Update encoder/decoder

This commit is contained in:
Jean-Marc Valin 2022-09-14 17:04:36 -04:00
parent 405aa7cf69
commit ef12c29f14
3 changed files with 22 additions and 17 deletions

View file

@ -72,17 +72,20 @@ sequence_size = args.seq_length
bits = np.memmap(bits_file + "-bits.s16", dtype='int16', mode='r') bits = np.memmap(bits_file + "-bits.s16", dtype='int16', mode='r')
nb_sequences = len(bits)//(20*sequence_size)//batch_size*batch_size nb_sequences = len(bits)//(40*sequence_size)//batch_size*batch_size
bits = bits[:nb_sequences*sequence_size*20] bits = bits[:nb_sequences*sequence_size*40]
bits = np.reshape(bits, (nb_sequences, sequence_size//4, 20*4)) bits = np.reshape(bits, (nb_sequences, sequence_size//2, 20*4))
bits = bits[:,1::2,:]
print(bits.shape) print(bits.shape)
quant = np.memmap(bits_file + "-quant.f32", dtype='float32', mode='r') quant = np.memmap(bits_file + "-quant.f32", dtype='float32', mode='r')
state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r') state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r')
quant = np.reshape(quant, (nb_sequences, sequence_size//4, 6*20*4)) quant = np.reshape(quant, (nb_sequences, sequence_size//2, 6*20*4))
state = np.reshape(state, (nb_sequences, sequence_size//2, 16)) quant = quant[:,1::2,:]
state = np.reshape(state, (nb_sequences, sequence_size//2, 24))
state = state[:,-1,:] state = state[:,-1,:]
print("shapes are:") print("shapes are:")

View file

@ -109,6 +109,6 @@ np.round(bits).astype('int16').tofile(args.output + "-bits.s16")
quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32") quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32")
gru_state_dec = gru_state_dec[:,-1,:] gru_state_dec = gru_state_dec[:,-1,:]
dec_out = decoder([bits, quant_embed_dec, gru_state_dec]) dec_out = decoder([bits[:,1::2,:], quant_embed_dec[:,1::2,:], gru_state_dec])
dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32") dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32")

View file

@ -195,7 +195,7 @@ def var_repeat(x):
nb_state_dim = 24 nb_state_dim = 24
def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256): def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
feat = Input(shape=(None, nb_used_features), batch_size=batch_size) feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
quant_id = Input(shape=(None,), batch_size=batch_size) quant_id = Input(shape=(None,), batch_size=batch_size)
@ -205,12 +205,13 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed)) quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed))
gru = CuDNNGRU if training else GRU
enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1') enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
enc_dense2 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2') enc_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3') enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
enc_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4') enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5') enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
enc_dense6 = CuDNNGRU(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6') enc_dense6 = gru(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7') enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8') enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')
@ -238,7 +239,7 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder') encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder')
return encoder return encoder
def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256): def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits") bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits")
quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size, name="dec_embed") quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size, name="dec_embed")
gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state") gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")
@ -247,9 +248,10 @@ def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1') dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
dec_dense2 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense2') dec_dense2 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense2')
dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3') dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
dec_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4') gru = CuDNNGRU if training else GRU
dec_dense5 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5') dec_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
dec_dense6 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6') dec_dense5 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5')
dec_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
dec_dense7 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense7') dec_dense7 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
dec_dense8 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense8') dec_dense8 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense8')
@ -313,17 +315,17 @@ def tensor_concat(x):
return Concatenate(axis=0)(y) return Concatenate(axis=0)(y)
def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256): def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
feat = Input(shape=(None, nb_used_features), batch_size=batch_size) feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
quant_id = Input(shape=(None,), batch_size=batch_size) quant_id = Input(shape=(None,), batch_size=batch_size)
lambda_val = Input(shape=(None, 1), batch_size=batch_size) lambda_val = Input(shape=(None, 1), batch_size=batch_size)
lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val) lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2) encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
ze, quant_embed_dec, gru_state_dec = encoder([feat, quant_id, lambda_val]) ze, quant_embed_dec, gru_state_dec = encoder([feat, quant_id, lambda_val])
decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2) decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
split_decoder = new_split_decoder(decoder) split_decoder = new_split_decoder(decoder)
dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec)) dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))