mirror of
https://github.com/xiph/opus.git
synced 2025-05-25 12:49:12 +00:00
added missing dead-zone to encode_rdovae.py
This commit is contained in:
parent
be42c3b514
commit
01baf1a0fc
1 changed files with 7 additions and 0 deletions
|
@ -48,6 +48,8 @@ args = parser.parse_args()
|
||||||
import importlib
|
import importlib
|
||||||
rdovae = importlib.import_module(args.model)
|
rdovae = importlib.import_module(args.model)
|
||||||
|
|
||||||
|
from rdovae import apply_dead_zone
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow.keras.optimizers import Adam
|
from tensorflow.keras.optimizers import Adam
|
||||||
|
@ -105,10 +107,15 @@ print(gru_state_dec.shape)
|
||||||
|
|
||||||
features.astype('float32').tofile(args.output + "-input.f32")
|
features.astype('float32').tofile(args.output + "-input.f32")
|
||||||
#quant_out.astype('float32').tofile(args.output + "-enc_dec.f32")
|
#quant_out.astype('float32').tofile(args.output + "-enc_dec.f32")
|
||||||
|
nbits=80
|
||||||
|
dead_zone = tf.math.softplus(quant_embed_dec[:, :, nbits : 2 * nbits])
|
||||||
|
symbols = apply_dead_zone([bits, dead_zone]).numpy()
|
||||||
np.round(bits).astype('int16').tofile(args.output + "-bits.s16")
|
np.round(bits).astype('int16').tofile(args.output + "-bits.s16")
|
||||||
quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32")
|
quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32")
|
||||||
|
|
||||||
gru_state_dec = gru_state_dec[:,-1,:]
|
gru_state_dec = gru_state_dec[:,-1,:]
|
||||||
dec_out = decoder([bits[:,1::2,:], quant_embed_dec[:,1::2,:], gru_state_dec])
|
dec_out = decoder([bits[:,1::2,:], quant_embed_dec[:,1::2,:], gru_state_dec])
|
||||||
|
|
||||||
|
print(dec_out.shape)
|
||||||
|
|
||||||
dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32")
|
dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue