/* mm.c */

/*
 * This file (mm.c) contains the low-level matrix multiplication
 * functions. These functions were donated by Fook Fah Yap. I patterned
 * the mixed mode multiply functions after the contributed functions.
 *
 * These functions are anywhere from 50% to 20% faster than the BLAS
 * counterparts (for matrix sizes between 100-350). However, if you
 * have a special set of optimized BLAS libraries you should test
 * their speed yourself.
 */

/*  This file is a part of RLaB ("Our"-LaB)
   Copyright (C) 1993  Ian R. Searle

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

   See the file ./COPYING
   ********************************************************************** */

#include "rlab.h"
#include "complex.h"

/*
 * Multiply an m (rows) by k (columns) real matrix A
 * by a k (rows) by n (columns) real matrix B
 * to produce an m by n matrix C.
 *
 * The matrices are addressed column by column: element (i, j) of A
 * is read from A[i + j * m], element (i, j) of B from B[i + j * k],
 * and element (i, j) of the result is stored in C[i + j * m].
 * C is only written, never read.
 *
 * If m or n is not positive nothing is written.  If k is not
 * positive every element of C is set to 0.0.
 */

void
rmmpy (m, k, n, A, B, C)
     int m, k, n;
     double *A, *B, *C;
{
  register double s;
  double *colA, *colC, *pB;
  int i, j, l;

  for (i = 0; i < m; i++)
  {
    /* B is read front to back, once in full, for every row of A. */
    pB = B;
    colC = C;
    for (j = 0; j < n; j++)
    {
      /*
       * Step a column base pointer and index row i inside it.  This
       * keeps every pointer within [A, A + k * m] and [C, C + n * m]
       * and never forms an int product of two dimensions, which
       * could overflow for large matrices.
       */
      colA = A;
      s = 0.0;
      for (l = 0; l < k; l++)
      {
        s += colA[i] * (*pB++);
        colA += m;
      }
      colC[i] = s;
      colC += m;
    }
  }
}

/*
 * Multiply an m (rows) by k (columns) complex matrix A
 * by a k (rows) by n (columns) complex matrix B
 * to produce an m by n matrix C.
 *
 * The matrices are addressed column by column: element (i, j) of A
 * is read from A[i + j * m], element (i, j) of B from B[i + j * k],
 * and element (i, j) of the result is stored in C[i + j * m].
 * Real and imaginary parts are the .r and .i members of Complex.
 * C is only written, never read.
 *
 * If m or n is not positive nothing is written.  If k is not
 * positive every element of C is set to 0.0 + 0.0i.
 */

void
cmmpy (m, k, n, A, B, C)
     int m, k, n;
     Complex *A, *B, *C;
{
  register double s1, s2;
  Complex *colA, *colC, *pB, *a;
  int i, j, l;

  for (i = 0; i < m; i++)
  {
    /* B is read front to back, once in full, for every row of A. */
    pB = B;
    colC = C;
    for (j = 0; j < n; j++)
    {
      /*
       * Step a column base pointer and index row i inside it.  This
       * keeps every pointer within [A, A + k * m] and [C, C + n * m]
       * and never forms an int product of two dimensions, which
       * could overflow for large matrices.
       */
      colA = A;
      s1 = 0.0;
      s2 = 0.0;
      for (l = 0; l < k; l++)
      {
        a = colA + i;
        /* (a.r + i a.i) * (b.r + i b.i), accumulated by parts. */
        s1 += a->r * pB->r - a->i * pB->i;
        s2 += a->r * pB->i + a->i * pB->r;
        pB++;
        colA += m;
      }
      colC[i].r = s1;
      colC[i].i = s2;
      colC += m;
    }
  }
}

