diff --git a/dnn/training_tf2/decode_rdovae.py b/dnn/training_tf2/decode_rdovae.py
new file mode 100644
index 00000000..db2ba3b9
--- /dev/null
+++ b/dnn/training_tf2/decode_rdovae.py
@@ -0,0 +1,95 @@
+#!/usr/bin/python3
+'''Copyright (c) 2021-2022 Amazon
+   Copyright (c) 2018-2019 Mozilla
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
+   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+'''
+
+# Decode features from an RDO-VAE bit stream
+
+import argparse
+#from plc_loader import PLCLoader
+
+parser = argparse.ArgumentParser(description='Decode with an RDO-VAE model')
+
+parser.add_argument('bits', metavar='<bits file>', help='input file prefix, as written by encode_rdovae.py (int16 bit stream)')
+parser.add_argument('output', metavar='<output>', help='output features file (float32)')
+parser.add_argument('--model', metavar='<model>', default='rdovae', help='rdovae model python definition (without .py)')
+group1 = parser.add_mutually_exclusive_group()
+group1.add_argument('--weights', metavar='<input weights>', help='model weights')
+parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
+parser.add_argument('--batch-size', metavar='<batch size>', default=1, type=int, help='batch size to use (default 1)')
+parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
+
+
+args = parser.parse_args()
+
+import importlib
+rdovae = importlib.import_module(args.model)
+
+import sys
+import numpy as np
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
+import tensorflow.keras.backend as K
+import h5py
+
+import tensorflow as tf
+
+# Try reducing batch_size if you run out of memory on your GPU
+batch_size = args.batch_size
+
+model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
+model.load_weights(args.weights)
+
+lpc_order = 16
+
+bits_file = args.bits
+sequence_size = args.seq_length
+
+# load the quantized bit stream (20 symbols per frame, bunched 4 frames per step)
+bits = np.memmap(bits_file + "-bits.s16", dtype='int16', mode='r')
+nb_sequences = len(bits)//(20*sequence_size)//batch_size*batch_size
+bits = bits[:nb_sequences*sequence_size*20]
+
+bits = np.reshape(bits, (nb_sequences, sequence_size//4, 20*4))
+print(bits.shape)
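The reshape above packs 20 symbols per frame into bunches of 4 frames, so each step carries 80 symbols. A standalone sketch of the shape bookkeeping, assuming the script defaults (batch_size=1, seq-length=1000) and a hypothetical input length:

    import numpy as np

    sequence_size, batch_size = 1000, 1
    bits = np.zeros(20 * 41000, dtype='int16')  # stand-in for the memmap
    nb_sequences = len(bits)//(20*sequence_size)//batch_size*batch_size   # -> 41
    bits = bits[:nb_sequences*sequence_size*20]
    # 20 symbols per frame, 4 frames per step -> 80 symbols per step
    bits = np.reshape(bits, (nb_sequences, sequence_size//4, 20*4))
    print(bits.shape)  # (41, 250, 80)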
+quant = np.memmap(bits_file + "-quant.f32", dtype='float32', mode='r')
+state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r')
+
+quant = np.reshape(quant, (nb_sequences, sequence_size//4, 6*20*4))
+state = np.reshape(state, (nb_sequences, sequence_size//2, 24))  # nb_state_dim values per paired frame
+state = state[:,-1,:]
+
+print("shapes are:")
+print(bits.shape)
+print(quant.shape)
+print(state.shape)
+
+features = decoder.predict([bits, quant, state], batch_size=batch_size)
+
+features.astype('float32').tofile(args.output)
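The decoder reads three companion files sharing one prefix, all written by encode_rdovae.py (next file in this diff). A sketch of the naming convention, with a hypothetical prefix:

    # Companion files read by decode_rdovae.py for the 'bits' argument "out":
    prefix = "out"  # hypothetical prefix
    inputs = {
        "bits":  prefix + "-bits.s16",   # int16 quantized symbols, 80 per 4-frame bunch
        "quant": prefix + "-quant.f32",  # float32 quantizer embeddings, 6*80 per bunch
        "state": prefix + "-state.f32",  # float32 decoder state vectors, one per paired frame
    }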
diff --git a/dnn/training_tf2/encode_rdovae.py b/dnn/training_tf2/encode_rdovae.py
new file mode 100644
index 00000000..429d80c3
--- /dev/null
+++ b/dnn/training_tf2/encode_rdovae.py
@@ -0,0 +1,114 @@
+#!/usr/bin/python3
+'''Copyright (c) 2021-2022 Amazon
+   Copyright (c) 2018-2019 Mozilla
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
+   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+'''
+
+# Encode features with an RDO-VAE model
+
+import argparse
+#from plc_loader import PLCLoader
+
+parser = argparse.ArgumentParser(description='Encode with an RDO-VAE model')
+
+parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
+parser.add_argument('output', metavar='<output>', help='output file prefix')
+parser.add_argument('--model', metavar='<model>', default='rdovae', help='rdovae model python definition (without .py)')
+group1 = parser.add_mutually_exclusive_group()
+group1.add_argument('--weights', metavar='<input weights>', help='model weights')
+parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
+parser.add_argument('--batch-size', metavar='<batch size>', default=1, type=int, help='batch size to use (default 1)')
+parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
+
+
+args = parser.parse_args()
+
+import importlib
+rdovae = importlib.import_module(args.model)
+
+import sys
+import numpy as np
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
+import tensorflow.keras.backend as K
+import h5py
+
+import tensorflow as tf
+
+# Try reducing batch_size if you run out of memory on your GPU
+batch_size = args.batch_size
+
+model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
+model.load_weights(args.weights)
+
+lpc_order = 16
+
+feature_file = args.features
+nb_features = model.nb_used_features + lpc_order
+nb_used_features = model.nb_used_features
+sequence_size = args.seq_length
+
+# load float32 feature vectors and keep only the used features
+features = np.memmap(feature_file, dtype='float32', mode='r')
+nb_sequences = len(features)//(nb_features*sequence_size)//batch_size*batch_size
+features = features[:nb_sequences*sequence_size*nb_features]
+
+features = np.reshape(features, (nb_sequences, sequence_size, nb_features))
+print(features.shape)
+features = features[:, :, :nb_used_features]
+#features = np.random.randn(73600, 1000, 17)
+
+lambda_val = 0.001 * np.ones((nb_sequences, sequence_size//2, 1))
+quant_id = np.round(10*np.log(lambda_val/.0007)).astype('int16')
+quant_id = quant_id[:,:,0]
+
+
+bits, quant_embed_dec, gru_state_dec = encoder.predict([features, quant_id, lambda_val], batch_size=batch_size)
+gru_state_dec.astype('float32').tofile(args.output + "-state.f32")
+
+
+#quant_out, _, _, model_bits, _ = model.predict([features, quant_id, lambda_val], batch_size=batch_size)
+
+#dist = rdovae.feat_dist_loss(features, quant_out)
+#rate = rdovae.sq1_rate_loss(features, model_bits)
+#rate2 = rdovae.sq_rate_metric(features, model_bits)
+#print(dist, rate, rate2)
+
+print("shapes are:")
+print(bits.shape)
+print(quant_embed_dec.shape)
+print(gru_state_dec.shape)
+
+features.astype('float32').tofile(args.output + "-input.f32")
+#quant_out.astype('float32').tofile(args.output + "-enc_dec.f32")
+np.round(bits).astype('int16').tofile(args.output + "-bits.s16")
+quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32")
+
+gru_state_dec = gru_state_dec[:,-1,:]
+dec_out = decoder([bits, quant_embed_dec, gru_state_dec])
+
+dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32")
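A possible round trip through the two scripts above, sketched under assumed file names (the weights file name follows the checkpoint pattern of train_rdovae.py further below; all paths are hypothetical):

    import subprocess

    # writes out-bits.s16, out-quant.f32, out-state.f32, out-input.f32, out-dec_out.f32
    subprocess.run(['python3', 'encode_rdovae.py', '--weights', 'rdovae_1024_120.h5',
                    'features.f32', 'out'], check=True)
    # reads the companion files back and writes the decoded features
    subprocess.run(['python3', 'decode_rdovae.py', '--weights', 'rdovae_1024_120.h5',
                    'out', 'decoded_features.f32'], check=True)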
diff --git a/dnn/training_tf2/rdovae.py b/dnn/training_tf2/rdovae.py
new file mode 100644
index 00000000..0d620cd8
--- /dev/null
+++ b/dnn/training_tf2/rdovae.py
@@ -0,0 +1,340 @@
+#!/usr/bin/python3
+'''Copyright (c) 2022 Amazon
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
+   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+'''
+
+import math
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise, AveragePooling1D, RepeatVector
+from tensorflow.compat.v1.keras.layers import CuDNNGRU
+from tensorflow.keras import backend as K
+from tensorflow.keras.constraints import Constraint
+from tensorflow.keras.initializers import Initializer
+from tensorflow.keras.callbacks import Callback
+from tensorflow.keras.regularizers import l1
+import numpy as np
+import h5py
+from uniform_noise import UniformNoise
+
+class WeightClip(Constraint):
+    '''Clips the weights incident to each hidden unit to be inside a range
+    '''
+    def __init__(self, c=2):
+        self.c = c
+
+    def __call__(self, p):
+        # Ensure that the absolute values of adjacent weights don't sum to more than 127
+        # (after int8 scaling). Otherwise there's a risk of saturation when implementing
+        # dot products with SSSE3 or AVX2.
+        return self.c*p/tf.maximum(self.c, tf.repeat(tf.abs(p[:, 1::2])+tf.abs(p[:, 0::2]), 2, axis=1))
+        #return K.clip(p, -self.c, self.c)
+
+    def get_config(self):
+        return {'name': self.__class__.__name__,
+                'c': self.c}
+
+constraint = WeightClip(0.496)
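A quick standalone check of the pairwise constraint above: any pair of adjacent weights whose absolute values sum past c is rescaled so the pair sum lands exactly on c, while pairs already inside the bound pass through unchanged (values below worked out by hand):

    import tensorflow as tf

    c = 0.496
    p = tf.constant([[0.4, 0.3, -0.1, 0.05]])  # first pair sums to 0.7 > c
    clipped = c*p/tf.maximum(c, tf.repeat(tf.abs(p[:, 1::2])+tf.abs(p[:, 0::2]), 2, axis=1))
    print(tf.abs(clipped[:, 0::2]) + tf.abs(clipped[:, 1::2]))  # -> [[0.496, 0.15]]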
+def soft_quantize(x):
+    #x = 4*x
+    #x = x - (.25/np.math.pi)*tf.math.sin(2*np.math.pi*x)
+    #x = x - (.25/np.math.pi)*tf.math.sin(2*np.math.pi*x)
+    #x = x - (.25/np.math.pi)*tf.math.sin(2*np.math.pi*x)
+    return x
+
+def noise_quantize(x):
+    return soft_quantize(x + (K.random_uniform((128, 16, 80))-.5) )
+
+def hard_quantize(x):
+    # straight-through estimator: forward pass rounds, backward pass is identity
+    x = soft_quantize(x)
+    quantized = tf.round(x)
+    return x + tf.stop_gradient(quantized - x)
+
+def apply_dead_zone(x):
+    # shrink values toward zero by a learned per-symbol dead zone d
+    d = x[1]*.05
+    x = x[0]
+    y = x - d*tf.math.tanh(x/(.1+d))
+    return y
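A minimal sketch of the straight-through trick used by hard_quantize(): the forward value is round(x), but the gradient flows as if no rounding had happened.

    import tensorflow as tf

    x = tf.Variable([0.2, 0.8, -1.4])
    with tf.GradientTape() as tape:
        q = x + tf.stop_gradient(tf.round(x) - x)
        loss = tf.reduce_sum(q*q)
    print(q)                       # [0., 1., -1.]: the rounded values
    print(tape.gradient(loss, x))  # [0., 2., -2.] = 2*q: gradient passes straight through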
+def rate_loss(y_true,y_pred):
+    log2_e = 1.4427
+    n = y_pred.shape[-1]
+    C = n - log2_e*math.log(math.gamma(n))
+    k = K.sum(K.abs(y_pred), axis=-1)
+    p = 1.5
+    #rate = C + (n-1)*log2_e*tf.math.log((k**p + (n/5)**p)**(1/p))
+    rate = C + (n-1)*log2_e*tf.math.log(k + .112*n**2/(n/1.8+k) )
+    return K.mean(rate)
+
+eps=1e-6
+def safelog2(x):
+    log2_e = 1.4427
+    return log2_e*tf.math.log(eps+x)
+
+def feat_dist_loss(y_true,y_pred):
+    ceps = y_pred[:,:,:18] - y_true[:,:,:18]
+    pitch = 2*(y_pred[:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
+    corr = y_pred[:,:,19:] - y_true[:,:,19:]
+    pitch_weight = K.square(K.maximum(0., y_true[:,:,19:]+.5))
+    return K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr))
+
+def sq1_rate_loss(y_true,y_pred):
+    lambda_val = y_pred[:,:,-1]
+    y_pred = y_pred[:,:,:-1]
+    log2_e = 1.4427
+    n = y_pred.shape[-1]//3
+    r = y_pred[:,:,2*n:]
+    p0 = y_pred[:,:,n:2*n]
+    p0 = 1-r**(.5+.5*p0)
+    y_pred = y_pred[:,:,:n]
+    y_pred = soft_quantize(y_pred)
+
+    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
+    # earlier rate models, kept for reference; only the last assignment is used
+    #rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
+    #rate = -safelog2(-.5*tf.math.log(r)*r**K.abs(y_pred))
+    rate = -safelog2((1-r)/(1+r)*r**K.abs(y_pred))
+    #rate = -safelog2(- tf.math.sinh(.5*tf.math.log(r))* r**K.abs(y_pred) - tf.math.cosh(K.maximum(0., .5 - K.abs(y_pred))*tf.math.log(r)) + 1)
+    rate = lambda_val*K.sum(rate, axis=-1)
+    return K.mean(rate)
+
+def sq2_rate_loss(y_true,y_pred):
+    lambda_val = y_pred[:,:,-1]
+    y_pred = y_pred[:,:,:-1]
+    log2_e = 1.4427
+    n = y_pred.shape[-1]//3
+    r = y_pred[:,:,2*n:]
+    p0 = y_pred[:,:,n:2*n]
+    p0 = 1-r**(.5+.5*p0)
+    #theta = K.minimum(1., .5 + 0*p0 - 0.04*tf.math.log(r))
+    #p0 = 1-r**theta
+    y_pred = tf.round(y_pred[:,:,:n])
+    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
+    rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
+    rate = lambda_val*K.sum(rate, axis=-1)
+    return K.mean(rate)
+
+def sq_rate_metric(y_true,y_pred):
+    lambda_val = y_pred[:,:,-1]
+    y_pred = y_pred[:,:,:-1]
+    log2_e = 1.4427
+    n = y_pred.shape[-1]//3
+    r = y_pred[:,:,2*n:]
+    p0 = y_pred[:,:,n:2*n]
+    p0 = 1-r**(.5+.5*p0)
+    #theta = K.minimum(1., .5 + 0*p0 - 0.04*tf.math.log(r))
+    #p0 = 1-r**theta
+    y_pred = tf.round(y_pred[:,:,:n])
+    y0 = K.maximum(0., 1. - K.abs(y_pred))**2
+    rate = -y0*safelog2(p0*r**K.abs(y_pred)) - (1-y0)*safelog2(.5*(1-p0)*(1-r)*r**(K.abs(y_pred)-1))
+    rate = K.sum(rate, axis=-1)
+    return K.mean(rate)
+
+def pvq_quant_search(x, k):
+    # search for the scaling of x whose rounding has exactly k unit pulses
+    x = x/tf.reduce_sum(tf.abs(x), axis=-1, keepdims=True)
+    kx = k*x
+    y = tf.round(kx)
+    newk = k
+
+    for j in range(10):
+        #print("y = ", y)
+        #print("iteration ", j)
+        abs_y = tf.abs(y)
+        abs_kx = tf.abs(kx)
+        kk=tf.reduce_sum(abs_y, axis=-1)
+        #print("sums = ", kk)
+        plus = 1.0001*tf.reduce_min((abs_y+.5)/(abs_kx+1e-15), axis=-1)
+        minus = .9999*tf.reduce_max((abs_y-.5)/(abs_kx+1e-15), axis=-1)
+        #print("plus = ", plus)
+        #print("minus = ", minus)
+        factor = tf.where(kk>k, minus, plus)
+        factor = tf.where(kk==k, tf.ones_like(factor), factor)
+        #print("scale = ", factor)
+        factor = tf.expand_dims(factor, axis=-1)
+        #newk = newk * (k/kk)**.2
+        newk = newk*factor
+        kx = newk*x
+        #print("newk = ", newk)
+        #print("unquantized = ", newk*x)
+        y = tf.round(kx)
+
+    #print(y)
+
+    return y
+
+def pvq_quantize(x, k):
+    # normalize, quantize on the unit sphere, and pass gradients straight through
+    x = x/(1e-15+tf.norm(x, axis=-1,keepdims=True))
+    quantized = pvq_quant_search(x, k)
+    quantized = quantized/(1e-15+tf.norm(quantized, axis=-1,keepdims=True))
+    return x + tf.stop_gradient(quantized - x)
+
+
+def var_repeat(x):
+    # RepeatVector with a dynamic repetition count: tile x[0] across as many
+    # time steps as x[1] has
+    return RepeatVector(K.shape(x[1])[1])(x[0])
+
+nb_state_dim = 24
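A standalone check of the PVQ search above (a sketch, assuming rdovae.py is importable): pvq_quant_search() should return an integer codevector whose absolute values sum to k once the fixed 10-iteration scale search converges, which it normally does.

    import tensorflow as tf
    from rdovae import pvq_quant_search

    x = tf.random.normal((1, 24))
    y = pvq_quant_search(x, 30)
    print(tf.reduce_sum(tf.abs(y), axis=-1))  # expected: [30.]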
+def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
+
+    quant_id = Input(shape=(None,), batch_size=batch_size)
+    lambda_val = Input(shape=(None, 1), batch_size=batch_size)
+    qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros')
+    quant_embed = qembedding(quant_id)
+    quant_embed_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(quant_embed)
+
+    quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed_bunched))
+
+    enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
+    enc_dense2 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
+    enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
+    enc_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
+    enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
+    enc_dense6 = CuDNNGRU(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
+    enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
+    enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')
+
+    #bits_dense = Dense(nb_bits, activation='linear', name='bits_dense')
+    bits_dense = Conv1D(nb_bits, 4, padding='causal', activation='linear', name='bits_dense')
+
+    zero_out = Lambda(lambda x: 0*x)
+    inputs = Concatenate()([Reshape((-1, 2*nb_used_features))(feat), tf.stop_gradient(quant_embed), lambda_val])
+    #inputs = Concatenate()([feat, tf.stop_gradient(quant_embed), lambda_val])
+    d1 = enc_dense1(inputs)
+    d2 = enc_dense2(d1)
+    d3 = enc_dense3(d2)
+    d4 = enc_dense4(d3)
+    d5 = enc_dense5(d4)
+    d6, gru_state = enc_dense6(d5)
+    d7 = enc_dense7(d6)
+    d8 = enc_dense8(d7)
+    enc_out = bits_dense(Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8]))
+    enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out)
+    bits = Multiply()([enc_out, quant_scale])
+    global_dense1 = Dense(128, activation='tanh', name='gdense1')
+    global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
+    global_bits = global_dense2(global_dense1(d6))
+
+    encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed_bunched, global_bits], name='encoder')
+    return encoder
+
+def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+    bits_input = Input(shape=(None, nb_bits), batch_size=batch_size)
+    quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size)
+    gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size)
+
+
+    dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
+    dec_dense2 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense2')
+    dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
+    dec_dense4 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
+    dec_dense5 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5')
+    dec_dense6 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
+    dec_dense7 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
+    dec_dense8 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense8')
+
+    dec_final = Dense(bunch*nb_used_features, activation='linear', name='dec_final')
+
+    div = Lambda(lambda x: x[0]/x[1])
+    time_reverse = Lambda(lambda x: K.reverse(x, 1))
+    #time_reverse = Lambda(lambda x: x)
+    quant_scale_dec = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed_dec')(quant_embed_input))
+    #gru_state_rep = RepeatVector(64//bunch)(gru_state_input)
+
+    gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
+
+    dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input), gru_state_rep])
+    dec1 = dec_dense1(time_reverse(dec_inputs))
+    dec2 = dec_dense2(dec1)
+    dec3 = dec_dense3(dec2)
+    dec4 = dec_dense4(dec3)
+    dec5 = dec_dense5(dec4)
+    dec6 = dec_dense6(dec5)
+    dec7 = dec_dense7(dec6)
+    dec8 = dec_dense8(dec7)
+    output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))
+    decoder = Model([bits_input, quant_embed_input, gru_state_input], time_reverse(output), name='decoder')
+    decoder.nb_bits = nb_bits
+    decoder.bunch = bunch
+    return decoder
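What the var_repeat Lambda in the decoder does, written with plain tf ops (a standalone sketch; shapes follow the decoder inputs above):

    import tensorflow as tf

    state = tf.ones((2, 24))      # (batch, nb_state_dim), one vector per sequence
    bits = tf.zeros((2, 16, 80))  # (batch, timesteps, nb_bits)
    rep = tf.tile(tf.expand_dims(state, 1), [1, tf.shape(bits)[1], 1])
    print(rep.shape)              # (2, 16, 24): the state, repeated per time step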
+def new_split_decoder(decoder):
+    nb_bits = decoder.nb_bits
+    bunch = decoder.bunch
+    bits_input = Input(shape=(None, nb_bits))
+    quant_embed_input = Input(shape=(None, 6*nb_bits))
+    gru_state_input = Input(shape=(None,nb_state_dim))
+
+    range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
+    elem_select = Lambda(lambda x: x[0][:,x[1],:])
+    # decode the sequence as four independent 64-frame windows
+    points = [0, 64, 128, 192, 256]
+    outputs = []
+    for i in range(len(points)-1):
+        begin = points[i]//bunch
+        end = points[i+1]//bunch
+        state = elem_select([gru_state_input, 2*end-1])
+        bits = range_select([bits_input, begin, end])
+        embed = range_select([quant_embed_input, begin, end])
+        outputs.append(decoder([bits, embed, state]))
+    output = Concatenate(axis=1)(outputs)
+    split = Model([bits_input, quant_embed_input, gru_state_input], output, name="split")
+    return split
+
+
+def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
+
+    feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
+    quant_id = Input(shape=(None,), batch_size=batch_size)
+    lambda_val = Input(shape=(None, 1), batch_size=batch_size)
+    lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
+
+    encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2)
+    ze, quant_embed_dec, gru_state_dec = encoder([feat, quant_id, lambda_val])
+
+    decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2)
+    split_decoder = new_split_decoder(decoder)
+
+    dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
+    soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
+    hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))
+
+    noisequant = UniformNoise()
+    hardquant = Lambda(hard_quantize)
+    dzone = Lambda(apply_dead_zone)
+    dze = dzone([ze,dead_zone])
+    gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
+    combined_output = split_decoder([hardquant(dze), tf.stop_gradient(quant_embed_dec), gru_state_dec])
+    ndze = noisequant(dze)
+    unquantized_output = split_decoder([ndze, quant_embed_dec, gru_state_dec])
+    unquantized_output_dec = split_decoder([tf.stop_gradient(ndze), tf.stop_gradient(quant_embed_dec), gru_state_dec])
+
+    e2 = Concatenate(name="hard_bits")([dze, hard_distr_embed, lambda_bunched])
+    e = Concatenate(name="soft_bits")([dze, soft_distr_embed, lambda_bunched])
+
+
+    model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, unquantized_output_dec, e, e2], name="end2end")
+    model.nb_used_features = nb_used_features
+
+    return model, encoder, decoder
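The end-to-end model exposes five outputs; train_rdovae.py (next file in this diff) pairs them positionally with losses and weights in model.compile(). A sketch of that mapping, for reference:

    # output (from new_rdovae_model)            -> (loss, weight) in train_rdovae.py
    outputs_to_losses = {
        'combined_output (hard-quantized path)':  ('feat_dist_loss', 0.5),
        'unquantized_output (noise-quantized)':   ('feat_dist_loss', 0.5),
        'unquantized_output_dec (stopped grads)': ('feat_dist_loss', 0.0),
        'e (name="soft_bits")':                   ('sq1_rate_loss', 1.0),
        'e2 (name="hard_bits")':                  ('sq2_rate_loss', 0.1),
    }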
diff --git a/dnn/training_tf2/train_rdovae.py b/dnn/training_tf2/train_rdovae.py
new file mode 100644
index 00000000..33a45e3e
--- /dev/null
+++ b/dnn/training_tf2/train_rdovae.py
@@ -0,0 +1,148 @@
+#!/usr/bin/python3
+'''Copyright (c) 2021-2022 Amazon
+   Copyright (c) 2018-2019 Mozilla
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
+   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+'''
+
+# Train an RDO-VAE model
+import tensorflow as tf
+strategy = tf.distribute.MultiWorkerMirroredStrategy()
+
+
+import argparse
+#from plc_loader import PLCLoader
+
+parser = argparse.ArgumentParser(description='Train a quantization model')
+
+parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
+parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
+parser.add_argument('--model', metavar='<model>', default='rdovae', help='rdovae model python definition (without .py)')
+group1 = parser.add_mutually_exclusive_group()
+group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
+group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
+parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
+parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
+parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
+parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
+parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
+parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
+parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')
+
+
+args = parser.parse_args()
+
+import importlib
+rdovae = importlib.import_module(args.model)
+
+import sys
+import numpy as np
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
+import tensorflow.keras.backend as K
+import h5py
+
+#gpus = tf.config.experimental.list_physical_devices('GPU')
+#if gpus:
+#    try:
+#        tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
+#    except RuntimeError as e:
+#        print(e)
+
+nb_epochs = args.epochs
+
+# Try reducing batch_size if you run out of memory on your GPU
+batch_size = args.batch_size
+
+quantize = args.quantize is not None
+retrain = args.retrain is not None
+
+if quantize:
+    lr = 0.00003
+    decay = 0
+    input_model = args.quantize
+else:
+    lr = 0.001
+    decay = 2.5e-5
+
+if args.lr is not None:
+    lr = args.lr
+
+if args.decay is not None:
+    decay = args.decay
+
+if retrain:
+    input_model = args.retrain
+
+
+opt = Adam(lr, decay=decay, beta_2=0.99)
+
+with strategy.scope():
+    model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
+    model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[0.5, 0.5, 0., 1., .1], metrics={'split':'mse', 'hard_bits':rdovae.sq_rate_metric})
+    model.summary()
+
+lpc_order = 16
+
+feature_file = args.features
+nb_features = model.nb_used_features + lpc_order
+nb_used_features = model.nb_used_features
+sequence_size = args.seq_length
+
+# load float32 feature vectors and keep only the used features
+features = np.memmap(feature_file, dtype='float32', mode='r')
+nb_sequences = len(features)//(nb_features*sequence_size)//batch_size*batch_size
+features = features[:nb_sequences*sequence_size*nb_features]
+
+features = np.reshape(features, (nb_sequences, sequence_size, nb_features))
+print(features.shape)
+features = features[:, :, :nb_used_features]
+
+#lambda_val = np.random.uniform(.0007, .002, (features.shape[0], features.shape[1], 1))
+lambda_val = np.repeat(np.random.uniform(.0007, .002, (features.shape[0], 1, 1)), features.shape[1]//2, axis=1)
+#lambda_val = 0*lambda_val + .001
+quant_id = np.round(10*np.log(lambda_val/.0007)).astype('int16')
+quant_id = quant_id[:,:,0]
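A quick standalone check of the lambda-to-quant_id mapping used above: quant_id = round(10*ln(lambda/0.0007)), so the sampled range [0.0007, 0.002] maps to integer ids 0 through 10, well inside the nb_quant=40 embedding table defined in rdovae.py.

    import numpy as np

    lam = np.array([0.0007, 0.001, 0.002])
    print(np.round(10*np.log(lam/.0007)).astype('int16'))  # -> [ 0  4 10]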
+# dump models to disk as we go
+checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.cond_size, '{epoch:02d}'))
+
+if quantize or retrain:
+    # adapting from an existing model
+    model.load_weights(input_model)
+
+model.save_weights('{}_{}_initial.h5'.format(args.output, args.cond_size))
+
+callbacks = [checkpoint]
+#callbacks = []
+
+if args.logdir is not None:
+    logdir = '{}/{}_{}_logs'.format(args.logdir, args.output, args.cond_size)
+    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
+    callbacks.append(tensorboard_callback)
+
+model.fit([features, quant_id, lambda_val], [features, features, features, features, features], batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
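During training, the model replaces rounding with additive uniform noise in [-0.5, 0.5), which mimics the error of rounding while keeping the computation differentiable; that is what the UniformNoise layer in the next file provides. A sketch of the two paths:

    import tensorflow as tf

    x = tf.constant([0.2, 0.8, -1.4])
    noisy = x + tf.random.uniform(tf.shape(x), minval=-0.5, maxval=0.5)  # training path
    hard = tf.round(x)                                                   # inference path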
diff --git a/dnn/training_tf2/uniform_noise.py b/dnn/training_tf2/uniform_noise.py
new file mode 100644
index 00000000..6197dd5f
--- /dev/null
+++ b/dnn/training_tf2/uniform_noise.py
@@ -0,0 +1,78 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the UniformNoise layer."""
+
+
+import tensorflow.compat.v2 as tf
+
+from tensorflow.keras import backend
+
+from tensorflow.keras.layers import Layer
+
+class UniformNoise(Layer):
+    """Apply additive zero-centered uniform noise.
+
+    Noise drawn from [-stddev, stddev) serves as a differentiable stand-in
+    for rounding when training a model that quantizes values at inference time.
+
+    As it is a regularization layer, it is only active at training time.
+
+    Args:
+      stddev: Float, half-width of the uniform noise distribution.
+      seed: Integer, optional random seed to enable deterministic behavior.
+
+    Call arguments:
+      inputs: Input tensor (of any rank).
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding noise) or in inference mode (doing nothing).
+
+    Input shape:
+      Arbitrary. Use the keyword argument `input_shape`
+      (tuple of integers, does not include the samples axis)
+      when using this layer as the first layer in a model.
+
+    Output shape:
+      Same shape as input.
+    """
+
+    def __init__(self, stddev=0.5, seed=None, **kwargs):
+        super().__init__(**kwargs)
+        self.supports_masking = True
+        self.stddev = stddev
+        self.seed = seed
+
+    def call(self, inputs, training=None):
+        def noised():
+            return inputs + backend.random_uniform(
+                shape=tf.shape(inputs),
+                minval=-self.stddev,
+                maxval=self.stddev,
+                dtype=inputs.dtype,
+                seed=self.seed,
+            )
+
+        return backend.in_train_phase(noised, inputs, training=training)
+
+    def get_config(self):
+        config = {"stddev": self.stddev, "seed": self.seed}
+        base_config = super().get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
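A usage sketch for the layer above (assumes uniform_noise.py is importable): noise is only added when the layer runs in training mode; at inference it is the identity.

    import tensorflow as tf
    from uniform_noise import UniformNoise

    layer = UniformNoise(stddev=0.5)
    x = tf.zeros((2, 4))
    print(layer(x, training=True))   # values drawn from [-0.5, 0.5)
    print(layer(x, training=False))  # all zeros: identity at inference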