Add input embedding

This commit is contained in:
Jean-Marc Valin 2018-07-27 16:33:01 -04:00
parent 1837dad072
commit 2aba2a9c49
2 changed files with 30 additions and 3 deletions

View file

@ -31,7 +31,7 @@ feature_chunk_size = 15
pcm_chunk_size = frame_size*feature_chunk_size pcm_chunk_size = frame_size*feature_chunk_size
data = np.fromfile(pcmfile, dtype='int16') data = np.fromfile(pcmfile, dtype='int16')
data = np.minimum(127, lin2ulaw(data[160:]/32768.)) data = np.minimum(127, lin2ulaw(data[80:]/32768.))
nb_frames = len(data)//pcm_chunk_size nb_frames = len(data)//pcm_chunk_size
features = np.fromfile(feature_file, dtype='float32') features = np.fromfile(feature_file, dtype='float32')
@ -39,7 +39,7 @@ features = np.fromfile(feature_file, dtype='float32')
data = data[:nb_frames*pcm_chunk_size] data = data[:nb_frames*pcm_chunk_size]
features = features[:nb_frames*feature_chunk_size*nb_features] features = features[:nb_frames*feature_chunk_size*nb_features]
in_data = np.concatenate([data[0:1], data[:-1]])/16.; in_data = np.concatenate([data[0:1], data[:-1]]);
features = np.reshape(features, (nb_frames*feature_chunk_size, nb_features)) features = np.reshape(features, (nb_frames*feature_chunk_size, nb_features))
pitch = 1.*data pitch = 1.*data
@ -51,6 +51,7 @@ for i in range(2, nb_frames*feature_chunk_size):
in_pitch = np.reshape(pitch/16., (nb_frames, pcm_chunk_size, 1)) in_pitch = np.reshape(pitch/16., (nb_frames, pcm_chunk_size, 1))
in_data = np.reshape(in_data, (nb_frames, pcm_chunk_size, 1)) in_data = np.reshape(in_data, (nb_frames, pcm_chunk_size, 1))
in_data = (in_data.astype('int16')+128).astype('uint8')
out_data = np.reshape(data, (nb_frames, pcm_chunk_size, 1)) out_data = np.reshape(data, (nb_frames, pcm_chunk_size, 1))
out_data = (out_data.astype('int16')+128).astype('uint8') out_data = (out_data.astype('int16')+128).astype('uint8')
features = np.reshape(features, (nb_frames, feature_chunk_size, nb_features)) features = np.reshape(features, (nb_frames, feature_chunk_size, nb_features))

View file

@ -4,6 +4,7 @@ import math
from keras.models import Model from keras.models import Model
from keras.layers import Input, LSTM, CuDNNGRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Add, Multiply, Bidirectional, MaxPooling1D, Activation from keras.layers import Input, LSTM, CuDNNGRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Add, Multiply, Bidirectional, MaxPooling1D, Activation
from keras import backend as K from keras import backend as K
from keras.initializers import Initializer
from keras.initializers import VarianceScaling from keras.initializers import VarianceScaling
from mdense import MDense from mdense import MDense
import numpy as np import numpy as np
@ -17,6 +18,30 @@ pcm_bits = 8
pcm_levels = 2**pcm_bits pcm_levels = 2**pcm_bits
nb_used_features = 38 nb_used_features = 38
class PCMInit(Initializer):
    """Initializer producing uniform noise plus a per-row linear ramp.

    The requested shape is flattened to ``(num_rows, num_cols)``, where
    ``num_rows`` is the product of every dimension except the last.  Row
    ``k`` of the result is uniform noise offset by a value that grows
    linearly with ``k`` and is centred on zero, so neighbouring rows start
    out with neighbouring values.  The whole matrix is scaled by ``gain``.

    Args:
        gain: Multiplier applied to the generated matrix (default ``0.1``).
        seed: Optional integer seed.  When given, the noise is drawn from a
            private ``RandomState`` so results are reproducible and the
            global NumPy random state is left untouched.  When ``None``,
            the global NumPy generator is used.
    """

    def __init__(self, gain=.1, seed=None):
        self.gain = gain
        self.seed = seed

    def __call__(self, shape, dtype=None):
        """Return the initial weights as a 2-D NumPy array.

        Args:
            shape: Requested weight shape; all leading dimensions are
                collapsed into rows, the last dimension gives the columns.
            dtype: Accepted for interface compatibility but not used here;
                the returned array keeps NumPy's default float type.
                NOTE(review): presumably the caller converts it -- confirm.

        Returns:
            Array of shape ``(num_rows, num_cols)`` equal to
            ``gain * (noise + ramp)``.
        """
        num_rows = 1
        for dim in shape[:-1]:
            num_rows *= dim
        num_cols = shape[-1]
        flat_shape = (num_rows, num_cols)

        # A seeded private RandomState yields the same stream as
        # np.random.seed(seed) followed by np.random.uniform, without
        # reseeding the process-wide generator as a side effect.
        if self.seed is None:
            rng = np.random
        else:
            rng = np.random.RandomState(self.seed)

        # +/-1.7321 is approximately sqrt(3): a uniform distribution on that
        # interval has unit variance.
        a = rng.uniform(-1.7321, 1.7321, flat_shape)

        # num_rows evenly spaced values centred on zero (step 1 before
        # scaling); the "-.4" on the stop keeps arange at exactly num_rows
        # elements.  Scaled by sqrt(12)/num_rows.
        # NOTE: an earlier variant (previously left commented out) wrote
        # this ramp into column 0 only and .5*ramp**3 into column 1 instead
        # of adding the ramp to every column.
        ramp = math.sqrt(12)*np.arange(-.5*num_rows+.5, .5*num_rows-.4)/num_rows
        a = a + np.reshape(ramp, (num_rows, 1))
        return self.gain * a

    def get_config(self):
        """Return the constructor arguments as a dict for serialization."""
        return {
            'gain': self.gain,
            'seed': self.seed
        }
def new_wavenet_model(fftnet=False): def new_wavenet_model(fftnet=False):
pcm = Input(shape=(None, 1)) pcm = Input(shape=(None, 1))
@ -34,7 +59,8 @@ def new_wavenet_model(fftnet=False):
activation='tanh' activation='tanh'
rfeat = rep(cfeat) rfeat = rep(cfeat)
#tmp = Concatenate()([pcm, rfeat]) #tmp = Concatenate()([pcm, rfeat])
tmp = pcm embed = Embedding(256, units, embeddings_initializer=PCMInit())
tmp = Reshape((-1, units))(embed(pcm))
init = VarianceScaling(scale=1.5,mode='fan_avg',distribution='uniform') init = VarianceScaling(scale=1.5,mode='fan_avg',distribution='uniform')
for k in range(10): for k in range(10):
res = tmp res = tmp