mirror of
https://github.com/xiph/opus.git
synced 2025-06-02 00:27:43 +00:00
Add loading for LinearLayer
Untested
This commit is contained in:
parent
587c1020fe
commit
9d40e5cb08
2 changed files with 53 additions and 0 deletions
|
@ -113,6 +113,48 @@ static const void *find_idx_check(const WeightArray *arrays, const char *name, i
|
|||
return a->data;
|
||||
}
|
||||
|
||||
int linear_init(LinearLayer *layer, const WeightArray *arrays,
|
||||
const char *bias,
|
||||
const char *subias,
|
||||
const char *weights,
|
||||
const char *float_weights,
|
||||
const char *weights_idx,
|
||||
const char *diag,
|
||||
const char *scale,
|
||||
int nb_inputs,
|
||||
int nb_outputs)
|
||||
{
|
||||
int total_blocks;
|
||||
if ((layer->bias = find_array_check(arrays, bias, nb_outputs*sizeof(layer->bias[0]))) == NULL) return 1;
|
||||
if ((layer->subias = find_array_check(arrays, subias, nb_outputs*sizeof(layer->subias[0]))) == NULL) return 1;
|
||||
layer->weights = NULL;
|
||||
layer->float_weights = NULL;
|
||||
layer->weights_idx = NULL;
|
||||
if (weights_idx != NULL) {
|
||||
if ((layer->weights_idx = find_idx_check(arrays, weights_idx, nb_outputs, nb_inputs, &total_blocks)) == NULL) return 1;
|
||||
}
|
||||
if (weights_idx != NULL) {
|
||||
if (weights != NULL) {
|
||||
if ((layer->weights = find_array_check(arrays, weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->weights[0]))) == NULL) return 1;
|
||||
}
|
||||
if (float_weights != NULL) {
|
||||
if ((layer->float_weights = find_array_check(arrays, float_weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->float_weights[0]))) == NULL) return 1;
|
||||
}
|
||||
} else {
|
||||
if (weights != NULL) {
|
||||
if ((layer->weights = find_array_check(arrays, weights, nb_inputs*nb_outputs*sizeof(layer->weights[0]))) == NULL) return 1;
|
||||
}
|
||||
if (float_weights != NULL) {
|
||||
if ((layer->float_weights = find_array_check(arrays, float_weights, nb_inputs*nb_outputs*sizeof(layer->float_weights[0]))) == NULL) return 1;
|
||||
}
|
||||
}
|
||||
if ((layer->diag = find_array_check(arrays, diag, nb_outputs*sizeof(layer->diag[0]))) == NULL) return 1;
|
||||
if ((layer->scale = find_array_check(arrays, scale, nb_outputs*sizeof(layer->scale[0]))) == NULL) return 1;
|
||||
layer->nb_inputs = nb_inputs;
|
||||
layer->nb_outputs = nb_outputs;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int mdense_init(MDenseLayer *layer, const WeightArray *arrays,
|
||||
const char *bias,
|
||||
const char *input_weights,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue