fixed enable_binary_blob option for CWriter

This commit is contained in:
Jan Buethe 2024-05-06 14:11:59 +02:00
parent 20568812ae
commit 1711e97165
No known key found for this signature in database
GPG key ID: 9E32027A35B36314
3 changed files with 37 additions and 38 deletions

View file

@ -52,7 +52,7 @@ def c_export(args, model):
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}" message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen') writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False)
writer.header.write( writer.header.write(
f""" f"""
#include "opus_types.h" #include "opus_types.h"

View file

@ -120,7 +120,6 @@ f"""
def _finalize_header(self): def _finalize_header(self):
# create model type # create model type
if self.enable_binary_blob:
if self.add_typedef: if self.add_typedef:
self.header.write(f"\ntypedef struct {{") self.header.write(f"\ntypedef struct {{")
else: else:
@ -140,11 +139,11 @@ f"""
def _finalize_source(self): def _finalize_source(self):
if self.enable_binary_blob:
# create weight array # create weight array
if len(set(self.weight_arrays)) != len(self.weight_arrays): if len(set(self.weight_arrays)) != len(self.weight_arrays):
raise ValueError("error: detected duplicates in weight arrays") raise ValueError("error: detected duplicates in weight arrays")
self.source.write("\n#ifndef USE_WEIGHTS_FILE\n") if self.enable_binary_blob: self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")
self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n") self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")
for name in self.weight_arrays: for name in self.weight_arrays:
self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n") self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")
@ -153,17 +152,17 @@ f"""
self.source.write(" {NULL, 0, 0, NULL}\n") self.source.write(" {NULL, 0, 0, NULL}\n")
self.source.write("};\n") self.source.write("};\n")
self.source.write("#endif /* USE_WEIGHTS_FILE */\n") if self.enable_binary_blob: self.source.write("#endif /* USE_WEIGHTS_FILE */\n")
# create init function definition # create init function definition
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)" init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n") if self.enable_binary_blob: self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")
self.source.write(f"{init_prototype} {{\n") self.source.write(f"{init_prototype} {{\n")
for name, data in self.layer_dict.items(): for name, data in self.layer_dict.items():
self.source.write(f" if ({data[1]}) return 1;\n") self.source.write(f" if ({data[1]}) return 1;\n")
self.source.write(" return 0;\n") self.source.write(" return 0;\n")
self.source.write("}\n") self.source.write("}\n")
self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n") if self.enable_binary_blob:self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")
def close(self): def close(self):