From 0a2d6dfcb656108eaace165310245e7f74a5b360 Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin Date: Wed, 28 Sep 2022 15:34:02 -0400 Subject: [PATCH] Use the encoder state as decoder initial state Helps reduce the error on the most recent frames --- dnn/training_tf2/rdovae.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/dnn/training_tf2/rdovae.py b/dnn/training_tf2/rdovae.py index 81e958f7..883d59e3 100644 --- a/dnn/training_tf2/rdovae.py +++ b/dnn/training_tf2/rdovae.py @@ -213,7 +213,7 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3') enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4') enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5') - enc_dense6 = gru(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6') + enc_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6') enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7') enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8') @@ -228,15 +228,16 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba d3 = enc_dense3(d2) d4 = enc_dense4(d3) d5 = enc_dense5(d4) - d6, gru_state = enc_dense6(d5) + d6 = enc_dense6(d5) d7 = enc_dense7(d6) d8 = enc_dense8(d7) - enc_out = bits_dense(Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8])) + pre_out = Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8]) + enc_out = bits_dense(pre_out) #enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out) bits = Multiply()([enc_out, quant_scale]) global_dense1 = Dense(128, activation='tanh', name='gdense1') global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2') - global_bits = global_dense2(global_dense1(d6)) + global_bits = global_dense2(global_dense1(pre_out)) encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder') return encoder @@ -265,15 +266,18 @@ def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba quant_scale_dec = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed_dec')(quant_embed_input)) #gru_state_rep = RepeatVector(64//bunch)(gru_state_input) - gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input]) + #gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input]) + gru_state1 = Dense(cond_size, name="state1", activation='tanh')(gru_state_input) + gru_state2 = Dense(cond_size, name="state2", activation='tanh')(gru_state_input) + gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input) - dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input), gru_state_rep]) + dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input)]) dec1 = dec_dense1(time_reverse(dec_inputs)) dec2 = dec_dense2(dec1) dec3 = dec_dense3(dec2) - dec4 = dec_dense4(dec3) - dec5 = dec_dense5(dec4) - dec6 = dec_dense6(dec5) + dec4 = dec_dense4(dec3, initial_state=gru_state1) + dec5 = dec_dense5(dec4, initial_state=gru_state2) + dec6 = dec_dense6(dec5, initial_state=gru_state3) dec7 = dec_dense7(dec6) dec8 = dec_dense8(dec7) output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))