diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py index 450cc7b4..5fb94bec 100644 --- a/dnn/torch/rdovae/rdovae/rdovae.py +++ b/dnn/torch/rdovae/rdovae/rdovae.py @@ -551,6 +551,7 @@ class RDOVAE(nn.Module): cond_size2, state_dim=24, split_mode='split', + chunks_per_offset=4, clip_weights=False, pvq_num_pulses=82, state_dropout_rate=0, @@ -564,6 +565,7 @@ class RDOVAE(nn.Module): self.cond_size = cond_size self.cond_size2 = cond_size2 self.split_mode = split_mode + self.chunks_per_offset = chunks_per_offset self.state_dim = state_dim self.pvq_num_pulses = pvq_num_pulses self.state_dropout_rate = state_dropout_rate @@ -670,7 +672,7 @@ class RDOVAE(nn.Module): states_q = states_q * mask # decoder - chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode) + chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode, chunks_per_offset=self.chunks_per_offset) outputs_hq = [] outputs_sq = [] diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py index 543e326f..71c8a656 100644 --- a/dnn/torch/rdovae/train_rdovae.py +++ b/dnn/torch/rdovae/train_rdovae.py @@ -60,7 +60,8 @@ training_group = parser.add_argument_group(title="training parameters") training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32) training_group.add_argument('--lr', type=float, help='learning rate, default: 3e-4', default=3e-4) training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 100', default=100) -training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by 4, default: 256', default=256) +training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by chunks_per_offset, default: 400', default=400) +training_group.add_argument('--chunks-per-offset', type=int, help='chunks per offset', default=4) training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5) training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split') training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames') @@ -120,7 +121,7 @@ feature_file = args.features # model checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2) -checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate, 'softquant': softquant} +checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate, 'softquant': softquant, 'chunks_per_offset': args.chunks_per_offset} model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs']) if type(args.initial_checkpoint) != type(None):