mirror of
https://github.com/xiph/opus.git
synced 2025-05-30 23:27:42 +00:00
Use the encoder state as decoder initial state
Helps reduce the error on the most recent frames
This commit is contained in:
parent
38dda0f950
commit
0a2d6dfcb6
1 changed file with 13 additions and 9 deletions
|
@@ -213,7 +213,7 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
|
|||
enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
|
||||
enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
|
||||
enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
|
||||
enc_dense6 = gru(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
|
||||
enc_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
|
||||
enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
|
||||
enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')
|
||||
|
||||
|
@@ -228,15 +228,16 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
|
|||
d3 = enc_dense3(d2)
|
||||
d4 = enc_dense4(d3)
|
||||
d5 = enc_dense5(d4)
|
||||
d6, gru_state = enc_dense6(d5)
|
||||
d6 = enc_dense6(d5)
|
||||
d7 = enc_dense7(d6)
|
||||
d8 = enc_dense8(d7)
|
||||
enc_out = bits_dense(Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8]))
|
||||
pre_out = Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8])
|
||||
enc_out = bits_dense(pre_out)
|
||||
#enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out)
|
||||
bits = Multiply()([enc_out, quant_scale])
|
||||
global_dense1 = Dense(128, activation='tanh', name='gdense1')
|
||||
global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
|
||||
global_bits = global_dense2(global_dense1(d6))
|
||||
global_bits = global_dense2(global_dense1(pre_out))
|
||||
|
||||
encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder')
|
||||
return encoder
|
||||
|
@@ -265,15 +266,18 @@ def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
|
|||
quant_scale_dec = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed_dec')(quant_embed_input))
|
||||
#gru_state_rep = RepeatVector(64//bunch)(gru_state_input)
|
||||
|
||||
gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
|
||||
#gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
|
||||
gru_state1 = Dense(cond_size, name="state1", activation='tanh')(gru_state_input)
|
||||
gru_state2 = Dense(cond_size, name="state2", activation='tanh')(gru_state_input)
|
||||
gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input)
|
||||
|
||||
dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input), gru_state_rep])
|
||||
dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input)])
|
||||
dec1 = dec_dense1(time_reverse(dec_inputs))
|
||||
dec2 = dec_dense2(dec1)
|
||||
dec3 = dec_dense3(dec2)
|
||||
dec4 = dec_dense4(dec3)
|
||||
dec5 = dec_dense5(dec4)
|
||||
dec6 = dec_dense6(dec5)
|
||||
dec4 = dec_dense4(dec3, initial_state=gru_state1)
|
||||
dec5 = dec_dense5(dec4, initial_state=gru_state2)
|
||||
dec6 = dec_dense6(dec5, initial_state=gru_state3)
|
||||
dec7 = dec_dense7(dec6)
|
||||
dec8 = dec_dense8(dec7)
|
||||
output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue