Add loading for LinearLayer

Untested
This commit is contained in:
Jean-Marc Valin 2023-07-23 14:21:21 -04:00
parent 587c1020fe
commit 9d40e5cb08
No known key found for this signature in database
GPG key ID: 531A52533318F00A
2 changed files with 53 additions and 0 deletions

View file

@ -113,6 +113,48 @@ static const void *find_idx_check(const WeightArray *arrays, const char *name, i
return a->data;
}
int linear_init(LinearLayer *layer, const WeightArray *arrays,
const char *bias,
const char *subias,
const char *weights,
const char *float_weights,
const char *weights_idx,
const char *diag,
const char *scale,
int nb_inputs,
int nb_outputs)
{
int total_blocks;
if ((layer->bias = find_array_check(arrays, bias, nb_outputs*sizeof(layer->bias[0]))) == NULL) return 1;
if ((layer->subias = find_array_check(arrays, subias, nb_outputs*sizeof(layer->subias[0]))) == NULL) return 1;
layer->weights = NULL;
layer->float_weights = NULL;
layer->weights_idx = NULL;
if (weights_idx != NULL) {
if ((layer->weights_idx = find_idx_check(arrays, weights_idx, nb_outputs, nb_inputs, &total_blocks)) == NULL) return 1;
}
if (weights_idx != NULL) {
if (weights != NULL) {
if ((layer->weights = find_array_check(arrays, weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->weights[0]))) == NULL) return 1;
}
if (float_weights != NULL) {
if ((layer->float_weights = find_array_check(arrays, float_weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->float_weights[0]))) == NULL) return 1;
}
} else {
if (weights != NULL) {
if ((layer->weights = find_array_check(arrays, weights, nb_inputs*nb_outputs*sizeof(layer->weights[0]))) == NULL) return 1;
}
if (float_weights != NULL) {
if ((layer->float_weights = find_array_check(arrays, float_weights, nb_inputs*nb_outputs*sizeof(layer->float_weights[0]))) == NULL) return 1;
}
}
if ((layer->diag = find_array_check(arrays, diag, nb_outputs*sizeof(layer->diag[0]))) == NULL) return 1;
if ((layer->scale = find_array_check(arrays, scale, nb_outputs*sizeof(layer->scale[0]))) == NULL) return 1;
layer->nb_inputs = nb_inputs;
layer->nb_outputs = nb_outputs;
return 0;
}
int mdense_init(MDenseLayer *layer, const WeightArray *arrays,
const char *bias,
const char *input_weights,