Using KISS99 (taken from Daala) as RNG

This commit is contained in:
Jean-Marc Valin 2021-11-10 17:58:51 -05:00
parent 81229a7412
commit 3a47548536
7 changed files with 129 additions and 6 deletions

View file

@ -13,6 +13,7 @@ noinst_HEADERS = arch.h \
freq.h \ freq.h \
_kiss_fft_guts.h \ _kiss_fft_guts.h \
kiss_fft.h \ kiss_fft.h \
kiss99.h \
lpcnet_private.h \ lpcnet_private.h \
opus_types.h \ opus_types.h \
nnet_data.h \ nnet_data.h \
@ -25,6 +26,7 @@ noinst_HEADERS = arch.h \
liblpcnet_la_SOURCES = \ liblpcnet_la_SOURCES = \
common.c \ common.c \
kiss99.c \
lpcnet.c \ lpcnet.c \
lpcnet_dec.c \ lpcnet_dec.c \
lpcnet_enc.c \ lpcnet_enc.c \

68
dnn/kiss99.c Normal file
View file

@ -0,0 +1,68 @@
/*Daala video codec
Copyright (c) 2012 Daala project contributors. All rights reserved.
Author: Timothy B. Terriberry
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
- Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS AS IS
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "kiss99.h"
void kiss99_srand(kiss99_ctx *_this,const unsigned char *_data,int _ndata){
int i;
_this->z=362436069;
_this->w=521288629;
_this->jsr=123456789;
_this->jcong=380116160;
for(i=3;i<_ndata;i+=4){
_this->z^=_data[i-3];
_this->w^=_data[i-2];
_this->jsr^=_data[i-1];
_this->jcong^=_data[i];
kiss99_rand(_this);
}
if(i-3<_ndata)_this->z^=_data[i-3];
if(i-2<_ndata)_this->w^=_data[i-2];
if(i-1<_ndata)_this->jsr^=_data[i-1];
}
uint32_t kiss99_rand(kiss99_ctx *_this){
uint32_t znew;
uint32_t wnew;
uint32_t mwc;
uint32_t shr3;
uint32_t cong;
znew=36969*(_this->z&0xFFFF)+(_this->z>>16);
wnew=18000*(_this->w&0xFFFF)+(_this->w>>16);
mwc=(znew<<16)+wnew;
shr3=_this->jsr^(_this->jsr<<17);
shr3^=shr3>>13;
shr3^=shr3<<5;
cong=69069*_this->jcong+1234567;
_this->z=znew;
_this->w=wnew;
_this->jsr=shr3;
_this->jcong=cong;
return (mwc^cong)+shr3;
}

42
dnn/kiss99.h Normal file
View file

@ -0,0 +1,42 @@
/*Daala video codec
Copyright (c) 2012 Daala project contributors. All rights reserved.
Author: Timothy B. Terriberry
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
- Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS AS IS
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.*/
#if !defined(_kiss99_H)
# define _kiss99_H (1)
# include <stdint.h>
typedef struct kiss99_ctx kiss99_ctx;
struct kiss99_ctx{
uint32_t z;
uint32_t w;
uint32_t jsr;
uint32_t jcong;
};
void kiss99_srand(kiss99_ctx *_this,const unsigned char *_data,int _ndata);
uint32_t kiss99_rand(kiss99_ctx *_this);
#endif

View file

@ -116,7 +116,7 @@ void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b
if (lpcnet->frame_count < 1000) lpcnet->frame_count++; if (lpcnet->frame_count < 1000) lpcnet->frame_count++;
} }
int run_sample_network(NNetState *net, const float *gru_a_condition, const float *gru_b_condition, int last_exc, int last_sig, int pred, const float *sampling_logit_table) int run_sample_network(NNetState *net, const float *gru_a_condition, const float *gru_b_condition, int last_exc, int last_sig, int pred, const float *sampling_logit_table, kiss99_ctx *rng)
{ {
float gru_a_input[3*GRU_A_STATE_SIZE]; float gru_a_input[3*GRU_A_STATE_SIZE];
float in_b[GRU_A_STATE_SIZE+FEATURE_DENSE2_OUT_SIZE]; float in_b[GRU_A_STATE_SIZE+FEATURE_DENSE2_OUT_SIZE];
@ -134,7 +134,7 @@ int run_sample_network(NNetState *net, const float *gru_a_condition, const float
RNN_COPY(in_b, net->gru_a_state, GRU_A_STATE_SIZE); RNN_COPY(in_b, net->gru_a_state, GRU_A_STATE_SIZE);
RNN_COPY(gru_b_input, gru_b_condition, 3*GRU_B_STATE_SIZE); RNN_COPY(gru_b_input, gru_b_condition, 3*GRU_B_STATE_SIZE);
compute_gruB(&gru_b, gru_b_input, net->gru_b_state, in_b); compute_gruB(&gru_b, gru_b_input, net->gru_b_state, in_b);
return sample_mdense(&dual_fc, net->gru_b_state, sampling_logit_table); return sample_mdense(&dual_fc, net->gru_b_state, sampling_logit_table, rng);
} }
LPCNET_EXPORT int lpcnet_get_size() LPCNET_EXPORT int lpcnet_get_size()
@ -145,12 +145,14 @@ LPCNET_EXPORT int lpcnet_get_size()
LPCNET_EXPORT int lpcnet_init(LPCNetState *lpcnet) LPCNET_EXPORT int lpcnet_init(LPCNetState *lpcnet)
{ {
int i; int i;
const char* rng_string="LPCNet";
memset(lpcnet, 0, lpcnet_get_size()); memset(lpcnet, 0, lpcnet_get_size());
lpcnet->last_exc = lin2ulaw(0.f); lpcnet->last_exc = lin2ulaw(0.f);
for (i=0;i<256;i++) { for (i=0;i<256;i++) {
float prob = .025+.95*i/255.; float prob = .025+.95*i/255.;
lpcnet->sampling_logit_table[i] = -log((1-prob)/prob); lpcnet->sampling_logit_table[i] = -log((1-prob)/prob);
} }
kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string));
return 0; return 0;
} }
@ -193,7 +195,7 @@ LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features,
for (j=0;j<LPC_ORDER;j++) pred -= lpcnet->last_sig[j]*lpc[j]; for (j=0;j<LPC_ORDER;j++) pred -= lpcnet->last_sig[j]*lpc[j];
last_sig_ulaw = lin2ulaw(lpcnet->last_sig[0]); last_sig_ulaw = lin2ulaw(lpcnet->last_sig[0]);
pred_ulaw = lin2ulaw(pred); pred_ulaw = lin2ulaw(pred);
exc = run_sample_network(&lpcnet->nnet, gru_a_condition, gru_b_condition, lpcnet->last_exc, last_sig_ulaw, pred_ulaw, lpcnet->sampling_logit_table); exc = run_sample_network(&lpcnet->nnet, gru_a_condition, gru_b_condition, lpcnet->last_exc, last_sig_ulaw, pred_ulaw, lpcnet->sampling_logit_table, &lpcnet->rng);
pcm = pred + ulaw2lin(exc); pcm = pred + ulaw2lin(exc);
RNN_MOVE(&lpcnet->last_sig[1], &lpcnet->last_sig[0], LPC_ORDER-1); RNN_MOVE(&lpcnet->last_sig[1], &lpcnet->last_sig[0], LPC_ORDER-1);
lpcnet->last_sig[0] = pcm; lpcnet->last_sig[0] = pcm;

View file

@ -7,6 +7,7 @@
#include "lpcnet.h" #include "lpcnet.h"
#include "nnet_data.h" #include "nnet_data.h"
#include "celt_lpc.h" #include "celt_lpc.h"
#include "kiss99.h"
#define BITS_PER_CHAR 8 #define BITS_PER_CHAR 8
@ -33,6 +34,7 @@ struct LPCNetState {
float sampling_logit_table[256]; float sampling_logit_table[256];
int frame_count; int frame_count;
float deemph_mem; float deemph_mem;
kiss99_ctx rng;
}; };
struct LPCNetDecState { struct LPCNetDecState {

View file

@ -141,7 +141,7 @@ void compute_mdense(const MDenseLayer *layer, float *output, const float *input)
compute_activation(output, output, N, layer->activation); compute_activation(output, output, N, layer->activation);
} }
int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table) int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table, kiss99_ctx *rng)
{ {
int b, j, N, M, C, stride; int b, j, N, M, C, stride;
M = layer->nb_inputs; M = layer->nb_inputs;
@ -156,7 +156,13 @@ int sample_mdense(const MDenseLayer *layer, const float *input, const float *sam
/* Computing all the random thresholds in advance. These thresholds are directly /* Computing all the random thresholds in advance. These thresholds are directly
based on the logit to avoid computing the sigmoid.*/ based on the logit to avoid computing the sigmoid.*/
for (b=0;b<8;b++) thresholds[b] = sampling_logit_table[rand()&0xFF]; for (b=0;b<8;b+=4) {
uint32_t val = kiss99_rand(rng);
thresholds[b] = sampling_logit_table[val&0xFF];
thresholds[b+1] = sampling_logit_table[(val>>8)&0xFF];
thresholds[b+2] = sampling_logit_table[(val>>16)&0xFF];
thresholds[b+3] = sampling_logit_table[(val>>24)&0xFF];
}
for (b=0;b<8;b++) for (b=0;b<8;b++)
{ {

View file

@ -29,6 +29,7 @@
#define _NNET_H_ #define _NNET_H_
#include "vec.h" #include "vec.h"
#include "kiss99.h"
#define ACTIVATION_LINEAR 0 #define ACTIVATION_LINEAR 0
#define ACTIVATION_SIGMOID 1 #define ACTIVATION_SIGMOID 1
@ -98,7 +99,7 @@ void compute_dense(const DenseLayer *layer, float *output, const float *input);
void compute_mdense(const MDenseLayer *layer, float *output, const float *input); void compute_mdense(const MDenseLayer *layer, float *output, const float *input);
int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table); int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table, kiss99_ctx *rng);
void compute_gru(const GRULayer *gru, float *state, const float *input); void compute_gru(const GRULayer *gru, float *state, const float *input);