diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py index 5fb94bec..a3f3e999 100644 --- a/dnn/torch/rdovae/rdovae/rdovae.py +++ b/dnn/torch/rdovae/rdovae/rdovae.py @@ -624,6 +624,8 @@ class RDOVAE(nn.Module): split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop] elif mode == 'random_split': split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)] + elif mode == 'skewed_split': + split_points = [start + stride * int(i * length / 4 / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop] else: raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}") diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py index 0e899713..444c5738 100644 --- a/dnn/torch/rdovae/train_rdovae.py +++ b/dnn/torch/rdovae/train_rdovae.py @@ -63,7 +63,7 @@ training_group.add_argument('--epochs', type=int, help='number of training epoch training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by chunks_per_offset, default: 400', default=400) training_group.add_argument('--chunks-per-offset', type=int, help='chunks per offset', default=4) training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5) -training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split') +training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split', 'skewed_split'], help='splitting mode for decoder input, default: split', default='split') training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames') training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None) training_group.add_argument('--train-decoder-only', action='store_true', help='freeze encoder and statistical model and train decoder only')