mirror of https://github.com/xiph/opus.git
added pytorch implementation of RDOVAE
parent a13aa3a077, commit fdb04d0eef
14 changed files with 1880 additions and 0 deletions
24 dnn/torch/rdovae/README.md Normal file
@@ -0,0 +1,24 @@
# Rate-Distortion-Optimized Variational Auto-Encoder

## Setup

The Python code requires Python >= 3.6 and has been tested with Python 3.6 and Python 3.10. To install the requirements, run

```
python -m pip install -r requirements.txt
```

## Training

To generate training data, use the dump_data tool from the main LPCNet repo:

```
./dump_data -train 16khz_speech_input.s16 features.f32 data.s16
```

To train the model, simply run

```
python train_rdovae.py features.f32 output_folder
```

To train on a CUDA device, add `--cuda-visible-devices idx`.
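
For example, a full training run on GPU 0 might look like this (using the feature file and output folder from above; the GPU index is illustrative):

```
python train_rdovae.py --cuda-visible-devices 0 features.f32 output_folder
```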

## ToDo
- Upload checkpoints and add URLs
256 dnn/torch/rdovae/export_rdovae_weights.py Normal file
@@ -0,0 +1,256 @@
"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
import argparse

parser = argparse.ArgumentParser()

parser.add_argument('checkpoint', type=str, help='rdovae model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--format', choices=['C', 'numpy'], help='output format, default: C', default='C')

args = parser.parse_args()

import torch
import numpy as np

from rdovae import RDOVAE
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector


def dump_statistical_model(writer, qembedding):
    w = qembedding.weight.detach()
    levels, dim = w.shape
    N = dim // 6

    print("printing statistical model")
    quant_scales = torch.nn.functional.softplus(w[:, : N]).numpy()
    dead_zone = 0.05 * torch.nn.functional.softplus(w[:, N : 2 * N]).numpy()
    r = torch.sigmoid(w[:, 5 * N : 6 * N]).numpy()
    p0 = torch.sigmoid(w[:, 4 * N : 5 * N]).numpy()
    p0 = 1 - r ** (0.5 + 0.5 * p0)

    quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
    dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16)
    r_q15 = np.round(r * 2**15).astype(np.uint16)
    p0_q15 = np.round(p0 * 2**15).astype(np.uint16)

    print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
    print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
    print_vector(writer.source, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False)
    print_vector(writer.source, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False)

    writer.header.write(
f"""
extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
extern const opus_uint16 dred_r_q15[{levels * N}];
extern const opus_uint16 dred_p0_q15[{levels * N}];

"""
    )
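
# Editorial note: the tables above are plain unsigned fixed point; a consumer
# recovers the float values approximately as q / 2**n, e.g.
#
#     quant_scales ~ quant_scales_q8 / 2**8
#     r            ~ r_q15 / 2**15
#
# (an illustration of the Q8/Q10/Q15 convention, not code from this commit).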

def c_export(args, model):

    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"

    enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message)
    dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message)
    stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats"), message=message)
    constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True)

    # some custom includes
    for writer in [enc_writer, dec_writer, stats_writer]:
        writer.header.write(
f"""
#include "opus_types.h"

#include "dred_rdovae_constants.h"

#include "nnet.h"
"""
        )

    # encoder
    encoder_dense_layers = [
        ('core_encoder.module.dense_1'       , 'enc_dense1', 'TANH'),
        ('core_encoder.module.dense_2'       , 'enc_dense3', 'TANH'),
        ('core_encoder.module.dense_3'       , 'enc_dense5', 'TANH'),
        ('core_encoder.module.dense_4'       , 'enc_dense7', 'TANH'),
        ('core_encoder.module.dense_5'       , 'enc_dense8', 'TANH'),
        ('core_encoder.module.state_dense_1' , 'gdense1'   , 'TANH'),
        ('core_encoder.module.state_dense_2' , 'gdense2'   , 'TANH')
    ]

    for name, export_name, activation in encoder_dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(enc_writer, layer, name=export_name, activation=activation, verbose=True)

    encoder_gru_layers = [
        ('core_encoder.module.gru_1' , 'enc_dense2', 'TANH'),
        ('core_encoder.module.gru_2' , 'enc_dense4', 'TANH'),
        ('core_encoder.module.gru_3' , 'enc_dense6', 'TANH')
    ]

    enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True)
                             for name, export_name, activation in encoder_gru_layers])

    encoder_conv_layers = [
        ('core_encoder.module.conv1' , 'bits_dense' , 'LINEAR')
    ]

    enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True)
                               for name, export_name, activation in encoder_conv_layers])

    del enc_writer

    # decoder
    decoder_dense_layers = [
        ('core_decoder.module.gru_1_init' , 'state1',     'TANH'),
        ('core_decoder.module.gru_2_init' , 'state2',     'TANH'),
        ('core_decoder.module.gru_3_init' , 'state3',     'TANH'),
        ('core_decoder.module.dense_1'    , 'dec_dense1', 'TANH'),
        ('core_decoder.module.dense_2'    , 'dec_dense3', 'TANH'),
        ('core_decoder.module.dense_3'    , 'dec_dense5', 'TANH'),
        ('core_decoder.module.dense_4'    , 'dec_dense7', 'TANH'),
        ('core_decoder.module.dense_5'    , 'dec_dense8', 'TANH'),
        ('core_decoder.module.output'     , 'dec_final',  'LINEAR')
    ]

    for name, export_name, activation in decoder_dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(dec_writer, layer, name=export_name, activation=activation, verbose=True)

    decoder_gru_layers = [
        ('core_decoder.module.gru_1' , 'dec_dense2', 'TANH'),
        ('core_decoder.module.gru_2' , 'dec_dense4', 'TANH'),
        ('core_decoder.module.gru_3' , 'dec_dense6', 'TANH')
    ]

    dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, activation, verbose=True)
                             for name, export_name, activation in decoder_gru_layers])

    del dec_writer

    # statistical model
    qembedding = model.statistical_model.quant_embedding
    dump_statistical_model(stats_writer, qembedding)

    del stats_writer

    # constants
    constants_writer.header.write(
f"""
#define DRED_NUM_FEATURES {model.feature_dim}

#define DRED_LATENT_DIM {model.latent_dim}

#define DRED_STATE_DIM {model.state_dim}

#define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels}

#define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)}

#define DRED_MAX_CONV_INPUTS {enc_max_conv_inputs}

#define DRED_ENC_MAX_RNN_NEURONS {enc_max_rnn_units}

#define DRED_ENC_MAX_CONV_INPUTS {enc_max_conv_inputs}

#define DRED_DEC_MAX_RNN_NEURONS {dec_max_rnn_units}

"""
    )

    del constants_writer


def numpy_export(args, model):

    exchange_name_to_name = {
        'encoder_stack_layer1_dense'    : 'core_encoder.module.dense_1',
        'encoder_stack_layer3_dense'    : 'core_encoder.module.dense_2',
        'encoder_stack_layer5_dense'    : 'core_encoder.module.dense_3',
        'encoder_stack_layer7_dense'    : 'core_encoder.module.dense_4',
        'encoder_stack_layer8_dense'    : 'core_encoder.module.dense_5',
        'encoder_state_layer1_dense'    : 'core_encoder.module.state_dense_1',
        'encoder_state_layer2_dense'    : 'core_encoder.module.state_dense_2',
        'encoder_stack_layer2_gru'      : 'core_encoder.module.gru_1',
        'encoder_stack_layer4_gru'      : 'core_encoder.module.gru_2',
        'encoder_stack_layer6_gru'      : 'core_encoder.module.gru_3',
        'encoder_stack_layer9_conv'     : 'core_encoder.module.conv1',
        'statistical_model_embedding'   : 'statistical_model.quant_embedding',
        'decoder_state1_dense'          : 'core_decoder.module.gru_1_init',
        'decoder_state2_dense'          : 'core_decoder.module.gru_2_init',
        'decoder_state3_dense'          : 'core_decoder.module.gru_3_init',
        'decoder_stack_layer1_dense'    : 'core_decoder.module.dense_1',
        'decoder_stack_layer3_dense'    : 'core_decoder.module.dense_2',
        'decoder_stack_layer5_dense'    : 'core_decoder.module.dense_3',
        'decoder_stack_layer7_dense'    : 'core_decoder.module.dense_4',
        'decoder_stack_layer8_dense'    : 'core_decoder.module.dense_5',
        'decoder_stack_layer9_dense'    : 'core_decoder.module.output',
        'decoder_stack_layer2_gru'      : 'core_decoder.module.gru_1',
        'decoder_stack_layer4_gru'      : 'core_decoder.module.gru_2',
        'decoder_stack_layer6_gru'      : 'core_decoder.module.gru_3'
    }

    name_to_exchange_name = {value : key for key, value in exchange_name_to_name.items()}

    for name, exchange_name in name_to_exchange_name.items():
        print(f"printing layer {name}...")
        dump_torch_weights(os.path.join(args.output_dir, exchange_name), model.get_submodule(name))


if __name__ == "__main__":

    os.makedirs(args.output_dir, exist_ok=True)

    # load model from checkpoint
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
    missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)

    if len(missing_keys) > 0:
        raise ValueError("error: missing keys in state dict")

    if len(unmatched_keys) > 0:
        print(f"warning: the following keys were unmatched {unmatched_keys}")

    if args.format == 'C':
        c_export(args, model)
    elif args.format == 'numpy':
        numpy_export(args, model)
    else:
        raise ValueError(f'error: unknown export format {args.format}')
213 dnn/torch/rdovae/fec_encoder.py Normal file
@@ -0,0 +1,213 @@
"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe and Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
import subprocess
import argparse

os.environ['CUDA_VISIBLE_DEVICES'] = ""

parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with the voip application and 20ms frames')

parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)')
parser.add_argument('checkpoint', metavar='<weights>', help='model checkpoint')
parser.add_argument('q0', metavar='<quant level 0>', type=int, help='quantization level for most recent frame')
parser.add_argument('q1', metavar='<quant level 1>', type=int, help='quantization level for oldest frame')
parser.add_argument('output', type=str, help='output file (will be extended with .fec)')

parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
parser.add_argument('--num-redundancy-frames', default=52, type=int, help='number of redundancy frames per packet (default 52)')
parser.add_argument('--extra-delay', default=0, type=int, help="last features in packet are calculated with the decoder aligned samples, use this option to add extra delay (in samples at 16kHz)")
parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')
parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')

args = parser.parse_args()

import numpy as np
from scipy.io import wavfile
import torch

from rdovae import RDOVAE
from packets import write_fec_packets

torch.set_num_threads(4)

checkpoint = torch.load(args.checkpoint, map_location="cpu")
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.to("cpu")

lpc_order = 16

## prepare input signal
# SILK frame size is 20ms and LPCNet subframes are 10ms
subframe_size = 160
frame_size = 2 * subframe_size

# 91 samples delay to align with SILK decoded frames
silk_delay = 91

# prepend zeros to have enough history to produce the first packet
zero_history = (args.num_redundancy_frames - 1) * frame_size

# dump_data has a (feature) delay of 10ms
dump_data_delay = 160

total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay
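
# Editorial note: with the defaults (num_redundancy_frames=52, extra_delay=0)
# this evaluates to 91 + 51 * 320 + 0 - 160 = 16251 samples at 16 kHz
# (a worked example, not code from this commit).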

# load signal
if args.input.endswith('.raw') or args.input.endswith('.pcm'):
    signal = np.fromfile(args.input, dtype='int16')

elif args.input.endswith('.wav'):
    fs, signal = wavfile.read(args.input)
else:
    raise ValueError(f'unknown input signal format: {args.input}')

# fill up last frame with zeros
padded_signal_length = len(signal) + total_delay
tail = padded_signal_length % frame_size
right_padding = (frame_size - tail) % frame_size

signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16)))

padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw'
signal.tofile(padded_signal_file)

# write signal and call dump_data to create features

feature_file = os.path.splitext(args.input)[0] + '_features.f32'
command = f"{args.dump_data} -test {padded_signal_file} {feature_file}"
r = subprocess.run(command, shell=True)
if r.returncode != 0:
    raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}")

# feature dimensions
nb_features = model.feature_dim + lpc_order
nb_used_features = model.feature_dim

# load features
features = np.fromfile(feature_file, dtype='float32')
num_subframes = len(features) // nb_features
num_subframes = 2 * (num_subframes // 2)
num_frames = num_subframes // 2

features = np.reshape(features, (1, -1, nb_features))
features = features[:, :, :nb_used_features]
features = features[:, :num_subframes, :]

# quant_ids in reverse decoding order
quant_ids = torch.round((args.q1 + (args.q0 - args.q1) * torch.arange(args.num_redundancy_frames // 2) / (args.num_redundancy_frames // 2 - 1))).long()
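
# Editorial note: the line above interpolates linearly from q1 (oldest frame)
# to q0 (most recent frame). For example, q0=0, q1=15 with the default 52
# redundancy frames yields 26 ids stepping from 15 down to 0 after rounding
# (an illustration, not code from this commit).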

print(f"using quantization levels {quant_ids}...")

# convert input to torch tensors
features = torch.from_numpy(features)


# run encoder
print("running fec encoder...")
with torch.no_grad():

    # encoding
    z, states, state_size = model.encode(features)


    # decoder on packet chunks
    input_length = args.num_redundancy_frames // 2
    offset = args.num_redundancy_frames - 1

    packets = []
    packet_sizes = []

    for i in range(offset, num_frames):
        print(f"processing frame {i - offset}...")
        # quantize / unquantize latent vectors
        zi = torch.clone(z[:, i - 2 * input_length + 2 : i + 1 : 2, :])
        zi, rates = model.quantize(zi, quant_ids)
        zi = model.unquantize(zi, quant_ids)

        features = model.decode(zi, states[:, i : i + 1, :])
        packets.append(features.squeeze(0).numpy())
        packet_size = 8 * int((torch.sum(rates) + 7 + state_size) / 8)
        packet_sizes.append(packet_size)


# write packets
packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output
write_fec_packets(packet_file, packets, packet_sizes)


print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")

# assemble features according to loss file
if args.lossfile is not None:
    num_packets = len(packets)
    loss = np.loadtxt(args.lossfile, dtype='int16')
    fec_out = np.zeros((num_packets * 2, packets[0].shape[-1]), dtype='float32')
    foffset = -2
    ptr = 0
    count = 2
    for i in range(num_packets):
        if (loss[i] == 0) or (i == num_packets - 1):

            fec_out[ptr : ptr + count, :] = packets[i][foffset:, :]

            ptr += count
            foffset = -2
            count = 2
        else:
            count += 2
            foffset -= 2

    fec_out_full = np.zeros((fec_out.shape[0], 36), dtype=np.float32)
    fec_out_full[:, : fec_out.shape[-1]] = fec_out

    fec_out_full.tofile(packet_file[:-4] + '_fec.f32')


if args.debug_output:
    import itertools

    batches = [4]
    offsets = [0, 2 * args.num_redundancy_frames - 4]

    # sanity checks
    # 1. concatenate features at offset 0
    for batch, offset in itertools.product(batches, offsets):

        stop = packets[0].shape[1] - offset
        test_features = np.concatenate([packet[stop - batch : stop, :] for packet in packets[::batch // 2]], axis=0)

        test_features_full = np.zeros((test_features.shape[0], nb_features), dtype=np.float32)
        test_features_full[:, :nb_used_features] = test_features[:, :]

        print(f"writing debug output {packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32'}")
        test_features_full.tofile(packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32')
143 dnn/torch/rdovae/import_rdovae_weights.py Normal file
@@ -0,0 +1,143 @@
"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""

import argparse


parser = argparse.ArgumentParser()

parser.add_argument('exchange_folder', type=str, help='exchange folder path')
parser.add_argument('output', type=str, help='path to output model checkpoint')

model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--num-features', type=int, help="number of features, default: 20", default=20)
model_group.add_argument('--latent-dim', type=int, help="number of symbols produced by the encoder, default: 80", default=80)
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
model_group.add_argument('--state-dim', type=int, help="dimensionality of transferred state, default: 24", default=24)
model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 40", default=40)

args = parser.parse_args()

import torch
from rdovae import RDOVAE
from wexchange.torch import load_torch_weights

exchange_name_to_name = {
    'encoder_stack_layer1_dense'    : 'core_encoder.module.dense_1',
    'encoder_stack_layer3_dense'    : 'core_encoder.module.dense_2',
    'encoder_stack_layer5_dense'    : 'core_encoder.module.dense_3',
    'encoder_stack_layer7_dense'    : 'core_encoder.module.dense_4',
    'encoder_stack_layer8_dense'    : 'core_encoder.module.dense_5',
    'encoder_state_layer1_dense'    : 'core_encoder.module.state_dense_1',
    'encoder_state_layer2_dense'    : 'core_encoder.module.state_dense_2',
    'encoder_stack_layer2_gru'      : 'core_encoder.module.gru_1',
    'encoder_stack_layer4_gru'      : 'core_encoder.module.gru_2',
    'encoder_stack_layer6_gru'      : 'core_encoder.module.gru_3',
    'encoder_stack_layer9_conv'     : 'core_encoder.module.conv1',
    'statistical_model_embedding'   : 'statistical_model.quant_embedding',
    'decoder_state1_dense'          : 'core_decoder.module.gru_1_init',
    'decoder_state2_dense'          : 'core_decoder.module.gru_2_init',
    'decoder_state3_dense'          : 'core_decoder.module.gru_3_init',
    'decoder_stack_layer1_dense'    : 'core_decoder.module.dense_1',
    'decoder_stack_layer3_dense'    : 'core_decoder.module.dense_2',
    'decoder_stack_layer5_dense'    : 'core_decoder.module.dense_3',
    'decoder_stack_layer7_dense'    : 'core_decoder.module.dense_4',
    'decoder_stack_layer8_dense'    : 'core_decoder.module.dense_5',
    'decoder_stack_layer9_dense'    : 'core_decoder.module.output',
    'decoder_stack_layer2_gru'      : 'core_decoder.module.gru_1',
    'decoder_stack_layer4_gru'      : 'core_decoder.module.gru_2',
    'decoder_stack_layer6_gru'      : 'core_decoder.module.gru_3'
}

if __name__ == "__main__":
    checkpoint = dict()

    # parameters
    num_features = args.num_features
    latent_dim = args.latent_dim
    quant_levels = args.quant_levels
    cond_size = args.cond_size
    cond_size2 = args.cond_size2
    state_dim = args.state_dim

    # model
    checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
    checkpoint['model_kwargs'] = {'state_dim': state_dim}
    model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])

    dense_layer_names = [
        'encoder_stack_layer1_dense',
        'encoder_stack_layer3_dense',
        'encoder_stack_layer5_dense',
        'encoder_stack_layer7_dense',
        'encoder_stack_layer8_dense',
        'encoder_state_layer1_dense',
        'encoder_state_layer2_dense',
        'decoder_state1_dense',
        'decoder_state2_dense',
        'decoder_state3_dense',
        'decoder_stack_layer1_dense',
        'decoder_stack_layer3_dense',
        'decoder_stack_layer5_dense',
        'decoder_stack_layer7_dense',
        'decoder_stack_layer8_dense',
        'decoder_stack_layer9_dense'
    ]

    gru_layer_names = [
        'encoder_stack_layer2_gru',
        'encoder_stack_layer4_gru',
        'encoder_stack_layer6_gru',
        'decoder_stack_layer2_gru',
        'decoder_stack_layer4_gru',
        'decoder_stack_layer6_gru'
    ]

    conv1d_layer_names = [
        'encoder_stack_layer9_conv'
    ]

    embedding_layer_names = [
        'statistical_model_embedding'
    ]

    for name in dense_layer_names + gru_layer_names + conv1d_layer_names + embedding_layer_names:
        print(f"loading weights for layer {exchange_name_to_name[name]}")
        layer = model.get_submodule(exchange_name_to_name[name])
        load_torch_weights(os.path.join(args.exchange_folder, name), layer)

    checkpoint['state_dict'] = model.state_dict()

    torch.save(checkpoint, args.output)
BIN dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl Normal file
Binary file not shown.
1 dnn/torch/rdovae/packets/__init__.py Normal file
@@ -0,0 +1 @@
from .fec_packets import write_fec_packets, read_fec_packets
142 dnn/torch/rdovae/packets/fec_packets.c Normal file
@@ -0,0 +1,142 @@
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#include <stdio.h>
#include <inttypes.h>

#include "fec_packets.h"

int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index)
{

    int16_t version;
    int16_t header_size;
    int16_t num_packets;
    int16_t packet_size;
    int16_t subframe_size;
    int16_t subframes_per_packet;
    int16_t num_features;
    long offset;

    FILE *fid = fopen(filename, "rb");
    /* guard against a failed open before any fread/fclose */
    if (fid == NULL) return 1;

    /* read header */
    if (fread(&version, sizeof(version), 1, fid) != 1) goto error;
    if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error;
    if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error;
    if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error;
    if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error;
    if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error;
    if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error;

    /* check if indices are valid */
    if (packet_index >= num_packets || subframe_index >= subframes_per_packet)
    {
        fprintf(stderr, "get_fec_frame: index out of bounds\n");
        goto error;
    }

    /* calculate offset in file (+ 2 is for rate) */
    offset = header_size + packet_index * packet_size + 2 + subframe_index * subframe_size;
    fseek(fid, offset, SEEK_SET);

    /* read features */
    if (fread(features, sizeof(*features), num_features, fid) != num_features) goto error;

    fclose(fid);
    return 0;

error:
    fclose(fid);
    return 1;
}

int get_fec_rate(const char * const filename, int packet_index)
{
    int16_t version;
    int16_t header_size;
    int16_t num_packets;
    int16_t packet_size;
    int16_t subframe_size;
    int16_t subframes_per_packet;
    int16_t num_features;
    long offset;
    int16_t rate;

    FILE *fid = fopen(filename, "rb");
    /* guard against a failed open before any fread/fclose */
    if (fid == NULL) return -1;

    /* read header */
    if (fread(&version, sizeof(version), 1, fid) != 1) goto error;
    if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error;
    if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error;
    if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error;
    if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error;
    if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error;
    if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error;

    /* check if indices are valid */
    if (packet_index >= num_packets)
    {
        fprintf(stderr, "get_fec_rate: index out of bounds\n");
        goto error;
    }

    /* calculate offset in file (the rate is the first field of the packet) */
    offset = header_size + packet_index * packet_size;
    fseek(fid, offset, SEEK_SET);

    /* read rate */
    if (fread(&rate, sizeof(rate), 1, fid) != 1) goto error;

    fclose(fid);
    return (int) rate;

error:
    fclose(fid);
    return -1;
}

#if 0
int main()
{
    float features[20];
    int i;

    if (get_fec_frame("../test.fec", &features[0], 0, 127))
    {
        return 1;
    }

    for (i = 0; i < 20; i ++)
    {
        printf("%d %f\n", i, features[i]);
    }

    printf("rate: %d\n", get_fec_rate("../test.fec", 0));

}
#endif
34 dnn/torch/rdovae/packets/fec_packets.h Normal file
@@ -0,0 +1,34 @@
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#ifndef _FEC_PACKETS_H
#define _FEC_PACKETS_H

int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index);
int get_fec_rate(const char * const filename, int packet_index);

#endif
108 dnn/torch/rdovae/packets/fec_packets.py Normal file
@@ -0,0 +1,108 @@
"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import numpy as np


def write_fec_packets(filename, packets, rates=None):
    """ writes packets in binary format """

    assert np.dtype(np.float32).itemsize == 4
    assert np.dtype(np.int16).itemsize == 2

    # derive some sizes
    num_packets = len(packets)
    subframes_per_packet = packets[0].shape[-2]
    num_features = packets[0].shape[-1]

    # size of float is 4
    subframe_size = num_features * 4
    packet_size = subframe_size * subframes_per_packet + 2 # two bytes for rate

    version = 1
    # header size (version, header_size, num_packets, packet_size, subframe_size, subframes_per_packet, num_features)
    header_size = 14

    with open(filename, 'wb') as f:

        # header
        f.write(np.int16(version).tobytes())
        f.write(np.int16(header_size).tobytes())
        f.write(np.int16(num_packets).tobytes())
        f.write(np.int16(packet_size).tobytes())
        f.write(np.int16(subframe_size).tobytes())
        f.write(np.int16(subframes_per_packet).tobytes())
        f.write(np.int16(num_features).tobytes())

        # packets
        for i, packet in enumerate(packets):
            if rates is None:
                rate = 0
            else:
                rate = rates[i]

            f.write(np.int16(rate).tobytes())

            features = np.flip(packet, axis=-2)
            f.write(features.astype(np.float32).tobytes())
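
# Editorial note: the 14-byte header written above consists of seven
# native-endian int16 fields; on a little-endian machine it could equivalently
# be parsed with the standard struct module (a sketch, not code from this
# commit):
#
#     import struct
#     with open(filename, 'rb') as f:
#         (version, header_size, num_packets, packet_size, subframe_size,
#          subframes_per_packet, num_features) = struct.unpack('<7h', f.read(14))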


def read_fec_packets(filename):
    """ reads packets from binary format """

    assert np.dtype(np.float32).itemsize == 4
    assert np.dtype(np.int16).itemsize == 2

    with open(filename, 'rb') as f:

        # header
        version              = np.frombuffer(f.read(2), dtype=np.int16).item()
        header_size          = np.frombuffer(f.read(2), dtype=np.int16).item()
        num_packets          = np.frombuffer(f.read(2), dtype=np.int16).item()
        packet_size          = np.frombuffer(f.read(2), dtype=np.int16).item()
        subframe_size        = np.frombuffer(f.read(2), dtype=np.int16).item()
        subframes_per_packet = np.frombuffer(f.read(2), dtype=np.int16).item()
        num_features         = np.frombuffer(f.read(2), dtype=np.int16).item()

        dummy_features = np.zeros((subframes_per_packet, num_features), dtype=np.float32)

        # packets
        rates = []
        packets = []
        for i in range(num_packets):

            rate = np.frombuffer(f.read(2), dtype=np.int16).item()
            rates.append(rate)

            features = np.reshape(np.frombuffer(f.read(subframe_size * subframes_per_packet), dtype=np.float32), dummy_features.shape)
            packet = np.flip(features, axis=-2)
            packets.append(packet)

    return packets
2 dnn/torch/rdovae/rdovae/__init__.py Normal file
@@ -0,0 +1,2 @@
from .rdovae import RDOVAE, distortion_loss, hard_rate_estimate, soft_rate_estimate
from .dataset import RDOVAEDataset
68 dnn/torch/rdovae/rdovae/dataset.py Normal file
@@ -0,0 +1,68 @@
"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch
import numpy as np

class RDOVAEDataset(torch.utils.data.Dataset):
    def __init__(self,
                 feature_file,
                 sequence_length,
                 num_used_features=20,
                 num_features=36,
                 lambda_min=0.0002,
                 lambda_max=0.0135,
                 quant_levels=16,
                 enc_stride=2):

        self.sequence_length = sequence_length
        self.lambda_min = lambda_min
        self.lambda_max = lambda_max
        self.enc_stride = enc_stride
        self.quant_levels = quant_levels
        self.denominator = (quant_levels - 1) / np.log(lambda_max / lambda_min)

        if sequence_length % enc_stride:
            raise ValueError(f"RDOVAEDataset.__init__: enc_stride {enc_stride} does not divide sequence length {sequence_length}")

        self.features = np.reshape(np.fromfile(feature_file, dtype=np.float32), (-1, num_features))
        self.features = self.features[:, :num_used_features]
        self.num_sequences = self.features.shape[0] // sequence_length

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, index):
        features = self.features[index * self.sequence_length : (index + 1) * self.sequence_length, :]
        q_ids = np.random.randint(0, self.quant_levels, (1)).astype(np.int64)
        q_ids = np.repeat(q_ids, self.sequence_length // self.enc_stride, axis=0)
        rate_lambda = self.lambda_min * np.exp(q_ids.astype(np.float32) / self.denominator).astype(np.float32)

        return features, rate_lambda, q_ids
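
# Editorial note: __getitem__ maps quantizer ids to lambdas on a log scale, so
# q_id = 0 gives lambda_min and q_id = quant_levels - 1 gives lambda_max. With
# the defaults: denominator = 15 / ln(0.0135 / 0.0002) ~ 3.561 and
# 0.0002 * exp(15 / 3.561) ~ 0.0135 (a worked check, not code from this commit).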
614 dnn/torch/rdovae/rdovae/rdovae.py Normal file
@@ -0,0 +1,614 @@
"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

""" PyTorch implementation of a rate-distortion-optimized variational autoencoder """

import math as m

import torch
from torch import nn
import torch.nn.functional as F

# Quantization and rate related utility functions

def soft_pvq(x, k):
    """ soft pyramid vector quantizer """

    # L2 normalization
    x_norm2 = x / (1e-15 + torch.norm(x, dim=-1, keepdim=True))


    with torch.no_grad():
        # quantization loop, no need to track gradients here
        x_norm1 = x / torch.sum(torch.abs(x), dim=-1, keepdim=True)

        # set initial scaling factor to k
        scale_factor = k
        x_scaled = scale_factor * x_norm1
        x_quant = torch.round(x_scaled)

        # we aim for ||x_quant||_L1 = k
        for _ in range(10):
            # remove signs and calculate L1 norm
            abs_x_quant = torch.abs(x_quant)
            abs_x_scaled = torch.abs(x_scaled)
            l1_x_quant = torch.sum(abs_x_quant, axis=-1)

            # increase, where target is too small and decrease, where target is too large
            plus = 1.0001 * torch.min((abs_x_quant + 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
            minus = 0.9999 * torch.max((abs_x_quant - 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
            factor = torch.where(l1_x_quant > k, minus, plus)
            factor = torch.where(l1_x_quant == k, torch.ones_like(factor), factor)
            scale_factor = scale_factor * factor.unsqueeze(-1)

            # update x (re-round the rescaled vector, not the previous rounding)
            x_scaled = scale_factor * x_norm1
            x_quant = torch.round(x_scaled)

    # L2 normalization of quantized x
    x_quant_norm2 = x_quant / (1e-15 + torch.norm(x_quant, dim=-1, keepdim=True))
    quantization_error = x_quant_norm2 - x_norm2

    return x_norm2 + quantization_error.detach()
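
# Editorial note: the return line above is the usual straight-through trick --
# the forward value equals the normalized quantized vector x_quant_norm2 while
# gradients flow through x_norm2 only. A hedged usage sketch (shapes are
# illustrative, not code from this commit):
#
#     x = torch.randn(4, 80)   # batch of latent vectors
#     y = soft_pvq(x, 82)      # unit-norm output from 82 pulses
#     assert torch.allclose(torch.norm(y, dim=-1), torch.ones(4), atol=1e-4)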

def cache_parameters(func):
    cache = dict()
    def cached_func(*args):
        if args in cache:
            return cache[args]
        else:
            cache[args] = func(*args)

        return cache[args]
    return cached_func

@cache_parameters
def pvq_codebook_size(n, k):

    if k == 0:
        return 1

    if n == 0:
        return 0

    return pvq_codebook_size(n - 1, k) + pvq_codebook_size(n, k - 1) + pvq_codebook_size(n - 1, k - 1)


def soft_rate_estimate(z, r, reduce=True):
    """ rate approximation with dependent theta Eq. (7)"""

    rate = torch.sum(
        - torch.log2((1 - r)/(1 + r) * r ** torch.abs(z) + 1e-6),
        dim=-1
    )

    if reduce:
        rate = torch.mean(rate)

    return rate


def hard_rate_estimate(z, r, theta, reduce=True):
    """ hard rate approximation """

    z_q = torch.round(z)
    p0 = 1 - r ** (0.5 + 0.5 * theta)
    alpha = torch.relu(1 - torch.abs(z_q)) ** 2
    rate = - torch.sum(
        (alpha * torch.log2(p0 * r ** torch.abs(z_q) + 1e-6)
        + (1 - alpha) * torch.log2(0.5 * (1 - p0) * (1 - r) * r ** (torch.abs(z_q) - 1) + 1e-6)),
        dim=-1
    )

    if reduce:
        rate = torch.mean(rate)

    return rate
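
# Editorial note: both estimators assume a two-sided geometric prior on the
# quantized latents, p(z=0) = p0 and p(z=k) = 0.5 * (1 - p0) * (1 - r) * r**(|k|-1)
# for k != 0, so -log2 p(z) is the codeword length in bits. A hedged sanity check
# (not code from this commit):
#
#     z = torch.zeros(1, 80)
#     r = 0.5 * torch.ones(1, 80)
#     theta = torch.ones(1, 80)
#     hard_rate_estimate(z, r, theta)   # ~80 bits: theta=1 gives p0 = 1 - r = 0.5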


def soft_dead_zone(x, dead_zone):
    """ approximates application of a dead zone to x """
    d = dead_zone * 0.05
    return x - d * torch.tanh(x / (0.1 + d))


def hard_quantize(x):
    """ round with copy gradient trick """
    return x + (torch.round(x) - x).detach()
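
# Editorial note: hard_quantize is the straight-through estimator -- forward
# returns torch.round(x) exactly, while the backward pass treats the operation
# as identity. A hedged illustration, not code from this commit:
#
#     x = torch.tensor([0.4, 1.6], requires_grad=True)
#     y = hard_quantize(x)   # tensor([0., 2.])
#     y.sum().backward()
#     x.grad                 # tensor([1., 1.]); the gradient of round() itself is zero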


def noise_quantize(x):
    """ simulates quantization with addition of random uniform noise """
    return x + (torch.rand_like(x) - 0.5)


# loss functions


def distortion_loss(y_true, y_pred, rate_lambda=None):
    """ custom distortion loss for LPCNet features """

    if y_true.size(-1) != 20:
        raise ValueError('distortion loss is designed to work with 20 features')

    ceps_error   = y_pred[..., :18] - y_true[..., :18]
    pitch_error  = 2 * (y_pred[..., 18:19] - y_true[..., 18:19]) / (2 + y_true[..., 18:19])
    corr_error   = y_pred[..., 19:] - y_true[..., 19:]
    pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2

    loss = torch.mean(ceps_error ** 2 + (10/18) * torch.abs(pitch_error) * pitch_weight + (1/18) * corr_error ** 2, dim=-1)

    if rate_lambda is not None:
        loss = loss / torch.sqrt(rate_lambda)

    loss = torch.mean(loss)

    return loss


# sampling functions

import random

def random_split(start, stop, num_splits=3, min_len=3):
    get_min_len = lambda x : min([x[i+1] - x[i] for i in range(len(x) - 1)])
    candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]

    while get_min_len(candidate) < min_len:
        candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]

    return candidate



# weight initialization and clipping
def init_weights(module):

    if isinstance(module, nn.GRU):
        for p in module.named_parameters():
            if p[0].startswith('weight_hh_'):
                nn.init.orthogonal_(p[1])


def weight_clip_factory(max_value):
    """ weight clipping function concerning sum of abs values of adjacent weights """
    def clip_weight_(w):
        stop = w.size(1)
        # omit last column if stop is odd
        if stop % 2:
            stop -= 1
        max_values = max_value * torch.ones_like(w[:, :stop])
        factor = max_value / torch.maximum(max_values,
                                 torch.repeat_interleave(
                                     torch.abs(w[:, :stop:2]) + torch.abs(w[:, 1:stop:2]),
                                     2,
                                     1))
        with torch.no_grad():
            w[:, :stop] *= factor

    def clip_weights(module):
        if isinstance(module, nn.GRU) or isinstance(module, nn.Linear):
            for name, w in module.named_parameters():
                if name.startswith('weight'):
                    clip_weight_(w)

    return clip_weights
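
# Editorial note: the factory returns a hook suitable for nn.Module.apply;
# clipping keeps |w[:, 2i]| + |w[:, 2i+1]| <= max_value for each adjacent
# column pair (pairs whose sum is already within bounds get factor 1). A hedged
# usage sketch, not code from this commit:
#
#     clip_fn = weight_clip_factory(0.496)
#     model.apply(clip_fn)   # e.g. after each optimizer step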

# RDOVAE module and submodules


class CoreEncoder(nn.Module):
    STATE_HIDDEN = 128
    FRAMES_PER_STEP = 2
    CONV_KERNEL_SIZE = 4

    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24):
        """ core encoder for RDOVAE

            Computes latents, initial states, and rate estimates from features and lambda parameter

        """

        super(CoreEncoder, self).__init__()

        # hyper parameters
        self.feature_dim = feature_dim
        self.output_dim = output_dim
        self.cond_size = cond_size
        self.cond_size2 = cond_size2
        self.state_size = state_size

        # derived parameters
        self.input_dim = self.FRAMES_PER_STEP * self.feature_dim
        self.conv_input_channels = 5 * cond_size + 3 * cond_size2

        # layers
        self.dense_1 = nn.Linear(self.input_dim, self.cond_size2)
        self.gru_1 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
        self.dense_2 = nn.Linear(self.cond_size, self.cond_size2)
        self.gru_2 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
        self.dense_3 = nn.Linear(self.cond_size, self.cond_size2)
        self.gru_3 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
        self.dense_4 = nn.Linear(self.cond_size, self.cond_size)
        self.dense_5 = nn.Linear(self.cond_size, self.cond_size)
        self.conv1 = nn.Conv1d(self.conv_input_channels, self.output_dim, kernel_size=self.CONV_KERNEL_SIZE, padding='valid')

        self.state_dense_1 = nn.Linear(self.conv_input_channels, self.STATE_HIDDEN)

        self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)

        # initialize weights
        self.apply(init_weights)


    def forward(self, features):

        # reshape features
        x = torch.reshape(features, (features.size(0), features.size(1) // self.FRAMES_PER_STEP, self.FRAMES_PER_STEP * features.size(2)))

        batch = x.size(0)
        device = x.device

        # run encoding layer stack
        x1 = torch.tanh(self.dense_1(x))
        x2, _ = self.gru_1(x1, torch.zeros((1, batch, self.cond_size)).to(device))
        x3 = torch.tanh(self.dense_2(x2))
        x4, _ = self.gru_2(x3, torch.zeros((1, batch, self.cond_size)).to(device))
        x5 = torch.tanh(self.dense_3(x4))
        x6, _ = self.gru_3(x5, torch.zeros((1, batch, self.cond_size)).to(device))
        x7 = torch.tanh(self.dense_4(x6))
        x8 = torch.tanh(self.dense_5(x7))

        # concatenation of all hidden layer outputs
        x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)

        # init state for decoder
        states = torch.tanh(self.state_dense_1(x9))
        states = torch.tanh(self.state_dense_2(states))

        # latent representation via convolution
        x9 = F.pad(x9.permute(0, 2, 1), [self.CONV_KERNEL_SIZE - 1, 0])
        z = self.conv1(x9).permute(0, 2, 1)

        return z, states
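
# Editorial note: shape walk-through under the defaults (feature_dim=20,
# FRAMES_PER_STEP=2): input (B, T, 20) is reshaped to (B, T/2, 40); the layer
# stack's concatenation x9 has 5*cond_size + 3*cond_size2 channels (three
# cond_size2-wide dense outputs, three cond_size-wide GRU outputs, two
# cond_size-wide dense outputs), matching conv_input_channels above; the
# causal conv then yields z of shape (B, T/2, output_dim). Illustration only,
# not code from this commit.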
|
||||
|
||||
|
||||
|
||||
class CoreDecoder(nn.Module):

    FRAMES_PER_STEP = 4

    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24):
        """ core decoder for RDOVAE

            Computes features from latents, initial state, and quantization index

        """

        super(CoreDecoder, self).__init__()

        # hyper parameters
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.cond_size = cond_size
        self.cond_size2 = cond_size2
        self.state_size = state_size

        self.input_size = self.input_dim

        self.concat_size = 4 * self.cond_size + 4 * self.cond_size2

        # layers
        self.dense_1 = nn.Linear(self.input_size, cond_size2)
        self.gru_1 = nn.GRU(cond_size2, cond_size, batch_first=True)
        self.dense_2 = nn.Linear(cond_size, cond_size2)
        self.gru_2 = nn.GRU(cond_size2, cond_size, batch_first=True)
        self.dense_3 = nn.Linear(cond_size, cond_size2)
        self.gru_3 = nn.GRU(cond_size2, cond_size, batch_first=True)
        self.dense_4 = nn.Linear(cond_size, cond_size2)
        self.dense_5 = nn.Linear(cond_size2, cond_size2)

        self.output = nn.Linear(self.concat_size, self.FRAMES_PER_STEP * self.output_dim)

        self.gru_1_init = nn.Linear(self.state_size, self.cond_size)
        self.gru_2_init = nn.Linear(self.state_size, self.cond_size)
        self.gru_3_init = nn.Linear(self.state_size, self.cond_size)

        # initialize weights
        self.apply(init_weights)

    def forward(self, z, initial_state):

        gru_1_state = torch.tanh(self.gru_1_init(initial_state).permute(1, 0, 2))
        gru_2_state = torch.tanh(self.gru_2_init(initial_state).permute(1, 0, 2))
        gru_3_state = torch.tanh(self.gru_3_init(initial_state).permute(1, 0, 2))

        # run decoding layer stack
        x1 = torch.tanh(self.dense_1(z))
        x2, _ = self.gru_1(x1, gru_1_state)
        x3 = torch.tanh(self.dense_2(x2))
        x4, _ = self.gru_2(x3, gru_2_state)
        x5 = torch.tanh(self.dense_3(x4))
        x6, _ = self.gru_3(x5, gru_3_state)
        x7 = torch.tanh(self.dense_4(x6))
        x8 = torch.tanh(self.dense_5(x7))
        x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)

        # output layer and reshaping
        x10 = self.output(x9)
        features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP))

        return features

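# Illustrative sketch (not from the original file): with output_dim=20 and
# FRAMES_PER_STEP=4, self.output maps each decoder step to 80 values and the
# reshape above unfolds them into four feature frames per step:
#
#   x10 = torch.zeros((32, 64, 4 * 20))           # (batch, steps, 4 * output_dim)
#   feats = torch.reshape(x10, (32, 64 * 4, 20))  # (batch, frames, output_dim)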
class StatisticalModel(nn.Module):
    def __init__(self, quant_levels, latent_dim):
        """ Statistical model for latent space

            Computes scaling, deadzone, r, and theta

        """

        super(StatisticalModel, self).__init__()

        # copy parameters
        self.latent_dim = latent_dim
        self.quant_levels = quant_levels
        self.embedding_dim = 6 * latent_dim

        # quantization embedding
        self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)

        # initialize embedding to 0
        with torch.no_grad():
            self.quant_embedding.weight[:] = 0

    def forward(self, quant_ids):
        """ takes quant_ids and returns statistical model parameters"""

        x = self.quant_embedding(quant_ids)

        # CAVE: theta_soft is not used anymore. Kick it out?
        quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim])
        dead_zone = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim])
        theta_soft = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim])
        r_soft = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim])
        theta_hard = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim])
        r_hard = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim])

        return {
            'quant_embedding' : x,
            'quant_scale' : quant_scale,
            'dead_zone' : dead_zone,
            'r_hard' : r_hard,
            'theta_hard' : theta_hard,
            'r_soft' : r_soft,
            'theta_soft' : theta_soft
        }

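# Illustrative sketch (not from the original file): each embedding row packs
# six latent_dim-sized blocks, so for latent_dim=80 a row has 480 entries:
#
#   model = StatisticalModel(quant_levels=16, latent_dim=80)
#   stats = model(torch.tensor([0, 5, 15]))
#   stats['quant_scale'].shape    # torch.Size([3, 80]), softplus of block 0
#   stats['dead_zone'].shape      # torch.Size([3, 80]), softplus of block 1
#
# with the r/theta pairs taken as sigmoids of blocks 2 through 5.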
class RDOVAE(nn.Module):
    def __init__(self,
                 feature_dim,
                 latent_dim,
                 quant_levels,
                 cond_size,
                 cond_size2,
                 state_dim=24,
                 split_mode='split',
                 clip_weights=True,
                 pvq_num_pulses=82,
                 state_dropout_rate=0):

        super(RDOVAE, self).__init__()

        self.feature_dim = feature_dim
        self.latent_dim = latent_dim
        self.quant_levels = quant_levels
        self.cond_size = cond_size
        self.cond_size2 = cond_size2
        self.split_mode = split_mode
        self.state_dim = state_dim
        self.pvq_num_pulses = pvq_num_pulses
        self.state_dropout_rate = state_dropout_rate

        # submodules; encoder and decoder share the statistical model
        self.statistical_model = StatisticalModel(quant_levels, latent_dim)
        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))

        self.enc_stride = CoreEncoder.FRAMES_PER_STEP
        self.dec_stride = CoreDecoder.FRAMES_PER_STEP

        if clip_weights:
            self.weight_clip_fn = weight_clip_factory(0.496)
        else:
            self.weight_clip_fn = None

        if self.dec_stride % self.enc_stride != 0:
            raise ValueError("get_decoder_chunks_generic: encoder stride does not divide decoder stride")

    def clip_weights(self):
        if self.weight_clip_fn is not None:
            self.apply(self.weight_clip_fn)

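    # Illustrative sketch (not from the original file): self.apply walks every
    # submodule, so weight_clip_fn is expected to act per module, roughly like
    #
    #   def clip(module):
    #       if hasattr(module, 'weight') and module.weight is not None:
    #           with torch.no_grad():
    #               module.weight.clamp_(-0.496, 0.496)
    #
    # The actual weight_clip_factory is defined earlier in this file and may
    # clip more tensors than this sketch shows.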
    def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset=4):

        enc_stride = self.enc_stride
        dec_stride = self.dec_stride

        stride = dec_stride // enc_stride

        chunks = []

        for offset in range(stride):
            # start is the smallest number congruent to offset mod stride that decodes to a valid feature range
            start = offset
            while enc_stride * (start + 1) - dec_stride < 0:
                start += stride

            # check if start is a valid index
            if start >= z_frames:
                raise ValueError("get_decoder_chunks_generic: range too small")

            # stop is the smallest number >= z_frames that is congruent to offset mod stride
            stop = z_frames - (z_frames % stride) + offset
            while stop < z_frames:
                stop += stride

            # calculate split points
            length = (stop - start)
            if mode == 'split':
                split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
            elif mode == 'random_split':
                split_points = [stride * x + start for x in random_split(0, (stop - start) // stride - 1, chunks_per_offset - 1, 1)]
            else:
                raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}")

            for i in range(chunks_per_offset):
                # (enc_frame_start, enc_frame_stop, enc_frame_stride, stride, feature_frame_start, feature_frame_stop)
                # encoder range(i, j, stride) maps to feature range(enc_stride * (i + 1) - dec_stride, enc_stride * j)
                # provided that j - i is congruent to 1 mod stride
                chunks.append({
                    'z_start' : split_points[i],
                    'z_stop' : split_points[i + 1] - stride + 1,
                    'z_stride' : stride,
                    'features_start' : enc_stride * (split_points[i] + 1) - dec_stride,
                    'features_stop' : enc_stride * (split_points[i + 1] - stride + 1)
                })

        return chunks

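    # Worked example (not from the original file): with enc_stride=2 and
    # dec_stride=4, stride=2. For z_frames=16 and offset=1: start=1, stop=17,
    # and mode='split' gives split_points = [1, 5, 9, 13, 17]. The first chunk
    # is then {'z_start': 1, 'z_stop': 4, 'z_stride': 2}, i.e. latents 1 and 3,
    # which decode to feature frames 0 (= 2 * (1 + 1) - 4) up to 8 (= 2 * 4).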
    def forward(self, features, q_id):

        # calculate statistical model from quantization ID
        statistical_model = self.statistical_model(q_id)

        # run encoder
        z, states = self.core_encoder(features)

        # scaling, dead-zone and quantization
        z = z * statistical_model['quant_scale']
        z = soft_dead_zone(z, statistical_model['dead_zone'])

        # quantization
        z_q = hard_quantize(z) / statistical_model['quant_scale']
        z_n = noise_quantize(z) / statistical_model['quant_scale']
        states_q = soft_pvq(states, self.pvq_num_pulses)

        if self.state_dropout_rate > 0:
            drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
            mask = torch.ones_like(states_q)
            mask[drop] = 0
            states_q = states_q * mask

        # decoder
        chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode)

        outputs_hq = []
        outputs_sq = []
        for chunk in chunks:
            # decoder with hard quantized input
            z_dec_reverse = torch.flip(z_q[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
            dec_initial_state = states_q[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
            features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
            outputs_hq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))

            # decoder with soft quantized input
            z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
            features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
            outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))

        return {
            'outputs_hard_quant' : outputs_hq,
            'outputs_soft_quant' : outputs_sq,
            'z' : z,
            'statistical_model' : statistical_model
        }

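    # Note (not from the original file): the decoder consumes time-reversed
    # chunks because dec_initial_state is the quantized state taken at the
    # chunk end; e.g. torch.flip(z_q[..., 1:4:2, :], [1]) presents latents
    # (3, 1) to the decoder, and its output is flipped back before the
    # distortion loss is computed.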
    def encode(self, features):
        """ encoder with quantization and rate estimation """

        z, states = self.core_encoder(features)

        # quantization of initial states
        states = soft_pvq(states, self.pvq_num_pulses)
        state_size = m.log2(pvq_codebook_size(self.state_dim, self.pvq_num_pulses))

        return z, states, state_size

    def decode(self, z, initial_state):
        """ decoder (flips sequences by itself) """

        z_reverse = torch.flip(z, [1])
        features_reverse = self.core_decoder(z_reverse, initial_state)
        features = torch.flip(features_reverse, [1])

        return features

    def quantize(self, z, q_ids):
        """ quantization of latent vectors """

        stats = self.statistical_model(q_ids)

        zq = z * stats['quant_scale']
        zq = soft_dead_zone(zq, stats['dead_zone'])
        zq = torch.round(zq)

        sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False)

        return zq, sizes

    def unquantize(self, zq, q_ids):
        """ re-scaling of latent vectors """

        stats = self.statistical_model(q_ids)

        z = zq / stats['quant_scale']

        return z

    def freeze_model(self):

        # freeze all parameters, then re-enable the statistical model
        for p in self.parameters():
            p.requires_grad = False

        for p in self.statistical_model.parameters():
            p.requires_grad = True

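    # Usage sketch (not from the original file):
    #
    #   model = RDOVAE(20, 80, 16, 256, 256)
    #   z, states, state_bits = model.encode(features)
    #   zq, sizes = model.quantize(z, q_ids)
    #   features_hat = model.decode(model.unquantize(zq, q_ids), states)
    #
    # where q_ids selects rows of the statistical model and sizes holds the
    # per-latent rate estimate.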
5
dnn/torch/rdovae/requirements.txt
Normal file

@ -0,0 +1,5 @@
numpy
scipy
torch
tqdm
libs/wexchange-1.0-py3-none-any.whl
270
dnn/torch/rdovae/train_rdovae.py
Normal file

@ -0,0 +1,270 @@
"""
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
import argparse

import torch
import tqdm

from rdovae import RDOVAE, RDOVAEDataset, distortion_loss, hard_rate_estimate, soft_rate_estimate


parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('output', type=str, help='path to output folder')

parser.add_argument('--cuda-visible-devices', type=str, help="comma-separated list of cuda visible device indices, default: ''", default="")


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--latent-dim', type=int, help="number of symbols produced by encoder, default: 80", default=80)
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
model_group.add_argument('--state-dim', type=int, help="dimensionality of transferred state, default: 24", default=24)
model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 16", default=16)
model_group.add_argument('--lambda-min', type=float, help="minimal value for rate lambda, default: 0.0002", default=2e-4)
model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)
model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82)
model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0)

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
training_group.add_argument('--lr', type=float, help='learning rate, default: 3e-4', default=3e-4)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 100', default=100)
training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by 4, default: 256', default=256)
training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5)
training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split')
training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames')
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
training_group.add_argument('--train-decoder-only', action='store_true', help='freeze encoder and statistical model and train decoder only')

args = parser.parse_args()

# set visible devices
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)

# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay_factor = args.lr_decay_factor
split_mode = args.split_mode
# not exposed
adam_betas = [0.9, 0.99]
adam_eps = 1e-8

checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay_factor'] = lr_decay_factor
checkpoint['split_mode'] = split_mode
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas

# logging
log_interval = 10

# device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# model parameters
cond_size = args.cond_size
cond_size2 = args.cond_size2
latent_dim = args.latent_dim
quant_levels = args.quant_levels
lambda_min = args.lambda_min
lambda_max = args.lambda_max
state_dim = args.state_dim
# not exposed
num_features = 20


# training data
feature_file = args.features

# model
checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode': split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate}
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])

if args.initial_checkpoint is not None:
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict'] = model.state_dict()

if args.train_decoder_only:
    if args.initial_checkpoint is None:
        print("warning: training decoder only without providing initial checkpoint")

    for p in model.core_encoder.module.parameters():
        p.requires_grad = False

    for p in model.statistical_model.parameters():
        p.requires_grad = False

# dataloader
checkpoint['dataset_args'] = (feature_file, sequence_length, num_features, 36)
checkpoint['dataset_kwargs'] = {'lambda_min': lambda_min, 'lambda_max': lambda_max, 'enc_stride': model.enc_stride, 'quant_levels': quant_levels}
dataset = RDOVAEDataset(*checkpoint['dataset_args'], **checkpoint['dataset_kwargs'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)


# optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps)


# learning rate scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))

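# Note (not from the original file): scheduler.step() is called once per batch
# below, so after n updates the learning rate is lr / (1 + lr_decay_factor * n);
# with the default factor of 2.5e-5 it reaches half the initial rate after
# 40000 updates.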
if __name__ == '__main__':

    # push model to device
    model.to(device)

    # training loop
    for epoch in range(1, epochs + 1):

        print(f"training epoch {epoch}...")

        # running stats
        running_rate_loss = 0
        running_soft_dist_loss = 0
        running_hard_dist_loss = 0
        running_hard_rate_loss = 0
        running_soft_rate_loss = 0
        running_total_loss = 0
        running_rate_metric = 0
        previous_total_loss = 0
        running_first_frame_loss = 0

        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, rate_lambda, q_ids) in enumerate(tepoch):

                # zero out gradients
                optimizer.zero_grad()

                # push inputs to device
                features = features.to(device)
                q_ids = q_ids.to(device)
                rate_lambda = rate_lambda.to(device)

                rate_lambda_upsamp = torch.repeat_interleave(rate_lambda, 2, 1)

                # run model
                model_output = model(features, q_ids)

                # collect outputs
                z = model_output['z']
                outputs_hard_quant = model_output['outputs_hard_quant']
                outputs_soft_quant = model_output['outputs_soft_quant']
                statistical_model = model_output['statistical_model']

                # rate loss
                hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False)
                soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False)
                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate)
                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate)
                rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
                hard_rate_metric = torch.mean(hard_rate)

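                # Note (not from the original file): judging by the names, the
                # soft estimate serves as the smooth surrogate for training
                # while the hard estimate tracks the rate of the rounded
                # symbols; the hard term enters with weight 0.1 and is also
                # logged as hard_rate_metric.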
                ## distortion losses

                # hard quantized decoder input
                distortion_loss_hard_quant = torch.zeros_like(rate_loss)
                for dec_features, start, stop in outputs_hard_quant:
                    distortion_loss_hard_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_hard_quant)

                first_frame_loss = torch.zeros_like(rate_loss)
                for dec_features, start, stop in outputs_hard_quant:
                    first_frame_loss += distortion_loss(features[..., stop - 4 : stop, :], dec_features[..., -4:, :], rate_lambda_upsamp[..., stop - 4 : stop]) / len(outputs_hard_quant)

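                # Note (not from the original file): because the decoder runs
                # time-reversed, dec_features[..., -4:, :] are the four frames
                # it produces first, directly from the quantized initial
                # state; the relu term below only penalizes them when they are
                # worse than the chunk average.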
                # soft quantized decoder input
                distortion_loss_soft_quant = torch.zeros_like(rate_loss)
                for dec_features, start, stop in outputs_soft_quant:
                    distortion_loss_soft_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_soft_quant)

                # total loss
                total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2

                if args.enable_first_frame_loss:
                    total_loss = total_loss + 0.5 * torch.relu(first_frame_loss - distortion_loss_hard_quant)

                total_loss.backward()

                optimizer.step()

                model.clip_weights()

                scheduler.step()

                # collect running stats
                running_hard_dist_loss += float(distortion_loss_hard_quant.detach().cpu())
                running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu())
                running_rate_loss += float(rate_loss.detach().cpu())
                running_rate_metric += float(hard_rate_metric.detach().cpu())
                running_total_loss += float(total_loss.detach().cpu())
                running_first_frame_loss += float(first_frame_loss.detach().cpu())
                running_soft_rate_loss += float(soft_rate_loss.detach().cpu())
                running_hard_rate_loss += float(hard_rate_loss.detach().cpu())

                if (i + 1) % log_interval == 0:
                    current_loss = (running_total_loss - previous_total_loss) / log_interval
                    tepoch.set_postfix(
                        current_loss=current_loss,
                        total_loss=running_total_loss / (i + 1),
                        dist_hq=running_hard_dist_loss / (i + 1),
                        dist_sq=running_soft_dist_loss / (i + 1),
                        rate_loss=running_rate_loss / (i + 1),
                        rate=running_rate_metric / (i + 1),
                        ffloss=running_first_frame_loss / (i + 1),
                        rateloss_hard=running_hard_rate_loss / (i + 1),
                        rateloss_soft=running_soft_rate_loss / (i + 1)
                    )
                    previous_total_loss = running_total_loss

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_total_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)