#include <stdio.h>
#include <math.h>

#include "easy.h"

#define msgtag_ARGS 1

#define msgtag_A 100
#define msgtag_B 200
#define msgtag_C 300

extern void mm(double *C, double *A, double *B, int lsize, int update);
extern void print_mat(char *text, double *C, int lsize);

/*
 * Shift the local block one step through the processor grid: its current
 * contents are handed to the north() neighbour and the same buffer is then
 * filled with the block coming from the south() neighbour. Both directions
 * use message tag msgtag_B.
 *
 *   buf   - local block, used as both send and receive buffer
 *   count - number of elements in buf
 *
 * NOTE(review): relies on sendrecv() completing the send before it
 * overwrites the (shared) buffer with the received data - confirm in easy.h.
 */
static void roll(double *buf, int count)
{
  sendrecv(north(), msgtag_B, buf, count,
           south(), msgtag_B, buf, count);
}

int main(int argc, char **argv)
{
  extern int atoi(char *);
  int lsize, sqrt_nproc;
  double *A=NULL, *B=NULL, *C=NULL, *T=NULL;
  int ME;

  attachproc();
  setstride(1);

  /* pvm_setopt(PvmRoute,PvmRouteDirect); */

  for (;;) {

    if (--argc >= 2) {
      lsize = atoi(argv[1]);
      sqrt_nproc = atoi(argv[2]);
      argc = 0;
    }
    else {
      int args[2];
      
      setdatatype(INTEGER4);
      recv(myhost(),msgtag_ARGS,args,2);

      lsize = args[0];
      sqrt_nproc = args[1];
    }

    if (lsize < 1 || sqrt_nproc < 1) break;
    
    ME = mynode();
    
    A = (double *)malloc(lsize * lsize * sizeof(*A));
    B = (double *)malloc(lsize * lsize * sizeof(*B));
    C = (double *)malloc(lsize * lsize * sizeof(*C));
    
    setdatatype(REAL8);
    
    recv2d(myhost(),msgtag_A,A,lsize,lsize,lsize);
    recv2d(myhost(),msgtag_B,B,lsize,lsize,lsize);
    
    
    if (sqrt_nproc == 1) {
      int update = 0;

      mm(C,A,B,lsize,update);
      
    }
    else {
      int i;
      int mycol, myrow;
      int update = 0;
      int tmp;
      
      getxyz(&mycol, &myrow, NULL);
      tmp = (mycol + myrow - 1) % sqrt_nproc;
      
      T = (double *)malloc(lsize * lsize * sizeof(*T));
      
      for (i=0; i<sqrt_nproc; i++) {
	double *A_mat = NULL;
	
	if ( tmp == 0 ) { /* send() */
	  send(east(), msgtag_A , A, lsize * lsize);
	  A_mat = A;
	}
	else if ( tmp == sqrt_nproc - 1 ) { /* recv() only */
	  recv(west(), msgtag_A, T, lsize * lsize);
	  A_mat = T;
	}
	else { /* storefwd(), i.e. recv() & send forward */
	  storefwd(west(), msgtag_A,
		   east(), msgtag_A,
		   T, lsize * lsize);
	  A_mat = T;
	}
	
	mm(C,A_mat,B,lsize,update++);

	{
	  double xxx=0;
	  if (ME == 0) {
	    broadcast(999,&xxx,1);
	  }
	  else {
	    recv(0,999,&xxx,1);
	  }
	}

	if (i < sqrt_nproc - 1) {
	  roll(B,lsize * lsize);
	  if ( --tmp < 0 ) tmp = sqrt_nproc - 1;
	}
	
      } /* for (i=0; i<sqrt_nproc; i++) */
      
      free(T);
      
    }
    
    free(A);
    free(B);
    
    send2d(myhost(), msgtag_C, C, lsize, lsize, lsize);

    free(C);

  } /* for (;;) */
  
  exit(0);
}
