changed checkpoint format

This commit is contained in:
Jan Buethe 2023-09-26 14:35:36 +02:00 committed by Jean-Marc Valin
parent 733a095ba2
commit 41a4c9515d
No known key found for this signature in database
GPG key ID: 531A52533318F00A
4 changed files with 29 additions and 132 deletions

View file

@ -120,31 +120,9 @@ def rpa(model,device = 'cpu',data_format = 'if'):
cent = np.rint(1200*np.log2(np.divide(pitch, (16000/256), out=np.zeros_like(pitch), where=pitch!=0) + 1.0e-8)).astype('int')
# if (model == 'penn'):
# model_frequency, _ = penn.from_audio(
# torch.from_numpy(audio).unsqueeze(0).float(),
# 16000,
# hopsize=0.01,
# fmin=(16000.0/256),
# fmax=500,
# checkpoint=penn.DEFAULT_CHECKPOINT,
# batch_size=32,
# pad=True,
# interp_unvoiced_at=0.065,
# gpu=0)
# model_frequency = model_frequency.cpu().detach().squeeze().numpy()
# model_cents = 1200*np.log2(model_frequency/(16000/256))
# elif (model == 'crepe'):
# _, model_frequency, _, _ = crepe.predict(audio, 16000, viterbi=vflag,center=True,verbose=0)
# lpcnet_file_name = '/home/ubuntu/Code/Datasets/SPEECH_DATA/lpcnet_f0_16k_residual/' + file_name + '_f0.f32'
# period_lpcnet = np.fromfile(lpcnet_file_name, dtype='float32')
# model_frequency = 16000/(period_lpcnet + 1.0e-6)
# model_cents = 1200*np.log2(model_frequency/(16000/256))
# else:
model_cents = model(torch.from_numpy(np.copy(np.expand_dims(feature,0))).float().to(device))
model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy()
# model_cents = np.roll(model_cents,-1*3)
num_frames = min(cent.shape[0],model_cents.shape[0])
pitch = pitch[:num_frames]
@ -158,131 +136,62 @@ def rpa(model,device = 'cpu',data_format = 'if'):
voicing_all[force_out_of_pitch] = 0
C_all = C_all + np.where(voicing_all != 0)[0].shape[0]
# list_rca_model_all.append(sweep_rca(cent,model_cents,voicing_all,thresh,[0]))
list_rca_model_all.append(rca(cent,model_cents,voicing_all,thresh))
# list_rca_model_all.append(np.count_nonzero(np.where(np.abs(cent - model_cents))))
if "mic_M" in audio_file:
# list_rca_male_all.append(sweep_rca(cent,model_cents,voicing_all,thresh,[0]))
list_rca_male_all.append(rca(cent,model_cents,voicing_all,thresh))
C_all_m = C_all_m + np.where(voicing_all != 0)[0].shape[0]
else:
# list_rca_female_all.append(sweep_rca(cent,model_cents,voicing_all,thresh,[0]))
list_rca_female_all.append(rca(cent,model_cents,voicing_all,thresh))
C_all_f = C_all_f + np.where(voicing_all != 0)[0].shape[0]
"""
# Low pitch estimation
voicing_lp = np.copy(voicing)
force_out_of_pitch = np.where(np.logical_or(pitch < 65,pitch > 125)==True)
voicing_lp[force_out_of_pitch] = 0
C_lp = C_lp + np.where(voicing_lp != 0)[0].shape[0]
# list_rca_model_lp.append(sweep_rca(cent,model_cents,voicing_lp,thresh,[0]))
list_rca_model_lp.append(rca(cent,model_cents,voicing_lp,thresh))
if "mic_M" in audio_file:
# list_rca_male_lp.append(sweep_rca(cent,model_cents,voicing_lp,thresh,[0]))
list_rca_male_lp.append(rca(cent,model_cents,voicing_lp,thresh))
C_lp_m = C_lp_m + np.where(voicing_lp != 0)[0].shape[0]
else:
# list_rca_female_lp.append(sweep_rca(cent,model_cents,voicing_lp,thresh,[0]))
list_rca_female_lp.append(rca(cent,model_cents,voicing_lp,thresh))
C_lp_f = C_lp_f + np.where(voicing_lp != 0)[0].shape[0]
# High pitch estimation
voicing_hp = np.copy(voicing)
force_out_of_pitch = np.where(np.logical_or(pitch < 125,pitch > 500)==True)
voicing_hp[force_out_of_pitch] = 0
C_hp = C_hp + np.where(voicing_hp != 0)[0].shape[0]
# list_rca_model_hp.append(sweep_rca(cent,model_cents,voicing_hp,thresh,[0]))
list_rca_model_hp.append(rca(cent,model_cents,voicing_hp,thresh))
if "mic_M" in audio_file:
# list_rca_male_hp.append(sweep_rca(cent,model_cents,voicing_hp,thresh,[0]))
list_rca_male_hp.append(rca(cent,model_cents,voicing_hp,thresh))
C_hp_m = C_hp_m + np.where(voicing_hp != 0)[0].shape[0]
else:
# list_rca_female_hp.append(sweep_rca(cent,model_cents,voicing_hp,thresh,[0]))
list_rca_female_hp.append(rca(cent,model_cents,voicing_hp,thresh))
C_hp_f = C_hp_f + np.where(voicing_hp != 0)[0].shape[0]
# list_rca_model.append(acc_model)
# list_rca_crepe.append(acc_crepe)
# list_rca_lpcnet.append(acc_lpcnet)
# list_rca_penn.append(acc_penn)
"""
# list_rca_crepe = np.array(list_rca_crepe)
# list_rca_model_lp = np.array(list_rca_model_lp)
# list_rca_male_lp = np.array(list_rca_male_lp)
# list_rca_female_lp = np.array(list_rca_female_lp)
# list_rca_model_hp = np.array(list_rca_model_hp)
# list_rca_male_hp = np.array(list_rca_male_hp)
# list_rca_female_hp = np.array(list_rca_female_hp)
list_rca_model_all = np.array(list_rca_model_all)
list_rca_male_all = np.array(list_rca_male_all)
list_rca_female_all = np.array(list_rca_female_all)
# list_rca_lpcnet = np.array(list_rca_lpcnet)
# list_rca_penn = np.array(list_rca_penn)
x = PrettyTable()
x.field_names = ["Experiment", "Mean RPA"]
x.add_row(["Both all pitches", np.sum(list_rca_model_all)/C_all])
# x.add_row(["Both low pitches", np.sum(list_rca_model_lp)/C_lp])
# x.add_row(["Both high pitches", np.sum(list_rca_model_hp)/C_hp])
x.add_row(["Male all pitches", np.sum(list_rca_male_all)/C_all_m])
# x.add_row(["Male low pitches", np.sum(list_rca_male_lp)/C_lp_m])
# x.add_row(["Male high pitches", np.sum(list_rca_male_hp)/C_hp_m])
x.add_row(["Female all pitches", np.sum(list_rca_female_all)/C_all_f])
# x.add_row(["Female low pitches", np.sum(list_rca_female_lp)/C_lp_f])
# x.add_row(["Female high pitches", np.sum(list_rca_female_hp)/C_hp_f])
print(x)
return None
def cycle_eval(list_files_pth, noise_type = 'synthetic', noise_dataset = None, list_snr = [-20,-15,-10,-5,0,5,10,15,20], ptdb_dataset_path = None,fraction = 0.1,thresh = 50):
def cycle_eval(checkpoint_list, noise_type = 'synthetic', noise_dataset = None, list_snr = [-20,-15,-10,-5,0,5,10,15,20], ptdb_dataset_path = None,fraction = 0.1,thresh = 50):
"""
Cycle through SNR evaluation for list of .pth files
Cycle through SNR evaluation for list of checkpoints
"""
# list_files = glob.glob('/home/ubuntu/Code/Datasets/SPEECH DATA/combined_mic_16k_raw/*.raw')
# dir_f0 = '/home/ubuntu/Code/Datasets/SPEECH DATA/combine_f0_ptdb/'
# random_shuffle = list(np.random.permutation(len(list_files)))
list_files = glob.glob(ptdb_dataset_path + 'combined_mic_16k/*.raw')
dir_f0 = ptdb_dataset_path + 'combined_reference_f0/'
random.shuffle(list_files)
list_files = list_files[:(int)(fraction*len(list_files))]
# list_nfiles = ['DKITCHEN','NFIELD','OHALLWAY','PCAFETER','SPSQUARE','TCAR','DLIVING','NPARK','OMEETING','PRESTO','STRAFFIC','TMETRO','DWASHING','NRIVER','OOFFICE','PSTATION','TBUS']
dict_models = {}
list_snr.append(np.inf)
# thresh = 50
for f in list_files_pth:
for f in checkpoint_list:
if (f!='crepe') and (f!='lpcnet'):
fname = os.path.basename(f).split('_')[0] + '_' + os.path.basename(f).split('_')[-1][:-4]
config_path = os.path.dirname(f) + '/' + os.path.basename(f).split('_')[0] + '_' + 'config_' + os.path.basename(f).split('_')[-1][:-4] + '.json'
with open(config_path) as json_file:
dict_params = json.load(json_file)
checkpoint = torch.load(f, map_location='cpu')
dict_params = checkpoint['config']
if dict_params['data_format'] == 'if':
from models import large_if_ccode as model
pitch_nn = model(dict_params['freq_keep']*3,dict_params['gru_dim'],dict_params['output_dim']).to(device)
pitch_nn = model(dict_params['freq_keep']*3,dict_params['gru_dim'],dict_params['output_dim'])
elif dict_params['data_format'] == 'xcorr':
from models import large_xcorr as model
pitch_nn = model(dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim']).to(device)
pitch_nn = model(dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])
else:
from models import large_joint as model
pitch_nn = model(dict_params['freq_keep']*3,dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim']).to(device)
pitch_nn = model(dict_params['freq_keep']*3,dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])
pitch_nn.load_state_dict(torch.load(f))
pitch_nn.load_state_dict(checkpoint['state_dict'])
N = dict_params['window_size']
H = dict_params['hop_factor']
@ -356,15 +265,8 @@ def cycle_eval(list_files_pth, noise_type = 'synthetic', noise_dataset = None, l
cent = np.rint(1200*np.log2(np.divide(pitch, (16000/256), out=np.zeros_like(pitch), where=pitch!=0) + 1.0e-8)).astype('int')
# if os.path.basename(f) == 'crepe':
# elif (model == 'crepe'):
# _, model_frequency, _, _ = crepe.predict(np.concatenate([np.zeros(80),audio]), 16000, viterbi=True,center=True,verbose=0)
# model_cents = 1200*np.log2(model_frequency/(16000/256))
# else:
# else:
model_cents = pitch_nn(torch.from_numpy(np.copy(np.expand_dims(feature,0))).float().to(device))
model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy()
# model_cents = np.roll(model_cents,-1*3)
num_frames = min(cent.shape[0],model_cents.shape[0])
pitch = pitch[:num_frames]
@ -378,9 +280,7 @@ def cycle_eval(list_files_pth, noise_type = 'synthetic', noise_dataset = None, l
voicing_all[force_out_of_pitch] = 0
C_all = C_all + np.where(voicing_all != 0)[0].shape[0]
# list_rca_model_all.append(sweep_rca(cent,model_cents,voicing_all,thresh,[0]))
C_correct = C_correct + rca(cent,model_cents,voicing_all,thresh)
# list_rca_model_all.append(np.count_nonzero(np.where(np.abs(cent - model_cents))))
list_mean.append(C_correct/C_all)
else:
fname = f
@ -453,9 +353,7 @@ def cycle_eval(list_files_pth, noise_type = 'synthetic', noise_dataset = None, l
voicing_all[force_out_of_pitch] = 0
C_all = C_all + np.where(voicing_all != 0)[0].shape[0]
# list_rca_model_all.append(sweep_rca(cent,model_cents,voicing_all,thresh,[0]))
C_correct = C_correct + rca(cent,model_cents,voicing_all,thresh)
# list_rca_model_all.append(np.count_nonzero(np.where(np.abs(cent - model_cents))))
list_mean.append(C_correct/C_all)
dict_models[fname] = {}
dict_models[fname]['list_SNR'] = list_mean[:-1]