/*
 * Real - Complex Matrix Multiply
 *
 * Multiply an m (rows) by k (columns) real matrix A
 * by a k (rows) by n (columns) complex matrix B
 * to produce an m by n complex matrix C.
 *
 * The matrices are addressed column by column: element (i, j) of A
 * is read from A[i + j * m], element (i, j) of B from B[i + j * k],
 * and element (i, j) of the result is stored in C[i + j * m].
 * Real and imaginary parts are the .r and .i members of Complex.
 * C is only written, never read.
 *
 * If m or n is not positive nothing is written.  If k is not
 * positive every element of C is set to 0.0 + 0.0i.
 */

void
rcmmpy (m, k, n, A, B, C)
     int m, k, n;
     double *A;
     Complex *B, *C;
{
  register double s1, s2;
  double *colA;
  Complex *colC, *pB;
  int i, j, l;

  for (i = 0; i < m; i++)
  {
    /* B is read front to back, once in full, for every row of A. */
    pB = B;
    colC = C;
    for (j = 0; j < n; j++)
    {
      /*
       * Step a column base pointer and index row i inside it.  This
       * keeps every pointer within [A, A + k * m] and [C, C + n * m]
       * and never forms an int product of two dimensions, which
       * could overflow for large matrices.
       */
      colA = A;
      s1 = 0.0;
      s2 = 0.0;
      for (l = 0; l < k; l++)
      {
        /* A real factor scales both parts of the complex factor. */
        s1 += colA[i] * pB->r;
        s2 += colA[i] * pB->i;
        pB++;
        colA += m;
      }
      colC[i].r = s1;
      colC[i].i = s2;
      colC += m;
    }
  }
}

/*
 * Complex - Real Multiply
 *
 * Multiply an m (rows) by k (columns) complex matrix A
 * by a k (rows) by n (columns) real matrix B
 * to produce an m by n complex matrix C.
 *
 * The matrices are addressed column by column: element (i, j) of A
 * is read from A[i + j * m], element (i, j) of B from B[i + j * k],
 * and element (i, j) of the result is stored in C[i + j * m].
 * Real and imaginary parts are the .r and .i members of Complex.
 * C is only written, never read.
 *
 * If m or n is not positive nothing is written.  If k is not
 * positive every element of C is set to 0.0 + 0.0i.
 */

void
crmmpy (m, k, n, A, B, C)
     int m, k, n;
     Complex *A, *C;
     double *B;
{
  register double s1, s2;
  Complex *colA, *colC, *a;
  double *pB;
  int i, j, l;

  for (i = 0; i < m; i++)
  {
    /* B is read front to back, once in full, for every row of A. */
    pB = B;
    colC = C;
    for (j = 0; j < n; j++)
    {
      /*
       * Step a column base pointer and index row i inside it.  This
       * keeps every pointer within [A, A + k * m] and [C, C + n * m]
       * and never forms an int product of two dimensions, which
       * could overflow for large matrices.
       */
      colA = A;
      s1 = 0.0;
      s2 = 0.0;
      for (l = 0; l < k; l++)
      {
        a = colA + i;
        /* A real factor scales both parts of the complex factor. */
        s1 += a->r * (*pB);
        s2 += a->i * (*pB);
        pB++;
        colA += m;
      }
      colC[i].r = s1;
      colC[i].i = s2;
      colC += m;
    }
  }
}

/* **************************************************************
 * The following are some recursive matrix-multiply functions.
 * ************************************************************** */

/*
 * Multiply an m (rows) by k (columns) real matrix A on the left
 * by a k (rows) by n (columns) real matrix B on the right
 * to produce an m by n matrix C, by recursive subdivision.
 *
 * M is the number of rows of matrix A when rmmpyr is first called
 * from matrix_Multiply; it is the distance, in elements, between
 * horizontally adjacent elements of A and stays fixed through the
 * recursion.  The call from matrix_Multiply should look like this:
 *
 *   rmmpyr (MNR(A), MNR(A), MNC(A), MNC(B), A, B, C)
 *
 * The product is halved by columns of B and C until one column is
 * left, then halved by rows of A and C until one element of C is
 * left, which is computed as a dot product.
 *
 * If m or n is not positive the function returns without touching
 * C.  If k is not positive each element of C is set to 0.0.
 */

