/*----------------------------------------------------------------------------*/
/* Sequential and Multithreaded Matrix Multiply Function Definitions          */
/*                                                                            */
/* Written by: John Thornley, Computer Science Dept., Caltech.                */
/* Date: October, 1997.                                                       */
/*                                                                            */
/* Algorithm:                                                                 */
/* - Horizontal stripes with j-i-k loop ordering within each stripe.          */
/* - Stripe size chosen to fit within L2 cache.                               */
/* - Stripes assigned to threads with blocked mapping.                        */
/*                                                                            */
/* Copyright (c) 1997 by John Thornley.                                       */
/*----------------------------------------------------------------------------*/

#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <sthreads.h>
#include "matrix_multiply.h"

/*----------------------------------------------------------------------------*/
/* L2 cache size (used to decide matrix stripe size)                          */
/*----------------------------------------------------------------------------*/

/* Cache capacity used to size the matrix stripes.  The stripe-size code
   divides it by sizeof(float), so it is presumably a byte count (524288 =
   512 KiB).  Can be overridden at build time, e.g. -DCACHE_SIZE=1048576,
   to match the target machine without editing this file. */
#ifndef CACHE_SIZE
#define CACHE_SIZE 524288
#endif

/*----------------------------------------------------------------------------*/
/* Handle errors from Sthreads function calls                                 */
/*----------------------------------------------------------------------------*/

/* Abort the program if an Sthreads call reported an error.
   error_code: value returned by an Sthreads function.  On
   STHREAD_ERROR_NONE this returns immediately; otherwise the error text is
   printed to stderr and the process exits with EXIT_FAILURE.
   NOTE(review): assumes sthread_sprint_error never writes more than 100
   bytes (including the terminator) -- confirm against the Sthreads docs. */
void handle_sthread_error(int error_code)
{
    char text[100];

    if (error_code == STHREAD_ERROR_NONE)
        return;

    sthread_sprint_error(text, error_code);
    fprintf(stderr, "Sthread error: %s\n", text);
    exit(EXIT_FAILURE);
}

/*----------------------------------------------------------------------------*/
/* Miscellaneous utilities                                                    */
/*----------------------------------------------------------------------------*/

/* Minimum / maximum of two values.  These are function-like macros: an
   argument may be evaluated twice, so never pass expressions with side
   effects (e.g. MIN(i++, j)). */
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define MAX(x, y) ((x) > (y) ? (x) : (y))

/*----------------------------------------------------------------------------*/
/* Split 0 .. n - 1 into near-equal parts and return first index of part p    */
/*----------------------------------------------------------------------------*/

/* Return the first index of part p when 0 .. n - 1 is divided into
   num_parts consecutive parts whose sizes differ by at most one (the first
   n % num_parts parts get the extra element).  Calling it with
   p == num_parts yields n, so part p spans
   [split_range(n, num_parts, p), split_range(n, num_parts, p + 1)).
   Requires num_parts >= 1. */
static int split_range(int n, int num_parts, int p)
{
    int base_size = n / num_parts;  /* size of every "short" part */
    int extra = n % num_parts;      /* number of parts that get one more */

    return base_size * p + (p < extra ? p : extra);
}

/*----------------------------------------------------------------------------*/
/* Sequential matrix multiply                                                 */
/*----------------------------------------------------------------------------*/

/* Single-threaded product = left * right over the leading n x n
   sub-matrices (1 <= n <= N).
   Rows are processed in horizontal stripes whose height is derived from
   CACHE_SIZE.  Within a stripe the loops run column j, then row i, then the
   dot product over k. */
void seq_matrix_multiply(
        int n,
        const float left[N][N], const float right[N][N],
        float product[N][N])
{
    int rows_in_cache, stripe_size, num_stripes;
    int stripe, row_begin, row_end;
    int row, col, idx;
    float acc;

    assert(1 <= n && n <= N);

    /* NOTE(review): CACHE_SIZE/sizeof(float) is an element count, but it is
       divided by n*sizeof(float), which is a byte count per row.  The result
       is a factor of sizeof(float) smaller than the number of rows that fit
       in CACHE_SIZE bytes.  This only tunes performance and matches
       multi_matrix_multiply, so it is kept as is; confirm it is intended.
       The "- 1" presumably leaves room for the column of right being
       streamed. */
    rows_in_cache = (CACHE_SIZE/sizeof(float))/(n*sizeof(float));
    stripe_size = MIN(n, MAX(1, rows_in_cache - 1));
    num_stripes = (n + stripe_size - 1)/stripe_size;

    for (stripe = 0; stripe < num_stripes; stripe++) {
        /* Half-open row range [row_begin, row_end) covered by this stripe. */
        row_begin = split_range(n, num_stripes, stripe);
        row_end = split_range(n, num_stripes, stripe + 1);

        for (col = 0; col < n; col++) {
            for (row = row_begin; row < row_end; row++) {
                acc = 0.0;
                for (idx = 0; idx < n; idx++)
                    acc += left[row][idx]*right[idx][col];
                product[row][col] = acc;
            }
        }
    }
}
            
