/*----------------------------------------------------------------------------*/
/* Main Program to Test and Time                                              */
/* Sequential and Multithreaded Matrix Multiply                               */
/*                                                                            */
/* Written by: John Thornley, Computer Science Dept., Caltech.                */
/* Date: October, 1997.                                                       */
/*                                                                            */
/* Usage: COMMAND -test -seq                                                  */
/*        COMMAND -test -multi                                                */
/*        COMMAND -time -seq n seed trials                                    */
/*        COMMAND -time -breaks -seq n seed trials                            */
/*        COMMAND -time -multi n t p seed trials                              */
/*        COMMAND -time -breaks -multi n t p seed trials                      */
/* where: n       = size of matrices (>= 1)                                   */
/*        t       = number of threads (>= 1)                                  */
/*        p       = number of processors (>= 1)                               */
/*        seed    = seed for random generation of data items                  */
/*        trials  = number of timing trials (>= 1)                            */
/*                                                                            */
/* Copyright (c) 1997 by John Thornley.                                       */
/*----------------------------------------------------------------------------*/

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <assert.h>
#include <time.h>
#include <bool.h>
#include <timing.h>
#include <multiprocessing.h>
#include "matrix_multiply.h"

/*----------------------------------------------------------------------------*/
/* Constants                                                                  */
/*----------------------------------------------------------------------------*/

/* 1x1 test case: left operand A1, right operand B1, and AB1, the product */
/* the tests expect for A1 times B1.                                      */
static const float A1[]  = 
    {2.0};
static const float B1[]  = 
    {3.0};
static const float AB1[] = 
    {6.0};

/* 2x2 test case, elements listed row by row: left operand A2, right      */
/* operand B2, and AB2, the product the tests expect for A2 times B2.     */
static const float A2[]  = 
    {1.0, 2.0, 
     3.0, 4.0};
static const float B2[]  = 
    {4.0, 3.0,
     2.0, 1.0};
static const float AB2[] = 
    { 8.0,  5.0,
     20.0, 13.0};

/* 5x5 test case, elements listed row by row: left operand A5, right      */
/* operand B5, and AB5, the product the tests expect for A5 times B5.     */
static const float A5[]  = 
    {1.0, 2.0, 3.0, 4.0, 5.0,
     0.0, 1.0, 2.0, 3.0, 4.0,
     5.0, 0.0, 1.0, 2.0, 3.0,
     4.0, 5.0, 0.0, 1.0, 2.0,
     3.0, 4.0, 5.0, 0.0, 1.0};
static const float B5[]  = 
    {5.0, 4.0, 3.0, 2.0, 1.0,
     0.0, 5.0, 4.0, 3.0, 2.0,
     1.0, 0.0, 5.0, 4.0, 3.0,
     2.0, 1.0, 0.0, 5.0, 4.0,
     3.0, 2.0, 1.0, 0.0, 5.0};
static const float AB5[] = 
    {31.0, 28.0, 31.0, 40.0, 55.0,
     20.0, 16.0, 18.0, 26.0, 40.0,
     39.0, 28.0, 23.0, 24.0, 31.0,
     28.0, 46.0, 34.0, 28.0, 28.0,
     23.0, 34.0, 51.0, 38.0, 31.0};

/* 8x8 test case, elements listed row by row: left operand A8, right      */
/* operand B8, and AB8, the product the tests expect for A8 times B8.     */
static const float A8[]  = 
    {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
     0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
     8.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
     7.0, 8.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
     6.0, 7.0, 8.0, 0.0, 1.0, 2.0, 3.0, 4.0,
     5.0, 6.0, 7.0, 8.0, 0.0, 1.0, 2.0, 3.0,
     4.0, 5.0, 6.0, 7.0, 8.0, 0.0, 1.0, 2.0,
     3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 0.0, 1.0};
static const float B8[]  = 
    {8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0,
     0.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0,
     1.0, 0.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0,
     2.0, 1.0, 0.0, 8.0, 7.0, 6.0, 5.0, 4.0,
     3.0, 2.0, 1.0, 0.0, 8.0, 7.0, 6.0, 5.0,
     4.0, 3.0, 2.0, 1.0, 0.0, 8.0, 7.0, 6.0,
     5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 8.0, 7.0,
     6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 8.0};
