C implementation of FARGAN

Jean-Marc Valin 2023-10-10 02:18:21 -04:00
parent 9e76a7bfb8
commit 35cb8d7f66
11 changed files with 487 additions and 6 deletions


@@ -9,7 +9,7 @@ set -e
srcdir=`dirname $0`
test -n "$srcdir" && cd "$srcdir"
-dnn/download_model.sh f68e31d
+dnn/download_model.sh 9e76a7b
echo "Updating build configuration files, please wait...."

dnn/fargan.c Normal file

@@ -0,0 +1,220 @@
/* Copyright (c) 2023 Amazon */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "fargan.h"
#include "os_support.h"
#include "freq.h"
#include "fargan_data.h"
#include "lpcnet.h"
#include "pitch.h"
#include "nnet.h"
#include "lpcnet_private.h"
#define FARGAN_FEATURES (NB_FEATURES)
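/* Compute the per-frame conditioning vector: look up the pitch embedding for the clamped
   period, concatenate it with the acoustic features, and run the result through a tanh
   dense layer followed by two tanh conv1d layers. */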
static void compute_fargan_cond(FARGANState *st, float *cond, const float *features, int period)
{
FARGAN *model;
float dense_in[NB_FEATURES+COND_NET_PEMBED_OUT_SIZE];
float conv1_in[COND_NET_FCONV1_IN_SIZE];
float conv2_in[COND_NET_FCONV2_IN_SIZE];
model = &st->model;
celt_assert(FARGAN_FEATURES+COND_NET_PEMBED_OUT_SIZE == model->cond_net_fdense1.nb_inputs);
celt_assert(COND_NET_FCONV1_IN_SIZE == model->cond_net_fdense1.nb_outputs);
celt_assert(COND_NET_FCONV2_IN_SIZE == model->cond_net_fconv1.nb_outputs);
OPUS_COPY(&dense_in[NB_FEATURES], &model->cond_net_pembed.float_weights[IMAX(0,IMIN(period-32, 224))*COND_NET_PEMBED_OUT_SIZE], COND_NET_PEMBED_OUT_SIZE);
OPUS_COPY(dense_in, features, NB_FEATURES);
compute_generic_dense(&model->cond_net_fdense1, conv1_in, dense_in, ACTIVATION_TANH);
compute_generic_conv1d(&model->cond_net_fconv1, conv2_in, st->cond_conv1_state, conv1_in, COND_NET_FCONV1_IN_SIZE, ACTIVATION_TANH);
compute_generic_conv1d(&model->cond_net_fconv2, cond, st->cond_conv2_state, conv2_in, COND_NET_FCONV2_IN_SIZE, ACTIVATION_TANH);
}
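/* The network operates in the pre-emphasized domain (see fargan_cont()); this first-order
   IIR filter y[n] = x[n] + FARGAN_DEEMPHASIS*y[n-1] converts back to the signal domain. */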
static void fargan_deemphasis(float *pcm, float *deemph_mem) {
int i;
for (i=0;i<FARGAN_SUBFRAME_SIZE;i++) {
pcm[i] += FARGAN_DEEMPHASIS * *deemph_mem;
*deemph_mem = pcm[i];
}
}
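/* Synthesize one FARGAN_SUBFRAME_SIZE-sample subframe: derive a gain from the conditioning,
   build gain-normalized pitch-prediction (pred) and previous-output (prev) vectors from the
   pitch buffer, then run the framewise conv + GLU, three GRU+GLU stages (each also fed the
   gated pitch prediction), a skip dense + GLU over the concatenated activations, and a final
   tanh dense output scaled back by the gain. The pitch buffer is updated with the new samples
   before de-emphasis is applied. */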
static void run_fargan_subframe(FARGANState *st, float *pcm, const float *cond, int period)
{
int i, pos;
float fwc0_in[SIG_NET_INPUT_SIZE];
float gru1_in[SIG_NET_FWC0_CONV_OUT_SIZE+2*FARGAN_SUBFRAME_SIZE];
float gru2_in[SIG_NET_GRU1_OUT_SIZE+2*FARGAN_SUBFRAME_SIZE];
float gru3_in[SIG_NET_GRU2_OUT_SIZE+2*FARGAN_SUBFRAME_SIZE];
float pred[FARGAN_SUBFRAME_SIZE+4];
float prev[FARGAN_SUBFRAME_SIZE];
float pitch_gate[4];
float gain;
float gain_1;
float skip_cat[10000];
float skip_out[SIG_NET_SKIP_DENSE_OUT_SIZE];
FARGAN *model;
celt_assert(st->cont_initialized);
model = &st->model;
compute_generic_dense(&model->sig_net_cond_gain_dense, &gain, cond, ACTIVATION_LINEAR);
gain = exp(gain);
gain_1 = 1.f/(1e-5 + gain);
pos = PITCH_MAX_PERIOD-period-2;
for (i=0;i<FARGAN_SUBFRAME_SIZE+4;i++) {
pred[i] = MIN32(1.f, MAX32(-1.f, gain_1*st->pitch_buf[IMAX(0, pos)]));
pos++;
if (pos == PITCH_MAX_PERIOD) pos -= period;
}
for (i=0;i<FARGAN_SUBFRAME_SIZE;i++) prev[i] = MAX32(-1.f, MIN16(1.f, gain_1*st->pitch_buf[PITCH_MAX_PERIOD-FARGAN_SUBFRAME_SIZE+i]));
OPUS_COPY(&fwc0_in[0], &cond[0], FARGAN_COND_SIZE);
OPUS_COPY(&fwc0_in[FARGAN_COND_SIZE], pred, FARGAN_SUBFRAME_SIZE+4);
OPUS_COPY(&fwc0_in[FARGAN_COND_SIZE+FARGAN_SUBFRAME_SIZE+4], prev, FARGAN_SUBFRAME_SIZE);
compute_generic_conv1d(&model->sig_net_fwc0_conv, gru1_in, st->fwc0_mem, fwc0_in, SIG_NET_INPUT_SIZE, ACTIVATION_TANH);
celt_assert(SIG_NET_FWC0_GLU_GATE_OUT_SIZE == model->sig_net_fwc0_glu_gate.nb_outputs);
compute_glu(&model->sig_net_fwc0_glu_gate, gru1_in, gru1_in);
compute_generic_dense(&model->sig_net_gain_dense_out, pitch_gate, gru1_in, ACTIVATION_SIGMOID);
for (i=0;i<FARGAN_SUBFRAME_SIZE;i++) gru1_in[SIG_NET_FWC0_GLU_GATE_OUT_SIZE+i] = pitch_gate[0]*pred[i+2];
OPUS_COPY(&gru1_in[SIG_NET_FWC0_GLU_GATE_OUT_SIZE+FARGAN_SUBFRAME_SIZE], prev, FARGAN_SUBFRAME_SIZE);
compute_generic_gru(&model->sig_net_gru1_input, &model->sig_net_gru1_recurrent, st->gru1_state, gru1_in);
compute_glu(&model->sig_net_gru1_glu_gate, gru2_in, st->gru1_state);
for (i=0;i<FARGAN_SUBFRAME_SIZE;i++) gru2_in[SIG_NET_GRU1_OUT_SIZE+i] = pitch_gate[1]*pred[i+2];
OPUS_COPY(&gru2_in[SIG_NET_GRU1_OUT_SIZE+FARGAN_SUBFRAME_SIZE], prev, FARGAN_SUBFRAME_SIZE);
compute_generic_gru(&model->sig_net_gru2_input, &model->sig_net_gru2_recurrent, st->gru2_state, gru2_in);
compute_glu(&model->sig_net_gru2_glu_gate, gru3_in, st->gru2_state);
for (i=0;i<FARGAN_SUBFRAME_SIZE;i++) gru3_in[SIG_NET_GRU2_OUT_SIZE+i] = pitch_gate[2]*pred[i+2];
OPUS_COPY(&gru3_in[SIG_NET_GRU2_OUT_SIZE+FARGAN_SUBFRAME_SIZE], prev, FARGAN_SUBFRAME_SIZE);
compute_generic_gru(&model->sig_net_gru3_input, &model->sig_net_gru3_recurrent, st->gru3_state, gru3_in);
compute_glu(&model->sig_net_gru3_glu_gate, &skip_cat[SIG_NET_GRU1_OUT_SIZE+SIG_NET_GRU2_OUT_SIZE], st->gru3_state);
OPUS_COPY(skip_cat, gru2_in, SIG_NET_GRU1_OUT_SIZE);
OPUS_COPY(&skip_cat[SIG_NET_GRU1_OUT_SIZE], gru3_in, SIG_NET_GRU2_OUT_SIZE);
OPUS_COPY(&skip_cat[SIG_NET_GRU1_OUT_SIZE+SIG_NET_GRU2_OUT_SIZE+SIG_NET_GRU3_OUT_SIZE], gru1_in, SIG_NET_FWC0_CONV_OUT_SIZE);
for (i=0;i<FARGAN_SUBFRAME_SIZE;i++) skip_cat[SIG_NET_GRU1_OUT_SIZE+SIG_NET_GRU2_OUT_SIZE+SIG_NET_GRU3_OUT_SIZE+SIG_NET_FWC0_CONV_OUT_SIZE+i] = pitch_gate[3]*pred[i+2];
OPUS_COPY(&skip_cat[SIG_NET_GRU1_OUT_SIZE+SIG_NET_GRU2_OUT_SIZE+SIG_NET_GRU3_OUT_SIZE+SIG_NET_FWC0_CONV_OUT_SIZE+FARGAN_SUBFRAME_SIZE], prev, FARGAN_SUBFRAME_SIZE);
compute_generic_dense(&model->sig_net_skip_dense, skip_out, skip_cat, ACTIVATION_TANH);
compute_glu(&model->sig_net_skip_glu_gate, skip_out, skip_out);
compute_generic_dense(&model->sig_net_sig_dense_out, pcm, skip_out, ACTIVATION_TANH);
for (i=0;i<FARGAN_SUBFRAME_SIZE;i++) pcm[i] *= gain;
OPUS_MOVE(st->pitch_buf, &st->pitch_buf[FARGAN_SUBFRAME_SIZE], PITCH_MAX_PERIOD-FARGAN_SUBFRAME_SIZE);
OPUS_COPY(&st->pitch_buf[PITCH_MAX_PERIOD-FARGAN_SUBFRAME_SIZE], pcm, FARGAN_SUBFRAME_SIZE);
fargan_deemphasis(pcm, &st->deemph_mem);
}
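/* Warm up the synthesis state from past signal: run the conditioning network over the five
   feature frames in features0, pre-emphasize the FARGAN_CONT_SAMPLES of past audio in pcm0,
   seed the pitch buffer with the first frame, then run one frame of subframes whose output
   is discarded and replaced by the corresponding ground-truth samples so the GRU states
   match the supplied signal. */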
void fargan_cont(FARGANState *st, const float *pcm0, const float *features0)
{
int i;
float cond[COND_NET_FCONV2_OUT_SIZE];
float x0[FARGAN_CONT_SAMPLES];
float dummy[FARGAN_SUBFRAME_SIZE];
int period=0;
/* Pre-load features. */
for (i=0;i<5;i++) {
const float *features = &features0[i*NB_FEATURES];
st->last_period = period;
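/* The 60/60 factors cancel: period = round(256 / 2^(features[NB_BANDS]+1.5)). */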
period = (int)floor(.5+256./pow(2.f,((1./60.)*((features[NB_BANDS]+1.5)*60))));
compute_fargan_cond(st, cond, features, period);
}
x0[0] = 0;
for (i=1;i<FARGAN_CONT_SAMPLES;i++) {
x0[i] = pcm0[i] - FARGAN_DEEMPHASIS*pcm0[i-1];
}
OPUS_COPY(&st->pitch_buf[PITCH_MAX_PERIOD-FARGAN_FRAME_SIZE], x0, FARGAN_FRAME_SIZE);
st->cont_initialized = 1;
for (i=0;i<FARGAN_NB_SUBFRAMES;i++) {
run_fargan_subframe(st, dummy, &cond[i*FARGAN_COND_SIZE], st->last_period);
OPUS_COPY(&st->pitch_buf[PITCH_MAX_PERIOD-FARGAN_SUBFRAME_SIZE], &x0[FARGAN_FRAME_SIZE+i*FARGAN_SUBFRAME_SIZE], FARGAN_SUBFRAME_SIZE);
}
st->deemph_mem = pcm0[FARGAN_CONT_SAMPLES-1];
}
void fargan_init(FARGANState *st)
{
int ret;
OPUS_CLEAR(st, 1);
ret = init_fargan(&st->model, fargan_arrays);
celt_assert(ret == 0);
/* FIXME: perform arch detection. */
}
int fargan_load_model(FARGANState *st, const unsigned char *data, int len) {
WeightArray *list;
int ret;
parse_weights(&list, data, len);
ret = init_fargan(&st->model, list);
free(list);
if (ret == 0) return 0;
else return -1;
}
static void fargan_synthesize_impl(FARGANState *st, float *pcm, const float *features)
{
int subframe;
float cond[COND_NET_FCONV2_OUT_SIZE];
int period;
celt_assert(st->cont_initialized);
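/* Same pitch-period mapping as in fargan_cont(): period = round(256 / 2^(features[NB_BANDS]+1.5)). */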
period = (int)floor(.5+256./pow(2.f,((1./60.)*((features[NB_BANDS]+1.5)*60))));
compute_fargan_cond(st, cond, features, period);
for (subframe=0;subframe<FARGAN_NB_SUBFRAMES;subframe++) {
float *sub_cond;
sub_cond = &cond[subframe*FARGAN_COND_SIZE];
run_fargan_subframe(st, &pcm[subframe*FARGAN_SUBFRAME_SIZE], sub_cond, st->last_period);
}
st->last_period = period;
}
void fargan_synthesize(FARGANState *st, float *pcm, const float *features)
{
fargan_synthesize_impl(st, pcm, features);
}
void fargan_synthesize_int(FARGANState *st, opus_int16 *pcm, const float *features)
{
int i;
float fpcm[FARGAN_FRAME_SIZE];
fargan_synthesize(st, fpcm, features);
for (i=0;i<LPCNET_FRAME_SIZE;i++) pcm[i] = (int)floor(.5 + MIN32(32767, MAX32(-32767, 32768.f*fpcm[i])));
}

dnn/fargan.h Normal file

@@ -0,0 +1,68 @@
/* Copyright (c) 2023 Amazon */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef FARGAN_H
#define FARGAN_H
#include "freq.h"
#include "fargan_data.h"
#include "pitchdnn.h"
#define FARGAN_CONT_SAMPLES 320
#define FARGAN_NB_SUBFRAMES 4
#define FARGAN_SUBFRAME_SIZE 40
#define FARGAN_FRAME_SIZE (FARGAN_NB_SUBFRAMES*FARGAN_SUBFRAME_SIZE)
#define FARGAN_COND_SIZE (COND_NET_FCONV2_OUT_SIZE/FARGAN_NB_SUBFRAMES)
#define FARGAN_DEEMPHASIS 0.85f
#define SIG_NET_INPUT_SIZE (FARGAN_COND_SIZE+2*FARGAN_SUBFRAME_SIZE+4)
#define SIG_NET_FWC0_STATE_SIZE (2*SIG_NET_INPUT_SIZE)
typedef struct {
FARGAN model;
int arch;
int cont_initialized;
float deemph_mem;
float pitch_buf[PITCH_MAX_PERIOD];
float cond_conv1_state[COND_NET_FCONV1_STATE_SIZE];
float cond_conv2_state[COND_NET_FCONV2_STATE_SIZE];
float fwc0_mem[SIG_NET_FWC0_STATE_SIZE];
float gru1_state[SIG_NET_GRU1_STATE_SIZE];
float gru2_state[SIG_NET_GRU2_STATE_SIZE];
float gru3_state[SIG_NET_GRU3_STATE_SIZE];
int last_period;
} FARGANState;
void fargan_init(FARGANState *st);
int fargan_load_model(FARGANState *st, const unsigned char *data, int len);
void fargan_cont(FARGANState *st, const float *pcm0, const float *features0);
void fargan_synthesize(FARGANState *st, float *pcm, const float *features);
void fargan_synthesize_int(FARGANState *st, opus_int16 *pcm, const float *features);
#endif /* FARGAN_H */
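
For reference, a minimal calling sequence for the new API (a sketch, not part of the commit), following the pattern used by lpcnet_demo.c below; fill_past_features() and get_next_features() are hypothetical placeholders:

FARGANState st;
float cont_pcm[FARGAN_CONT_SAMPLES] = {0};   /* past audio used to prime the state (zeros also work) */
float cont_features[5*NB_FEATURES];          /* five past feature frames */
float features[NB_FEATURES];
opus_int16 frame[FARGAN_FRAME_SIZE];

fargan_init(&st);
/* fargan_load_model(&st, blob, blob_len) may be called here when using an external weight file. */
fill_past_features(cont_features);
fargan_cont(&st, cont_pcm, cont_features);   /* prime the conv/GRU states before synthesis */
while (get_next_features(features)) {
   fargan_synthesize_int(&st, frame, features);
   /* ...write FARGAN_FRAME_SIZE samples... */
}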


@@ -37,6 +37,7 @@
#include "freq.h"
#include "os_support.h"
#include "fwgan.h"
#include "fargan.h"
#ifdef USE_WEIGHTS_FILE
# if __unix__
@@ -86,6 +87,7 @@ void free_blob(unsigned char *blob, int len) {
#define MODE_PLC 4
#define MODE_ADDLPC 5
#define MODE_FWGAN_SYNTHESIS 6
#define MODE_FARGAN_SYNTHESIS 7
void usage(void) {
fprintf(stderr, "usage: lpcnet_demo -features <input.pcm> <features.f32>\n");
@@ -115,6 +117,7 @@ int main(int argc, char **argv) {
if (strcmp(argv[1], "-features") == 0) mode=MODE_FEATURES;
else if (strcmp(argv[1], "-synthesis") == 0) mode=MODE_SYNTHESIS;
else if (strcmp(argv[1], "-fwgan-synthesis") == 0) mode=MODE_FWGAN_SYNTHESIS;
else if (strcmp(argv[1], "-fargan-synthesis") == 0) mode=MODE_FARGAN_SYNTHESIS;
else if (strcmp(argv[1], "-plc") == 0) {
mode=MODE_PLC;
plc_options = argv[2];
@@ -210,6 +213,32 @@ int main(int argc, char **argv) {
for (i=0;i<LPCNET_FRAME_SIZE;i++) pcm[i] = (int)floor(.5 + MIN32(32767, MAX32(-32767, 32768.f*fpcm[i])));
fwrite(pcm, sizeof(pcm[0]), LPCNET_FRAME_SIZE, fout);
}
} else if (mode == MODE_FARGAN_SYNTHESIS) {
FARGANState fargan;
size_t ret, i;
float in_features[5*NB_TOTAL_FEATURES];
float zeros[320] = {0};
fargan_init(&fargan);
#ifdef USE_WEIGHTS_FILE
fargan_load_model(&fargan, data, len);
#endif
/* uncomment the following to align with Python code */
/*ret = fread(&in_features[0], sizeof(in_features[0]), NB_TOTAL_FEATURES, fin);*/
for (i=0;i<5;i++) {
ret = fread(&in_features[i*NB_FEATURES], sizeof(in_features[0]), NB_TOTAL_FEATURES, fin);
}
fargan_cont(&fargan, zeros, in_features);
while (1) {
float features[NB_FEATURES];
float fpcm[LPCNET_FRAME_SIZE];
opus_int16 pcm[LPCNET_FRAME_SIZE];
ret = fread(in_features, sizeof(features[0]), NB_TOTAL_FEATURES, fin);
if (feof(fin) || ret != NB_TOTAL_FEATURES) break;
OPUS_COPY(features, in_features, NB_FEATURES);
fargan_synthesize(&fargan, fpcm, features);
for (i=0;i<LPCNET_FRAME_SIZE;i++) pcm[i] = (int)floor(.5 + MIN32(32767, MAX32(-32767, 32768.f*fpcm[i])));
fwrite(pcm, sizeof(pcm[0]), LPCNET_FRAME_SIZE, fout);
}
} else if (mode == MODE_PLC) {
opus_int16 pcm[FRAME_SIZE];
int count=0;


@@ -143,6 +143,16 @@ void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *re
state[i] = h[i];
}
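/* Gated Linear Unit: output[i] = input[i] * sigmoid(layer(input))[i]; requires nb_inputs == nb_outputs. */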
void compute_glu(const LinearLayer *layer, float *output, const float *input)
{
int i;
float act2[MAX_INPUTS];
celt_assert(layer->nb_inputs == layer->nb_outputs);
compute_linear(layer, act2, input);
compute_activation(act2, act2, layer->nb_outputs, ACTIVATION_SIGMOID);
for (i=0;i<layer->nb_outputs;i++) output[i] = input[i]*act2[i];
}
void compute_gated_activation(const LinearLayer *layer, float *output, const float *input, int activation)
{
int i;


@@ -146,6 +146,7 @@ void compute_generic_dense(const LinearLayer *layer, float *output, const float
void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in);
void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int activation);
void compute_generic_conv1d_dilation(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int dilation, int activation);
void compute_glu(const LinearLayer *layer, float *output, const float *input);
void compute_gated_activation(const LinearLayer *layer, float *output, const float *input, int activation);
void compute_activation(float *output, const float *input, int N, int activation);
@@ -176,6 +177,7 @@ extern const WeightArray lpcnet_plc_arrays[];
extern const WeightArray rdovaeenc_arrays[];
extern const WeightArray rdovaedec_arrays[];
extern const WeightArray fwgan_arrays[];
extern const WeightArray fargan_arrays[];
extern const WeightArray pitchdnn_arrays[];
int linear_init(LinearLayer *layer, const WeightArray *arrays,


@@ -0,0 +1,112 @@
import os
import sys
import argparse
import torch
from torch import nn
sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
import wexchange.torch
import fargan
#from models import model_dict
unquantized = [ 'cond_net.pembed', 'cond_net.fdense1', 'sig_net.cond_gain_dense', 'sig_net.gain_dense_out' ]
unquantized2 = [
'cond_net.pembed',
'cond_net.fdense1',
'cond_net.fconv1',
'cond_net.fconv2',
'cont_net.0',
'sig_net.cond_gain_dense',
'sig_net.fwc0.conv',
'sig_net.fwc0.glu.gate',
'sig_net.dense1_glu.gate',
'sig_net.gru1_glu.gate',
'sig_net.gru2_glu.gate',
'sig_net.gru3_glu.gate',
'sig_net.skip_glu.gate',
'sig_net.skip_dense',
'sig_net.sig_dense_out',
'sig_net.gain_dense_out'
]
description=f"""
This is an unsafe dumping script for FARGAN models. It assumes that all weights are contained in Linear, Conv1d or GRU layers
and will fail to export any other weights.
Furthermore, the quantize option relies on the following explicit list of layers being excluded from quantization:
{unquantized}.
Modify this script manually if adjustments are needed.
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument('weightfile', type=str, help='weight file path')
parser.add_argument('export_folder', type=str)
parser.add_argument('--export-filename', type=str, default='fargan_data', help='filename for source and header file (.c and .h will be added), defaults to fargan_data')
parser.add_argument('--struct-name', type=str, default='FARGAN', help='name for C struct, defaults to FARGAN')
parser.add_argument('--quantize', action='store_true', help='apply quantization')
if __name__ == "__main__":
args = parser.parse_args()
print(f"loading weights from {args.weightfile}...")
saved_gen = torch.load(args.weightfile, map_location='cpu')
saved_gen['model_args'] = ()
saved_gen['model_kwargs'] = {'cond_size': 256, 'gamma': 0.9}
model = fargan.FARGAN(*saved_gen['model_args'], **saved_gen['model_kwargs'])
model.load_state_dict(saved_gen['state_dict'], strict=False)
def _remove_weight_norm(m):
try:
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
model.apply(_remove_weight_norm)
print("dumping model...")
quantize_model=args.quantize
output_folder = args.export_folder
os.makedirs(output_folder, exist_ok=True)
writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name)
for name, module in model.named_modules():
if quantize_model:
quantize=name not in unquantized
scale = None if quantize else 1/128
else:
quantize=False
scale=1/128
if isinstance(module, nn.Linear):
print(f"dumping linear layer {name}...")
wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
elif isinstance(module, nn.Conv1d):
print(f"dumping conv1d layer {name}...")
wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
elif isinstance(module, nn.GRU):
print(f"dumping GRU layer {name}...")
wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)
elif isinstance(module, nn.GRUCell):
print(f"dumping GRUCell layer {name}...")
wexchange.torch.dump_torch_grucell_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)
elif isinstance(module, nn.Embedding):
print(f"dumping Embedding layer {name}...")
wexchange.torch.dump_torch_embedding_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
#wexchange.torch.dump_torch_embedding_weights(writer, module)
else:
print(f"Ignoring layer {name}...")
writer.close()


@@ -31,5 +31,6 @@ from .torch import dump_torch_conv1d_weights, load_torch_conv1d_weights
from .torch import dump_torch_conv2d_weights, load_torch_conv2d_weights
from .torch import dump_torch_dense_weights, load_torch_dense_weights
from .torch import dump_torch_gru_weights, load_torch_gru_weights
from .torch import dump_torch_grucell_weights
from .torch import dump_torch_embedding_weights, load_torch_embedding_weights
from .torch import dump_torch_weights, load_torch_weights


@@ -61,6 +61,30 @@ def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent
np.save(os.path.join(where, 'bias_hh_rzn.npy'), b_hh)
def dump_torch_grucell_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
w_ih = gru.weight_ih.detach().cpu().numpy().copy()
w_hh = gru.weight_hh.detach().cpu().numpy().copy()
if hasattr(gru, 'bias_ih') and gru.bias_ih is not None:
b_ih = gru.bias_ih.detach().cpu().numpy().copy()
else:
b_ih = None
if hasattr(gru, 'bias_hh') and gru.bias_hh is not None:
b_hh = gru.bias_hh.detach().cpu().numpy().copy()
else:
b_hh = None
if isinstance(where, CWriter):
return print_gru_layer(where, name, w_ih, w_hh, b_ih, b_hh, format='torch', input_sparse=input_sparse, recurrent_sparse=recurrent_sparse, quantize=quantize, scale=scale, recurrent_scale=recurrent_scale)
else:
os.makedirs(where, exist_ok=True)
np.save(os.path.join(where, 'weight_ih_rzn.npy'), w_ih)
np.save(os.path.join(where, 'weight_hh_rzn.npy'), w_hh)
np.save(os.path.join(where, 'bias_ih_rzn.npy'), b_ih)
np.save(os.path.join(where, 'bias_hh_rzn.npy'), b_hh)
def load_torch_gru_weights(where, gru):
@@ -165,11 +189,20 @@ def load_torch_conv2d_weights(where, conv):
conv.bias.set_(torch.from_numpy(b))
-def dump_torch_embedding_weights(where, emb):
-os.makedirs(where, exist_ok=True)
+def dump_torch_embedding_weights(where, embed, name='embed', scale=1/128, sparse=False, diagonal=False, quantize=False):
-w = emb.weight.detach().cpu().numpy().copy()
-np.save(os.path.join(where, 'weight.npy'), w)
+print("quantize = ", quantize)
+w = embed.weight.detach().cpu().numpy().copy().transpose()
+b = np.zeros(w.shape[0], dtype=w.dtype)
+if isinstance(where, CWriter):
+return print_dense_layer(where, name, w, b, scale=scale, format='torch', sparse=sparse, diagonal=diagonal, quantize=quantize)
+else:
+os.makedirs(where, exist_ok=True)
+np.save(os.path.join(where, 'weight.npy'), w)
+np.save(os.path.join(where, 'bias.npy'), b)
def load_torch_embedding_weights(where, emb):
@@ -187,6 +220,8 @@ def dump_torch_weights(where, module, name=None, verbose=False, **kwargs):
return dump_torch_dense_weights(where, module, name, **kwargs)
elif isinstance(module, torch.nn.GRU):
return dump_torch_gru_weights(where, module, name, **kwargs)
elif isinstance(module, torch.nn.GRUCell):
return dump_torch_grucell_weights(where, module, name, **kwargs)
elif isinstance(module, torch.nn.Conv1d):
return dump_torch_conv1d_weights(where, module, name, **kwargs)
elif isinstance(module, torch.nn.Conv2d):
@@ -209,4 +244,4 @@ def load_torch_weights(where, module):
elif isinstance(module, torch.nn.Embedding):
load_torch_embedding_weights(where, module)
else:
-raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')
+raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')


@@ -8,6 +8,8 @@ dnn/lpcnet.h \
dnn/burg.h \
dnn/common.h \
dnn/freq.h \
dnn/fargan.h \
dnn/fargan_data.h \
dnn/fwgan.h \
dnn/fwgan_data.h \
dnn/kiss99.h \


@@ -1,6 +1,8 @@
LPCNET_SOURCES = \
dnn/burg.c \
dnn/freq.c \
dnn/fargan.c \
dnn/fargan_data.c \
dnn/fwgan.c \
dnn/fwgan_data.c \
dnn/kiss99.c \