lossgen: better training, README.md

Jean-Marc Valin 2023-12-21 18:01:57 -05:00
parent c40add59af
commit b923fd1e28
2 changed files with 35 additions and 5 deletions

README.md (new file, +27 lines):
# Packet loss simulator
This code is an attempt at simulating packet loss more realistically. The most common way of simulating
packet loss is to use a random sequence where each packet loss event is uncorrelated with previous events.
That is a simplistic model, since we know that losses often occur in bursts. This model instead uses real data
to build a generative model for packet loss.
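For contrast, here is a minimal sketch of the simplistic uncorrelated baseline described above (illustrative only, not part of this repository): every packet is dropped independently with a fixed probability, so no burst structure is modeled.

```python
import numpy as np

# Simplistic baseline: each packet is lost independently with probability p,
# so there is no burst structure beyond what chance produces.
def iid_loss_sequence(p=0.2, length=10000, seed=0):
    rng = np.random.default_rng(seed)
    return (rng.random(length) < p).astype(int)   # 1 = lost, 0 = received

print(iid_loss_sequence().mean())   # realized loss rate, close to p
```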
We use the training data provided for the Audio Deep Packet Loss Concealment Challenge, which is available at:
http://plcchallenge2022pub.blob.core.windows.net/plcchallengearchive/test_train.tar.gz
To create the training data, run:
`./process_data.sh /<path>/test_train/train/lossy_signals/`
That will create an ASCII file, `loss_sorted.txt`, with all the loss data sorted in increasing packet loss
percentage. Then just run:
`python ./train_lossgen.py`
to train a model.
To generate a sequence, run:
`python3 ./test_lossgen.py <checkpoint> <percentage> output.txt --length 10000`
where `<checkpoint>` is the `.pth` model file and `<percentage>` is the amount of loss (e.g. 0.2 for 20% loss).
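A generated sequence can then be sanity-checked, for example by measuring the realized loss rate and the burst lengths. A small sketch, assuming `output.txt` holds one 0/1 flag per line (1 = lost packet):

```python
import numpy as np

# Sanity-check a generated loss sequence.  Assumes output.txt contains one
# 0/1 loss flag per line, with 1 marking a lost packet.
seq = np.loadtxt('output.txt').astype(int)

loss_rate = seq.mean()

# Burst lengths: lengths of runs of consecutive 1s.
padded = np.concatenate(([0], seq, [0]))
starts = np.where(np.diff(padded) == 1)[0]
ends = np.where(np.diff(padded) == -1)[0]
bursts = ends - starts

print(f"loss rate: {loss_rate:.3f}, mean burst length: {bursts.mean():.2f}")
```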

train_lossgen.py:

```diff
@@ -27,9 +27,11 @@ class LossDataset(torch.utils.data.Dataset):
         return self.nb_sequences

     def __getitem__(self, index):
-        r0 = np.random.normal(scale=.02, size=(1,1)).astype('float32')
-        r1 = np.random.normal(scale=.02, size=(self.sequence_length,1)).astype('float32')
-        return [self.loss[index, :, :], self.perc[index, :, :]+r0+r1]
+        r0 = np.random.normal(scale=.1, size=(1,1)).astype('float32')
+        r1 = np.random.normal(scale=.1, size=(self.sequence_length,1)).astype('float32')
+        perc = self.perc[index, :, :]
+        perc = perc + (r0+r1)*perc*(1-perc)
+        return [self.loss[index, :, :], perc]

 adam_betas = [0.8, 0.98]
```
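The change above replaces the fixed additive noise on the conditioning loss percentage with noise scaled by `perc*(1-perc)`, so the augmentation is strongest at intermediate loss rates and fades out near 0% and 100%, keeping the perturbed percentage in a sensible range. A minimal standalone sketch of that perturbation (variable names and example values are illustrative):

```python
import numpy as np

# Standalone sketch of the augmentation above: noise on the conditioning loss
# percentage is scaled by perc*(1-perc), so it vanishes at 0 and 1 and the
# perturbed value stays in a sensible range.
def perturb_percentage(perc, scale=0.1, seed=0):
    rng = np.random.default_rng(seed)
    r0 = rng.normal(scale=scale, size=(1, 1)).astype('float32')       # per-sequence offset
    r1 = rng.normal(scale=scale, size=perc.shape).astype('float32')   # per-frame noise
    return perc + (r0 + r1) * perc * (1 - perc)

# Hypothetical example: a sequence nominally at 20% loss.
perc = np.full((8, 1), 0.2, dtype='float32')
print(perturb_percentage(perc).squeeze())
```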
```diff
@@ -61,7 +63,7 @@ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lam
 if __name__ == '__main__':
     model.to(device)
+    states = None

     for epoch in range(1, epochs + 1):
         running_loss = 0
@@ -73,7 +75,8 @@ if __name__ == '__main__':
            loss = loss.to(device)
            perc = perc.to(device)
-           out, _ = model(loss, perc)
+           out, states = model(loss, perc, states=states)
+           states = [state.detach() for state in states]
            out = torch.sigmoid(out[:,:-1,:])
            target = loss[:,1:,:]
```
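This last change carries the recurrent states across batches instead of discarding them, detaching them after each step so gradients do not propagate across batch boundaries (truncated backpropagation through time). A minimal sketch of the same pattern with a plain GRU standing in for the lossgen model (the real model's signature and data are not shown here):

```python
import torch

# Sketch of the state-carrying pattern above, with a plain GRU standing in
# for the lossgen model.  The state from one batch seeds the next, and is
# detached so backprop is truncated at the batch boundary.
gru = torch.nn.GRU(input_size=2, hidden_size=16, batch_first=True)
opt = torch.optim.Adam(gru.parameters(), lr=1e-3)

state = None
for step in range(4):
    x = torch.randn(8, 100, 2)        # hypothetical (batch, time, features) input
    out, state = gru(x, state)        # reuse the state from the previous batch
    loss = out.pow(2).mean()          # dummy loss, just to drive a backward pass
    opt.zero_grad()
    loss.backward()
    opt.step()
    state = state.detach()            # cut the graph: truncated BPTT
```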