More fixes for the non-blob weight export

This commit is contained in:
Jean-Marc Valin 2024-05-18 02:40:53 -04:00
parent 1711e97165
commit 0e564fdfaf
No known key found for this signature in database
GPG key ID: 5E5DD9A36F9189C8
2 changed files with 3 additions and 4 deletions

View file

@ -52,7 +52,7 @@ def c_export(args, model):
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}" message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False) writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False, add_typedef=True)
writer.header.write( writer.header.write(
f""" f"""
#include "opus_types.h" #include "opus_types.h"

View file

@ -64,8 +64,7 @@ f'''
if debug_float: if debug_float:
f.write('#ifndef DISABLE_DEBUG_FLOAT\n') f.write('#ifndef DISABLE_DEBUG_FLOAT\n')
if binary_blob: f.write(
f.write(
f''' f'''
#define WEIGHTS_{name}_DEFINED #define WEIGHTS_{name}_DEFINED
#define WEIGHTS_{name}_TYPE WEIGHT_TYPE_{dtype_suffix[dtype]} #define WEIGHTS_{name}_TYPE WEIGHT_TYPE_{dtype_suffix[dtype]}
@ -384,4 +383,4 @@ def print_tconv1d_layer(writer : CWriter,
writer.header.write(f"\n#define {name.upper()}_KERNEL_SIZE {kernel_size}\n") writer.header.write(f"\n#define {name.upper()}_KERNEL_SIZE {kernel_size}\n")
writer.header.write(f"\n#define {name.upper()}_STRIDE {stride}\n") writer.header.write(f"\n#define {name.upper()}_STRIDE {stride}\n")
writer.header.write(f"\n#define {name.upper()}_IN_CHANNELS {in_channels}\n") writer.header.write(f"\n#define {name.upper()}_IN_CHANNELS {in_channels}\n")
writer.header.write(f"\n#define {name.upper()}_OUT_CHANNELS {out_channels}\n") writer.header.write(f"\n#define {name.upper()}_OUT_CHANNELS {out_channels}\n")