diff --git a/dnn/torch/rdovae/README.md b/dnn/torch/rdovae/README.md
new file mode 100644
index 00000000..14359d82
--- /dev/null
+++ b/dnn/torch/rdovae/README.md
@@ -0,0 +1,24 @@
+# Rate-Distortion-Optimized Variational Auto-Encoder
+
+## Setup
+The python code requires python >= 3.6 and has been tested with python 3.6 and python 3.10. To install requirements run
+```
+python -m pip install -r requirements.txt
+```
+
+## Training
+To generate training data use dump date from the main LPCNet repo
+```
+./dump_data -train 16khz_speech_input.s16 features.f32 data.s16
+```
+
+To train the model, simply run
+```
+python train_rdovae.py features.f32 output_folder
+```
+
+To train on CUDA device add `--cuda-visible-devices idx`.
+
+
+## ToDo
+- Upload checkpoints and add URLs
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py
new file mode 100644
index 00000000..35b43704
--- /dev/null
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -0,0 +1,256 @@
+"""
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('checkpoint', type=str, help='rdovae model checkpoint')
+parser.add_argument('output_dir', type=str, help='output folder')
+parser.add_argument('--format', choices=['C', 'numpy'], help='output format, default: C', default='C')
+
+args = parser.parse_args()
+
+import torch
+import numpy as np
+
+from rdovae import RDOVAE
+from wexchange.torch import dump_torch_weights
+from wexchange.c_export import CWriter, print_vector
+
+
+def dump_statistical_model(writer, qembedding):
+ w = qembedding.weight.detach()
+ levels, dim = w.shape
+ N = dim // 6
+
+ print("printing statistical model")
+ quant_scales = torch.nn.functional.softplus(w[:, : N]).numpy()
+ dead_zone = 0.05 * torch.nn.functional.softplus(w[:, N : 2 * N]).numpy()
+ r = torch.sigmoid(w[:, 5 * N : 6 * N]).numpy()
+ p0 = torch.sigmoid(w[:, 4 * N : 5 * N]).numpy()
+ p0 = 1 - r ** (0.5 + 0.5 * p0)
+
+ quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
+ dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16)
+ r_q15 = np.round(r * 2**15).astype(np.uint16)
+ p0_q15 = np.round(p0 * 2**15).astype(np.uint16)
+
+ print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
+ print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
+ print_vector(writer.source, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False)
+ print_vector(writer.source, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False)
+
+ writer.header.write(
+f"""
+extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
+extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
+extern const opus_uint16 dred_r_q15[{levels * N}];
+extern const opus_uint16 dred_p0_q15[{levels * N}];
+
+"""
+ )
+
+
+def c_export(args, model):
+
+ message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
+
+ enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message)
+ dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message)
+ stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats"), message=message)
+ constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True)
+
+ # some custom includes
+ for writer in [enc_writer, dec_writer, stats_writer]:
+ writer.header.write(
+f"""
+#include "opus_types.h"
+
+#include "dred_rdovae_constants.h"
+
+#include "nnet.h"
+"""
+ )
+
+ # encoder
+ encoder_dense_layers = [
+ ('core_encoder.module.dense_1' , 'enc_dense1', 'TANH'),
+ ('core_encoder.module.dense_2' , 'enc_dense3', 'TANH'),
+ ('core_encoder.module.dense_3' , 'enc_dense5', 'TANH'),
+ ('core_encoder.module.dense_4' , 'enc_dense7', 'TANH'),
+ ('core_encoder.module.dense_5' , 'enc_dense8', 'TANH'),
+ ('core_encoder.module.state_dense_1' , 'gdense1' , 'TANH'),
+ ('core_encoder.module.state_dense_2' , 'gdense2' , 'TANH')
+ ]
+
+ for name, export_name, activation in encoder_dense_layers:
+ layer = model.get_submodule(name)
+ dump_torch_weights(enc_writer, layer, name=export_name, activation=activation, verbose=True)
+
+
+ encoder_gru_layers = [
+ ('core_encoder.module.gru_1' , 'enc_dense2', 'TANH'),
+ ('core_encoder.module.gru_2' , 'enc_dense4', 'TANH'),
+ ('core_encoder.module.gru_3' , 'enc_dense6', 'TANH')
+ ]
+
+ enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_gru_layers])
+
+
+ encoder_conv_layers = [
+ ('core_encoder.module.conv1' , 'bits_dense' , 'LINEAR')
+ ]
+
+ enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_conv_layers])
+
+
+ del enc_writer
+
+ # decoder
+ decoder_dense_layers = [
+ ('core_decoder.module.gru_1_init' , 'state1', 'TANH'),
+ ('core_decoder.module.gru_2_init' , 'state2', 'TANH'),
+ ('core_decoder.module.gru_3_init' , 'state3', 'TANH'),
+ ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH'),
+ ('core_decoder.module.dense_2' , 'dec_dense3', 'TANH'),
+ ('core_decoder.module.dense_3' , 'dec_dense5', 'TANH'),
+ ('core_decoder.module.dense_4' , 'dec_dense7', 'TANH'),
+ ('core_decoder.module.dense_5' , 'dec_dense8', 'TANH'),
+ ('core_decoder.module.output' , 'dec_final', 'LINEAR')
+ ]
+
+ for name, export_name, activation in decoder_dense_layers:
+ layer = model.get_submodule(name)
+ dump_torch_weights(dec_writer, layer, name=export_name, activation=activation, verbose=True)
+
+
+ decoder_gru_layers = [
+ ('core_decoder.module.gru_1' , 'dec_dense2', 'TANH'),
+ ('core_decoder.module.gru_2' , 'dec_dense4', 'TANH'),
+ ('core_decoder.module.gru_3' , 'dec_dense6', 'TANH')
+ ]
+
+ dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in decoder_gru_layers])
+
+ del dec_writer
+
+ # statistical model
+ qembedding = model.statistical_model.quant_embedding
+ dump_statistical_model(stats_writer, qembedding)
+
+ del stats_writer
+
+ # constants
+ constants_writer.header.write(
+f"""
+#define DRED_NUM_FEATURES {model.feature_dim}
+
+#define DRED_LATENT_DIM {model.latent_dim}
+
+#define DRED_STATE_DIME {model.state_dim}
+
+#define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels}
+
+#define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)}
+
+#define DRED_MAX_CONV_INPUTS {enc_max_conv_inputs}
+
+#define DRED_ENC_MAX_RNN_NEURONS {enc_max_conv_inputs}
+
+#define DRED_ENC_MAX_CONV_INPUTS {enc_max_conv_inputs}
+
+#define DRED_DEC_MAX_RNN_NEURONS {dec_max_rnn_units}
+
+"""
+ )
+
+ del constants_writer
+
+
+def numpy_export(args, model):
+
+ exchange_name_to_name = {
+ 'encoder_stack_layer1_dense' : 'core_encoder.module.dense_1',
+ 'encoder_stack_layer3_dense' : 'core_encoder.module.dense_2',
+ 'encoder_stack_layer5_dense' : 'core_encoder.module.dense_3',
+ 'encoder_stack_layer7_dense' : 'core_encoder.module.dense_4',
+ 'encoder_stack_layer8_dense' : 'core_encoder.module.dense_5',
+ 'encoder_state_layer1_dense' : 'core_encoder.module.state_dense_1',
+ 'encoder_state_layer2_dense' : 'core_encoder.module.state_dense_2',
+ 'encoder_stack_layer2_gru' : 'core_encoder.module.gru_1',
+ 'encoder_stack_layer4_gru' : 'core_encoder.module.gru_2',
+ 'encoder_stack_layer6_gru' : 'core_encoder.module.gru_3',
+ 'encoder_stack_layer9_conv' : 'core_encoder.module.conv1',
+ 'statistical_model_embedding' : 'statistical_model.quant_embedding',
+ 'decoder_state1_dense' : 'core_decoder.module.gru_1_init',
+ 'decoder_state2_dense' : 'core_decoder.module.gru_2_init',
+ 'decoder_state3_dense' : 'core_decoder.module.gru_3_init',
+ 'decoder_stack_layer1_dense' : 'core_decoder.module.dense_1',
+ 'decoder_stack_layer3_dense' : 'core_decoder.module.dense_2',
+ 'decoder_stack_layer5_dense' : 'core_decoder.module.dense_3',
+ 'decoder_stack_layer7_dense' : 'core_decoder.module.dense_4',
+ 'decoder_stack_layer8_dense' : 'core_decoder.module.dense_5',
+ 'decoder_stack_layer9_dense' : 'core_decoder.module.output',
+ 'decoder_stack_layer2_gru' : 'core_decoder.module.gru_1',
+ 'decoder_stack_layer4_gru' : 'core_decoder.module.gru_2',
+ 'decoder_stack_layer6_gru' : 'core_decoder.module.gru_3'
+ }
+
+ name_to_exchange_name = {value : key for key, value in exchange_name_to_name.items()}
+
+ for name, exchange_name in name_to_exchange_name.items():
+ print(f"printing layer {name}...")
+ dump_torch_weights(os.path.join(args.output_dir, exchange_name), model.get_submodule(name))
+
+
+if __name__ == "__main__":
+
+
+ os.makedirs(args.output_dir, exist_ok=True)
+
+
+ # load model from checkpoint
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
+ model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+ missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+ if len(missing_keys) > 0:
+ raise ValueError(f"error: missing keys in state dict")
+
+ if len(unmatched_keys) > 0:
+ print(f"warning: the following keys were unmatched {unmatched_keys}")
+
+ if args.format == 'C':
+ c_export(args, model)
+ elif args.format == 'numpy':
+ numpy_export(args, model)
+ else:
+ raise ValueError(f'error: unknown export format {args.format}')
\ No newline at end of file
diff --git a/dnn/torch/rdovae/fec_encoder.py b/dnn/torch/rdovae/fec_encoder.py
new file mode 100644
index 00000000..291c0628
--- /dev/null
+++ b/dnn/torch/rdovae/fec_encoder.py
@@ -0,0 +1,213 @@
+"""
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe and Jean-Marc Valin */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import subprocess
+import argparse
+
+os.environ['CUDA_VISIBLE_DEVICES'] = ""
+
+parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with voip application and 20ms frames')
+
+parser.add_argument('input', metavar='', help='audio input (.wav or .raw or .pcm as int16)')
+parser.add_argument('checkpoint', metavar='', help='model checkpoint')
+parser.add_argument('q0', metavar='', type=int, help='quantization level for most recent frame')
+parser.add_argument('q1', metavar='', type=int, help='quantization level for oldest frame')
+parser.add_argument('output', type=str, help='output file (will be extended with .fec)')
+
+parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
+parser.add_argument('--num-redundancy-frames', default=52, type=int, help='number of redundancy frames per packet (default 52)')
+parser.add_argument('--extra-delay', default=0, type=int, help="last features in packet are calculated with the decoder aligned samples, use this option to add extra delay (in samples at 16kHz)")
+parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')
+parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')
+
+args = parser.parse_args()
+
+import numpy as np
+from scipy.io import wavfile
+import torch
+
+from rdovae import RDOVAE
+from packets import write_fec_packets
+
+torch.set_num_threads(4)
+
+checkpoint = torch.load(args.checkpoint, map_location="cpu")
+model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+model.load_state_dict(checkpoint['state_dict'], strict=False)
+model.to("cpu")
+
+lpc_order = 16
+
+## prepare input signal
+# SILK frame size is 20ms and LPCNet subframes are 10ms
+subframe_size = 160
+frame_size = 2 * subframe_size
+
+# 91 samples delay to align with SILK decoded frames
+silk_delay = 91
+
+# prepend zeros to have enough history to produce the first package
+zero_history = (args.num_redundancy_frames - 1) * frame_size
+
+# dump data has a (feature) delay of 10ms
+dump_data_delay = 160
+
+total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay
+
+# load signal
+if args.input.endswith('.raw') or args.input.endswith('.pcm'):
+ signal = np.fromfile(args.input, dtype='int16')
+
+elif args.input.endswith('.wav'):
+ fs, signal = wavfile.read(args.input)
+else:
+ raise ValueError(f'unknown input signal format: {args.input}')
+
+# fill up last frame with zeros
+padded_signal_length = len(signal) + total_delay
+tail = padded_signal_length % frame_size
+right_padding = (frame_size - tail) % frame_size
+
+signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16)))
+
+padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw'
+signal.tofile(padded_signal_file)
+
+# write signal and call dump_data to create features
+
+feature_file = os.path.splitext(args.input)[0] + '_features.f32'
+command = f"{args.dump_data} -test {padded_signal_file} {feature_file}"
+r = subprocess.run(command, shell=True)
+if r.returncode != 0:
+ raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}")
+
+# load features
+nb_features = model.feature_dim + lpc_order
+nb_used_features = model.feature_dim
+
+# load features
+features = np.fromfile(feature_file, dtype='float32')
+num_subframes = len(features) // nb_features
+num_subframes = 2 * (num_subframes // 2)
+num_frames = num_subframes // 2
+
+features = np.reshape(features, (1, -1, nb_features))
+features = features[:, :, :nb_used_features]
+features = features[:, :num_subframes, :]
+
+# quant_ids in reverse decoding order
+quant_ids = torch.round((args.q1 + (args.q0 - args.q1) * torch.arange(args.num_redundancy_frames // 2) / (args.num_redundancy_frames // 2 - 1))).long()
+
+print(f"using quantization levels {quant_ids}...")
+
+# convert input to torch tensors
+features = torch.from_numpy(features)
+
+
+# run encoder
+print("running fec encoder...")
+with torch.no_grad():
+
+ # encoding
+ z, states, state_size = model.encode(features)
+
+
+ # decoder on packet chunks
+ input_length = args.num_redundancy_frames // 2
+ offset = args.num_redundancy_frames - 1
+
+ packets = []
+ packet_sizes = []
+
+ for i in range(offset, num_frames):
+ print(f"processing frame {i - offset}...")
+ # quantize / unquantize latent vectors
+ zi = torch.clone(z[:, i - 2 * input_length + 2: i + 1 : 2, :])
+ zi, rates = model.quantize(zi, quant_ids)
+ zi = model.unquantize(zi, quant_ids)
+
+ features = model.decode(zi, states[:, i : i + 1, :])
+ packets.append(features.squeeze(0).numpy())
+ packet_size = 8 * int((torch.sum(rates) + 7 + state_size) / 8)
+ packet_sizes.append(packet_size)
+
+
+# write packets
+packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output
+write_fec_packets(packet_file, packets, packet_sizes)
+
+
+print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")
+
+# assemble features according to loss file
+if args.lossfile != None:
+ num_packets = len(packets)
+ loss = np.loadtxt(args.lossfile, dtype='int16')
+ fec_out = np.zeros((num_packets * 2, packets[0].shape[-1]), dtype='float32')
+ foffset = -2
+ ptr = 0
+ count = 2
+ for i in range(num_packets):
+ if (loss[i] == 0) or (i == num_packets - 1):
+
+ fec_out[ptr:ptr+count,:] = packets[i][foffset:, :]
+
+ ptr += count
+ foffset = -2
+ count = 2
+ else:
+ count += 2
+ foffset -= 2
+
+ fec_out_full = np.zeros((fec_out.shape[0], 36), dtype=np.float32)
+ fec_out_full[:, : fec_out.shape[-1]] = fec_out
+
+ fec_out_full.tofile(packet_file[:-4] + f'_fec.f32')
+
+
+if args.debug_output:
+ import itertools
+
+ batches = [4]
+ offsets = [0, 2 * args.num_redundancy_frames - 4]
+
+ # sanity checks
+ # 1. concatenate features at offset 0
+ for batch, offset in itertools.product(batches, offsets):
+
+ stop = packets[0].shape[1] - offset
+ test_features = np.concatenate([packet[stop - batch: stop, :] for packet in packets[::batch//2]], axis=0)
+
+ test_features_full = np.zeros((test_features.shape[0], nb_features), dtype=np.float32)
+ test_features_full[:, :nb_used_features] = test_features[:, :]
+
+ print(f"writing debug output {packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32'}")
+ test_features_full.tofile(packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32')
+
diff --git a/dnn/torch/rdovae/import_rdovae_weights.py b/dnn/torch/rdovae/import_rdovae_weights.py
new file mode 100644
index 00000000..eba05018
--- /dev/null
+++ b/dnn/torch/rdovae/import_rdovae_weights.py
@@ -0,0 +1,143 @@
+"""
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = ""
+
+import argparse
+
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('exchange_folder', type=str, help='exchange folder path')
+parser.add_argument('output', type=str, help='path to output model checkpoint')
+
+model_group = parser.add_argument_group(title="model parameters")
+model_group.add_argument('--num-features', type=int, help="number of features, default: 20", default=20)
+model_group.add_argument('--latent-dim', type=int, help="number of symbols produces by encoder, default: 80", default=80)
+model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
+model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
+model_group.add_argument('--state-dim', type=int, help="dimensionality of transfered state, default: 24", default=24)
+model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 40", default=40)
+
+args = parser.parse_args()
+
+import torch
+from rdovae import RDOVAE
+from wexchange.torch import load_torch_weights
+
+exchange_name_to_name = {
+ 'encoder_stack_layer1_dense' : 'core_encoder.module.dense_1',
+ 'encoder_stack_layer3_dense' : 'core_encoder.module.dense_2',
+ 'encoder_stack_layer5_dense' : 'core_encoder.module.dense_3',
+ 'encoder_stack_layer7_dense' : 'core_encoder.module.dense_4',
+ 'encoder_stack_layer8_dense' : 'core_encoder.module.dense_5',
+ 'encoder_state_layer1_dense' : 'core_encoder.module.state_dense_1',
+ 'encoder_state_layer2_dense' : 'core_encoder.module.state_dense_2',
+ 'encoder_stack_layer2_gru' : 'core_encoder.module.gru_1',
+ 'encoder_stack_layer4_gru' : 'core_encoder.module.gru_2',
+ 'encoder_stack_layer6_gru' : 'core_encoder.module.gru_3',
+ 'encoder_stack_layer9_conv' : 'core_encoder.module.conv1',
+ 'statistical_model_embedding' : 'statistical_model.quant_embedding',
+ 'decoder_state1_dense' : 'core_decoder.module.gru_1_init',
+ 'decoder_state2_dense' : 'core_decoder.module.gru_2_init',
+ 'decoder_state3_dense' : 'core_decoder.module.gru_3_init',
+ 'decoder_stack_layer1_dense' : 'core_decoder.module.dense_1',
+ 'decoder_stack_layer3_dense' : 'core_decoder.module.dense_2',
+ 'decoder_stack_layer5_dense' : 'core_decoder.module.dense_3',
+ 'decoder_stack_layer7_dense' : 'core_decoder.module.dense_4',
+ 'decoder_stack_layer8_dense' : 'core_decoder.module.dense_5',
+ 'decoder_stack_layer9_dense' : 'core_decoder.module.output',
+ 'decoder_stack_layer2_gru' : 'core_decoder.module.gru_1',
+ 'decoder_stack_layer4_gru' : 'core_decoder.module.gru_2',
+ 'decoder_stack_layer6_gru' : 'core_decoder.module.gru_3'
+}
+
+if __name__ == "__main__":
+ checkpoint = dict()
+
+ # parameters
+ num_features = args.num_features
+ latent_dim = args.latent_dim
+ quant_levels = args.quant_levels
+ cond_size = args.cond_size
+ cond_size2 = args.cond_size2
+ state_dim = args.state_dim
+
+
+ # model
+ checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
+ checkpoint['model_kwargs'] = {'state_dim': state_dim}
+ model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+
+ dense_layer_names = [
+ 'encoder_stack_layer1_dense',
+ 'encoder_stack_layer3_dense',
+ 'encoder_stack_layer5_dense',
+ 'encoder_stack_layer7_dense',
+ 'encoder_stack_layer8_dense',
+ 'encoder_state_layer1_dense',
+ 'encoder_state_layer2_dense',
+ 'decoder_state1_dense',
+ 'decoder_state2_dense',
+ 'decoder_state3_dense',
+ 'decoder_stack_layer1_dense',
+ 'decoder_stack_layer3_dense',
+ 'decoder_stack_layer5_dense',
+ 'decoder_stack_layer7_dense',
+ 'decoder_stack_layer8_dense',
+ 'decoder_stack_layer9_dense'
+ ]
+
+ gru_layer_names = [
+ 'encoder_stack_layer2_gru',
+ 'encoder_stack_layer4_gru',
+ 'encoder_stack_layer6_gru',
+ 'decoder_stack_layer2_gru',
+ 'decoder_stack_layer4_gru',
+ 'decoder_stack_layer6_gru'
+ ]
+
+ conv1d_layer_names = [
+ 'encoder_stack_layer9_conv'
+ ]
+
+ embedding_layer_names = [
+ 'statistical_model_embedding'
+ ]
+
+ for name in dense_layer_names + gru_layer_names + conv1d_layer_names + embedding_layer_names:
+ print(f"loading weights for layer {exchange_name_to_name[name]}")
+ layer = model.get_submodule(exchange_name_to_name[name])
+ load_torch_weights(os.path.join(args.exchange_folder, name), layer)
+
+ checkpoint['state_dict'] = model.state_dict()
+
+ torch.save(checkpoint, args.output)
\ No newline at end of file
diff --git a/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl b/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl
new file mode 100644
index 00000000..cfeebae5
Binary files /dev/null and b/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl differ
diff --git a/dnn/torch/rdovae/packets/__init__.py b/dnn/torch/rdovae/packets/__init__.py
new file mode 100644
index 00000000..fb71ab3d
--- /dev/null
+++ b/dnn/torch/rdovae/packets/__init__.py
@@ -0,0 +1 @@
+from .fec_packets import write_fec_packets, read_fec_packets
\ No newline at end of file
diff --git a/dnn/torch/rdovae/packets/fec_packets.c b/dnn/torch/rdovae/packets/fec_packets.c
new file mode 100644
index 00000000..376fb4f1
--- /dev/null
+++ b/dnn/torch/rdovae/packets/fec_packets.c
@@ -0,0 +1,142 @@
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+
+#include
+#include
+
+#include "fec_packets.h"
+
+int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index)
+{
+
+ int16_t version;
+ int16_t header_size;
+ int16_t num_packets;
+ int16_t packet_size;
+ int16_t subframe_size;
+ int16_t subframes_per_packet;
+ int16_t num_features;
+ long offset;
+
+ FILE *fid = fopen(filename, "rb");
+
+ /* read header */
+ if (fread(&version, sizeof(version), 1, fid) != 1) goto error;
+ if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error;
+ if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error;
+ if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error;
+ if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error;
+ if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error;
+ if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error;
+
+ /* check if indices are valid */
+ if (packet_index >= num_packets || subframe_index >= subframes_per_packet)
+ {
+ fprintf(stderr, "get_fec_frame: index out of bounds\n");
+ goto error;
+ }
+
+ /* calculate offset in file (+ 2 is for rate) */
+ offset = header_size + packet_index * packet_size + 2 + subframe_index * subframe_size;
+ fseek(fid, offset, SEEK_SET);
+
+ /* read features */
+ if (fread(features, sizeof(*features), num_features, fid) != num_features) goto error;
+
+ fclose(fid);
+ return 0;
+
+error:
+ fclose(fid);
+ return 1;
+}
+
+int get_fec_rate(const char * const filename, int packet_index)
+{
+ int16_t version;
+ int16_t header_size;
+ int16_t num_packets;
+ int16_t packet_size;
+ int16_t subframe_size;
+ int16_t subframes_per_packet;
+ int16_t num_features;
+ long offset;
+ int16_t rate;
+
+ FILE *fid = fopen(filename, "rb");
+
+ /* read header */
+ if (fread(&version, sizeof(version), 1, fid) != 1) goto error;
+ if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error;
+ if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error;
+ if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error;
+ if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error;
+ if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error;
+ if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error;
+
+ /* check if indices are valid */
+ if (packet_index >= num_packets)
+ {
+ fprintf(stderr, "get_fec_rate: index out of bounds\n");
+ goto error;
+ }
+
+ /* calculate offset in file (+ 2 is for rate) */
+ offset = header_size + packet_index * packet_size;
+ fseek(fid, offset, SEEK_SET);
+
+ /* read rate */
+ if (fread(&rate, sizeof(rate), 1, fid) != 1) goto error;
+
+ fclose(fid);
+ return (int) rate;
+
+error:
+ fclose(fid);
+ return -1;
+}
+
+#if 0
+int main()
+{
+ float features[20];
+ int i;
+
+ if (get_fec_frame("../test.fec", &features[0], 0, 127))
+ {
+ return 1;
+ }
+
+ for (i = 0; i < 20; i ++)
+ {
+ printf("%d %f\n", i, features[i]);
+ }
+
+ printf("rate: %d\n", get_fec_rate("../test.fec", 0));
+
+}
+#endif
\ No newline at end of file
diff --git a/dnn/torch/rdovae/packets/fec_packets.h b/dnn/torch/rdovae/packets/fec_packets.h
new file mode 100644
index 00000000..35d35542
--- /dev/null
+++ b/dnn/torch/rdovae/packets/fec_packets.h
@@ -0,0 +1,34 @@
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+
+#ifndef _FEC_PACKETS_H
+#define _FEC_PACKETS_H
+
+int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index);
+int get_fec_rate(const char * const filename, int packet_index);
+
+#endif
\ No newline at end of file
diff --git a/dnn/torch/rdovae/packets/fec_packets.py b/dnn/torch/rdovae/packets/fec_packets.py
new file mode 100644
index 00000000..14bed1f8
--- /dev/null
+++ b/dnn/torch/rdovae/packets/fec_packets.py
@@ -0,0 +1,108 @@
+"""
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import numpy as np
+
+
+
+def write_fec_packets(filename, packets, rates=None):
+ """ writes packets in binary format """
+
+ assert np.dtype(np.float32).itemsize == 4
+ assert np.dtype(np.int16).itemsize == 2
+
+ # derive some sizes
+ num_packets = len(packets)
+ subframes_per_packet = packets[0].shape[-2]
+ num_features = packets[0].shape[-1]
+
+ # size of float is 4
+ subframe_size = num_features * 4
+ packet_size = subframe_size * subframes_per_packet + 2 # two bytes for rate
+
+ version = 1
+ # header size (version, header_size, num_packets, packet_size, subframe_size, subrames_per_packet, num_features)
+ header_size = 14
+
+ with open(filename, 'wb') as f:
+
+ # header
+ f.write(np.int16(version).tobytes())
+ f.write(np.int16(header_size).tobytes())
+ f.write(np.int16(num_packets).tobytes())
+ f.write(np.int16(packet_size).tobytes())
+ f.write(np.int16(subframe_size).tobytes())
+ f.write(np.int16(subframes_per_packet).tobytes())
+ f.write(np.int16(num_features).tobytes())
+
+ # packets
+ for i, packet in enumerate(packets):
+ if type(rates) == type(None):
+ rate = 0
+ else:
+ rate = rates[i]
+
+ f.write(np.int16(rate).tobytes())
+
+ features = np.flip(packet, axis=-2)
+ f.write(features.astype(np.float32).tobytes())
+
+
+def read_fec_packets(filename):
+ """ reads packets from binary format """
+
+ assert np.dtype(np.float32).itemsize == 4
+ assert np.dtype(np.int16).itemsize == 2
+
+ with open(filename, 'rb') as f:
+
+ # header
+ version = np.frombuffer(f.read(2), dtype=np.int16).item()
+ header_size = np.frombuffer(f.read(2), dtype=np.int16).item()
+ num_packets = np.frombuffer(f.read(2), dtype=np.int16).item()
+ packet_size = np.frombuffer(f.read(2), dtype=np.int16).item()
+ subframe_size = np.frombuffer(f.read(2), dtype=np.int16).item()
+ subframes_per_packet = np.frombuffer(f.read(2), dtype=np.int16).item()
+ num_features = np.frombuffer(f.read(2), dtype=np.int16).item()
+
+ dummy_features = np.zeros((subframes_per_packet, num_features), dtype=np.float32)
+
+ # packets
+ rates = []
+ packets = []
+ for i in range(num_packets):
+
+ rate = np.frombuffer(f.read(2), dtype=np.int16).item
+ rates.append(rate)
+
+ features = np.reshape(np.frombuffer(f.read(subframe_size * subframes_per_packet), dtype=np.float32), dummy_features.shape)
+ packet = np.flip(features, axis=-2)
+ packets.append(packet)
+
+ return packets
\ No newline at end of file
diff --git a/dnn/torch/rdovae/rdovae/__init__.py b/dnn/torch/rdovae/rdovae/__init__.py
new file mode 100644
index 00000000..b945adde
--- /dev/null
+++ b/dnn/torch/rdovae/rdovae/__init__.py
@@ -0,0 +1,2 @@
+from .rdovae import RDOVAE, distortion_loss, hard_rate_estimate, soft_rate_estimate
+from .dataset import RDOVAEDataset
diff --git a/dnn/torch/rdovae/rdovae/dataset.py b/dnn/torch/rdovae/rdovae/dataset.py
new file mode 100644
index 00000000..99630d8b
--- /dev/null
+++ b/dnn/torch/rdovae/rdovae/dataset.py
@@ -0,0 +1,68 @@
+"""
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+import numpy as np
+
+class RDOVAEDataset(torch.utils.data.Dataset):
+ def __init__(self,
+ feature_file,
+ sequence_length,
+ num_used_features=20,
+ num_features=36,
+ lambda_min=0.0002,
+ lambda_max=0.0135,
+ quant_levels=16,
+ enc_stride=2):
+
+ self.sequence_length = sequence_length
+ self.lambda_min = lambda_min
+ self.lambda_max = lambda_max
+ self.enc_stride = enc_stride
+ self.quant_levels = quant_levels
+ self.denominator = (quant_levels - 1) / np.log(lambda_max / lambda_min)
+
+ if sequence_length % enc_stride:
+ raise ValueError(f"RDOVAEDataset.__init__: enc_stride {enc_stride} does not divide sequence length {sequence_length}")
+
+ self.features = np.reshape(np.fromfile(feature_file, dtype=np.float32), (-1, num_features))
+ self.features = self.features[:, :num_used_features]
+ self.num_sequences = self.features.shape[0] // sequence_length
+
+ def __len__(self):
+ return self.num_sequences
+
+ def __getitem__(self, index):
+ features = self.features[index * self.sequence_length: (index + 1) * self.sequence_length, :]
+ q_ids = np.random.randint(0, self.quant_levels, (1)).astype(np.int64)
+ q_ids = np.repeat(q_ids, self.sequence_length // self.enc_stride, axis=0)
+ rate_lambda = self.lambda_min * np.exp(q_ids.astype(np.float32) / self.denominator).astype(np.float32)
+
+ return features, rate_lambda, q_ids
+
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
new file mode 100644
index 00000000..b45d2b8c
--- /dev/null
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -0,0 +1,614 @@
+"""
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+""" Pytorch implementations of rate distortion optimized variational autoencoder """
+
+import math as m
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+# Quantization and rate related utily functions
+
+def soft_pvq(x, k):
+ """ soft pyramid vector quantizer """
+
+ # L2 normalization
+ x_norm2 = x / (1e-15 + torch.norm(x, dim=-1, keepdim=True))
+
+
+ with torch.no_grad():
+ # quantization loop, no need to track gradients here
+ x_norm1 = x / torch.sum(torch.abs(x), dim=-1, keepdim=True)
+
+ # set initial scaling factor to k
+ scale_factor = k
+ x_scaled = scale_factor * x_norm1
+ x_quant = torch.round(x_scaled)
+
+ # we aim for ||x_quant||_L1 = k
+ for _ in range(10):
+ # remove signs and calculate L1 norm
+ abs_x_quant = torch.abs(x_quant)
+ abs_x_scaled = torch.abs(x_scaled)
+ l1_x_quant = torch.sum(abs_x_quant, axis=-1)
+
+ # increase, where target is too small and decrease, where target is too large
+ plus = 1.0001 * torch.min((abs_x_quant + 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
+ minus = 0.9999 * torch.max((abs_x_quant - 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
+ factor = torch.where(l1_x_quant > k, minus, plus)
+ factor = torch.where(l1_x_quant == k, torch.ones_like(factor), factor)
+ scale_factor = scale_factor * factor.unsqueeze(-1)
+
+ # update x
+ x_scaled = scale_factor * x_norm1
+ x_quant = torch.round(x_quant)
+
+ # L2 normalization of quantized x
+ x_quant_norm2 = x_quant / (1e-15 + torch.norm(x_quant, dim=-1, keepdim=True))
+ quantization_error = x_quant_norm2 - x_norm2
+
+ return x_norm2 + quantization_error.detach()
+
+def cache_parameters(func):
+ cache = dict()
+ def cached_func(*args):
+ if args in cache:
+ return cache[args]
+ else:
+ cache[args] = func(*args)
+
+ return cache[args]
+ return cached_func
+
+@cache_parameters
+def pvq_codebook_size(n, k):
+
+ if k == 0:
+ return 1
+
+ if n == 0:
+ return 0
+
+ return pvq_codebook_size(n - 1, k) + pvq_codebook_size(n, k - 1) + pvq_codebook_size(n - 1, k - 1)
+
+
+def soft_rate_estimate(z, r, reduce=True):
+ """ rate approximation with dependent theta Eq. (7)"""
+
+ rate = torch.sum(
+ - torch.log2((1 - r)/(1 + r) * r ** torch.abs(z) + 1e-6),
+ dim=-1
+ )
+
+ if reduce:
+ rate = torch.mean(rate)
+
+ return rate
+
+
+def hard_rate_estimate(z, r, theta, reduce=True):
+ """ hard rate approximation """
+
+ z_q = torch.round(z)
+ p0 = 1 - r ** (0.5 + 0.5 * theta)
+ alpha = torch.relu(1 - torch.abs(z_q)) ** 2
+ rate = - torch.sum(
+ (alpha * torch.log2(p0 * r ** torch.abs(z_q) + 1e-6)
+ + (1 - alpha) * torch.log2(0.5 * (1 - p0) * (1 - r) * r ** (torch.abs(z_q) - 1) + 1e-6)),
+ dim=-1
+ )
+
+ if reduce:
+ rate = torch.mean(rate)
+
+ return rate
+
+
+
+def soft_dead_zone(x, dead_zone):
+ """ approximates application of a dead zone to x """
+ d = dead_zone * 0.05
+ return x - d * torch.tanh(x / (0.1 + d))
+
+
+def hard_quantize(x):
+ """ round with copy gradient trick """
+ return x + (torch.round(x) - x).detach()
+
+
+def noise_quantize(x):
+ """ simulates quantization with addition of random uniform noise """
+ return x + (torch.rand_like(x) - 0.5)
+
+
+# loss functions
+
+
+def distortion_loss(y_true, y_pred, rate_lambda=None):
+ """ custom distortion loss for LPCNet features """
+
+ if y_true.size(-1) != 20:
+ raise ValueError('distortion loss is designed to work with 20 features')
+
+ ceps_error = y_pred[..., :18] - y_true[..., :18]
+ pitch_error = 2 * (y_pred[..., 18:19] - y_true[..., 18:19]) / (2 + y_true[..., 18:19])
+ corr_error = y_pred[..., 19:] - y_true[..., 19:]
+ pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2
+
+ loss = torch.mean(ceps_error ** 2 + (10/18) * torch.abs(pitch_error) * pitch_weight + (1/18) * corr_error ** 2, dim=-1)
+
+ if type(rate_lambda) != type(None):
+ loss = loss / torch.sqrt(rate_lambda)
+
+ loss = torch.mean(loss)
+
+ return loss
+
+
+# sampling functions
+
+import random
+
+
+def random_split(start, stop, num_splits=3, min_len=3):
+ get_min_len = lambda x : min([x[i+1] - x[i] for i in range(len(x) - 1)])
+ candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]
+
+ while get_min_len(candidate) < min_len:
+ candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]
+
+ return candidate
+
+
+
+# weight initialization and clipping
+def init_weights(module):
+
+ if isinstance(module, nn.GRU):
+ for p in module.named_parameters():
+ if p[0].startswith('weight_hh_'):
+ nn.init.orthogonal_(p[1])
+
+
+def weight_clip_factory(max_value):
+ """ weight clipping function concerning sum of abs values of adjecent weights """
+ def clip_weight_(w):
+ stop = w.size(1)
+ # omit last column if stop is odd
+ if stop % 2:
+ stop -= 1
+ max_values = max_value * torch.ones_like(w[:, :stop])
+ factor = max_value / torch.maximum(max_values,
+ torch.repeat_interleave(
+ torch.abs(w[:, :stop:2]) + torch.abs(w[:, 1:stop:2]),
+ 2,
+ 1))
+ with torch.no_grad():
+ w[:, :stop] *= factor
+
+ def clip_weights(module):
+ if isinstance(module, nn.GRU) or isinstance(module, nn.Linear):
+ for name, w in module.named_parameters():
+ if name.startswith('weight'):
+ clip_weight_(w)
+
+ return clip_weights
+
+# RDOVAE module and submodules
+
+
+class CoreEncoder(nn.Module):
+ STATE_HIDDEN = 128
+ FRAMES_PER_STEP = 2
+ CONV_KERNEL_SIZE = 4
+
+ def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24):
+ """ core encoder for RDOVAE
+
+ Computes latents, initial states, and rate estimates from features and lambda parameter
+
+ """
+
+ super(CoreEncoder, self).__init__()
+
+ # hyper parameters
+ self.feature_dim = feature_dim
+ self.output_dim = output_dim
+ self.cond_size = cond_size
+ self.cond_size2 = cond_size2
+ self.state_size = state_size
+
+ # derived parameters
+ self.input_dim = self.FRAMES_PER_STEP * self.feature_dim
+ self.conv_input_channels = 5 * cond_size + 3 * cond_size2
+
+ # layers
+ self.dense_1 = nn.Linear(self.input_dim, self.cond_size2)
+ self.gru_1 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
+ self.dense_2 = nn.Linear(self.cond_size, self.cond_size2)
+ self.gru_2 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
+ self.dense_3 = nn.Linear(self.cond_size, self.cond_size2)
+ self.gru_3 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
+ self.dense_4 = nn.Linear(self.cond_size, self.cond_size)
+ self.dense_5 = nn.Linear(self.cond_size, self.cond_size)
+ self.conv1 = nn.Conv1d(self.conv_input_channels, self.output_dim, kernel_size=self.CONV_KERNEL_SIZE, padding='valid')
+
+ self.state_dense_1 = nn.Linear(self.conv_input_channels, self.STATE_HIDDEN)
+
+ self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)
+
+ # initialize weights
+ self.apply(init_weights)
+
+
+ def forward(self, features):
+
+ # reshape features
+ x = torch.reshape(features, (features.size(0), features.size(1) // self.FRAMES_PER_STEP, self.FRAMES_PER_STEP * features.size(2)))
+
+ batch = x.size(0)
+ device = x.device
+
+ # run encoding layer stack
+ x1 = torch.tanh(self.dense_1(x))
+ x2, _ = self.gru_1(x1, torch.zeros((1, batch, self.cond_size)).to(device))
+ x3 = torch.tanh(self.dense_2(x2))
+ x4, _ = self.gru_2(x3, torch.zeros((1, batch, self.cond_size)).to(device))
+ x5 = torch.tanh(self.dense_3(x4))
+ x6, _ = self.gru_3(x5, torch.zeros((1, batch, self.cond_size)).to(device))
+ x7 = torch.tanh(self.dense_4(x6))
+ x8 = torch.tanh(self.dense_5(x7))
+
+ # concatenation of all hidden layer outputs
+ x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
+
+ # init state for decoder
+ states = torch.tanh(self.state_dense_1(x9))
+ states = torch.tanh(self.state_dense_2(states))
+
+ # latent representation via convolution
+ x9 = F.pad(x9.permute(0, 2, 1), [self.CONV_KERNEL_SIZE - 1, 0])
+ z = self.conv1(x9).permute(0, 2, 1)
+
+ return z, states
+
+
+
+
+class CoreDecoder(nn.Module):
+
+ FRAMES_PER_STEP = 4
+
+ def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24):
+ """ core decoder for RDOVAE
+
+ Computes features from latents, initial state, and quantization index
+
+ """
+
+ super(CoreDecoder, self).__init__()
+
+ # hyper parameters
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.cond_size = cond_size
+ self.cond_size2 = cond_size2
+ self.state_size = state_size
+
+ self.input_size = self.input_dim
+
+ self.concat_size = 4 * self.cond_size + 4 * self.cond_size2
+
+ # layers
+ self.dense_1 = nn.Linear(self.input_size, cond_size2)
+ self.gru_1 = nn.GRU(cond_size2, cond_size, batch_first=True)
+ self.dense_2 = nn.Linear(cond_size, cond_size2)
+ self.gru_2 = nn.GRU(cond_size2, cond_size, batch_first=True)
+ self.dense_3 = nn.Linear(cond_size, cond_size2)
+ self.gru_3 = nn.GRU(cond_size2, cond_size, batch_first=True)
+ self.dense_4 = nn.Linear(cond_size, cond_size2)
+ self.dense_5 = nn.Linear(cond_size2, cond_size2)
+
+ self.output = nn.Linear(self.concat_size, self.FRAMES_PER_STEP * self.output_dim)
+
+
+ self.gru_1_init = nn.Linear(self.state_size, self.cond_size)
+ self.gru_2_init = nn.Linear(self.state_size, self.cond_size)
+ self.gru_3_init = nn.Linear(self.state_size, self.cond_size)
+
+ # initialize weights
+ self.apply(init_weights)
+
+ def forward(self, z, initial_state):
+
+ gru_1_state = torch.tanh(self.gru_1_init(initial_state).permute(1, 0, 2))
+ gru_2_state = torch.tanh(self.gru_2_init(initial_state).permute(1, 0, 2))
+ gru_3_state = torch.tanh(self.gru_3_init(initial_state).permute(1, 0, 2))
+
+ # run decoding layer stack
+ x1 = torch.tanh(self.dense_1(z))
+ x2, _ = self.gru_1(x1, gru_1_state)
+ x3 = torch.tanh(self.dense_2(x2))
+ x4, _ = self.gru_2(x3, gru_2_state)
+ x5 = torch.tanh(self.dense_3(x4))
+ x6, _ = self.gru_3(x5, gru_3_state)
+ x7 = torch.tanh(self.dense_4(x6))
+ x8 = torch.tanh(self.dense_5(x7))
+ x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
+
+ # output layer and reshaping
+ x10 = self.output(x9)
+ features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP))
+
+ return features
+
+
+class StatisticalModel(nn.Module):
+ def __init__(self, quant_levels, latent_dim):
+ """ Statistical model for latent space
+
+ Computes scaling, deadzone, r, and theta
+
+ """
+
+ super(StatisticalModel, self).__init__()
+
+ # copy parameters
+ self.latent_dim = latent_dim
+ self.quant_levels = quant_levels
+ self.embedding_dim = 6 * latent_dim
+
+ # quantization embedding
+ self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)
+
+ # initialize embedding to 0
+ with torch.no_grad():
+ self.quant_embedding.weight[:] = 0
+
+
+ def forward(self, quant_ids):
+ """ takes quant_ids and returns statistical model parameters"""
+
+ x = self.quant_embedding(quant_ids)
+
+ # CAVE: theta_soft is not used anymore. Kick it out?
+ quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim])
+ dead_zone = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim])
+ theta_soft = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim])
+ r_soft = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim])
+ theta_hard = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim])
+ r_hard = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim])
+
+
+ return {
+ 'quant_embedding' : x,
+ 'quant_scale' : quant_scale,
+ 'dead_zone' : dead_zone,
+ 'r_hard' : r_hard,
+ 'theta_hard' : theta_hard,
+ 'r_soft' : r_soft,
+ 'theta_soft' : theta_soft
+ }
+
+
+class RDOVAE(nn.Module):
+ def __init__(self,
+ feature_dim,
+ latent_dim,
+ quant_levels,
+ cond_size,
+ cond_size2,
+ state_dim=24,
+ split_mode='split',
+ clip_weights=True,
+ pvq_num_pulses=82,
+ state_dropout_rate=0):
+
+ super(RDOVAE, self).__init__()
+
+ self.feature_dim = feature_dim
+ self.latent_dim = latent_dim
+ self.quant_levels = quant_levels
+ self.cond_size = cond_size
+ self.cond_size2 = cond_size2
+ self.split_mode = split_mode
+ self.state_dim = state_dim
+ self.pvq_num_pulses = pvq_num_pulses
+ self.state_dropout_rate = state_dropout_rate
+
+ # submodules encoder and decoder share the statistical model
+ self.statistical_model = StatisticalModel(quant_levels, latent_dim)
+ self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
+ self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
+
+ self.enc_stride = CoreEncoder.FRAMES_PER_STEP
+ self.dec_stride = CoreDecoder.FRAMES_PER_STEP
+
+ if clip_weights:
+ self.weight_clip_fn = weight_clip_factory(0.496)
+ else:
+ self.weight_clip_fn = None
+
+ if self.dec_stride % self.enc_stride != 0:
+ raise ValueError(f"get_decoder_chunks_generic: encoder stride does not divide decoder stride")
+
+ def clip_weights(self):
+ if not type(self.weight_clip_fn) == type(None):
+ self.apply(self.weight_clip_fn)
+
+ def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
+
+ enc_stride = self.enc_stride
+ dec_stride = self.dec_stride
+
+ stride = dec_stride // enc_stride
+
+ chunks = []
+
+ for offset in range(stride):
+ # start is the smalles number = offset mod stride that decodes to a valid range
+ start = offset
+ while enc_stride * (start + 1) - dec_stride < 0:
+ start += stride
+
+ # check if start is a valid index
+ if start >= z_frames:
+ raise ValueError("get_decoder_chunks_generic: range too small")
+
+ # stop is the smallest number outside [0, num_enc_frames] that's congruent to offset mod stride
+ stop = z_frames - (z_frames % stride) + offset
+ while stop < z_frames:
+ stop += stride
+
+ # calculate split points
+ length = (stop - start)
+ if mode == 'split':
+ split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
+ elif mode == 'random_split':
+ split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)]
+ else:
+ raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}")
+
+
+ for i in range(chunks_per_offset):
+ # (enc_frame_start, enc_frame_stop, enc_frame_stride, stride, feature_frame_start, feature_frame_stop)
+ # encoder range(i, j, stride) maps to feature range(enc_stride * (i + 1) - dec_stride, enc_stride * j)
+ # provided that i - j = 1 mod stride
+ chunks.append({
+ 'z_start' : split_points[i],
+ 'z_stop' : split_points[i + 1] - stride + 1,
+ 'z_stride' : stride,
+ 'features_start' : enc_stride * (split_points[i] + 1) - dec_stride,
+ 'features_stop' : enc_stride * (split_points[i + 1] - stride + 1)
+ })
+
+ return chunks
+
+
+ def forward(self, features, q_id):
+
+ # calculate statistical model from quantization ID
+ statistical_model = self.statistical_model(q_id)
+
+ # run encoder
+ z, states = self.core_encoder(features)
+
+ # scaling, dead-zone and quantization
+ z = z * statistical_model['quant_scale']
+ z = soft_dead_zone(z, statistical_model['dead_zone'])
+
+ # quantization
+ z_q = hard_quantize(z) / statistical_model['quant_scale']
+ z_n = noise_quantize(z) / statistical_model['quant_scale']
+ states_q = soft_pvq(states, self.pvq_num_pulses)
+
+ if self.state_dropout_rate > 0:
+ drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
+ mask = torch.ones_like(states_q)
+ mask[drop] = 0
+ states_q = states_q * mask
+
+ # decoder
+ chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode)
+
+ outputs_hq = []
+ outputs_sq = []
+ for chunk in chunks:
+ # decoder with hard quantized input
+ z_dec_reverse = torch.flip(z_q[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
+ dec_initial_state = states_q[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
+ features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
+ outputs_hq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
+
+
+ # decoder with soft quantized input
+ z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
+ features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
+ outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
+
+ return {
+ 'outputs_hard_quant' : outputs_hq,
+ 'outputs_soft_quant' : outputs_sq,
+ 'z' : z,
+ 'statistical_model' : statistical_model
+ }
+
+ def encode(self, features):
+ """ encoder with quantization and rate estimation """
+
+ z, states = self.core_encoder(features)
+
+ # quantization of initial states
+ states = soft_pvq(states, self.pvq_num_pulses)
+ state_size = m.log2(pvq_codebook_size(self.state_dim, self.pvq_num_pulses))
+
+ return z, states, state_size
+
+ def decode(self, z, initial_state):
+ """ decoder (flips sequences by itself) """
+
+ z_reverse = torch.flip(z, [1])
+ features_reverse = self.core_decoder(z_reverse, initial_state)
+ features = torch.flip(features_reverse, [1])
+
+ return features
+
+ def quantize(self, z, q_ids):
+ """ quantization of latent vectors """
+
+ stats = self.statistical_model(q_ids)
+
+ zq = z * stats['quant_scale']
+ zq = soft_dead_zone(zq, stats['dead_zone'])
+ zq = torch.round(zq)
+
+ sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False)
+
+ return zq, sizes
+
+ def unquantize(self, zq, q_ids):
+ """ re-scaling of latent vector """
+
+ stats = self.statistical_model(q_ids)
+
+ z = zq / stats['quant_scale']
+
+ return z
+
+ def freeze_model(self):
+
+ # freeze all parameters
+ for p in self.parameters():
+ p.requires_grad = False
+
+ for p in self.statistical_model.parameters():
+ p.requires_grad = True
+
diff --git a/dnn/torch/rdovae/requirements.txt b/dnn/torch/rdovae/requirements.txt
new file mode 100644
index 00000000..668c8462
--- /dev/null
+++ b/dnn/torch/rdovae/requirements.txt
@@ -0,0 +1,5 @@
+numpy
+scipy
+torch
+tqdm
+libs/wexchange-1.0-py3-none-any.whl
\ No newline at end of file
diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py
new file mode 100644
index 00000000..68ccf2eb
--- /dev/null
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -0,0 +1,270 @@
+"""
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+
+import torch
+import tqdm
+
+from rdovae import RDOVAE, RDOVAEDataset, distortion_loss, hard_rate_estimate, soft_rate_estimate
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('features', type=str, help='path to feature file in .f32 format')
+parser.add_argument('output', type=str, help='path to output folder')
+
+parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: ''", default="")
+
+
+model_group = parser.add_argument_group(title="model parameters")
+model_group.add_argument('--latent-dim', type=int, help="number of symbols produces by encoder, default: 80", default=80)
+model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
+model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
+model_group.add_argument('--state-dim', type=int, help="dimensionality of transfered state, default: 24", default=24)
+model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 16", default=16)
+model_group.add_argument('--lambda-min', type=float, help="minimal value for rate lambda, default: 0.0002", default=2e-4)
+model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)
+model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82)
+model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0)
+
+training_group = parser.add_argument_group(title="training parameters")
+training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
+training_group.add_argument('--lr', type=float, help='learning rate, default: 3e-4', default=3e-4)
+training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 100', default=100)
+training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by 4, default: 256', default=256)
+training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5)
+training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split')
+training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames')
+training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
+training_group.add_argument('--train-decoder-only', action='store_true', help='freeze encoder and statistical model and train decoder only')
+
+args = parser.parse_args()
+
+# set visible devices
+os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
+
+# checkpoints
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+checkpoint = dict()
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+# training parameters
+batch_size = args.batch_size
+lr = args.lr
+epochs = args.epochs
+sequence_length = args.sequence_length
+lr_decay_factor = args.lr_decay_factor
+split_mode = args.split_mode
+# not exposed
+adam_betas = [0.9, 0.99]
+adam_eps = 1e-8
+
+checkpoint['batch_size'] = batch_size
+checkpoint['lr'] = lr
+checkpoint['lr_decay_factor'] = lr_decay_factor
+checkpoint['split_mode'] = split_mode
+checkpoint['epochs'] = epochs
+checkpoint['sequence_length'] = sequence_length
+checkpoint['adam_betas'] = adam_betas
+
+# logging
+log_interval = 10
+
+# device
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+# model parameters
+cond_size = args.cond_size
+cond_size2 = args.cond_size2
+latent_dim = args.latent_dim
+quant_levels = args.quant_levels
+lambda_min = args.lambda_min
+lambda_max = args.lambda_max
+state_dim = args.state_dim
+# not expsed
+num_features = 20
+
+
+# training data
+feature_file = args.features
+
+# model
+checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
+checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate}
+model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+
+if type(args.initial_checkpoint) != type(None):
+ checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+checkpoint['state_dict'] = model.state_dict()
+
+if args.train_decoder_only:
+ if args.initial_checkpoint is None:
+ print("warning: training decoder only without providing initial checkpoint")
+
+ for p in model.core_encoder.module.parameters():
+ p.requires_grad = False
+
+ for p in model.statistical_model.parameters():
+ p.requires_grad = False
+
+# dataloader
+checkpoint['dataset_args'] = (feature_file, sequence_length, num_features, 36)
+checkpoint['dataset_kwargs'] = {'lambda_min': lambda_min, 'lambda_max': lambda_max, 'enc_stride': model.enc_stride, 'quant_levels': quant_levels}
+dataset = RDOVAEDataset(*checkpoint['dataset_args'], **checkpoint['dataset_kwargs'])
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
+
+
+
+# optimizer
+params = [p for p in model.parameters() if p.requires_grad]
+optimizer = torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps)
+
+
+# learning rate scheduler
+scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
+
+if __name__ == '__main__':
+
+ # push model to device
+ model.to(device)
+
+ # training loop
+
+ for epoch in range(1, epochs + 1):
+
+ print(f"training epoch {epoch}...")
+
+ # running stats
+ running_rate_loss = 0
+ running_soft_dist_loss = 0
+ running_hard_dist_loss = 0
+ running_hard_rate_loss = 0
+ running_soft_rate_loss = 0
+ running_total_loss = 0
+ running_rate_metric = 0
+ previous_total_loss = 0
+ running_first_frame_loss = 0
+
+ with tqdm.tqdm(dataloader, unit='batch') as tepoch:
+ for i, (features, rate_lambda, q_ids) in enumerate(tepoch):
+
+ # zero out gradients
+ optimizer.zero_grad()
+
+ # push inputs to device
+ features = features.to(device)
+ q_ids = q_ids.to(device)
+ rate_lambda = rate_lambda.to(device)
+
+
+ rate_lambda_upsamp = torch.repeat_interleave(rate_lambda, 2, 1)
+
+ # run model
+ model_output = model(features, q_ids)
+
+ # collect outputs
+ z = model_output['z']
+ outputs_hard_quant = model_output['outputs_hard_quant']
+ outputs_soft_quant = model_output['outputs_soft_quant']
+ statistical_model = model_output['statistical_model']
+
+ # rate loss
+ hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False)
+ soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False)
+ soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate)
+ hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate)
+ rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
+ hard_rate_metric = torch.mean(hard_rate)
+
+ ## distortion losses
+
+ # hard quantized decoder input
+ distortion_loss_hard_quant = torch.zeros_like(rate_loss)
+ for dec_features, start, stop in outputs_hard_quant:
+ distortion_loss_hard_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_hard_quant)
+
+ first_frame_loss = torch.zeros_like(rate_loss)
+ for dec_features, start, stop in outputs_hard_quant:
+ first_frame_loss += distortion_loss(features[..., stop-4 : stop, :], dec_features[..., -4:, :], rate_lambda_upsamp[..., stop - 4 : stop]) / len(outputs_hard_quant)
+
+ # soft quantized decoder input
+ distortion_loss_soft_quant = torch.zeros_like(rate_loss)
+ for dec_features, start, stop in outputs_soft_quant:
+ distortion_loss_soft_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_soft_quant)
+
+ # total loss
+ total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2
+
+ if args.enable_first_frame_loss:
+ total_loss = total_loss + 0.5 * torch.relu(first_frame_loss - distortion_loss_hard_quant)
+
+
+ total_loss.backward()
+
+ optimizer.step()
+
+ model.clip_weights()
+
+ scheduler.step()
+
+ # collect running stats
+ running_hard_dist_loss += float(distortion_loss_hard_quant.detach().cpu())
+ running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu())
+ running_rate_loss += float(rate_loss.detach().cpu())
+ running_rate_metric += float(hard_rate_metric.detach().cpu())
+ running_total_loss += float(total_loss.detach().cpu())
+ running_first_frame_loss += float(first_frame_loss.detach().cpu())
+ running_soft_rate_loss += float(soft_rate_loss.detach().cpu())
+ running_hard_rate_loss += float(hard_rate_loss.detach().cpu())
+
+ if (i + 1) % log_interval == 0:
+ current_loss = (running_total_loss - previous_total_loss) / log_interval
+ tepoch.set_postfix(
+ current_loss=current_loss,
+ total_loss=running_total_loss / (i + 1),
+ dist_hq=running_hard_dist_loss / (i + 1),
+ dist_sq=running_soft_dist_loss / (i + 1),
+ rate_loss=running_rate_loss / (i + 1),
+ rate=running_rate_metric / (i + 1),
+ ffloss=running_first_frame_loss / (i + 1),
+ rateloss_hard=running_hard_rate_loss / (i + 1),
+ rateloss_soft=running_soft_rate_loss / (i + 1)
+ )
+ previous_total_loss = running_total_loss
+
+ # save checkpoint
+ checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
+ checkpoint['state_dict'] = model.state_dict()
+ checkpoint['loss'] = running_total_loss / len(dataloader)
+ checkpoint['epoch'] = epoch
+ torch.save(checkpoint, checkpoint_path)