static const float AB8[] = 
    {141.0, 123.0, 114.0, 114.0, 123.0, 141.0, 168.0, 204.0,
     112.0,  93.0,  83.0,  82.0,  90.0, 107.0, 133.0, 168.0,
     155.0, 126.0, 106.0,  95.0,  93.0, 100.0, 116.0, 141.0,
     126.0, 168.0, 138.0, 117.0, 105.0, 102.0, 108.0, 123.0,
     106.0, 138.0, 179.0, 148.0, 126.0, 113.0, 109.0, 114.0,
      95.0, 117.0, 148.0, 188.0, 156.0, 133.0, 119.0, 114.0,
      93.0, 105.0, 126.0, 156.0, 195.0, 162.0, 138.0, 123.0,
     100.0, 102.0, 113.0, 133.0, 162.0, 200.0, 166.0, 141.0};

/*----------------------------------------------------------------------------*/
/* Miscellaneous matrix functions                                             */
/*----------------------------------------------------------------------------*/

/* Fill the top-left n-by-n corner of m from elements, which holds n*n    */
/* values in row-major order (row i, column j at index i*n + j).          */
/* Entries of m outside that corner are left untouched.                   */
static void assign_matrix(int n, float m[N][N], const float elements[])
{
    int row, col;
    const float *next;

    assert(1 <= n && n <= N);

    /* Walking the source linearly visits it in the same row-major order */
    /* as the nested loops visit the destination.                        */
    next = elements;
    for (row = 0; row < n; row++) {
        for (col = 0; col < n; col++) {
            m[row][col] = *next;
            next++;
        }
    }
}

/*----------------------------------------------------------------------------*/

/* Write the top-left n-by-n corner of m to stdout as a brace-delimited   */
/* list of brace-delimited rows, one row per line, each element formatted */
/* with "%3.1f". No newline is written after the closing brace.           */
static void print_matrix(int n, const float m[N][N])
{
    int row, col;

    assert(1 <= n && n <= N);

    putchar('{');
    for (row = 0; row < n; row++) {
        /* Separators go in front of every row/element except the first. */
        if (row > 0)
            fputs(",\n ", stdout);
        putchar('{');
        for (col = 0; col < n; col++) {
            if (col > 0)
                fputs(", ", stdout);
            printf("%3.1f", m[row][col]);
        }
        putchar('}');
    }
    putchar('}');
}

/*----------------------------------------------------------------------------*/

static bool equal_matrices(
        int n, const float left[N][N], const float right[N][N])
{
    bool equal;
    int i, j;

    assert(1 <= n && n <= N);

    equal = true;
    for (i = 0; equal && i < n; i++)
        for (j = 0; equal && j < n; j++)
            equal = (left[i][j] == right[i][j]);
    return equal;
}

/*----------------------------------------------------------------------------*/
/* Declaration of matrix variables                                            */
/*----------------------------------------------------------------------------*/

/* Working storage shared by the test and timing routines in this file:   */
/* the two operands, the computed product, and the reference result it is */
/* compared against. Only the top-left n-by-n corner is used for a given  */
/* matrix size n.                                                          */
static float left[N][N], right[N][N], product[N][N], expected[N][N];

/*----------------------------------------------------------------------------*/
/* Test sequential matrix multiply                                            */
/*----------------------------------------------------------------------------*/

/* Run one sequential test case: load the n-by-n operands and expected    */
/* product from the given row-major element arrays, multiply with         */
/* seq_matrix_multiply, print operands and result, and set *passed to     */
/* true if the product equals the expected matrix, false otherwise.       */
static void test_seq(
        int n, const float left_elements[], const float right_elements[],
        const float expected_elements[], bool *passed)
{
    assert(1 <= n && n <= N);

    assign_matrix(n, left, left_elements);
    assign_matrix(n, right, right_elements);
    assign_matrix(n, expected, expected_elements);

    printf("Multiplying %ix%i matrices:\n", n, n);
    puts("Left =");
    print_matrix(n, left);
    puts("");
    puts("Right =");
    print_matrix(n, right);
    puts("");

    seq_matrix_multiply(n, left, right, product);

    puts("Product =");
    print_matrix(n, product);
    puts("");

    if (!equal_matrices(n, product, expected)) {
        /* On failure also show what the product should have been. */
        *passed = false;
        puts("Expected =");
        print_matrix(n, expected);
        puts("");
        puts(">>>>> FAILED TEST <<<<<");
        puts("");
        return;
    }

    *passed = true;
    puts("Passed test");
    puts("");
}