/*----------------------------------------------------------------------------*/
/* Multithreaded matrix multiply (pragma version)                             */
/*----------------------------------------------------------------------------*/
/*
void multi_matrix_multiply(
        int n, 
        const float left[N][N], const float right[N][N], 
        float product[N][N],
        int t)
{
    int rows_in_cache, stripe_size;
    int s, num_stripes;

    assert(1 <= n && n <= N);
    assert(t >= 1);

    rows_in_cache = (CACHE_SIZE/sizeof(float))/(n*sizeof(float));
    stripe_size = MIN(n, MAX(1, rows_in_cache - 1));
    num_stripes = MAX(t, (((n + stripe_size - 1)/stripe_size + t - 1)/t)*t);
    #pragma multithreadable mapping(blocked(t))
    for (s = 0; s < num_stripes; s++) {
        int first_i, last_i;
        int i, j, k;
        float sum;

        first_i = split_range(n, num_stripes, s);
        last_i = split_range(n, num_stripes, s + 1) - 1;
        for (j = 0; j < n; j++)
            for (i = first_i; i <= last_i; i++) {
                sum = 0.0;
                for (k = 0; k < n; k++)
                    sum = sum + left[i][k]*right[k][j];
                product[i][j] = sum;
            }
    }
}
*/
/*----------------------------------------------------------------------------*/
/* Multithreaded matrix multiply (Sthreads version)                           */
/*----------------------------------------------------------------------------*/

/* Arguments shared by every iteration of the Sthreads loop body
   (inner_loops); filled in by multi_matrix_multiply. */
typedef struct {
    int n;                                /* matrix order in use, 1 <= n <= N */
    const float (*left)[N], (*right)[N];  /* input matrices (read only) */
    float (*product)[N];                  /* output matrix */
    int num_stripes;                      /* number of horizontal row stripes */
} loop_args;

static void inner_loops(int s, int last, int step, loop_args *args)
{
    int first_i, last_i;
    int i, j, k;
    float sum;

    assert(last == s && step == 1);

    first_i = split_range(args->n, args->num_stripes, s);
    last_i = split_range(args->n, args->num_stripes, s + 1) - 1;
    for (j = 0; j < args->n; j++)
        for (i = first_i; i <= last_i; i++) {
            sum = 0.0;
            for (k = 0; k < args->n; k++)
                sum = sum + args->left[i][k]*args->right[k][j];
            args->product[i][j] = sum;
        }
}

/* Multithreaded product = left * right over the leading n x n
   sub-matrices (1 <= n <= N), using t >= 1 threads via Sthreads.
   Rows are split into horizontal stripes (stripe height derived from
   CACHE_SIZE, as in seq_matrix_multiply).  Each stripe is computed by
   inner_loops.  Any Sthreads error terminates the program through
   handle_sthread_error. */
void multi_matrix_multiply(
        int n,
        const float left[N][N], const float right[N][N],
        float product[N][N],
        int t)
{
    int rows_in_cache, stripe_size;
    int stripes_needed, stripes_per_thread, num_stripes;
    loop_args args;
    int status;

    assert(1 <= n && n <= N);
    assert(t >= 1);

    rows_in_cache = (CACHE_SIZE/sizeof(float))/(n*sizeof(float));
    stripe_size = MIN(n, MAX(1, rows_in_cache - 1));

    /* Round the stripe count up to a multiple of t (and at least t).
       Presumably this lets the blocked mapping hand every thread the same
       number of stripes. */
    stripes_needed = (n + stripe_size - 1)/stripe_size;
    stripes_per_thread = (stripes_needed + t - 1)/t;
    num_stripes = MAX(t, stripes_per_thread*t);

    args.left = left;
    args.right = right;
    args.product = product;
    args.n = n;
    args.num_stripes = num_stripes;

    /* Run inner_loops for s = 0, 1, ..., num_stripes - 1 across t threads
       with a blocked mapping.
       NOTE(review): &args lives on this stack frame, so this relies on
       sthread_regular_for_loop returning only after every iteration has
       finished.
       NOTE(review): calling through this cast is well-defined only if
       inner_loops' parameter types match the cast target exactly. */
    status = sthread_regular_for_loop(
        (void (*)(int, int, int, void *)) inner_loops, (void *) &args,
        0, STHREAD_CONDITION_LT, num_stripes, 1,
        1, STHREAD_MAPPING_BLOCKED, t,
        STHREAD_PRIORITY_PARENT, STHREAD_STACK_SIZE_DEFAULT);
    handle_sthread_error(status);
}

/*----------------------------------------------------------------------------*/