WIP: Adding a constraint

This commit is contained in:
Jean-Marc Valin 2020-12-24 02:50:20 -05:00
parent c045702e51
commit 1657bae024
2 changed files with 24 additions and 5 deletions

View file

@ -29,6 +29,7 @@ import math
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import Initializer
from tensorflow.keras.callbacks import Callback
from mdense import MDense
@ -115,6 +116,21 @@ class PCMInit(Initializer):
'seed': self.seed
}
class WeightClip(Constraint):
'''Clips the weights incident to each hidden unit to be inside a range
'''
def __init__(self, c=2):
self.c = c
def __call__(self, p):
return K.clip(p, -self.c, self.c)
def get_config(self):
return {'name': self.__class__.__name__,
'c': self.c}
constraint = WeightClip(0.999)
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, adaptation=False):
pcm = Input(shape=(None, 3))
feat = Input(shape=(None, nb_used_features))
@ -142,8 +158,10 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, train
rep = Lambda(lambda x: K.repeat_elements(x, frame_size, 1))
rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a')
rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b')
rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a',
recurrent_constraint = constraint)
rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b',
kernel_constraint=constraint)
rnn_in = Concatenate()([cpcm, rep(cfeat)])
md = MDense(pcm_levels, activation='softmax', name='dual_fc')