mirror of
https://github.com/xiph/opus.git
synced 2025-05-25 20:59:13 +00:00

Model now stores LPC gamma, look-ahead, and end-to-end. Parameters aren't quite reliable yet, YMMV
29 lines
1.1 KiB
Python
29 lines
1.1 KiB
Python
""" module for handling extra model parameters for tf.keras models """
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
def set_parameter(model, parameter_name, parameter_value, dtype='float32'):
    """ stores parameter_value as non-trainable weight with name parameter_name:0

    If the model has no weight whose name equals ``parameter_name + ":0"``,
    a new non-trainable weight is added and initialised to ``parameter_value``.
    If exactly one such weight exists, its value is overwritten via ``assign``.

    Args:
        model: object exposing ``weights`` and ``add_weight`` (a tf.keras model).
        parameter_name: weight name without the ``:0`` suffix.
        parameter_value: value to store in the weight.
        dtype: dtype used when the weight is created; it is not applied when
            an existing weight is updated.

    Raises:
        ValueError: if more than one weight is named ``parameter_name:0``.
    """

    weights = [weight for weight in model.weights if weight.name == (parameter_name + ":0")]

    # an ambiguous name is an error: there is no way to tell which weight to update
    if len(weights) > 1:
        raise ValueError(f"more than one weight starting with {parameter_name}:0 in model")

    if weights:
        # parameter already present: overwrite in place, dtype is left untouched
        weights[0].assign(parameter_value)
        return

    # The name is passed by keyword on purpose: the first positional parameter
    # of add_weight is `name` in the legacy tf.keras signature but `shape` in
    # the Keras 3 signature, so a positional string would be misinterpreted.
    model.add_weight(
        name=parameter_name,
        trainable=False,
        initializer=tf.keras.initializers.Constant(parameter_value),
        dtype=dtype,
    )
|
|
|
|
|
|
def get_parameter(model, parameter_name, default=None):
    """ looks up the weight named parameter_name:0 and returns its value

    Returns ``default`` when the model has no weight with that name; otherwise
    the value is read through ``.numpy().item()`` on the single matching weight.

    Raises:
        ValueError: if more than one weight is named ``parameter_name:0``.
    """

    # collect every weight whose full name matches exactly
    matches = []
    for candidate in model.weights:
        if candidate.name == (parameter_name + ":0"):
            matches.append(candidate)

    if not matches:
        return default

    if len(matches) == 1:
        return matches[0].numpy().item()

    raise ValueError(f"more than one weight starting with {parameter_name}:0 in model")
|