mirror of
https://github.com/xiph/opus.git
synced 2025-06-01 16:17:42 +00:00
Fix non-128 batch sizes
Avoid hardcoding the batch size in the model
This commit is contained in:
parent
37c9bd8d28
commit
8cdc8081d8
2 changed files with 5 additions and 5 deletions
|
@ -230,10 +230,10 @@ class WeightClip(Constraint):
|
|||
|
||||
constraint = WeightClip(0.992)
|
||||
|
||||
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False):
|
||||
pcm = Input(shape=(None, 3), batch_size=128)
|
||||
feat = Input(shape=(None, nb_used_features), batch_size=128)
|
||||
pitch = Input(shape=(None, 1), batch_size=128)
|
||||
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_size=128, training=False, adaptation=False, quantize=False, flag_e2e = False):
|
||||
pcm = Input(shape=(None, 3), batch_size=batch_size)
|
||||
feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
|
||||
pitch = Input(shape=(None, 1), batch_size=batch_size)
|
||||
dec_feat = Input(shape=(None, 128))
|
||||
dec_state1 = Input(shape=(rnn_units1,))
|
||||
dec_state2 = Input(shape=(rnn_units2,))
|
||||
|
|
|
@ -121,7 +121,7 @@ opt = Adam(lr, decay=decay, beta_2=0.99)
|
|||
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
|
||||
|
||||
with strategy.scope():
|
||||
model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize, flag_e2e = flag_e2e)
|
||||
model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, batch_size=batch_size, training=True, quantize=quantize, flag_e2e = flag_e2e)
|
||||
if not flag_e2e:
|
||||
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue