diff --git a/dnn/training_tf2/decode_rdovae.py b/dnn/training_tf2/decode_rdovae.py index db2ba3b9..7c5eee2b 100644 --- a/dnn/training_tf2/decode_rdovae.py +++ b/dnn/training_tf2/decode_rdovae.py @@ -72,17 +72,20 @@ sequence_size = args.seq_length bits = np.memmap(bits_file + "-bits.s16", dtype='int16', mode='r') -nb_sequences = len(bits)//(20*sequence_size)//batch_size*batch_size -bits = bits[:nb_sequences*sequence_size*20] +nb_sequences = len(bits)//(40*sequence_size)//batch_size*batch_size +bits = bits[:nb_sequences*sequence_size*40] -bits = np.reshape(bits, (nb_sequences, sequence_size//4, 20*4)) +bits = np.reshape(bits, (nb_sequences, sequence_size//2, 20*4)) +bits = bits[:,1::2,:] print(bits.shape) quant = np.memmap(bits_file + "-quant.f32", dtype='float32', mode='r') state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r') -quant = np.reshape(quant, (nb_sequences, sequence_size//4, 6*20*4)) -state = np.reshape(state, (nb_sequences, sequence_size//2, 16)) +quant = np.reshape(quant, (nb_sequences, sequence_size//2, 6*20*4)) +quant = quant[:,1::2,:] + +state = np.reshape(state, (nb_sequences, sequence_size//2, 24)) state = state[:,-1,:] print("shapes are:") diff --git a/dnn/training_tf2/encode_rdovae.py b/dnn/training_tf2/encode_rdovae.py index 429d80c3..00a44305 100644 --- a/dnn/training_tf2/encode_rdovae.py +++ b/dnn/training_tf2/encode_rdovae.py @@ -109,6 +109,6 @@ np.round(bits).astype('int16').tofile(args.output + "-bits.s16") quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32") gru_state_dec = gru_state_dec[:,-1,:] -dec_out = decoder([bits, quant_embed_dec, gru_state_dec]) +dec_out = decoder([bits[:,1::2,:], quant_embed_dec[:,1::2,:], gru_state_dec]) dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32") diff --git a/dnn/training_tf2/rdovae.py b/dnn/training_tf2/rdovae.py index 164c2391..d33422be 100644 --- a/dnn/training_tf2/rdovae.py +++ b/dnn/training_tf2/rdovae.py @@ -195,7 +195,7 @@ def var_repeat(x): nb_state_dim = 24 -def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256): +def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False): feat = Input(shape=(None, nb_used_features), batch_size=batch_size) quant_id = Input(shape=(None,), batch_size=batch_size) @@ -205,12 +205,13 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed)) + gru = CuDNNGRU if training else GRU enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1') - enc_dense2 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2') + enc_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2') enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3') - enc_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4') + enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4') enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5') - enc_dense6 = CuDNNGRU(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6') + enc_dense6 = gru(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6') enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7') enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8') @@ -238,7 +239,7 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder') return encoder -def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256): +def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False): bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits") quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size, name="dec_embed") gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state") @@ -247,9 +248,10 @@ def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1') dec_dense2 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense2') dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3') - dec_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4') - dec_dense5 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5') - dec_dense6 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6') + gru = CuDNNGRU if training else GRU + dec_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4') + dec_dense5 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5') + dec_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6') dec_dense7 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense7') dec_dense8 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense8') @@ -313,17 +315,17 @@ def tensor_concat(x): return Concatenate(axis=0)(y) -def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256): +def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False): feat = Input(shape=(None, nb_used_features), batch_size=batch_size) quant_id = Input(shape=(None,), batch_size=batch_size) lambda_val = Input(shape=(None, 1), batch_size=batch_size) lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val) - encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2) + encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training) ze, quant_embed_dec, gru_state_dec = encoder([feat, quant_id, lambda_val]) - decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2) + decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training) split_decoder = new_split_decoder(decoder) dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))