Oops, forgot to run PVQ quantization for the state

This commit is contained in:
Jean-Marc Valin 2022-09-28 15:33:20 -04:00
parent b43f077ba8
commit 38dda0f950

View file

@ -56,6 +56,7 @@ import tensorflow.keras.backend as K
import h5py import h5py
import tensorflow as tf import tensorflow as tf
from rdovae import pvq_quantize
# Try reducing batch_size if you run out of memory on your GPU # Try reducing batch_size if you run out of memory on your GPU
batch_size = args.batch_size batch_size = args.batch_size
@ -87,6 +88,8 @@ quant = quant[:,1::2,:]
state = np.reshape(state, (nb_sequences, sequence_size//2, 24)) state = np.reshape(state, (nb_sequences, sequence_size//2, 24))
state = state[:,-1,:] state = state[:,-1,:]
state = pvq_quantize(state, 30)
#state = state/(1e-15+tf.norm(state, axis=-1,keepdims=True))
print("shapes are:") print("shapes are:")
print(bits.shape) print(bits.shape)