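From 9791b22b2c83980f6b4386c870cad58557c78007 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin
Date: Thu, 22 Nov 2018 14:06:34 -0500
Subject: Refactoring: Isolating the matrix-vector product in gemm_accum()

---
 src/mlp.c | 75 +++++++++++++++++++++++++++++++++------------------------------
 1 file changed, 39 insertions(+), 36 deletions(-)

(limited to 'src')

diff --git a/src/mlp.c b/src/mlp.c
index f43a704e..964c6a98 100644
--- a/src/mlp.c
+++ b/src/mlp.c
@@ -69,22 +69,29 @@ static OPUS_INLINE float sigmoid_approx(float x)
    return .5f + .5f*tansig_approx(.5f*x);
 }
 
-void compute_dense(const DenseLayer *layer, float *output, const float *input)
+static void gemm_accum(float *out, const opus_int8 *weights, int rows, int cols, int col_stride, const float *x)
 {
    int i, j;
+   for (i=0;i<rows;i++)
+   {
+      for (j=0;j<cols;j++)
+         out[i] += weights[j*col_stride + i]*x[j];
+   }
+}
+
+void compute_dense(const DenseLayer *layer, float *output, const float *input)
+{
+   int i;
    int N, M;
    int stride;
    M = layer->nb_inputs;
    N = layer->nb_neurons;
    stride = N;
    for (i=0;i<N;i++)
-   {
-      /* Compute update gate. */
-      float sum = layer->bias[i];
-      for (j=0;j<M;j++)
-         sum += layer->input_weights[j*stride + i]*input[j];
-      output[i] = WEIGHTS_SCALE*sum;
-   }
+      output[i] = layer->bias[i];
+   gemm_accum(output, layer->input_weights, N, M, stride, input);
+   for (i=0;i<N;i++)
+      output[i] *= WEIGHTS_SCALE;
    if (layer->sigmoid) {
       for (i=0;i<N;i++)
          output[i] = sigmoid_approx(output[i]);
@@ -96,45 +103,41 @@ void compute_dense(const DenseLayer *layer, float *output, const float *input)
 
 void compute_gru(const GRULayer *gru, float *state, const float *input)
 {
-   int i, j;
+   int i;
    int N, M;
    int stride;
+   float tmp[MAX_NEURONS];
    float z[MAX_NEURONS];
    float r[MAX_NEURONS];
    float h[MAX_NEURONS];
    M = gru->nb_inputs;
    N = gru->nb_neurons;
    stride = 3*N;
+   /* Compute update gate. */
    for (i=0;i<N;i++)
-   {
-      /* Compute update gate. */
-      float sum = gru->bias[i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[j*stride + i]*state[j];
-      z[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
-   }
+      z[i] = gru->bias[i];
+   gemm_accum(z, gru->input_weights, N, M, stride, input);
+   gemm_accum(z, gru->recurrent_weights, N, N, stride, state);
    for (i=0;i<N;i++)
-   {
-      /* Compute reset gate. */
-      float sum = gru->bias[N + i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[N + j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[N + j*stride + i]*state[j];
-      r[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
-   }
+      z[i] = sigmoid_approx(WEIGHTS_SCALE*z[i]);
+
+   /* Compute reset gate. */
    for (i=0;i<N;i++)
-   {
-      /* Compute output. */
-      float sum = gru->bias[2*N + i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[2*N + j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[2*N + j*stride + i]*state[j]*r[j];
-      h[i] = z[i]*state[i] + (1-z[i])*tansig_approx(WEIGHTS_SCALE*sum);
-   }
+      r[i] = gru->bias[N + i];
+   gemm_accum(r, &gru->input_weights[N], N, M, stride, input);
+   gemm_accum(r, &gru->recurrent_weights[N], N, N, stride, state);
+   for (i=0;i<N;i++)
+      r[i] = sigmoid_approx(WEIGHTS_SCALE*r[i]);
+
+   /* Compute output. */
+   for (i=0;i<N;i++)
+      h[i] = gru->bias[2*N + i];
+   for (i=0;i<N;i++)
+      tmp[i] = state[i] * r[i];
+   gemm_accum(h, &gru->input_weights[2*N], N, M, stride, input);
+   gemm_accum(h, &gru->recurrent_weights[2*N], N, N, stride, tmp);
+   for (i=0;i<N;i++)
+      h[i] = z[i]*state[i] + (1-z[i])*tansig_approx(WEIGHTS_SCALE*h[i]);
    for (i=0;i<N;i++)
       state[i] = h[i];
 }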