diff --git a/dnn/gatedconv.py b/dnn/gatedconv.py new file mode 100644 index 00000000..61f7383e --- /dev/null +++ b/dnn/gatedconv.py @@ -0,0 +1,62 @@ +from keras import backend as K +from keras.engine.topology import Layer +from keras.layers import activations, initializers, regularizers, constraints, InputSpec, Conv1D +import numpy as np + +class GatedConv(Conv1D): + + def __init__(self, filters, + kernel_size, + dilation_rate=1, + activation='tanh', + use_bias=True, + kernel_initializer='glorot_uniform', + bias_initializer='zeros', + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + return_memory=False, + **kwargs): + + super(GatedConv, self).__init__( + filters=2*filters, + kernel_size=kernel_size, + strides=1, + padding='valid', + data_format='channels_last', + dilation_rate=dilation_rate, + activation='linear', + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + **kwargs) + self.mem_size = dilation_rate*(kernel_size-1) + self.return_memory = return_memory + self.out_dims = filters + self.nongate_activation = activations.get(activation) + + def call(self, inputs, memory=None): + if memory is None: + mem = K.zeros((K.shape(inputs)[0], self.mem_size, K.shape(inputs)[-1])) + else: + mem = K.variable(K.cast_to_floatx(memory)) + inputs = K.concatenate([mem, inputs], axis=1) + ret = super(GatedConv, self).call(inputs) + ret = self.nongate_activation(ret[:, :, :self.out_dims]) * activations.sigmoid(ret[:, :, self.out_dims:]) + if self.return_memory: + ret = ret, inputs[:, :self.mem_size, :] + return ret + + def compute_output_shape(self, input_shape): + assert input_shape and len(input_shape) >= 2 + assert input_shape[-1] + output_shape = list(input_shape) + output_shape[-1] = self.out_dims + return tuple(output_shape) diff --git a/dnn/wavenet.py b/dnn/wavenet.py index 20a387a5..96853653 100644 --- a/dnn/wavenet.py +++ b/dnn/wavenet.py @@ -9,8 +9,9 @@ import numpy as np import h5py import sys from causalconv import CausalConv +from gatedconv import GatedConv -units=256 +units=128 pcm_bits = 8 pcm_levels = 2**pcm_bits nb_used_features = 38 @@ -37,10 +38,8 @@ def new_wavenet_model(fftnet=False): res = tmp tmp = Concatenate()([tmp, rfeat]) dilation = 9-k if fftnet else k - c1 = CausalConv(units, 2, dilation_rate=2**dilation, activation='tanh') - c2 = CausalConv(units, 2, dilation_rate=2**dilation, activation='sigmoid') - tmp = Multiply()([c1(tmp), c2(tmp)]) - tmp = Dense(units, activation='relu')(tmp) + c = GatedConv(units, 2, dilation_rate=2**dilation, activation='tanh') + tmp = Dense(units, activation='relu')(c(tmp)) if k != 0: tmp = Add()([tmp, res])