From 0e564fdfafc2d4fe8261af7bd691d486947ab534 Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin Date: Sat, 18 May 2024 02:40:53 -0400 Subject: [PATCH] More fixes for the non-blob weight export --- dnn/torch/lossgen/export_lossgen.py | 2 +- dnn/torch/weight-exchange/wexchange/c_export/common.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/dnn/torch/lossgen/export_lossgen.py b/dnn/torch/lossgen/export_lossgen.py index da63118f..15a65c36 100644 --- a/dnn/torch/lossgen/export_lossgen.py +++ b/dnn/torch/lossgen/export_lossgen.py @@ -52,7 +52,7 @@ def c_export(args, model): message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}" - writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False) + writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False, add_typedef=True) writer.header.write( f""" #include "opus_types.h" diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py index b96e0d6c..8d2cbcf9 100644 --- a/dnn/torch/weight-exchange/wexchange/c_export/common.py +++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py @@ -64,8 +64,7 @@ f''' if debug_float: f.write('#ifndef DISABLE_DEBUG_FLOAT\n') - if binary_blob: - f.write( + f.write( f''' #define WEIGHTS_{name}_DEFINED #define WEIGHTS_{name}_TYPE WEIGHT_TYPE_{dtype_suffix[dtype]} @@ -384,4 +383,4 @@ def print_tconv1d_layer(writer : CWriter, writer.header.write(f"\n#define {name.upper()}_KERNEL_SIZE {kernel_size}\n") writer.header.write(f"\n#define {name.upper()}_STRIDE {stride}\n") writer.header.write(f"\n#define {name.upper()}_IN_CHANNELS {in_channels}\n") - writer.header.write(f"\n#define {name.upper()}_OUT_CHANNELS {out_channels}\n") \ No newline at end of file + writer.header.write(f"\n#define {name.upper()}_OUT_CHANNELS {out_channels}\n")