From b6af21f31c9309e1ee031bb4c0cf51b213cf7f3c Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin
Date: Mon, 23 Jul 2018 17:05:21 -0400
Subject: [PATCH] wip...

---
 dnn/gatedconv.py |  7 +++++--
 dnn/wavenet.py   | 15 ++++++++++++---
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/dnn/gatedconv.py b/dnn/gatedconv.py
index 61f7383e..5d15c806 100644
--- a/dnn/gatedconv.py
+++ b/dnn/gatedconv.py
@@ -1,6 +1,6 @@
 from keras import backend as K
 from keras.engine.topology import Layer
-from keras.layers import activations, initializers, regularizers, constraints, InputSpec, Conv1D
+from keras.layers import activations, initializers, regularizers, constraints, InputSpec, Conv1D, Dense
 import numpy as np
 
 class GatedConv(Conv1D):
@@ -42,13 +42,16 @@ class GatedConv(Conv1D):
         self.out_dims = filters
         self.nongate_activation = activations.get(activation)
 
-    def call(self, inputs, memory=None):
+    def call(self, inputs, cond=None, memory=None):
         if memory is None:
             mem = K.zeros((K.shape(inputs)[0], self.mem_size, K.shape(inputs)[-1]))
         else:
             mem = K.variable(K.cast_to_floatx(memory))
         inputs = K.concatenate([mem, inputs], axis=1)
         ret = super(GatedConv, self).call(inputs)
+        if cond is not None:
+            d = Dense(2*self.out_dims, use_bias=False, activation='linear')
+            ret = ret + d(cond)
         ret = self.nongate_activation(ret[:, :, :self.out_dims]) * activations.sigmoid(ret[:, :, self.out_dims:])
         if self.return_memory:
             ret = ret, inputs[:, :self.mem_size, :]
diff --git a/dnn/wavenet.py b/dnn/wavenet.py
index 96853653..1b53548b 100644
--- a/dnn/wavenet.py
+++ b/dnn/wavenet.py
@@ -4,6 +4,7 @@ import math
 from keras.models import Model
 from keras.layers import Input, LSTM, CuDNNGRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Add, Multiply, Bidirectional, MaxPooling1D, Activation
 from keras import backend as K
+from keras.initializers import VarianceScaling
 from mdense import MDense
 import numpy as np
 import h5py
@@ -34,12 +35,20 @@ def new_wavenet_model(fftnet=False):
     rfeat = rep(cfeat)
     #tmp = Concatenate()([pcm, rfeat])
     tmp = pcm
+    init = VarianceScaling(scale=1.5,mode='fan_avg',distribution='uniform')
     for k in range(10):
         res = tmp
-        tmp = Concatenate()([tmp, rfeat])
         dilation = 9-k if fftnet else k
-        c = GatedConv(units, 2, dilation_rate=2**dilation, activation='tanh')
-        tmp = Dense(units, activation='relu')(c(tmp))
+        '''#tmp = Concatenate()([tmp, rfeat])
+        c = GatedConv(units, 2, dilation_rate=2**dilation, activation='tanh', kernel_initializer=init)
+        tmp = Dense(units, activation='relu')(c(tmp, cond=rfeat))'''
+
+        tmp = Concatenate()([tmp, rfeat])
+        c1 = CausalConv(units, 2, dilation_rate=2**dilation, activation='tanh')
+        c2 = CausalConv(units, 2, dilation_rate=2**dilation, activation='sigmoid')
+        tmp = Multiply()([c1(tmp), c2(tmp)])
+        tmp = Dense(units, activation='relu')(tmp)
+
         if k != 0:
             tmp = Add()([tmp, res])
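
Note on the gating structure: both variants in the wavenet.py loop implement the WaveNet-style gated activation z = tanh(a) * sigmoid(b). The commented-out GatedConv path fuses both branches into a single convolution producing 2*out_dims channels and splits it, with the conditioning features added pre-activation through a bias-free Dense (the gatedconv.py hunk); the active path uses two separate causal convolutions. The snippet below is a minimal, self-contained restatement of the active path; it assumes stock Keras Conv1D(padding='causal') as a stand-in for the repo's CausalConv (whose definition is not part of this patch), and units/feat_dim are placeholder sizes, not values from the repo.

# Sketch of the gated residual block built by the patched loop in
# wavenet.py.  Conv1D(padding='causal') stands in for CausalConv
# (assumption); units and feat_dim are placeholder widths.
from keras.models import Model
from keras.layers import Input, Dense, Concatenate, Conv1D, Multiply, Add

units = 64      # placeholder network width
feat_dim = 16   # placeholder conditioning-feature width

pcm = Input(shape=(None, units))       # sample stream
rfeat = Input(shape=(None, feat_dim))  # conditioning features, already
                                       # repeated up to the sample rate

tmp = pcm
for k in range(10):
    res = tmp
    dilation = k
    # Gated activation: tanh "filter" branch times sigmoid "gate" branch,
    # each a dilated causal convolution over [tmp, rfeat].
    x = Concatenate()([tmp, rfeat])
    filt = Conv1D(units, 2, dilation_rate=2**dilation, padding='causal',
                  activation='tanh')(x)
    gate = Conv1D(units, 2, dilation_rate=2**dilation, padding='causal',
                  activation='sigmoid')(x)
    tmp = Dense(units, activation='relu')(Multiply()([filt, gate]))
    if k != 0:
        tmp = Add()([tmp, res])  # residual connection, as in the patch

model = Model([pcm, rfeat], tmp)

In the fused GatedConv variant the same computation is [a; b] = conv(x) + V*cond followed by tanh(a) * sigmoid(b), i.e. the conditioning enters each branch pre-activation rather than through the input concatenation.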