Fix non-128 batch sizes

Avoid hardcoding the batch size in the model
Jean-Marc Valin 2021-10-09 03:20:22 -04:00
parent 37c9bd8d28
commit 8cdc8081d8
2 changed files with 5 additions and 5 deletions


@@ -230,10 +230,10 @@ class WeightClip(Constraint):
 constraint = WeightClip(0.992)
-def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False):
-    pcm = Input(shape=(None, 3), batch_size=128)
-    feat = Input(shape=(None, nb_used_features), batch_size=128)
-    pitch = Input(shape=(None, 1), batch_size=128)
+def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_size=128, training=False, adaptation=False, quantize=False, flag_e2e = False):
+    pcm = Input(shape=(None, 3), batch_size=batch_size)
+    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
+    pitch = Input(shape=(None, 1), batch_size=batch_size)
     dec_feat = Input(shape=(None, 128))
     dec_state1 = Input(shape=(rnn_units1,))
     dec_state2 = Input(shape=(rnn_units2,))
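
For reference, a minimal usage sketch (not part of the commit) of how a caller can now build the model for a non-default batch size; the lpcnet module path and the (model, encoder, decoder) return value are assumptions based on the surrounding training scripts, not shown in this diff:

    import lpcnet

    # Build a model whose Input layers are sized for batches of 64 instead of
    # the previously hardcoded 128; other hyperparameters keep their defaults.
    model, encoder, decoder = lpcnet.new_lpcnet_model(batch_size=64, training=True)
    model.summary()

Keeping 128 as the default value of batch_size preserves the previous behaviour for existing callers that do not pass the new argument.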