/*----------------------------------------------------------------------------*/

static void test_seq_matrix_multiply(void)
{
    bool passed, passed_all;

    printf("TESTING SEQUENTIAL MATRIX MULTIPLY\n");
    printf("\n");

    passed_all = true;

    test_seq(1, A1, B1, AB1, &passed);
    passed_all = passed_all && passed;

    test_seq(2, A2, B2, AB2, &passed);
    passed_all = passed_all && passed;

    test_seq(5, A5, B5, AB5, &passed);
    passed_all = passed_all && passed;

    test_seq(8, A8, B8, AB8, &passed);
    passed_all = passed_all && passed;

    if (passed_all)
        printf("PASSED ALL TESTS.\n");
    else
        printf("FAILED SOME TESTS.\n");
}

/*----------------------------------------------------------------------------*/
/* Test multithreaded matrix multiply                                         */
/*----------------------------------------------------------------------------*/

/* Run one multithreaded test case for several thread counts. The n-by-n  */
/* operands and expected product are loaded from the row-major element    */
/* arrays; multi_matrix_multiply is then run with t = 1, 2, 3, 4, 8, 16,  */
/* ... for as long as t <= n + 1. *passed ends up true only if every      */
/* thread count produced the expected product.                            */
static void test_multi(int n,
        const float left_elements[], const float right_elements[],
        const float expected_elements[], bool *passed)
{
    int threads;

    assert(1 <= n && n <= N);

    assign_matrix(n, left, left_elements);
    assign_matrix(n, right, right_elements);
    assign_matrix(n, expected, expected_elements);

    *passed = true;
    threads = 1;
    while (threads <= n + 1) {
        printf("Multiplying %ix%i matrices (t = %i):\n", n, n, threads);
        puts("Left =");
        print_matrix(n, left);
        puts("");
        puts("Right =");
        print_matrix(n, right);
        puts("");

        multi_matrix_multiply(n, left, right, product, threads);

        puts("Product =");
        print_matrix(n, product);
        puts("");

        if (equal_matrices(n, product, expected)) {
            puts("Passed test");
        } else {
            *passed = false;
            puts("Expected =");
            print_matrix(n, expected);
            puts("");
            puts(">>>>> FAILED TEST <<<<<");
        }
        puts("");

        /* Step by one up to 4 threads, then double. */
        if (threads < 4)
            threads = threads + 1;
        else
            threads = 2*threads;
    }
}
/*----------------------------------------------------------------------------*/

static void test_multi_matrix_multiply(void)
{
    bool passed, passed_all;

    printf("TESTING MULTITHREADED MATRIX MULTIPLY\n");
    printf("\n");

    passed_all = true;

    test_multi(1, A1, B1, AB1, &passed);
    passed_all = passed_all && passed;

    test_multi(2, A2, B2, AB2, &passed);
    passed_all = passed_all && passed;

    test_multi(5, A5, B5, AB5, &passed);
    passed_all = passed_all && passed;

    test_multi(8, A8, B8, AB8, &passed);
    passed_all = passed_all && passed;

    if (passed_all)
        printf("PASSED ALL TESTS.\n");
    else
        printf("FAILED SOME TESTS.\n");
}

/*----------------------------------------------------------------------------*/
/* Miscellaneous functions used in timing                                     */
/*----------------------------------------------------------------------------*/

/* Write the current local date and time to stdout in asctime() format    */
/* (which ends with a newline).                                           */
static void print_current_time(void)
{
    time_t now;
    struct tm *local;

    now = time(NULL);
    local = localtime(&now);
    fputs(asctime(local), stdout);
}

/*----------------------------------------------------------------------------*/

/* Sort data[0 .. n-1] into ascending order in place using insertion      */
/* sort. Equal elements keep their relative order. n may be 0.            */
static void sort(int n, double data[])
{
    int next;

    assert(n >= 0);

    for (next = 1; next < n; next++) {
        double key = data[next];
        int slot = next;

        /* Shift every larger element of the sorted prefix one place up, */
        /* opening the slot where key belongs.                           */
        while (slot > 0 && key < data[slot - 1]) {
            data[slot] = data[slot - 1];
            slot--;
        }
        data[slot] = key;
    }
}

/*----------------------------------------------------------------------------*/
/* Time sequential matrix multiply                                            */
/*----------------------------------------------------------------------------*/

/* Time seq_matrix_multiply on n-by-n matrices.                           */
/*                                                                        */
/* breaks: when true, wait for Enter on stdin before and after each       */
/*         timed multiply.                                                */
/* n:      matrix size, 1 <= n <= N.                                      */
/* seed:   passed to srand(); the operands are filled with whole-number   */
/*         values in the range -50 .. 50 drawn from rand().               */
/* trials: number of timed runs, >= 1.                                    */
/*                                                                        */
/* Prints the elapsed time of every trial, then the lowest, highest and   */
/* average time over the middle trials after sorting (the lowest and      */
/* highest trials/4 results are discarded). If the results array cannot   */
/* be allocated, an error is written to stderr and nothing is timed.      */
static void time_seq_matrix_multiply(
        bool breaks, int n, int seed, int trials)
{
    int trial;
    int i, j;
    double *elapsed;
    int quarter;
    double total;

    assert(1 <= n && n <= N);
    assert(trials >= 1);

    /* calloc checks the count*size multiplication for overflow, and the */
    /* result must be checked before it is written through below.        */
    elapsed = calloc((size_t) trials, sizeof *elapsed);
    if (elapsed == NULL) {
        fprintf(stderr, "Cannot allocate memory for %i timing results\n",
                trials);
        return;
    }

    printf("TIMING SEQUENTIAL MATRIX MULTIPLY ");
    printf("(N = %i, SEED = %i)\n", n, seed);
    srand(seed);
    for (i = 0; i < n; i++) 
        for (j = 0; j < n; j++) {
            left[i][j] = (float) (rand()%101 - 50);
            right[i][j] = (float) (rand()%101 - 50);
        }
    print_current_time();
    for (trial = 0; trial < trials; trial++) {
        clear_processor_caches();
        if (breaks) {
            printf("Starting trial %i (press Enter to continue) ...", trial);
            /* The prompt has no newline, so flush it explicitly to make */
            /* sure it is visible before blocking on input.              */
            fflush(stdout);
            getchar();
        }
        start_timing();
        seq_matrix_multiply(n, left, right, product);
        finish_timing(&elapsed[trial]);
        if (breaks) {
            printf("Finished trial %i (press Enter to continue) ...", trial);
            fflush(stdout);
            getchar();
        }
        printf("Trial %i took %.2f seconds\n", trial, elapsed[trial]);
    }
    print_current_time();

    /* Statistics cover elapsed[quarter .. trials-quarter-1] of the      */
    /* sorted times, i.e. the middle trials - 2*quarter results.         */
    sort(trials, elapsed);
    quarter = trials/4;
    total = 0.0;
    for (trial = quarter; trial < trials - quarter; trial++) 
        total = total + elapsed[trial];
    printf("STATISTICS FOR MEDIAN %i TRIALS:\n", trials - 2*quarter);
    printf("LOWEST = %.2f seconds, ", elapsed[quarter]);
    printf("HIGHEST = %.2f seconds, ", elapsed[trials - quarter - 1]);
    printf("AVERAGE = %.2f seconds.\n", total/(trials - 2*quarter));
    free(elapsed);
}

/*----------------------------------------------------------------------------*/
/* Time multithreaded matrix multiply                                         */
/*----------------------------------------------------------------------------*/

/* Time multi_matrix_multiply on n-by-n matrices.                         */
/*                                                                        */
/* breaks: when true, wait for Enter on stdin before and after each       */
/*         timed multiply.                                                */
/* n:      matrix size, 1 <= n <= N.                                      */
/* t:      thread count handed to multi_matrix_multiply, >= 1.            */
/* p:      value handed to set_processor_usage, >= 1.                     */
/* seed:   passed to srand(); the operands are filled with whole-number   */
/*         values in the range -50 .. 50 drawn from rand().               */
/* trials: number of timed runs, >= 1.                                    */
/*                                                                        */
/* Prints the elapsed time of every trial, checks the last multithreaded  */
/* product against seq_matrix_multiply, then prints the lowest, highest   */
/* and average time over the middle trials after sorting (the lowest and  */
/* highest trials/4 results are discarded). If the results array cannot   */
/* be allocated, an error is written to stderr and nothing is timed.      */
static void time_multi_matrix_multiply(
        bool breaks, int n, int t, int p, int seed, int trials)
{
    int trial;
    int i, j;
    double *elapsed;
    int quarter;
    double total;

    assert(1 <= n && n <= N);
    assert(t >= 1);
    assert(p >= 1);
    assert(trials >= 1);

    /* calloc checks the count*size multiplication for overflow, and the */
    /* result must be checked before it is written through below.        */
    elapsed = calloc((size_t) trials, sizeof *elapsed);
    if (elapsed == NULL) {
        fprintf(stderr, "Cannot allocate memory for %i timing results\n",
                trials);
        return;
    }

    printf("TIMING MULTITHREADED MATRIX MULTIPLY ");
    printf("(N = %i, T = %i, P = %i, SEED = %i)\n", n, t, p, seed);
    set_processor_usage(p);
    srand(seed);
    for (i = 0; i < n; i++) 
        for (j = 0; j < n; j++) {
            left[i][j] = (float) (rand()%101 - 50);
            right[i][j] = (float) (rand()%101 - 50);
        }
    /* NOTE(review): set_processor_usage(p) is called a second time here */
    /* with the same argument; kept as in the original -- confirm        */
    /* whether the repeat is required.                                   */
    set_processor_usage(p);
    print_current_time();
    for (trial = 0; trial < trials; trial++) {
        clear_processor_caches();
        if (breaks) {
            printf("Starting trial %i (press Enter to continue) ...", trial);
            /* The prompt has no newline, so flush it explicitly to make */
            /* sure it is visible before blocking on input.              */
            fflush(stdout);
            getchar();
        }
        start_timing();
        multi_matrix_multiply(n, left, right, product, t);
        finish_timing(&elapsed[trial]);
        if (breaks) {
            printf("Finished trial %i (press Enter to continue) ...", trial);
            fflush(stdout);
            getchar();
        }
        printf("Trial %i took %.2f seconds\n", trial, elapsed[trial]);
    }
    print_current_time();

    /* Sanity check: product still holds the result of the last trial;  */
    /* compare it with the sequential result for the same operands.     */
    {
        seq_matrix_multiply(n, left, right, expected);
        if (!equal_matrices(n, product, expected)) 
            printf(">>>>> MULTITHREADED DIFFERENT FROM SEQUENTIAL <<<<<\n");
    }

    /* Statistics cover elapsed[quarter .. trials-quarter-1] of the      */
    /* sorted times, i.e. the middle trials - 2*quarter results.         */
    sort(trials, elapsed);
    quarter = trials/4;
    total = 0.0;
    for (trial = quarter; trial < trials - quarter; trial++) 
        total = total + elapsed[trial];
    printf("STATISTICS FOR MEDIAN %i TRIALS:\n", trials - 2*quarter);
    printf("LOWEST = %.2f seconds, ", elapsed[quarter]);
    printf("HIGHEST = %.2f seconds, ", elapsed[trials - quarter - 1]);
    printf("AVERAGE = %.2f seconds.\n", total/(trials - 2*quarter));
    free(elapsed);
}

/*----------------------------------------------------------------------------*/
/* Print program usage                                                        */
/*----------------------------------------------------------------------------*/

/* Write the command-line synopsis and argument descriptions to stderr,   */
/* using command as the program name on each synopsis line.               */
static void print_program_usage(const char command[])
{
    /* Synopsis: one line per accepted invocation form. */
    fprintf(stderr,
            "Usage: %s -test -seq\n"
            "       %s -test -multi\n"
            "       %s -time -seq n seed trials\n"
            "       %s -time -breaks -seq n seed trials\n"
            "       %s -time -multi n t p seed trials\n"
            "       %s -time -breaks -multi n t p seed trials\n",
            command, command, command, command, command, command);

    /* Argument descriptions. */
    fputs("where: n       = size of matrices (>= 1)\n"
          "       t       = number of threads (>= 1)\n"
          "       p       = number of processors (>= 1)\n"
          "       seed    = seed for random generation of items\n"
          "       trials  = number of timing trials (>= 1)\n",
          stderr);
}

/*----------------------------------------------------------------------------*/
/* Main program                                                               */
/*----------------------------------------------------------------------------*/

void main(int argc, char *argv[])
{
    int n, t, p, seed, trials;
    bool args_ok;
    int result;

    switch (argc) {
    case 3:
        if (strcmp(argv[1], "-test") == 0 && 
            strcmp(argv[2], "-seq") == 0)
            test_seq_matrix_multiply();
        else if (strcmp(argv[1], "-test") == 0 && 
            strcmp(argv[2], "-multi") == 0)
            test_multi_matrix_multiply();
        else
            print_program_usage(argv[0]);
        break;
    case 6:
        if (strcmp(argv[1], "-time") == 0 && 
            strcmp(argv[2], "-seq") == 0) {
            args_ok = true;
            result = sscanf(argv[3], "%i", &n);
            args_ok = args_ok && result == 1 && n >= 0;
            result = sscanf(argv[4], "%i", &seed);
            args_ok = args_ok && result == 1;
            result = sscanf(argv[5], "%i", &trials);
            args_ok = args_ok && result == 1 && trials >= 1;
            if (args_ok)
                time_seq_matrix_multiply(false, n, seed, trials);
            else
                print_program_usage(argv[0]);
        } else 
            print_program_usage(argv[0]);
        break;
    case 7:
        if (strcmp(argv[1], "-time") == 0 && 
            strcmp(argv[2], "-breaks") == 0 &&
            strcmp(argv[3], "-seq") == 0) {
            args_ok = true;
            result = sscanf(argv[4], "%i", &n);
            args_ok = args_ok && result == 1 && n >= 0;
            result = sscanf(argv[5], "%i", &seed);
            args_ok = args_ok && result == 1;
            result = sscanf(argv[6], "%i", &trials);
            args_ok = args_ok && result == 1 && trials >= 1;
            if (args_ok)
                time_seq_matrix_multiply(true, n, seed, trials);
            else
                print_program_usage(argv[0]);
        } else 
            print_program_usage(argv[0]);
        break;
    case 8:
        if (strcmp(argv[1], "-time") == 0 && 
            strcmp(argv[2], "-multi") == 0) {
            args_ok = true;
            result = sscanf(argv[3], "%i", &n);
            args_ok = args_ok && result == 1 && n >= 0;
            result = sscanf(argv[4], "%i", &t);
            args_ok = args_ok && result == 1 && t >= 1;
            result = sscanf(argv[5], "%i", &p);
            args_ok = args_ok && result == 1 && p >= 1;
            result = sscanf(argv[6], "%i", &seed);
            args_ok = args_ok && result == 1;
            result = sscanf(argv[7], "%i", &trials);
            args_ok = args_ok && result == 1 && trials >= 1;
            if (args_ok)
                time_multi_matrix_multiply(false, n, t, p, seed, trials);
            else
                print_program_usage(argv[0]);
        
        } else
            print_program_usage(argv[0]);
        break;
    case 9:
        if (strcmp(argv[1], "-time") == 0 && 
            strcmp(argv[2], "-breaks") == 0 &&
            strcmp(argv[3], "-multi") == 0) {
            args_ok = true;
            result = sscanf(argv[4], "%i", &n);
            args_ok = args_ok && result == 1 && n >= 0;
            result = sscanf(argv[5], "%i", &t);
            args_ok = args_ok && result == 1 && t >= 1;
            result = sscanf(argv[6], "%i", &p);
            args_ok = args_ok && result == 1 && p >= 1;
            result = sscanf(argv[7], "%i", &seed);
            args_ok = args_ok && result == 1;
            result = sscanf(argv[8], "%i", &trials);
            args_ok = args_ok && result == 1 && trials >= 1;
            if (args_ok)
                time_multi_matrix_multiply(true, n, t, p, seed, trials);
            else
                print_program_usage(argv[0]);
        
        } else
            print_program_usage(argv[0]);
        break;
    default:
        print_program_usage(argv[0]);
        break;
    }

}

/*----------------------------------------------------------------------------*/
