// version 0.3, modified by Ian A. Mason
// May 21 @ U.N.E.
//
// added real argument arrays and got rid of the silly
// broadcasting of the values that belong in the argument arrays
// added three file arguments for the A, B and C matrix files
// got rid of the goto
// added VERSION prior to making it automatic
// put everything in a library

#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <sys/types.h>
#include <fcntl.h>
#include <unistd.h>

#include <pvm3.h>

#include "mmlib.h"

/* Program identity; also the executable name handed to pvm_spawn(). */
#define VERSION     "mm.0.3"

/* Encoding passed to every pvm_initsend(). */
#define ENCODING    PvmDataRaw

/* Size limits: m (grid width) may not exceed MAXROW, and the task
   count m*m must stay below MAXNTIDS. */
#define MAXROW      100
#define MAXNTIDS    100

/* Base message tags; step i of the main loop uses (i+1) times these.
   DIMTAG is not referenced in the code shown here. */
#define ATAG        2
#define BTAG        3
#define DIMTAG      5

/* Nonzero: every task prints its C and A blocks before finishing. */
#define VERBOSE     0

/* Permission bits given to creat() for the result matrix file. */
#define PERMS       0666

/* Validates the command line, opens the A/B/C files, derives m*m tasks. */
int parse_args(int argc, char *argv[], int *m, int *blksize, int *ntask,
               int fd[]);

/*
 * Close the three descriptors parse_args() opened (A, B and C).
 * Negative entries are skipped.
 */
static void close_matrix_files(const int fd[]){
    int k;

    for (k = 0; k < 3; k++)
      if (fd[k] >= 0)
        close(fd[k]);
}

/*
 * Print one blksize x blksize block stored row-major, preceded by a header
 * naming the matrix, the block's [row,col] grid position and the owning task.
 */
static void report_block(const char *name, const int *blk, int blksize,
                         int row, int col, int mytid){
    int r, k;

    printf("Block %s at [%d,%d] managed by task %d in:\n", name, row, col, mytid);
    for (r = 0; r < blksize; r++){
      for (k = 0; k < blksize; k++)
        printf("%5d ", blk[(r*blksize) + k]);
      printf("\n"); }
}

/*
 * Distributed block matrix multiplication over PVM.
 *
 * Command line (checked by parse_args): matrixA matrixB matrixC m blk.
 * ntask = m*m copies of this program join the group "mmult" and form an
 * m x m grid; group instance gid owns the block at row gid/m, column gid%m.
 * Instance 0 spawns the other ntask-1 copies (executable VERSION) with the
 * same arguments.
 *
 * Each task holds blksize x blksize int blocks a, b and c (filled by
 * init_block from mmlib).  For step i = 0..m-1:
 *   - the task in column (row+i)%m multicasts its A block to its row with
 *     tag (i+1)*ATAG; every task calls block_mult with that A block and
 *     its current B block (presumably c += a*b -- confirm in mmlib);
 *   - every task sends its B block to the task one row up and receives the
 *     block from the task one row down (wrapping), tag (i+1)*BTAG.
 * Afterwards each task prints every element where c differs from a.
 *
 * Returns 0 on success, -1 on any setup failure (including bad arguments).
 */
int main(int argc, char* argv[]){
    int ntask = 2;
    int info;
    int mytid, mygid;
    int child[MAXNTIDS-1];
    int i, m, blksize, nelem;
    size_t nbytes;
    int myrow[MAXROW];
    int *a, *b, *c, *atmp;
    int row, col, up, down;
    int fd[3];   /* A = fd[0] B = fd[1] C = fd[2]; opened by parse_args */
    // wall-clock time, used only for the final "done" report
    time_t start_time = time(NULL);

    // A bad command line is a failure: report it in the exit status.
    if(parse_args(argc, argv, &m, &blksize, &ntask, fd) < 0) return -1;

    mytid = pvm_mytid();
    if (mytid < 0){
      pvm_perror(argv[0]);
      close_matrix_files(fd);
      return -1; }

    // Only configure routing once enrollment is known to have worked.
    pvm_setopt(PvmRoute, PvmRouteDirect);

    mygid = pvm_joingroup("mmult");
    if (mygid < 0){
      pvm_perror(argv[0]);
      pvm_exit();
      close_matrix_files(fd);
      return -1; }

    // Instance 0 starts the rest of the grid with the same arguments.
    if ((mygid == 0) && (ntask > 1)) {
      info = pvm_spawn(VERSION,
                       &argv[1],
                       PvmTaskDefault,
                       (char*)0,
                       ntask-1,
                       child);

      if(info != ntask-1){
        if (info < 0)
          pvm_perror(argv[0]);
        else
          fprintf(stderr, "%s: only %d of %d tasks spawned\n",
                  argv[0], info, ntask-1);
        // The first info entries of child[] are started tasks; with the
        // group incomplete they would block in pvm_barrier forever.
        for (i = 0; i < info; i++)
          pvm_kill(child[i]);
        pvm_lvgroup("mmult");
        pvm_exit();
        close_matrix_files(fd);
        return -1; }
    }

    // barrier: wait until all ntask members have joined

    info = pvm_barrier("mmult",ntask);
    if (info < 0) pvm_perror(argv[0]);

    // tids of every task in this grid row (the A multicast targets)
    for (i = 0; i < m; i++)
      myrow[i] = pvm_gettid("mmult", (mygid/m)*m + i);

    nelem = blksize*blksize;
    nbytes = sizeof(int) * (size_t)blksize * (size_t)blksize;
    a = malloc(nbytes);
    b = malloc(nbytes);
    c = malloc(nbytes);
    atmp = malloc(nbytes);
    if (!(a && b && c && atmp)) {
      fprintf(stderr,
              "%s: out of memory!\n",
              argv[0]);
      free(a);
      free(b);
      free(c);
      free(atmp);
      pvm_lvgroup("mmult");
      pvm_exit();
      close_matrix_files(fd);
      return -1; }

    row = mygid/m;
    col = mygid % m;

    // grid neighbours in this column, wrapping top <-> bottom
    up = pvm_gettid("mmult",
                    ((row)?
                     (row-1):
                     (m-1))*m+col);
    down = pvm_gettid("mmult",
                      ((row == (m-1))?
                       col:
                       (row+1)*m+col));

    init_block(a, b, c, blksize, row, col);

    /* Tags (i+1)*ATAG and (j+1)*BTAG can coincide (e.g. 6), but every
       pvm_recv also matches the sender: A arrives from a task in this row,
       B from `down`, which lies in a different row whenever m > 1. */
    for (i = 0; i < m; i++) {
      if (col == (row + i)%m) {
        // this task owns step i's A block: share it along the row
        pvm_initsend(ENCODING);
        pvm_pkint(a, nelem, 1);
        pvm_mcast(myrow, m, (i+1)*ATAG);
        block_mult(c,a,b,blksize);
      } else {
        pvm_recv(pvm_gettid("mmult", row*m + (row +i)%m), (i+1)*ATAG);
        pvm_upkint(atmp, nelem, 1);
        block_mult(c,atmp,b,blksize);
      }
      // roll B up one grid row
      pvm_initsend(ENCODING);
      pvm_pkint(b, nelem, 1);
      pvm_send(up, (i+1)*BTAG);
      pvm_recv(down, (i+1)*BTAG);
      pvm_upkint(b, nelem, 1);
    }

    info = pvm_barrier("mmult",ntask);
    if (info < 0) pvm_perror(argv[0]);

    // Self-check: C is expected to equal A (presumably init_block makes B
    // the identity -- confirm in mmlib).
    for (i = 0 ; i < nelem; i++)
      if (a[i] != c[i])
        printf("Error a[%d] (%d) != c[%d] (%d) \n", i, a[i],i,c[i]);

    if(VERBOSE){
      report_block("C", c, blksize, row, col, mytid);
      report_block("A", a, blksize, row, col, mytid);
    }

    printf("%s %d %d task %d done sucessfully after %ld seconds.\n",
           VERSION,
           m,
           blksize,
           mytid,
           (long)difftime(time(NULL), start_time));
    free(a);
    free(b);
    free(c);
    free(atmp);
    pvm_lvgroup("mmult");
    pvm_exit();
    close_matrix_files(fd);
    return 0;
}


/*
 * Parse a base-10 command-line integer that must be strictly positive and
 * fit in an int.  On success stores it in *out and returns 0; otherwise
 * returns -1 and leaves *out untouched.
 */
static int parse_positive_int(const char *s, int *out){
  char *end;
  long v;

  errno = 0;
  v = strtol(s, &end, 10);
  if ((end == s) || (*end != '\0') || (errno == ERANGE) ||
      (v <= 0) || (v > INT_MAX))
    return(-1);
  *out = (int)v;
  return(0);
}

/*
 * Validate the command line:  prog matrixA matrixB matrixC m blk
 *
 * On success: *m and *blksize hold the grid width and block size,
 * *ntask = m*m (the number of tasks to run), fd[0]/fd[1] are matrixA and
 * matrixB opened read-only and fd[2] is matrixC created (or truncated)
 * with PERMS.  Returns 0.
 *
 * On failure: prints a diagnostic to stderr, closes any file it opened,
 * leaves every fd[] entry at -1 and returns -1.  Numeric arguments are
 * checked before any file is touched, so a bad m/blk never truncates C.
 * No PVM call is made here: main() calls this before enrolling in PVM.
 */
int parse_args(int argc, char *argv[], int *m, int *blksize, int *ntask, int fd[]){
  int k;

  fd[0] = fd[1] = fd[2] = -1;

  if ((argc != 6) ||
      (parse_positive_int(argv[4], m) < 0) ||
      (parse_positive_int(argv[5], blksize) < 0)){
    fprintf(stderr,
            "Usage: %s matrixA matrixB matrixC m blk\n", argv[0]);
    return(-1); }

  if(*m > MAXROW){
    fprintf(stderr, "m = %d not valid.\n", *m);
    return(-1); }

  *ntask = (*m)*(*m);   /* *m <= MAXROW keeps this product small */
  if ((*ntask < 1) || (*ntask >= MAXNTIDS)) {
    fprintf(stderr, "ntask in parse_args = %d not valid.\n", *ntask);
    return(-1); }

  /* main() uses blksize*blksize both as an int element count and, times
     sizeof(int), as a malloc size: both must be representable. */
  if ((*blksize > INT_MAX / *blksize) ||
      ((size_t)*blksize * (size_t)*blksize > SIZE_MAX / sizeof(int))){
    fprintf(stderr, "blk = %d not valid.\n", *blksize);
    return(-1); }

  /* A and B are read-only inputs; C is created/truncated for output. */
  for (k = 0; k < 3; k++){
    const char *path = argv[k+1];

    fd[k] = (k < 2) ? open(path, O_RDONLY) : creat(path, PERMS);
    if (fd[k] == -1){
      perror(path);
      while (k-- > 0){      /* undo the opens that did succeed */
        close(fd[k]);
        fd[k] = -1; }
      fprintf(stderr,
              "Usage: %s matrixA matrixB matrixC m blk\n", argv[0]);
      return(-1); }
  }
  return(0);
}