mirror of
https://github.com/xiph/opus.git
synced 2025-05-25 12:49:12 +00:00
Update encoder/decoder
This commit is contained in:
parent
405aa7cf69
commit
ef12c29f14
3 changed files with 22 additions and 17 deletions
|
@ -72,17 +72,20 @@ sequence_size = args.seq_length
|
||||||
|
|
||||||
|
|
||||||
bits = np.memmap(bits_file + "-bits.s16", dtype='int16', mode='r')
|
bits = np.memmap(bits_file + "-bits.s16", dtype='int16', mode='r')
|
||||||
nb_sequences = len(bits)//(20*sequence_size)//batch_size*batch_size
|
nb_sequences = len(bits)//(40*sequence_size)//batch_size*batch_size
|
||||||
bits = bits[:nb_sequences*sequence_size*20]
|
bits = bits[:nb_sequences*sequence_size*40]
|
||||||
|
|
||||||
bits = np.reshape(bits, (nb_sequences, sequence_size//4, 20*4))
|
bits = np.reshape(bits, (nb_sequences, sequence_size//2, 20*4))
|
||||||
|
bits = bits[:,1::2,:]
|
||||||
print(bits.shape)
|
print(bits.shape)
|
||||||
|
|
||||||
quant = np.memmap(bits_file + "-quant.f32", dtype='float32', mode='r')
|
quant = np.memmap(bits_file + "-quant.f32", dtype='float32', mode='r')
|
||||||
state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r')
|
state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r')
|
||||||
|
|
||||||
quant = np.reshape(quant, (nb_sequences, sequence_size//4, 6*20*4))
|
quant = np.reshape(quant, (nb_sequences, sequence_size//2, 6*20*4))
|
||||||
state = np.reshape(state, (nb_sequences, sequence_size//2, 16))
|
quant = quant[:,1::2,:]
|
||||||
|
|
||||||
|
state = np.reshape(state, (nb_sequences, sequence_size//2, 24))
|
||||||
state = state[:,-1,:]
|
state = state[:,-1,:]
|
||||||
|
|
||||||
print("shapes are:")
|
print("shapes are:")
|
||||||
|
|
|
@ -109,6 +109,6 @@ np.round(bits).astype('int16').tofile(args.output + "-bits.s16")
|
||||||
quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32")
|
quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32")
|
||||||
|
|
||||||
gru_state_dec = gru_state_dec[:,-1,:]
|
gru_state_dec = gru_state_dec[:,-1,:]
|
||||||
dec_out = decoder([bits, quant_embed_dec, gru_state_dec])
|
dec_out = decoder([bits[:,1::2,:], quant_embed_dec[:,1::2,:], gru_state_dec])
|
||||||
|
|
||||||
dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32")
|
dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32")
|
||||||
|
|
|
@ -195,7 +195,7 @@ def var_repeat(x):
|
||||||
|
|
||||||
nb_state_dim = 24
|
nb_state_dim = 24
|
||||||
|
|
||||||
def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
|
def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
|
||||||
feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
|
feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
|
||||||
|
|
||||||
quant_id = Input(shape=(None,), batch_size=batch_size)
|
quant_id = Input(shape=(None,), batch_size=batch_size)
|
||||||
|
@ -205,12 +205,13 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
|
||||||
|
|
||||||
quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed))
|
quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed))
|
||||||
|
|
||||||
|
gru = CuDNNGRU if training else GRU
|
||||||
enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
|
enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
|
||||||
enc_dense2 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
|
enc_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
|
||||||
enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
|
enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
|
||||||
enc_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
|
enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
|
||||||
enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
|
enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
|
||||||
enc_dense6 = CuDNNGRU(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
|
enc_dense6 = gru(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
|
||||||
enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
|
enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
|
||||||
enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')
|
enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')
|
||||||
|
|
||||||
|
@ -238,7 +239,7 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
|
||||||
encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder')
|
encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder')
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
|
def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
|
||||||
bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits")
|
bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits")
|
||||||
quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size, name="dec_embed")
|
quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size, name="dec_embed")
|
||||||
gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")
|
gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")
|
||||||
|
@ -247,9 +248,10 @@ def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
|
||||||
dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
|
dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
|
||||||
dec_dense2 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense2')
|
dec_dense2 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense2')
|
||||||
dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
|
dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
|
||||||
dec_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
|
gru = CuDNNGRU if training else GRU
|
||||||
dec_dense5 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5')
|
dec_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
|
||||||
dec_dense6 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
|
dec_dense5 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5')
|
||||||
|
dec_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
|
||||||
dec_dense7 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
|
dec_dense7 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
|
||||||
dec_dense8 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense8')
|
dec_dense8 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense8')
|
||||||
|
|
||||||
|
@ -313,17 +315,17 @@ def tensor_concat(x):
|
||||||
return Concatenate(axis=0)(y)
|
return Concatenate(axis=0)(y)
|
||||||
|
|
||||||
|
|
||||||
def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
|
def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
|
||||||
|
|
||||||
feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
|
feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
|
||||||
quant_id = Input(shape=(None,), batch_size=batch_size)
|
quant_id = Input(shape=(None,), batch_size=batch_size)
|
||||||
lambda_val = Input(shape=(None, 1), batch_size=batch_size)
|
lambda_val = Input(shape=(None, 1), batch_size=batch_size)
|
||||||
lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
|
lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
|
||||||
|
|
||||||
encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2)
|
encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
|
||||||
ze, quant_embed_dec, gru_state_dec = encoder([feat, quant_id, lambda_val])
|
ze, quant_embed_dec, gru_state_dec = encoder([feat, quant_id, lambda_val])
|
||||||
|
|
||||||
decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2)
|
decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
|
||||||
split_decoder = new_split_decoder(decoder)
|
split_decoder = new_split_decoder(decoder)
|
||||||
|
|
||||||
dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
|
dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue