From 3a4754853683555759f684e20374e9e51058be9c Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin Date: Wed, 10 Nov 2021 17:58:51 -0500 Subject: [PATCH] Using KISS99 (taken from Daala) as RNG --- dnn/Makefile.am | 2 ++ dnn/kiss99.c | 68 ++++++++++++++++++++++++++++++++++++++++++++ dnn/kiss99.h | 42 +++++++++++++++++++++++++++ dnn/lpcnet.c | 8 ++++-- dnn/lpcnet_private.h | 2 ++ dnn/nnet.c | 10 +++++-- dnn/nnet.h | 3 +- 7 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 dnn/kiss99.c create mode 100644 dnn/kiss99.h diff --git a/dnn/Makefile.am b/dnn/Makefile.am index 1a0f9bf4..2f48e613 100644 --- a/dnn/Makefile.am +++ b/dnn/Makefile.am @@ -13,6 +13,7 @@ noinst_HEADERS = arch.h \ freq.h \ _kiss_fft_guts.h \ kiss_fft.h \ + kiss99.h \ lpcnet_private.h \ opus_types.h \ nnet_data.h \ @@ -25,6 +26,7 @@ noinst_HEADERS = arch.h \ liblpcnet_la_SOURCES = \ common.c \ + kiss99.c \ lpcnet.c \ lpcnet_dec.c \ lpcnet_enc.c \ diff --git a/dnn/kiss99.c b/dnn/kiss99.c new file mode 100644 index 00000000..a9e4fe55 --- /dev/null +++ b/dnn/kiss99.c @@ -0,0 +1,68 @@ +/*Daala video codec +Copyright (c) 2012 Daala project contributors. All rights reserved. +Author: Timothy B. Terriberry + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.*/ + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "kiss99.h" + +void kiss99_srand(kiss99_ctx *_this,const unsigned char *_data,int _ndata){ + int i; + _this->z=362436069; + _this->w=521288629; + _this->jsr=123456789; + _this->jcong=380116160; + for(i=3;i<_ndata;i+=4){ + _this->z^=_data[i-3]; + _this->w^=_data[i-2]; + _this->jsr^=_data[i-1]; + _this->jcong^=_data[i]; + kiss99_rand(_this); + } + if(i-3<_ndata)_this->z^=_data[i-3]; + if(i-2<_ndata)_this->w^=_data[i-2]; + if(i-1<_ndata)_this->jsr^=_data[i-1]; +} + +uint32_t kiss99_rand(kiss99_ctx *_this){ + uint32_t znew; + uint32_t wnew; + uint32_t mwc; + uint32_t shr3; + uint32_t cong; + znew=36969*(_this->z&0xFFFF)+(_this->z>>16); + wnew=18000*(_this->w&0xFFFF)+(_this->w>>16); + mwc=(znew<<16)+wnew; + shr3=_this->jsr^(_this->jsr<<17); + shr3^=shr3>>13; + shr3^=shr3<<5; + cong=69069*_this->jcong+1234567; + _this->z=znew; + _this->w=wnew; + _this->jsr=shr3; + _this->jcong=cong; + return (mwc^cong)+shr3; +} diff --git a/dnn/kiss99.h b/dnn/kiss99.h new file mode 100644 index 00000000..4fb368c8 --- /dev/null +++ b/dnn/kiss99.h @@ -0,0 +1,42 @@ +/*Daala video codec +Copyright (c) 2012 Daala project contributors. All rights reserved. +Author: Timothy B. Terriberry + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.*/ + +#if !defined(_kiss99_H) +# define _kiss99_H (1) +# include + +typedef struct kiss99_ctx kiss99_ctx; + +struct kiss99_ctx{ + uint32_t z; + uint32_t w; + uint32_t jsr; + uint32_t jcong; +}; + +void kiss99_srand(kiss99_ctx *_this,const unsigned char *_data,int _ndata); +uint32_t kiss99_rand(kiss99_ctx *_this); + +#endif diff --git a/dnn/lpcnet.c b/dnn/lpcnet.c index 520fd67a..5fd451ec 100644 --- a/dnn/lpcnet.c +++ b/dnn/lpcnet.c @@ -116,7 +116,7 @@ void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b if (lpcnet->frame_count < 1000) lpcnet->frame_count++; } -int run_sample_network(NNetState *net, const float *gru_a_condition, const float *gru_b_condition, int last_exc, int last_sig, int pred, const float *sampling_logit_table) +int run_sample_network(NNetState *net, const float *gru_a_condition, const float *gru_b_condition, int last_exc, int last_sig, int pred, const float *sampling_logit_table, kiss99_ctx *rng) { float gru_a_input[3*GRU_A_STATE_SIZE]; float in_b[GRU_A_STATE_SIZE+FEATURE_DENSE2_OUT_SIZE]; @@ -134,7 +134,7 @@ int run_sample_network(NNetState *net, const float *gru_a_condition, const float RNN_COPY(in_b, net->gru_a_state, GRU_A_STATE_SIZE); RNN_COPY(gru_b_input, gru_b_condition, 3*GRU_B_STATE_SIZE); compute_gruB(&gru_b, gru_b_input, net->gru_b_state, in_b); - return sample_mdense(&dual_fc, net->gru_b_state, sampling_logit_table); + return sample_mdense(&dual_fc, net->gru_b_state, sampling_logit_table, rng); } LPCNET_EXPORT int lpcnet_get_size() @@ -145,12 +145,14 @@ LPCNET_EXPORT int lpcnet_get_size() LPCNET_EXPORT int lpcnet_init(LPCNetState *lpcnet) { int i; + const char* rng_string="LPCNet"; memset(lpcnet, 0, lpcnet_get_size()); lpcnet->last_exc = lin2ulaw(0.f); for (i=0;i<256;i++) { float prob = .025+.95*i/255.; lpcnet->sampling_logit_table[i] = -log((1-prob)/prob); } + kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string)); return 0; } @@ -193,7 +195,7 @@ LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features, for (j=0;jlast_sig[j]*lpc[j]; last_sig_ulaw = lin2ulaw(lpcnet->last_sig[0]); pred_ulaw = lin2ulaw(pred); - exc = run_sample_network(&lpcnet->nnet, gru_a_condition, gru_b_condition, lpcnet->last_exc, last_sig_ulaw, pred_ulaw, lpcnet->sampling_logit_table); + exc = run_sample_network(&lpcnet->nnet, gru_a_condition, gru_b_condition, lpcnet->last_exc, last_sig_ulaw, pred_ulaw, lpcnet->sampling_logit_table, &lpcnet->rng); pcm = pred + ulaw2lin(exc); RNN_MOVE(&lpcnet->last_sig[1], &lpcnet->last_sig[0], LPC_ORDER-1); lpcnet->last_sig[0] = pcm; diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index 65a2784e..c7e88f5b 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -7,6 +7,7 @@ #include "lpcnet.h" #include "nnet_data.h" #include "celt_lpc.h" +#include "kiss99.h" #define BITS_PER_CHAR 8 @@ -33,6 +34,7 @@ struct LPCNetState { float sampling_logit_table[256]; int frame_count; float deemph_mem; + kiss99_ctx rng; }; struct LPCNetDecState { diff --git a/dnn/nnet.c b/dnn/nnet.c index 7f4914c4..7a74c00b 100644 --- a/dnn/nnet.c +++ b/dnn/nnet.c @@ -141,7 +141,7 @@ void compute_mdense(const MDenseLayer *layer, float *output, const float *input) compute_activation(output, output, N, layer->activation); } -int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table) +int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table, kiss99_ctx *rng) { int b, j, N, M, C, stride; M = layer->nb_inputs; @@ -156,7 +156,13 @@ int sample_mdense(const MDenseLayer *layer, const float *input, const float *sam /* Computing all the random thresholds in advance. These thresholds are directly based on the logit to avoid computing the sigmoid.*/ - for (b=0;b<8;b++) thresholds[b] = sampling_logit_table[rand()&0xFF]; + for (b=0;b<8;b+=4) { + uint32_t val = kiss99_rand(rng); + thresholds[b] = sampling_logit_table[val&0xFF]; + thresholds[b+1] = sampling_logit_table[(val>>8)&0xFF]; + thresholds[b+2] = sampling_logit_table[(val>>16)&0xFF]; + thresholds[b+3] = sampling_logit_table[(val>>24)&0xFF]; + } for (b=0;b<8;b++) { diff --git a/dnn/nnet.h b/dnn/nnet.h index e0504e53..456c3266 100644 --- a/dnn/nnet.h +++ b/dnn/nnet.h @@ -29,6 +29,7 @@ #define _NNET_H_ #include "vec.h" +#include "kiss99.h" #define ACTIVATION_LINEAR 0 #define ACTIVATION_SIGMOID 1 @@ -98,7 +99,7 @@ void compute_dense(const DenseLayer *layer, float *output, const float *input); void compute_mdense(const MDenseLayer *layer, float *output, const float *input); -int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table); +int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table, kiss99_ctx *rng); void compute_gru(const GRULayer *gru, float *state, const float *input);