diff --git a/dnn/training_tf2/decode_rdovae.py b/dnn/training_tf2/decode_rdovae.py index 7c5eee2b..8dff0034 100644 --- a/dnn/training_tf2/decode_rdovae.py +++ b/dnn/training_tf2/decode_rdovae.py @@ -56,6 +56,7 @@ import tensorflow.keras.backend as K import h5py import tensorflow as tf +from rdovae import pvq_quantize # Try reducing batch_size if you run out of memory on your GPU batch_size = args.batch_size @@ -87,6 +88,8 @@ quant = quant[:,1::2,:] state = np.reshape(state, (nb_sequences, sequence_size//2, 24)) state = state[:,-1,:] +state = pvq_quantize(state, 30) +#state = state/(1e-15+tf.norm(state, axis=-1,keepdims=True)) print("shapes are:") print(bits.shape)