void
rmmpyr (M, m, k, n, A, B, C)
     int M, m, k, n;
     double *A, *B, *C;
{
  register double s;
  double *pA, *pB;
  int l;

  /*
   * An empty result has no base case below: halving 0 rows or
   * 0 columns produces the same sub-problem again, so the recursion
   * would never end.
   */
  if (m <= 0 || n <= 0)
  {
    return;
  }

  if (n > 1)
  {
    /* Left and right halves of the columns of B and C. */
    rmmpyr (M, m, k, n / 2, A, B, C);
    rmmpyr (M, m, k, n - n / 2, A, B + k * (n / 2), C + m * (n / 2));
    return;
  }

  if (m > 1)
  {
    /* Upper and lower halves of the rows of A and C. */
    rmmpyr (M, m / 2, k, 1, A, B, C);
    rmmpyr (M, m - m / 2, k, 1, A + m / 2, B, C + m / 2);
    return;
  }

  /* m == 1 and n == 1: one row of A times one column of B. */
  s = 0.0;
  pA = A;
  pB = B;
  for (l = k; l > 0; l--)
  {
    s += (*pA) * (*pB++);
    /*
     * Do not step past the last column: A may point at any row, so
     * one more stride of M could leave the storage of the matrix.
     */
    if (l > 1)
    {
      pA += M;
    }
  }
  *C = s;
}

/*
 * Multiply an m (rows) by k (columns) complex matrix A on the left
 * by a k (rows) by n (columns) complex matrix B on the right
 * to produce an m by n matrix C, by recursive subdivision.
 *
 * M is the distance, in elements, between horizontally adjacent
 * elements of A; it stays fixed through the recursion while m
 * shrinks.  Real and imaginary parts are the .r and .i members of
 * Complex.
 *
 * The product is halved by columns of B and C until one column is
 * left, then halved by rows of A and C until one element of C is
 * left, which is computed as a complex dot product.
 *
 * If m or n is not positive the function returns without touching
 * C.  If k is not positive each element of C is set to 0.0 + 0.0i.
 */

void
cmmpyr (M, m, k, n, A, B, C)
     int M, m, k, n;
     Complex *A, *B, *C;
{
  register double s1, s2;
  Complex *pA, *pB;
  int l;

  /*
   * An empty result has no base case below: halving 0 rows or
   * 0 columns produces the same sub-problem again, so the recursion
   * would never end.
   */
  if (m <= 0 || n <= 0)
  {
    return;
  }

  if (n > 1)
  {
    /* Left and right halves of the columns of B and C. */
    cmmpyr (M, m, k, n / 2, A, B, C);
    cmmpyr (M, m, k, n - n / 2, A, B + k * (n / 2), C + m * (n / 2));
    return;
  }

  if (m > 1)
  {
    /* Upper and lower halves of the rows of A and C. */
    cmmpyr (M, m / 2, k, 1, A, B, C);
    cmmpyr (M, m - m / 2, k, 1, A + m / 2, B, C + m / 2);
    return;
  }

  /* m == 1 and n == 1: one row of A times one column of B. */
  s1 = 0.0;
  s2 = 0.0;
  pA = A;
  pB = B;
  for (l = k; l > 0; l--)
  {
    s1 += pA->r * pB->r - pA->i * pB->i;
    s2 += pA->r * pB->i + pA->i * pB->r;
    /*
     * Advance B in a statement of its own: incrementing pB inside
     * an expression that also reads pB is unsequenced and undefined.
     */
    pB++;
    /*
     * Do not step past the last column: A may point at any row, so
     * one more stride of M could leave the storage of the matrix.
     */
    if (l > 1)
    {
      pA += M;
    }
  }
  C->r = s1;
  C->i = s2;
}
