added pytorch implementation of RDOVAE

This commit is contained in:
jbuethe 2022-11-23 11:02:29 +00:00
parent a13aa3a077
commit fdb04d0eef
14 changed files with 1880 additions and 0 deletions

View file

@ -0,0 +1,24 @@
# Rate-Distortion-Optimized Variational Auto-Encoder
## Setup
The Python code requires Python >= 3.6 and has been tested with Python 3.6 and Python 3.10. To install the requirements, run
```
python -m pip install -r requirements.txt
```
## Training
To generate training data, use `dump_data` from the main LPCNet repo
```
./dump_data -train 16khz_speech_input.s16 features.f32 data.s16
```
To train the model, simply run
```
python train_rdovae.py features.f32 output_folder
```
To train on a CUDA device, add `--cuda-visible-devices idx`.
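For example, a run combining some of the optional flags declared in `train_rdovae.py` might look like this (the values shown are illustrative, not tuned recommendations):
```
python train_rdovae.py features.f32 output_folder --batch-size 32 --epochs 100 --cuda-visible-devices 0
```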
## ToDo
- Upload checkpoints and add URLs

View file

@ -0,0 +1,256 @@
"""
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('checkpoint', type=str, help='rdovae model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--format', choices=['C', 'numpy'], help='output format, default: C', default='C')
args = parser.parse_args()
import torch
import numpy as np
from rdovae import RDOVAE
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
def dump_statistical_model(writer, qembedding):
w = qembedding.weight.detach()
levels, dim = w.shape
N = dim // 6
print("printing statistical model")
quant_scales = torch.nn.functional.softplus(w[:, : N]).numpy()
dead_zone = 0.05 * torch.nn.functional.softplus(w[:, N : 2 * N]).numpy()
r = torch.sigmoid(w[:, 5 * N : 6 * N]).numpy()
p0 = torch.sigmoid(w[:, 4 * N : 5 * N]).numpy()
p0 = 1 - r ** (0.5 + 0.5 * p0)
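# convert to unsigned 16-bit fixed point: a Qn value stores round(x * 2**n)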
quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16)
r_q15 = np.round(r * 2**15).astype(np.uint16)
p0_q15 = np.round(p0 * 2**15).astype(np.uint16)
print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
print_vector(writer.source, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False)
print_vector(writer.source, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False)
writer.header.write(
f"""
extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
extern const opus_uint16 dred_r_q15[{levels * N}];
extern const opus_uint16 dred_p0_q15[{levels * N}];
"""
)
def c_export(args, model):
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message)
dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message)
stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats"), message=message)
constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True)
# some custom includes
for writer in [enc_writer, dec_writer, stats_writer]:
writer.header.write(
f"""
#include "opus_types.h"
#include "dred_rdovae_constants.h"
#include "nnet.h"
"""
)
# encoder
encoder_dense_layers = [
('core_encoder.module.dense_1' , 'enc_dense1', 'TANH'),
('core_encoder.module.dense_2' , 'enc_dense3', 'TANH'),
('core_encoder.module.dense_3' , 'enc_dense5', 'TANH'),
('core_encoder.module.dense_4' , 'enc_dense7', 'TANH'),
('core_encoder.module.dense_5' , 'enc_dense8', 'TANH'),
('core_encoder.module.state_dense_1' , 'gdense1' , 'TANH'),
('core_encoder.module.state_dense_2' , 'gdense2' , 'TANH')
]
for name, export_name, activation in encoder_dense_layers:
layer = model.get_submodule(name)
dump_torch_weights(enc_writer, layer, name=export_name, activation=activation, verbose=True)
encoder_gru_layers = [
('core_encoder.module.gru_1' , 'enc_dense2', 'TANH'),
('core_encoder.module.gru_2' , 'enc_dense4', 'TANH'),
('core_encoder.module.gru_3' , 'enc_dense6', 'TANH')
]
enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_gru_layers])
encoder_conv_layers = [
('core_encoder.module.conv1' , 'bits_dense' , 'LINEAR')
]
enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_conv_layers])
del enc_writer
# decoder
decoder_dense_layers = [
('core_decoder.module.gru_1_init' , 'state1', 'TANH'),
('core_decoder.module.gru_2_init' , 'state2', 'TANH'),
('core_decoder.module.gru_3_init' , 'state3', 'TANH'),
('core_decoder.module.dense_1' , 'dec_dense1', 'TANH'),
('core_decoder.module.dense_2' , 'dec_dense3', 'TANH'),
('core_decoder.module.dense_3' , 'dec_dense5', 'TANH'),
('core_decoder.module.dense_4' , 'dec_dense7', 'TANH'),
('core_decoder.module.dense_5' , 'dec_dense8', 'TANH'),
('core_decoder.module.output' , 'dec_final', 'LINEAR')
]
for name, export_name, activation in decoder_dense_layers:
layer = model.get_submodule(name)
dump_torch_weights(dec_writer, layer, name=export_name, activation=activation, verbose=True)
decoder_gru_layers = [
('core_decoder.module.gru_1' , 'dec_dense2', 'TANH'),
('core_decoder.module.gru_2' , 'dec_dense4', 'TANH'),
('core_decoder.module.gru_3' , 'dec_dense6', 'TANH')
]
dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in decoder_gru_layers])
del dec_writer
# statistical model
qembedding = model.statistical_model.quant_embedding
dump_statistical_model(stats_writer, qembedding)
del stats_writer
# constants
constants_writer.header.write(
f"""
#define DRED_NUM_FEATURES {model.feature_dim}
#define DRED_LATENT_DIM {model.latent_dim}
#define DRED_STATE_DIM {model.state_dim}
#define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels}
#define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)}
#define DRED_MAX_CONV_INPUTS {enc_max_conv_inputs}
#define DRED_ENC_MAX_RNN_NEURONS {enc_max_rnn_units}
#define DRED_ENC_MAX_CONV_INPUTS {enc_max_conv_inputs}
#define DRED_DEC_MAX_RNN_NEURONS {dec_max_rnn_units}
"""
)
del constants_writer
def numpy_export(args, model):
exchange_name_to_name = {
'encoder_stack_layer1_dense' : 'core_encoder.module.dense_1',
'encoder_stack_layer3_dense' : 'core_encoder.module.dense_2',
'encoder_stack_layer5_dense' : 'core_encoder.module.dense_3',
'encoder_stack_layer7_dense' : 'core_encoder.module.dense_4',
'encoder_stack_layer8_dense' : 'core_encoder.module.dense_5',
'encoder_state_layer1_dense' : 'core_encoder.module.state_dense_1',
'encoder_state_layer2_dense' : 'core_encoder.module.state_dense_2',
'encoder_stack_layer2_gru' : 'core_encoder.module.gru_1',
'encoder_stack_layer4_gru' : 'core_encoder.module.gru_2',
'encoder_stack_layer6_gru' : 'core_encoder.module.gru_3',
'encoder_stack_layer9_conv' : 'core_encoder.module.conv1',
'statistical_model_embedding' : 'statistical_model.quant_embedding',
'decoder_state1_dense' : 'core_decoder.module.gru_1_init',
'decoder_state2_dense' : 'core_decoder.module.gru_2_init',
'decoder_state3_dense' : 'core_decoder.module.gru_3_init',
'decoder_stack_layer1_dense' : 'core_decoder.module.dense_1',
'decoder_stack_layer3_dense' : 'core_decoder.module.dense_2',
'decoder_stack_layer5_dense' : 'core_decoder.module.dense_3',
'decoder_stack_layer7_dense' : 'core_decoder.module.dense_4',
'decoder_stack_layer8_dense' : 'core_decoder.module.dense_5',
'decoder_stack_layer9_dense' : 'core_decoder.module.output',
'decoder_stack_layer2_gru' : 'core_decoder.module.gru_1',
'decoder_stack_layer4_gru' : 'core_decoder.module.gru_2',
'decoder_stack_layer6_gru' : 'core_decoder.module.gru_3'
}
name_to_exchange_name = {value : key for key, value in exchange_name_to_name.items()}
for name, exchange_name in name_to_exchange_name.items():
print(f"printing layer {name}...")
dump_torch_weights(os.path.join(args.output_dir, exchange_name), model.get_submodule(name))
if __name__ == "__main__":
os.makedirs(args.output_dir, exist_ok=True)
# load model from checkpoint
checkpoint = torch.load(args.checkpoint, map_location='cpu')
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
if len(missing_keys) > 0:
raise ValueError(f"error: missing keys in state dict")
if len(unmatched_keys) > 0:
print(f"warning: the following keys were unmatched {unmatched_keys}")
if args.format == 'C':
c_export(args, model)
elif args.format == 'numpy':
numpy_export(args, model)
else:
raise ValueError(f'error: unknown export format {args.format}')

View file

@ -0,0 +1,213 @@
"""
/* Copyright (c) 2022 Amazon
Written by Jan Buethe and Jean-Marc Valin */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import subprocess
import argparse
os.environ['CUDA_VISIBLE_DEVICES'] = ""
parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with VoIP applications and 20 ms frames')
parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)')
parser.add_argument('checkpoint', metavar='<weights>', help='model checkpoint')
parser.add_argument('q0', metavar='<quant level 0>', type=int, help='quantization level for most recent frame')
parser.add_argument('q1', metavar='<quant level 1>', type=int, help='quantization level for oldest frame')
parser.add_argument('output', type=str, help='output file (will be extended with .fec)')
parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
parser.add_argument('--num-redundancy-frames', default=52, type=int, help='number of redundancy frames per packet (default 52)')
parser.add_argument('--extra-delay', default=0, type=int, help="the last features in a packet are calculated from decoder-aligned samples; use this option to add extra delay (in samples at 16 kHz)")
parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')
parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')
args = parser.parse_args()
import numpy as np
from scipy.io import wavfile
import torch
from rdovae import RDOVAE
from packets import write_fec_packets
torch.set_num_threads(4)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.to("cpu")
lpc_order = 16
## prepare input signal
# SILK frame size is 20ms and LPCNet subframes are 10ms
subframe_size = 160
frame_size = 2 * subframe_size
# 91 samples delay to align with SILK decoded frames
silk_delay = 91
# prepend zeros to have enough history to produce the first packet
zero_history = (args.num_redundancy_frames - 1) * frame_size
# dump_data has a (feature) delay of 10ms
dump_data_delay = 160
total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay
# load signal
if args.input.endswith('.raw') or args.input.endswith('.pcm'):
signal = np.fromfile(args.input, dtype='int16')
elif args.input.endswith('.wav'):
fs, signal = wavfile.read(args.input)
else:
raise ValueError(f'unknown input signal format: {args.input}')
# fill up last frame with zeros
padded_signal_length = len(signal) + total_delay
tail = padded_signal_length % frame_size
right_padding = (frame_size - tail) % frame_size
signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16)))
padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw'
signal.tofile(padded_signal_file)
# write signal and call dump_data to create features
feature_file = os.path.splitext(args.input)[0] + '_features.f32'
command = f"{args.dump_data} -test {padded_signal_file} {feature_file}"
r = subprocess.run(command, shell=True)
if r.returncode != 0:
raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}")
# load features
nb_features = model.feature_dim + lpc_order
nb_used_features = model.feature_dim
# load features
features = np.fromfile(feature_file, dtype='float32')
num_subframes = len(features) // nb_features
num_subframes = 2 * (num_subframes // 2)
num_frames = num_subframes // 2
features = np.reshape(features, (1, -1, nb_features))
features = features[:, :, :nb_used_features]
features = features[:, :num_subframes, :]
# quant_ids in reverse decoding order
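# quantization levels are linearly interpolated from q1 (oldest frame) to q0 (most recent frame) over num_redundancy_frames // 2 latent vectors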
quant_ids = torch.round((args.q1 + (args.q0 - args.q1) * torch.arange(args.num_redundancy_frames // 2) / (args.num_redundancy_frames // 2 - 1))).long()
print(f"using quantization levels {quant_ids}...")
# convert input to torch tensors
features = torch.from_numpy(features)
# run encoder
print("running fec encoder...")
with torch.no_grad():
# encoding
z, states, state_size = model.encode(features)
# decoder on packet chunks
input_length = args.num_redundancy_frames // 2
offset = args.num_redundancy_frames - 1
packets = []
packet_sizes = []
for i in range(offset, num_frames):
print(f"processing frame {i - offset}...")
# quantize / unquantize latent vectors
zi = torch.clone(z[:, i - 2 * input_length + 2: i + 1 : 2, :])
zi, rates = model.quantize(zi, quant_ids)
zi = model.unquantize(zi, quant_ids)
features = model.decode(zi, states[:, i : i + 1, :])
packets.append(features.squeeze(0).numpy())
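# packet size in bits: latent rate plus state size, rounded up to a whole number of bytes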
packet_size = 8 * int((torch.sum(rates) + 7 + state_size) / 8)
packet_sizes.append(packet_size)
# write packets
packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output
write_fec_packets(packet_file, packets, packet_sizes)
print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")
# assemble features according to loss file
if args.lossfile != None:
num_packets = len(packets)
loss = np.loadtxt(args.lossfile, dtype='int16')
fec_out = np.zeros((num_packets * 2, packets[0].shape[-1]), dtype='float32')
foffset = -2
ptr = 0
count = 2
for i in range(num_packets):
if (loss[i] == 0) or (i == num_packets - 1):
fec_out[ptr:ptr+count,:] = packets[i][foffset:, :]
ptr += count
foffset = -2
count = 2
else:
count += 2
foffset -= 2
fec_out_full = np.zeros((fec_out.shape[0], 36), dtype=np.float32)
fec_out_full[:, : fec_out.shape[-1]] = fec_out
fec_out_full.tofile(packet_file[:-4] + f'_fec.f32')
if args.debug_output:
import itertools
batches = [4]
offsets = [0, 2 * args.num_redundancy_frames - 4]
# sanity checks
# 1. concatenate features at offset 0
for batch, offset in itertools.product(batches, offsets):
stop = packets[0].shape[1] - offset
test_features = np.concatenate([packet[stop - batch: stop, :] for packet in packets[::batch//2]], axis=0)
test_features_full = np.zeros((test_features.shape[0], nb_features), dtype=np.float32)
test_features_full[:, :nb_used_features] = test_features[:, :]
print(f"writing debug output {packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32'}")
test_features_full.tofile(packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32')

View file

@ -0,0 +1,143 @@
"""
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('exchange_folder', type=str, help='exchange folder path')
parser.add_argument('output', type=str, help='path to output model checkpoint')
model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--num-features', type=int, help="number of features, default: 20", default=20)
model_group.add_argument('--latent-dim', type=int, help="number of symbols produced by the encoder, default: 80", default=80)
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
model_group.add_argument('--state-dim', type=int, help="dimensionality of transferred state, default: 24", default=24)
model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 40", default=40)
args = parser.parse_args()
import torch
from rdovae import RDOVAE
from wexchange.torch import load_torch_weights
exchange_name_to_name = {
'encoder_stack_layer1_dense' : 'core_encoder.module.dense_1',
'encoder_stack_layer3_dense' : 'core_encoder.module.dense_2',
'encoder_stack_layer5_dense' : 'core_encoder.module.dense_3',
'encoder_stack_layer7_dense' : 'core_encoder.module.dense_4',
'encoder_stack_layer8_dense' : 'core_encoder.module.dense_5',
'encoder_state_layer1_dense' : 'core_encoder.module.state_dense_1',
'encoder_state_layer2_dense' : 'core_encoder.module.state_dense_2',
'encoder_stack_layer2_gru' : 'core_encoder.module.gru_1',
'encoder_stack_layer4_gru' : 'core_encoder.module.gru_2',
'encoder_stack_layer6_gru' : 'core_encoder.module.gru_3',
'encoder_stack_layer9_conv' : 'core_encoder.module.conv1',
'statistical_model_embedding' : 'statistical_model.quant_embedding',
'decoder_state1_dense' : 'core_decoder.module.gru_1_init',
'decoder_state2_dense' : 'core_decoder.module.gru_2_init',
'decoder_state3_dense' : 'core_decoder.module.gru_3_init',
'decoder_stack_layer1_dense' : 'core_decoder.module.dense_1',
'decoder_stack_layer3_dense' : 'core_decoder.module.dense_2',
'decoder_stack_layer5_dense' : 'core_decoder.module.dense_3',
'decoder_stack_layer7_dense' : 'core_decoder.module.dense_4',
'decoder_stack_layer8_dense' : 'core_decoder.module.dense_5',
'decoder_stack_layer9_dense' : 'core_decoder.module.output',
'decoder_stack_layer2_gru' : 'core_decoder.module.gru_1',
'decoder_stack_layer4_gru' : 'core_decoder.module.gru_2',
'decoder_stack_layer6_gru' : 'core_decoder.module.gru_3'
}
if __name__ == "__main__":
checkpoint = dict()
# parameters
num_features = args.num_features
latent_dim = args.latent_dim
quant_levels = args.quant_levels
cond_size = args.cond_size
cond_size2 = args.cond_size2
state_dim = args.state_dim
# model
checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
checkpoint['model_kwargs'] = {'state_dim': state_dim}
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
dense_layer_names = [
'encoder_stack_layer1_dense',
'encoder_stack_layer3_dense',
'encoder_stack_layer5_dense',
'encoder_stack_layer7_dense',
'encoder_stack_layer8_dense',
'encoder_state_layer1_dense',
'encoder_state_layer2_dense',
'decoder_state1_dense',
'decoder_state2_dense',
'decoder_state3_dense',
'decoder_stack_layer1_dense',
'decoder_stack_layer3_dense',
'decoder_stack_layer5_dense',
'decoder_stack_layer7_dense',
'decoder_stack_layer8_dense',
'decoder_stack_layer9_dense'
]
gru_layer_names = [
'encoder_stack_layer2_gru',
'encoder_stack_layer4_gru',
'encoder_stack_layer6_gru',
'decoder_stack_layer2_gru',
'decoder_stack_layer4_gru',
'decoder_stack_layer6_gru'
]
conv1d_layer_names = [
'encoder_stack_layer9_conv'
]
embedding_layer_names = [
'statistical_model_embedding'
]
for name in dense_layer_names + gru_layer_names + conv1d_layer_names + embedding_layer_names:
print(f"loading weights for layer {exchange_name_to_name[name]}")
layer = model.get_submodule(exchange_name_to_name[name])
load_torch_weights(os.path.join(args.exchange_folder, name), layer)
checkpoint['state_dict'] = model.state_dict()
torch.save(checkpoint, args.output)

Binary file not shown.

View file

@ -0,0 +1 @@
from .fec_packets import write_fec_packets, read_fec_packets

View file

@ -0,0 +1,142 @@
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <stdio.h>
#include <inttypes.h>
#include "fec_packets.h"
int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index)
{
int16_t version;
int16_t header_size;
int16_t num_packets;
int16_t packet_size;
int16_t subframe_size;
int16_t subframes_per_packet;
int16_t num_features;
long offset;
FILE *fid = fopen(filename, "rb");
/* read header */
if (fread(&version, sizeof(version), 1, fid) != 1) goto error;
if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error;
if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error;
if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error;
if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error;
if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error;
if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error;
/* check if indices are valid */
if (packet_index >= num_packets || subframe_index >= subframes_per_packet)
{
fprintf(stderr, "get_fec_frame: index out of bounds\n");
goto error;
}
/* calculate offset in file (+ 2 is for rate) */
offset = header_size + packet_index * packet_size + 2 + subframe_index * subframe_size;
fseek(fid, offset, SEEK_SET);
/* read features */
if (fread(features, sizeof(*features), num_features, fid) != num_features) goto error;
fclose(fid);
return 0;
error:
fclose(fid);
return 1;
}
int get_fec_rate(const char * const filename, int packet_index)
{
int16_t version;
int16_t header_size;
int16_t num_packets;
int16_t packet_size;
int16_t subframe_size;
int16_t subframes_per_packet;
int16_t num_features;
long offset;
int16_t rate;
FILE *fid = fopen(filename, "rb");
/* read header */
if (fread(&version, sizeof(version), 1, fid) != 1) goto error;
if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error;
if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error;
if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error;
if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error;
if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error;
if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error;
/* check if indices are valid */
if (packet_index >= num_packets)
{
fprintf(stderr, "get_fec_rate: index out of bounds\n");
goto error;
}
/* calculate offset in file (+ 2 is for rate) */
offset = header_size + packet_index * packet_size;
fseek(fid, offset, SEEK_SET);
/* read rate */
if (fread(&rate, sizeof(rate), 1, fid) != 1) goto error;
fclose(fid);
return (int) rate;
error:
fclose(fid);
return -1;
}
#if 0
int main()
{
float features[20];
int i;
if (get_fec_frame("../test.fec", &features[0], 0, 127))
{
return 1;
}
for (i = 0; i < 20; i ++)
{
printf("%d %f\n", i, features[i]);
}
printf("rate: %d\n", get_fec_rate("../test.fec", 0));
}
#endif

View file

@ -0,0 +1,34 @@
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef _FEC_PACKETS_H
#define _FEC_PACKETS_H
int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index);
int get_fec_rate(const char * const filename, int packet_index);
#endif

View file

@ -0,0 +1,108 @@
"""
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numpy as np
def write_fec_packets(filename, packets, rates=None):
""" writes packets in binary format """
assert np.dtype(np.float32).itemsize == 4
assert np.dtype(np.int16).itemsize == 2
# derive some sizes
num_packets = len(packets)
subframes_per_packet = packets[0].shape[-2]
num_features = packets[0].shape[-1]
# size of float is 4
subframe_size = num_features * 4
packet_size = subframe_size * subframes_per_packet + 2 # two bytes for rate
version = 1
# header size (version, header_size, num_packets, packet_size, subframe_size, subframes_per_packet, num_features)
header_size = 14
with open(filename, 'wb') as f:
# header
f.write(np.int16(version).tobytes())
f.write(np.int16(header_size).tobytes())
f.write(np.int16(num_packets).tobytes())
f.write(np.int16(packet_size).tobytes())
f.write(np.int16(subframe_size).tobytes())
f.write(np.int16(subframes_per_packet).tobytes())
f.write(np.int16(num_features).tobytes())
# packets
for i, packet in enumerate(packets):
if type(rates) == type(None):
rate = 0
else:
rate = rates[i]
f.write(np.int16(rate).tobytes())
features = np.flip(packet, axis=-2)
f.write(features.astype(np.float32).tobytes())
def read_fec_packets(filename):
""" reads packets from binary format """
assert np.dtype(np.float32).itemsize == 4
assert np.dtype(np.int16).itemsize == 2
with open(filename, 'rb') as f:
# header
version = np.frombuffer(f.read(2), dtype=np.int16).item()
header_size = np.frombuffer(f.read(2), dtype=np.int16).item()
num_packets = np.frombuffer(f.read(2), dtype=np.int16).item()
packet_size = np.frombuffer(f.read(2), dtype=np.int16).item()
subframe_size = np.frombuffer(f.read(2), dtype=np.int16).item()
subframes_per_packet = np.frombuffer(f.read(2), dtype=np.int16).item()
num_features = np.frombuffer(f.read(2), dtype=np.int16).item()
dummy_features = np.zeros((subframes_per_packet, num_features), dtype=np.float32)
# packets
rates = []
packets = []
for i in range(num_packets):
rate = np.frombuffer(f.read(2), dtype=np.int16).item()
rates.append(rate)
features = np.reshape(np.frombuffer(f.read(subframe_size * subframes_per_packet), dtype=np.float32), dummy_features.shape)
packet = np.flip(features, axis=-2)
packets.append(packet)
return packets

View file

@ -0,0 +1,2 @@
from .rdovae import RDOVAE, distortion_loss, hard_rate_estimate, soft_rate_estimate
from .dataset import RDOVAEDataset

View file

@ -0,0 +1,68 @@
"""
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
import numpy as np
class RDOVAEDataset(torch.utils.data.Dataset):
def __init__(self,
feature_file,
sequence_length,
num_used_features=20,
num_features=36,
lambda_min=0.0002,
lambda_max=0.0135,
quant_levels=16,
enc_stride=2):
self.sequence_length = sequence_length
self.lambda_min = lambda_min
self.lambda_max = lambda_max
self.enc_stride = enc_stride
self.quant_levels = quant_levels
self.denominator = (quant_levels - 1) / np.log(lambda_max / lambda_min)
if sequence_length % enc_stride:
raise ValueError(f"RDOVAEDataset.__init__: enc_stride {enc_stride} does not divide sequence length {sequence_length}")
self.features = np.reshape(np.fromfile(feature_file, dtype=np.float32), (-1, num_features))
self.features = self.features[:, :num_used_features]
self.num_sequences = self.features.shape[0] // sequence_length
def __len__(self):
return self.num_sequences
def __getitem__(self, index):
features = self.features[index * self.sequence_length: (index + 1) * self.sequence_length, :]
q_ids = np.random.randint(0, self.quant_levels, (1)).astype(np.int64)
q_ids = np.repeat(q_ids, self.sequence_length // self.enc_stride, axis=0)
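# rate lambda grows exponentially with the quantization id, from lambda_min at q_id = 0 to lambda_max at q_id = quant_levels - 1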
rate_lambda = self.lambda_min * np.exp(q_ids.astype(np.float32) / self.denominator).astype(np.float32)
return features, rate_lambda, q_ids

View file

@ -0,0 +1,614 @@
"""
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" Pytorch implementations of rate distortion optimized variational autoencoder """
import math as m
import torch
from torch import nn
import torch.nn.functional as F
# Quantization and rate related utility functions
def soft_pvq(x, k):
""" soft pyramid vector quantizer """
# L2 normalization
x_norm2 = x / (1e-15 + torch.norm(x, dim=-1, keepdim=True))
with torch.no_grad():
# quantization loop, no need to track gradients here
x_norm1 = x / torch.sum(torch.abs(x), dim=-1, keepdim=True)
# set initial scaling factor to k
scale_factor = k
x_scaled = scale_factor * x_norm1
x_quant = torch.round(x_scaled)
# we aim for ||x_quant||_L1 = k
for _ in range(10):
# remove signs and calculate L1 norm
abs_x_quant = torch.abs(x_quant)
abs_x_scaled = torch.abs(x_scaled)
l1_x_quant = torch.sum(abs_x_quant, axis=-1)
# increase, where target is too small and decrease, where target is too large
plus = 1.0001 * torch.min((abs_x_quant + 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
minus = 0.9999 * torch.max((abs_x_quant - 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
factor = torch.where(l1_x_quant > k, minus, plus)
factor = torch.where(l1_x_quant == k, torch.ones_like(factor), factor)
scale_factor = scale_factor * factor.unsqueeze(-1)
# update x
x_scaled = scale_factor * x_norm1
x_quant = torch.round(x_scaled)
# L2 normalization of quantized x
x_quant_norm2 = x_quant / (1e-15 + torch.norm(x_quant, dim=-1, keepdim=True))
quantization_error = x_quant_norm2 - x_norm2
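# straight-through estimator: the forward pass returns the normalized quantized vector, gradients flow through x_norm2 only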
return x_norm2 + quantization_error.detach()
def cache_parameters(func):
cache = dict()
def cached_func(*args):
if args in cache:
return cache[args]
else:
cache[args] = func(*args)
return cache[args]
return cached_func
@cache_parameters
def pvq_codebook_size(n, k):
if k == 0:
return 1
if n == 0:
return 0
return pvq_codebook_size(n - 1, k) + pvq_codebook_size(n, k - 1) + pvq_codebook_size(n - 1, k - 1)
def soft_rate_estimate(z, r, reduce=True):
""" rate approximation with dependent theta Eq. (7)"""
rate = torch.sum(
- torch.log2((1 - r)/(1 + r) * r ** torch.abs(z) + 1e-6),
dim=-1
)
if reduce:
rate = torch.mean(rate)
return rate
def hard_rate_estimate(z, r, theta, reduce=True):
""" hard rate approximation """
z_q = torch.round(z)
p0 = 1 - r ** (0.5 + 0.5 * theta)
alpha = torch.relu(1 - torch.abs(z_q)) ** 2
rate = - torch.sum(
(alpha * torch.log2(p0 * r ** torch.abs(z_q) + 1e-6)
+ (1 - alpha) * torch.log2(0.5 * (1 - p0) * (1 - r) * r ** (torch.abs(z_q) - 1) + 1e-6)),
dim=-1
)
if reduce:
rate = torch.mean(rate)
return rate
def soft_dead_zone(x, dead_zone):
""" approximates application of a dead zone to x """
d = dead_zone * 0.05
return x - d * torch.tanh(x / (0.1 + d))
def hard_quantize(x):
""" round with copy gradient trick """
return x + (torch.round(x) - x).detach()
def noise_quantize(x):
""" simulates quantization with addition of random uniform noise """
return x + (torch.rand_like(x) - 0.5)
# loss functions
def distortion_loss(y_true, y_pred, rate_lambda=None):
""" custom distortion loss for LPCNet features """
if y_true.size(-1) != 20:
raise ValueError('distortion loss is designed to work with 20 features')
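# assumed feature layout (inferred from the slices below): [0:18] cepstrum, [18] pitch parameter, [19] pitch correlation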
ceps_error = y_pred[..., :18] - y_true[..., :18]
pitch_error = 2 * (y_pred[..., 18:19] - y_true[..., 18:19]) / (2 + y_true[..., 18:19])
corr_error = y_pred[..., 19:] - y_true[..., 19:]
pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2
loss = torch.mean(ceps_error ** 2 + (10/18) * torch.abs(pitch_error) * pitch_weight + (1/18) * corr_error ** 2, dim=-1)
if type(rate_lambda) != type(None):
loss = loss / torch.sqrt(rate_lambda)
loss = torch.mean(loss)
return loss
# sampling functions
import random
def random_split(start, stop, num_splits=3, min_len=3):
get_min_len = lambda x : min([x[i+1] - x[i] for i in range(len(x) - 1)])
candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]
while get_min_len(candidate) < min_len:
candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]
return candidate
# weight initialization and clipping
def init_weights(module):
if isinstance(module, nn.GRU):
for p in module.named_parameters():
if p[0].startswith('weight_hh_'):
nn.init.orthogonal_(p[1])
def weight_clip_factory(max_value):
""" weight clipping function concerning sum of abs values of adjecent weights """
def clip_weight_(w):
stop = w.size(1)
# omit last column if stop is odd
if stop % 2:
stop -= 1
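# rescale weight pairs so that |w[:, 2i]| + |w[:, 2i+1]| stays at or below max_value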
max_values = max_value * torch.ones_like(w[:, :stop])
factor = max_value / torch.maximum(max_values,
torch.repeat_interleave(
torch.abs(w[:, :stop:2]) + torch.abs(w[:, 1:stop:2]),
2,
1))
with torch.no_grad():
w[:, :stop] *= factor
def clip_weights(module):
if isinstance(module, nn.GRU) or isinstance(module, nn.Linear):
for name, w in module.named_parameters():
if name.startswith('weight'):
clip_weight_(w)
return clip_weights
# RDOVAE module and submodules
class CoreEncoder(nn.Module):
STATE_HIDDEN = 128
FRAMES_PER_STEP = 2
CONV_KERNEL_SIZE = 4
def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24):
""" core encoder for RDOVAE
Computes latents, initial states, and rate estimates from features and lambda parameter
"""
super(CoreEncoder, self).__init__()
# hyper parameters
self.feature_dim = feature_dim
self.output_dim = output_dim
self.cond_size = cond_size
self.cond_size2 = cond_size2
self.state_size = state_size
# derived parameters
self.input_dim = self.FRAMES_PER_STEP * self.feature_dim
self.conv_input_channels = 5 * cond_size + 3 * cond_size2
# layers
self.dense_1 = nn.Linear(self.input_dim, self.cond_size2)
self.gru_1 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
self.dense_2 = nn.Linear(self.cond_size, self.cond_size2)
self.gru_2 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
self.dense_3 = nn.Linear(self.cond_size, self.cond_size2)
self.gru_3 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
self.dense_4 = nn.Linear(self.cond_size, self.cond_size)
self.dense_5 = nn.Linear(self.cond_size, self.cond_size)
self.conv1 = nn.Conv1d(self.conv_input_channels, self.output_dim, kernel_size=self.CONV_KERNEL_SIZE, padding='valid')
self.state_dense_1 = nn.Linear(self.conv_input_channels, self.STATE_HIDDEN)
self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)
# initialize weights
self.apply(init_weights)
def forward(self, features):
# reshape features
x = torch.reshape(features, (features.size(0), features.size(1) // self.FRAMES_PER_STEP, self.FRAMES_PER_STEP * features.size(2)))
batch = x.size(0)
device = x.device
# run encoding layer stack
x1 = torch.tanh(self.dense_1(x))
x2, _ = self.gru_1(x1, torch.zeros((1, batch, self.cond_size)).to(device))
x3 = torch.tanh(self.dense_2(x2))
x4, _ = self.gru_2(x3, torch.zeros((1, batch, self.cond_size)).to(device))
x5 = torch.tanh(self.dense_3(x4))
x6, _ = self.gru_3(x5, torch.zeros((1, batch, self.cond_size)).to(device))
x7 = torch.tanh(self.dense_4(x6))
x8 = torch.tanh(self.dense_5(x7))
# concatenation of all hidden layer outputs
x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
# init state for decoder
states = torch.tanh(self.state_dense_1(x9))
states = torch.tanh(self.state_dense_2(states))
# latent representation via convolution
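# left-only padding makes the convolution causal: the latent at step t depends only on steps <= t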
x9 = F.pad(x9.permute(0, 2, 1), [self.CONV_KERNEL_SIZE - 1, 0])
z = self.conv1(x9).permute(0, 2, 1)
return z, states
class CoreDecoder(nn.Module):
FRAMES_PER_STEP = 4
def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24):
""" core decoder for RDOVAE
Computes features from latents, initial state, and quantization index
"""
super(CoreDecoder, self).__init__()
# hyper parameters
self.input_dim = input_dim
self.output_dim = output_dim
self.cond_size = cond_size
self.cond_size2 = cond_size2
self.state_size = state_size
self.input_size = self.input_dim
self.concat_size = 4 * self.cond_size + 4 * self.cond_size2
# layers
self.dense_1 = nn.Linear(self.input_size, cond_size2)
self.gru_1 = nn.GRU(cond_size2, cond_size, batch_first=True)
self.dense_2 = nn.Linear(cond_size, cond_size2)
self.gru_2 = nn.GRU(cond_size2, cond_size, batch_first=True)
self.dense_3 = nn.Linear(cond_size, cond_size2)
self.gru_3 = nn.GRU(cond_size2, cond_size, batch_first=True)
self.dense_4 = nn.Linear(cond_size, cond_size2)
self.dense_5 = nn.Linear(cond_size2, cond_size2)
self.output = nn.Linear(self.concat_size, self.FRAMES_PER_STEP * self.output_dim)
self.gru_1_init = nn.Linear(self.state_size, self.cond_size)
self.gru_2_init = nn.Linear(self.state_size, self.cond_size)
self.gru_3_init = nn.Linear(self.state_size, self.cond_size)
# initialize weights
self.apply(init_weights)
def forward(self, z, initial_state):
gru_1_state = torch.tanh(self.gru_1_init(initial_state).permute(1, 0, 2))
gru_2_state = torch.tanh(self.gru_2_init(initial_state).permute(1, 0, 2))
gru_3_state = torch.tanh(self.gru_3_init(initial_state).permute(1, 0, 2))
# run decoding layer stack
x1 = torch.tanh(self.dense_1(z))
x2, _ = self.gru_1(x1, gru_1_state)
x3 = torch.tanh(self.dense_2(x2))
x4, _ = self.gru_2(x3, gru_2_state)
x5 = torch.tanh(self.dense_3(x4))
x6, _ = self.gru_3(x5, gru_3_state)
x7 = torch.tanh(self.dense_4(x6))
x8 = torch.tanh(self.dense_5(x7))
x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
# output layer and reshaping
x10 = self.output(x9)
features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP))
return features
class StatisticalModel(nn.Module):
def __init__(self, quant_levels, latent_dim):
""" Statistical model for latent space
Computes scaling, deadzone, r, and theta
"""
super(StatisticalModel, self).__init__()
# copy parameters
self.latent_dim = latent_dim
self.quant_levels = quant_levels
self.embedding_dim = 6 * latent_dim
# quantization embedding
self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)
# initialize embedding to 0
with torch.no_grad():
self.quant_embedding.weight[:] = 0
def forward(self, quant_ids):
""" takes quant_ids and returns statistical model parameters"""
x = self.quant_embedding(quant_ids)
# CAVE: theta_soft is not used anymore. Kick it out?
quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim])
dead_zone = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim])
theta_soft = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim])
r_soft = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim])
theta_hard = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim])
r_hard = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim])
return {
'quant_embedding' : x,
'quant_scale' : quant_scale,
'dead_zone' : dead_zone,
'r_hard' : r_hard,
'theta_hard' : theta_hard,
'r_soft' : r_soft,
'theta_soft' : theta_soft
}
class RDOVAE(nn.Module):
def __init__(self,
feature_dim,
latent_dim,
quant_levels,
cond_size,
cond_size2,
state_dim=24,
split_mode='split',
clip_weights=True,
pvq_num_pulses=82,
state_dropout_rate=0):
super(RDOVAE, self).__init__()
self.feature_dim = feature_dim
self.latent_dim = latent_dim
self.quant_levels = quant_levels
self.cond_size = cond_size
self.cond_size2 = cond_size2
self.split_mode = split_mode
self.state_dim = state_dim
self.pvq_num_pulses = pvq_num_pulses
self.state_dropout_rate = state_dropout_rate
# submodules; the encoder and decoder share the statistical model
self.statistical_model = StatisticalModel(quant_levels, latent_dim)
self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
self.enc_stride = CoreEncoder.FRAMES_PER_STEP
self.dec_stride = CoreDecoder.FRAMES_PER_STEP
if clip_weights:
self.weight_clip_fn = weight_clip_factory(0.496)
else:
self.weight_clip_fn = None
if self.dec_stride % self.enc_stride != 0:
raise ValueError(f"get_decoder_chunks_generic: encoder stride does not divide decoder stride")
def clip_weights(self):
if not type(self.weight_clip_fn) == type(None):
self.apply(self.weight_clip_fn)
def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
enc_stride = self.enc_stride
dec_stride = self.dec_stride
stride = dec_stride // enc_stride
chunks = []
for offset in range(stride):
# start is the smallest number congruent to offset mod stride that decodes to a valid range
start = offset
while enc_stride * (start + 1) - dec_stride < 0:
start += stride
# check if start is a valid index
if start >= z_frames:
raise ValueError("get_decoder_chunks_generic: range too small")
# stop is the smallest number outside [0, num_enc_frames] that's congruent to offset mod stride
stop = z_frames - (z_frames % stride) + offset
while stop < z_frames:
stop += stride
# calculate split points
length = (stop - start)
if mode == 'split':
split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
elif mode == 'random_split':
split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)]
else:
raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}")
for i in range(chunks_per_offset):
# (enc_frame_start, enc_frame_stop, enc_frame_stride, stride, feature_frame_start, feature_frame_stop)
# encoder range(i, j, stride) maps to feature range(enc_stride * (i + 1) - dec_stride, enc_stride * j)
# provided that i - j = 1 mod stride
chunks.append({
'z_start' : split_points[i],
'z_stop' : split_points[i + 1] - stride + 1,
'z_stride' : stride,
'features_start' : enc_stride * (split_points[i] + 1) - dec_stride,
'features_stop' : enc_stride * (split_points[i + 1] - stride + 1)
})
return chunks
def forward(self, features, q_id):
# calculate statistical model from quantization ID
statistical_model = self.statistical_model(q_id)
# run encoder
z, states = self.core_encoder(features)
# scaling, dead-zone and quantization
z = z * statistical_model['quant_scale']
z = soft_dead_zone(z, statistical_model['dead_zone'])
# quantization
z_q = hard_quantize(z) / statistical_model['quant_scale']
z_n = noise_quantize(z) / statistical_model['quant_scale']
states_q = soft_pvq(states, self.pvq_num_pulses)
if self.state_dropout_rate > 0:
drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
mask = torch.ones_like(states_q)
mask[drop] = 0
states_q = states_q * mask
# decoder
chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode)
outputs_hq = []
outputs_sq = []
for chunk in chunks:
# decoder with hard quantized input
z_dec_reverse = torch.flip(z_q[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
dec_initial_state = states_q[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
outputs_hq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
# decoder with soft quantized input
z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
return {
'outputs_hard_quant' : outputs_hq,
'outputs_soft_quant' : outputs_sq,
'z' : z,
'statistical_model' : statistical_model
}
def encode(self, features):
""" encoder with quantization and rate estimation """
z, states = self.core_encoder(features)
# quantization of initial states
states = soft_pvq(states, self.pvq_num_pulses)
state_size = m.log2(pvq_codebook_size(self.state_dim, self.pvq_num_pulses))
return z, states, state_size
def decode(self, z, initial_state):
""" decoder (flips sequences by itself) """
z_reverse = torch.flip(z, [1])
features_reverse = self.core_decoder(z_reverse, initial_state)
features = torch.flip(features_reverse, [1])
return features
def quantize(self, z, q_ids):
""" quantization of latent vectors """
stats = self.statistical_model(q_ids)
zq = z * stats['quant_scale']
zq = soft_dead_zone(zq, stats['dead_zone'])
zq = torch.round(zq)
sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False)
return zq, sizes
def unquantize(self, zq, q_ids):
""" re-scaling of latent vector """
stats = self.statistical_model(q_ids)
z = zq / stats['quant_scale']
return z
def freeze_model(self):
# freeze all parameters
for p in self.parameters():
p.requires_grad = False
for p in self.statistical_model.parameters():
p.requires_grad = True

View file

@ -0,0 +1,5 @@
numpy
scipy
torch
tqdm
libs/wexchange-1.0-py3-none-any.whl

View file

@ -0,0 +1,270 @@
"""
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import torch
import tqdm
from rdovae import RDOVAE, RDOVAEDataset, distortion_loss, hard_rate_estimate, soft_rate_estimate
parser = argparse.ArgumentParser()
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('output', type=str, help='path to output folder')
parser.add_argument('--cuda-visible-devices', type=str, help="comma separated list of CUDA visible device indices, default: ''", default="")
model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--latent-dim', type=int, help="number of symbols produced by the encoder, default: 80", default=80)
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
model_group.add_argument('--state-dim', type=int, help="dimensionality of transferred state, default: 24", default=24)
model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 16", default=16)
model_group.add_argument('--lambda-min', type=float, help="minimal value for rate lambda, default: 0.0002", default=2e-4)
model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)
model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82)
model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0)
training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
training_group.add_argument('--lr', type=float, help='learning rate, default: 3e-4', default=3e-4)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 100', default=100)
training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by 4, default: 256', default=256)
training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5)
training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split')
training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames')
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
training_group.add_argument('--train-decoder-only', action='store_true', help='freeze encoder and statistical model and train decoder only')
args = parser.parse_args()
# set visible devices
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)
# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay_factor = args.lr_decay_factor
split_mode = args.split_mode
# not exposed
adam_betas = [0.9, 0.99]
adam_eps = 1e-8
checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay_factor'] = lr_decay_factor
checkpoint['split_mode'] = split_mode
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas
# logging
log_interval = 10
# device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# model parameters
cond_size = args.cond_size
cond_size2 = args.cond_size2
latent_dim = args.latent_dim
quant_levels = args.quant_levels
lambda_min = args.lambda_min
lambda_max = args.lambda_max
state_dim = args.state_dim
# not exposed
num_features = 20
# training data
feature_file = args.features
# model
checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate}
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
if args.initial_checkpoint is not None:
checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
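    # strict=False permits a partial restore (e.g. reusing only some sub-modules); note that the loaded dict
    # replaces the checkpoint built above, so its stored settings are re-saved with new checkpoints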
model.load_state_dict(checkpoint['state_dict'], strict=False)
checkpoint['state_dict'] = model.state_dict()
if args.train_decoder_only:
if args.initial_checkpoint is None:
print("warning: training decoder only without providing initial checkpoint")
for p in model.core_encoder.module.parameters():
p.requires_grad = False
for p in model.statistical_model.parameters():
p.requires_grad = False
# dataloader
checkpoint['dataset_args'] = (feature_file, sequence_length, num_features, 36)
checkpoint['dataset_kwargs'] = {'lambda_min': lambda_min, 'lambda_max': lambda_max, 'enc_stride': model.enc_stride, 'quant_levels': quant_levels}
dataset = RDOVAEDataset(*checkpoint['dataset_args'], **checkpoint['dataset_kwargs'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
# optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps)
# learning rate scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
if __name__ == '__main__':
# push model to device
model.to(device)
# training loop
for epoch in range(1, epochs + 1):
print(f"training epoch {epoch}...")
# running stats
running_rate_loss = 0
running_soft_dist_loss = 0
running_hard_dist_loss = 0
running_hard_rate_loss = 0
running_soft_rate_loss = 0
running_total_loss = 0
running_rate_metric = 0
previous_total_loss = 0
running_first_frame_loss = 0
with tqdm.tqdm(dataloader, unit='batch') as tepoch:
for i, (features, rate_lambda, q_ids) in enumerate(tepoch):
# zero out gradients
optimizer.zero_grad()
# push inputs to device
features = features.to(device)
q_ids = q_ids.to(device)
rate_lambda = rate_lambda.to(device)
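                # repeat rate_lambda along the time axis, presumably to match the decoder output frame rate (twice the latent rate)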
rate_lambda_upsamp = torch.repeat_interleave(rate_lambda, 2, 1)
# run model
model_output = model(features, q_ids)
# collect outputs
z = model_output['z']
outputs_hard_quant = model_output['outputs_hard_quant']
outputs_soft_quant = model_output['outputs_soft_quant']
statistical_model = model_output['statistical_model']
# rate loss
hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False)
soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False)
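                # both rate terms are scaled by sqrt(lambda); the soft estimate is differentiable, the hard estimate enters with a small (0.1) weight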
soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate)
hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate)
rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
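                # unscaled hard-rate estimate, tracked for logging only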
hard_rate_metric = torch.mean(hard_rate)
## distortion losses
# hard quantized decoder input
distortion_loss_hard_quant = torch.zeros_like(rate_loss)
for dec_features, start, stop in outputs_hard_quant:
distortion_loss_hard_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_hard_quant)
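                # extra distortion loss on the last 4 output frames of each decoded chunk ('first frame' loss); only added to the total loss when --enable-first-frame-loss is set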
first_frame_loss = torch.zeros_like(rate_loss)
for dec_features, start, stop in outputs_hard_quant:
first_frame_loss += distortion_loss(features[..., stop-4 : stop, :], dec_features[..., -4:, :], rate_lambda_upsamp[..., stop - 4 : stop]) / len(outputs_hard_quant)
# soft quantized decoder input
distortion_loss_soft_quant = torch.zeros_like(rate_loss)
for dec_features, start, stop in outputs_soft_quant:
distortion_loss_soft_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_soft_quant)
# total loss
total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2
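                # penalize only the amount by which the first-frame loss exceeds the hard-quantized distortion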
if args.enable_first_frame_loss:
total_loss = total_loss + 0.5 * torch.relu(first_frame_loss - distortion_loss_hard_quant)
total_loss.backward()
optimizer.step()
model.clip_weights()
scheduler.step()
# collect running stats
running_hard_dist_loss += float(distortion_loss_hard_quant.detach().cpu())
running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu())
running_rate_loss += float(rate_loss.detach().cpu())
running_rate_metric += float(hard_rate_metric.detach().cpu())
running_total_loss += float(total_loss.detach().cpu())
running_first_frame_loss += float(first_frame_loss.detach().cpu())
running_soft_rate_loss += float(soft_rate_loss.detach().cpu())
running_hard_rate_loss += float(hard_rate_loss.detach().cpu())
if (i + 1) % log_interval == 0:
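                    # current_loss: mean total loss over the last log_interval batches; the remaining fields are running means over the epoch so far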
current_loss = (running_total_loss - previous_total_loss) / log_interval
tepoch.set_postfix(
current_loss=current_loss,
total_loss=running_total_loss / (i + 1),
dist_hq=running_hard_dist_loss / (i + 1),
dist_sq=running_soft_dist_loss / (i + 1),
rate_loss=running_rate_loss / (i + 1),
rate=running_rate_metric / (i + 1),
ffloss=running_first_frame_loss / (i + 1),
rateloss_hard=running_hard_rate_loss / (i + 1),
rateloss_soft=running_soft_rate_loss / (i + 1)
)
previous_total_loss = running_total_loss
# save checkpoint
checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
checkpoint['state_dict'] = model.state_dict()
checkpoint['loss'] = running_total_loss / len(dataloader)
checkpoint['epoch'] = epoch
torch.save(checkpoint, checkpoint_path)