Gated convolution

This commit is contained in:
Jean-Marc Valin 2018-07-13 17:10:03 -04:00
parent 0fa7150454
commit 211435f5d3
2 changed files with 66 additions and 5 deletions

View file

@ -9,8 +9,9 @@ import numpy as np
import h5py
import sys
from causalconv import CausalConv
from gatedconv import GatedConv
units=256
units=128
pcm_bits = 8
pcm_levels = 2**pcm_bits
nb_used_features = 38
@ -37,10 +38,8 @@ def new_wavenet_model(fftnet=False):
res = tmp
tmp = Concatenate()([tmp, rfeat])
dilation = 9-k if fftnet else k
c1 = CausalConv(units, 2, dilation_rate=2**dilation, activation='tanh')
c2 = CausalConv(units, 2, dilation_rate=2**dilation, activation='sigmoid')
tmp = Multiply()([c1(tmp), c2(tmp)])
tmp = Dense(units, activation='relu')(tmp)
c = GatedConv(units, 2, dilation_rate=2**dilation, activation='tanh')
tmp = Dense(units, activation='relu')(c(tmp))
if k != 0:
tmp = Add()([tmp, res])