fixed enable_binary_blob option for CWriter
This commit is contained in:
parent
20568812ae
commit
1711e97165
3 changed files with 37 additions and 38 deletions
|
@ -52,7 +52,7 @@ def c_export(args, model):
|
||||||
|
|
||||||
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
|
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
|
||||||
|
|
||||||
writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen')
|
writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False)
|
||||||
writer.header.write(
|
writer.header.write(
|
||||||
f"""
|
f"""
|
||||||
#include "opus_types.h"
|
#include "opus_types.h"
|
||||||
|
|
|
@ -120,50 +120,49 @@ f"""
|
||||||
def _finalize_header(self):
|
def _finalize_header(self):
|
||||||
|
|
||||||
# create model type
|
# create model type
|
||||||
if self.enable_binary_blob:
|
if self.add_typedef:
|
||||||
if self.add_typedef:
|
self.header.write(f"\ntypedef struct {{")
|
||||||
self.header.write(f"\ntypedef struct {{")
|
else:
|
||||||
else:
|
self.header.write(f"\nstruct {self.model_struct_name} {{")
|
||||||
self.header.write(f"\nstruct {self.model_struct_name} {{")
|
for name, data in self.layer_dict.items():
|
||||||
for name, data in self.layer_dict.items():
|
layer_type = data[0]
|
||||||
layer_type = data[0]
|
self.header.write(f"\n {layer_type} {name};")
|
||||||
self.header.write(f"\n {layer_type} {name};")
|
if self.add_typedef:
|
||||||
if self.add_typedef:
|
self.header.write(f"\n}} {self.model_struct_name};\n")
|
||||||
self.header.write(f"\n}} {self.model_struct_name};\n")
|
else:
|
||||||
else:
|
self.header.write(f"\n}};\n")
|
||||||
self.header.write(f"\n}};\n")
|
|
||||||
|
|
||||||
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
|
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
|
||||||
self.header.write(f"\n{init_prototype};\n")
|
self.header.write(f"\n{init_prototype};\n")
|
||||||
|
|
||||||
self.header.write(f"\n#endif /* {self.header_guard} */\n")
|
self.header.write(f"\n#endif /* {self.header_guard} */\n")
|
||||||
|
|
||||||
def _finalize_source(self):
|
def _finalize_source(self):
|
||||||
|
|
||||||
if self.enable_binary_blob:
|
|
||||||
# create weight array
|
|
||||||
if len(set(self.weight_arrays)) != len(self.weight_arrays):
|
|
||||||
raise ValueError("error: detected duplicates in weight arrays")
|
|
||||||
self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")
|
|
||||||
self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")
|
|
||||||
for name in self.weight_arrays:
|
|
||||||
self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")
|
|
||||||
self.source.write(f' {{"{name}", WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')
|
|
||||||
self.source.write(f"#endif\n")
|
|
||||||
self.source.write(" {NULL, 0, 0, NULL}\n")
|
|
||||||
self.source.write("};\n")
|
|
||||||
|
|
||||||
self.source.write("#endif /* USE_WEIGHTS_FILE */\n")
|
# create weight array
|
||||||
|
if len(set(self.weight_arrays)) != len(self.weight_arrays):
|
||||||
|
raise ValueError("error: detected duplicates in weight arrays")
|
||||||
|
if self.enable_binary_blob: self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")
|
||||||
|
self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")
|
||||||
|
for name in self.weight_arrays:
|
||||||
|
self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")
|
||||||
|
self.source.write(f' {{"{name}", WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')
|
||||||
|
self.source.write(f"#endif\n")
|
||||||
|
self.source.write(" {NULL, 0, 0, NULL}\n")
|
||||||
|
self.source.write("};\n")
|
||||||
|
|
||||||
# create init function definition
|
if self.enable_binary_blob: self.source.write("#endif /* USE_WEIGHTS_FILE */\n")
|
||||||
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
|
|
||||||
self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")
|
# create init function definition
|
||||||
self.source.write(f"{init_prototype} {{\n")
|
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
|
||||||
for name, data in self.layer_dict.items():
|
if self.enable_binary_blob: self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")
|
||||||
self.source.write(f" if ({data[1]}) return 1;\n")
|
self.source.write(f"{init_prototype} {{\n")
|
||||||
self.source.write(" return 0;\n")
|
for name, data in self.layer_dict.items():
|
||||||
self.source.write("}\n")
|
self.source.write(f" if ({data[1]}) return 1;\n")
|
||||||
self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")
|
self.source.write(" return 0;\n")
|
||||||
|
self.source.write("}\n")
|
||||||
|
if self.enable_binary_blob:self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|
|
@ -54,7 +54,7 @@ f'''
|
||||||
#ifndef USE_WEIGHTS_FILE
|
#ifndef USE_WEIGHTS_FILE
|
||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
writer.weight_arrays.append(name)
|
writer.weight_arrays.append(name)
|
||||||
|
|
||||||
if reshape_8x4:
|
if reshape_8x4:
|
||||||
vector = vector.reshape((vector.shape[0]//4, 4, vector.shape[1]//8, 8))
|
vector = vector.reshape((vector.shape[0]//4, 4, vector.shape[1]//8, 8))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue