Use the encoder state as decoder initial state

Helps reduce the error on the most recent frames
This commit is contained in:
Jean-Marc Valin 2022-09-28 15:34:02 -04:00
parent 38dda0f950
commit 0a2d6dfcb6

View file

@@ -213,7 +213,7 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
enc_dense6 = gru(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
enc_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')
@@ -228,15 +228,16 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
d3 = enc_dense3(d2)
d4 = enc_dense4(d3)
d5 = enc_dense5(d4)
d6, gru_state = enc_dense6(d5)
d6 = enc_dense6(d5)
d7 = enc_dense7(d6)
d8 = enc_dense8(d7)
enc_out = bits_dense(Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8]))
pre_out = Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8])
enc_out = bits_dense(pre_out)
#enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out)
bits = Multiply()([enc_out, quant_scale])
global_dense1 = Dense(128, activation='tanh', name='gdense1')
global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
global_bits = global_dense2(global_dense1(d6))
global_bits = global_dense2(global_dense1(pre_out))
encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder')
return encoder
@@ -265,15 +266,18 @@ def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
quant_scale_dec = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed_dec')(quant_embed_input))
#gru_state_rep = RepeatVector(64//bunch)(gru_state_input)
gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
#gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
gru_state1 = Dense(cond_size, name="state1", activation='tanh')(gru_state_input)
gru_state2 = Dense(cond_size, name="state2", activation='tanh')(gru_state_input)
gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input)
dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input), gru_state_rep])
dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input)])
dec1 = dec_dense1(time_reverse(dec_inputs))
dec2 = dec_dense2(dec1)
dec3 = dec_dense3(dec2)
dec4 = dec_dense4(dec3)
dec5 = dec_dense5(dec4)
dec6 = dec_dense6(dec5)
dec4 = dec_dense4(dec3, initial_state=gru_state1)
dec5 = dec_dense5(dec4, initial_state=gru_state2)
dec6 = dec_dense6(dec5, initial_state=gru_state3)
dec7 = dec_dense7(dec6)
dec8 = dec_dense8(dec7)
output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))