/* math_2.c 
 * Miscellaneous math functions for RLaB */

/*  This file is a part of RLaB ("Our"-LaB)
   Copyright (C) 1992, 1993, 1994  Ian R. Searle

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

   See the file ./COPYING
   ********************************************************************** */

#include "rlab.h"
#include "bltin.h"
#include "symbol.h"
#include "util.h"
#include "scop1.h"
#include "matop1.h"
#include "matop2.h"
#include "matop3.h"
#include "btree.h"
#include "listnode.h"
#include "r_string.h"
#include "fi_1.h"
#include "mathl.h"

#include <math.h>

extern int matrix_is_symm _PROTO ((Matrix * m));

/* **************************************************************
 * Compute condition number of a general matrix.
 * ************************************************************** */
void
Rcond (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  ListNode *arg;
  double rc;

  /*
   * rcond(A): exactly one numeric matrix argument.  The scalar
   * result is whatever matrix_Rcond reports for A (presumably a
   * reciprocal condition estimate -- confirm in mathl).
   */
  if (n_args != 1)
    error_1 ("rcond: 1 argument allowed", 0);

  arg = bltin_get_numeric_matrix ("rcond", d_arg, 1);
  rc = matrix_Rcond (e_data (arg));

  /* rc is a plain double, so the temporary argument can go first. */
  remove_tmp_destroy (arg);
  *return_ptr = (VPTR) scalar_Create (rc);
}

/* **************************************************************
 * Compute the determinant of a matrix
 * ************************************************************** */
void
Det (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  ListNode *arg;

  /* det(A): a single numeric matrix argument is required. */
  if (n_args != 1)
    error_1 ("det: 1 argument allowed", 0);

  arg = bltin_get_numeric_matrix ("det", d_arg, 1);

  /* Return the object built by matrix_Det, then drop the temporary. */
  *return_ptr = (VPTR) matrix_Det (e_data (arg));
  remove_tmp_destroy (arg);
}

/* **************************************************************
 * Compute the singular value decomposition of a matrix.
 * ************************************************************** */
void
Svd (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  char *str;
  int aflag;
  Btree *rlist;
  Matrix *rsv, *lsv, *sigma;
  ListNode *M, *TYPE;

  /*
   * svd(A [, type]) -> list with members "vt", "u" and "sigma".
   * The optional type string selects the aflag passed to matrix_Svd;
   * only its first character matters (case-insensitive):
   *   'S' -> 2 (default), 'A' -> 1, 'N' -> 3.
   */
  if (n_args < 1 || n_args > 2)
    error_1 ("svd: 1 or 2 argument(s) allowed", 0);

  M = bltin_get_numeric_matrix ("svd", d_arg, 1);
  aflag = 2;			/* default */

  if (n_args == 2)
  {
    TYPE = bltin_get_string ("svd", d_arg, 2);
    str = string_GetString (e_data (TYPE));
    switch (str[0])
    {
    case 'S':
    case 's':
      aflag = 2;
      break;
    case 'A':
    case 'a':
      aflag = 1;
      break;
    case 'N':
    case 'n':
      aflag = 3;
      break;
    default:
      /* Release the temporaries before bailing out, as QR() does. */
      remove_tmp_destroy (M);
      remove_tmp_destroy (TYPE);
      error_1 ("svd: Invalid 2nd argument", 0);
    }
    /* The option string is a temporary too; it was previously leaked. */
    remove_tmp_destroy (TYPE);
  }

  matrix_Svd (e_data (M), &rsv, &lsv, &sigma, aflag);
  rlist = btree_Create ();
  install (rlist, cpstr ("vt"), MATRIX, rsv);
  install (rlist, cpstr ("u"), MATRIX, lsv);
  install (rlist, cpstr ("sigma"), MATRIX, sigma);
  matrix_SetName (rsv, cpstr ("vt"));
  matrix_SetName (lsv, cpstr ("u"));
  matrix_SetName (sigma, cpstr ("sigma"));

  remove_tmp_destroy (M);
  *return_ptr = (VPTR) rlist;
}

/* **************************************************************
 * Compute the Eigenvalues of [A] (standard eigenvalue problem)
 * or [A], [B] (generalized eigenvalue problem). This function checks
 * for symmetry before deciding which solver to use.
 * ************************************************************** */
void
Eig (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Btree *rlist;
  ListNode *A, *B;
  Matrix *vec, *val, *lvec;

  /*
   * eig(A)    -> standard problem
   * eig(A, B) -> generalized problem
   * Returns a list with members "val" and "vec".
   */
  if ((n_args != 1) && (n_args != 2))
    error_1 ("eig: Wrong number of arguments", 0);

  A = bltin_get_numeric_matrix ("eig", d_arg, 1);
  B = 0;
  if (n_args == 2)
    B = bltin_get_numeric_matrix ("eig", d_arg, 2);

  /* Create list for returning results */
  rlist = btree_Create ();

  if (n_args == 1)
  {
    /* Standard problem: pick the solver by the symmetry of A. */
    if (matrix_is_symm (e_data (A)))
      matrix_Eig_SEP (e_data (A), &val, &vec);
    else
      matrix_Eig_NEP (e_data (A), &val, &vec, &lvec, 0);
  }
  else
  {
    /*
     * Generalized problem.  The symmetric solver is only valid when
     * BOTH A and B are symmetric (a symmetric-generalized solver
     * presumably reads just one triangle of each matrix -- confirm in
     * mathl).  Testing only A silently mis-solved symmetric-A /
     * non-symmetric-B problems.
     */
    if (matrix_is_symm (e_data (A)) && matrix_is_symm (e_data (B)))
      matrix_Eig_GSEP (e_data (A), e_data (B), &val, &vec);
    else
      matrix_Eig_GNEP (e_data (A), e_data (B), &val, &vec);
  }

  install (rlist, cpstr ("val"), MATRIX, val);
  install (rlist, cpstr ("vec"), MATRIX, vec);
  matrix_SetName (val, cpstr ("val"));
  matrix_SetName (vec, cpstr ("vec"));

  remove_tmp_destroy (A);
  if (n_args == 2)
    remove_tmp_destroy (B);

  *return_ptr = (VPTR) rlist;
}

/* **************************************************************
 * Compute various norms of matrices.
 * ************************************************************** */
void
Norm (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  char *s;
  double d, p;
  Matrix *m = 0;
  ListNode *M1, *M2;

  /* Check n_args */
  if ((n_args < 1) || (n_args > 2))
    error_1 ("norm: 1 or 2 args allowed", 0);

  /*
   * norm() takes one arg, and defaults to 1-norm.
   * Otherwise norm() takes two args. The first is always
   * the object, the second (optional) arg denotes the type
   * of norm. "1" = 1-norm, "i" = infinity norm,
   * "f" = frobenius, "m" = max(abs([a]))
   */

  M1 = bltin_get_numeric_matrix ("norm", d_arg, 1);
  if (n_args == 1)
  {
    d = matrix_Norm (e_data (M1), "1");
    *return_ptr = (VPTR) scalar_Create (d);
    remove_tmp_destroy (M1);
  }
  else if (n_args == 2)
  {
    M2 = bltin_get_matrix ("norm", d_arg, 2);
  
    if (MTYPE (e_data (M2)) == STRING)
    {
      s = MATs (e_data (M2), 1, 1);
      d = matrix_Norm (e_data (M1), s);
      *return_ptr = (VPTR) scalar_Create (d);
    }
    else
    {
      /* Compute a P-Norm */
      p = (double) MAT (e_data (M2), 1, 1);
      if (detect_inf_r (MDPTRr (e_data (M2)), 1))
      {
	d = matrix_Norm (e_data (M1), "i");
	remove_tmp_destroy (M1);
	remove_tmp_destroy (M2);
	*return_ptr = (VPTR) scalar_Create (d);
	return;
      }
      m = (Matrix *) e_data (M1);
      if (MNR (m) != 1 && MNC (m) != 1)
      {
	remove_tmp_destroy (M1);
	remove_tmp_destroy (M2);
	error_1 ("norm: cannot compute P-norm of a matrix", 0);
      }
      d = matrix_PNorm (m, p);
      *return_ptr = (VPTR) scalar_Create (d);
    }
    remove_tmp_destroy (M1);
    remove_tmp_destroy (M2);
  }
  return;
}

/* **************************************************************
 * Compute the Cholesky factorization.
 * ************************************************************** */
void
Chol (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  ListNode *arg;

  /* chol(A): exactly one numeric matrix argument. */
  if (n_args != 1)
    error_1 ("chol: 1 argument allowed", 0);

  arg = bltin_get_numeric_matrix ("chol", d_arg, 1);

  /* Result is the object produced by matrix_Chol for A. */
  *return_ptr = (VPTR) matrix_Chol (e_data (arg));
  remove_tmp_destroy (arg);
}

/* **************************************************************
 * Compute the QR decomposition of the input matrix.
 * ************************************************************** */
void
QR (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  int pflag;
  Btree *btree;
  Matrix *q, *r, *p;
  char *str;
  ListNode *M, *TYPE;

  /*
   * qr(A)      -> list with members "q" and "r"      (matrix_Qr)
   * qr(A, "p") -> list with members "q", "r" and "p" (matrix_QrP)
   */
  if (n_args != 1 && n_args != 2)
    error_1 ("qr: 1 or 2 arguments allowed", 0);
  pflag = 0;

  M = bltin_get_numeric_matrix ("qr", d_arg, 1);

  if (n_args == 2)
  {
    /*
     * The option string is the 2nd argument.  Previously argument 1
     * (the matrix) was fetched here under the name "any", so the
     * pivoted form could never be requested.
     */
    TYPE = bltin_get_string ("qr", d_arg, 2);
    str = string_GetString (e_data (TYPE));
    if (strcmp ("p", str) && strcmp ("P", str))
    {
      remove_tmp_destroy (M);
      remove_tmp_destroy (TYPE);
      error_1 ("qr: 2nd arg must be \"p\"", 0);
    }
    pflag = 1;
    remove_tmp_destroy (TYPE);
  }

  btree = btree_Create ();
  if (pflag == 0)
  {
    matrix_Qr (e_data (M), &q, &r);
    install (btree, cpstr ("q"), MATRIX, q);
    install (btree, cpstr ("r"), MATRIX, r);
    matrix_SetName (q, cpstr ("q"));
    matrix_SetName (r, cpstr ("r"));
  }
  else
  {
    matrix_QrP (e_data (M), &q, &r, &p);
    install (btree, cpstr ("q"), MATRIX, q);
    install (btree, cpstr ("r"), MATRIX, r);
    install (btree, cpstr ("p"), MATRIX, p);
    matrix_SetName (q, cpstr ("q"));
    matrix_SetName (r, cpstr ("r"));
    matrix_SetName (p, cpstr ("p"));
  }
  remove_tmp_destroy (M);
  *return_ptr = (VPTR) btree;
}

/* **************************************************************
 * Compute the Hessenberg form of the input matrix.
 * ************************************************************** */
void
Hess (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Btree *btree;
  Matrix *p, *h;
  ListNode *M;

  /* Check n_args */
  if (n_args != 1)
    error_1 ("hess: 1 argument allowed", 0);

  /* get arg from list */
  M = bltin_get_numeric_matrix ("hess", d_arg, 1);
  matrix_Hess (e_data (M), &p, &h);
  btree = btree_Create ();
  install (btree, cpstr ("p"), MATRIX, p);
  install (btree, cpstr ("h"), MATRIX, h);
  matrix_SetName (p, cpstr ("p"));
  matrix_SetName (h, cpstr ("h"));

  *return_ptr = (VPTR) btree;
  remove_tmp_destroy (M);
}

/* **************************************************************
 * Balance a matrix.
 * ************************************************************** */
void
Balance (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  ListNode *M;
  Btree *btree;
  Matrix *Ab, *t;

  /* Check n_args */
  if (n_args != 1)
    error_1 ("balance: 1 argument allowed", 0);

  M = bltin_get_numeric_matrix ("balance", d_arg, 1);
  matrix_Balance (e_data (M), &Ab, &t);
  btree = btree_Create ();
  install (btree, cpstr ("ab"), MATRIX, Ab);
  install (btree, cpstr ("t"), MATRIX, t);
  matrix_SetName (Ab, cpstr ("ab"));
  matrix_SetName (t, cpstr ("t"));
  *return_ptr = (VPTR) btree;

  remove_tmp_destroy (M);
  return;
}

/* **************************************************************
 * max function
 * ************************************************************** */

void
Max (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  ListNode *first, *second;

  /*
   * max(A)    -> matrix_Max(A)
   * max(A, B) -> matrix_2_max(A, B)
   */
  switch (n_args)
  {
  case 1:
    first = bltin_get_numeric_matrix ("max", d_arg, 1);
    *return_ptr = (VPTR) matrix_Max (e_data (first));
    remove_tmp_destroy (first);
    break;

  case 2:
    first = bltin_get_numeric_matrix ("max", d_arg, 1);
    second = bltin_get_numeric_matrix ("max", d_arg, 2);
    *return_ptr = (VPTR) matrix_2_max (e_data (first), e_data (second));
    remove_tmp_destroy (first);
    remove_tmp_destroy (second);
    break;

  default:
    error_1 ("max: 1 or 2 arguments allowed", 0);
  }
}

/* **************************************************************
 * max function that returns the corresponding index value.
 * ************************************************************** */
void
MaxI (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  ListNode *arg;

  /* maxi(A): index-returning max; exactly one numeric matrix. */
  if (n_args != 1)
    error_1 ("maxi: 1 argument allowed", 0);

  arg = bltin_get_numeric_matrix ("maxi", d_arg, 1);
  *return_ptr = (VPTR) matrix_Maxi (e_data (arg));
  remove_tmp_destroy (arg);
}

/* **************************************************************
 * min function
 * ************************************************************** */
void
Min (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  ListNode *first, *second;

  /*
   * min(A)    -> matrix_Min(A)
   * min(A, B) -> matrix_2_min(A, B)
   */
  switch (n_args)
  {
  case 1:
    first = bltin_get_numeric_matrix ("min", d_arg, 1);
    *return_ptr = (VPTR) matrix_Min (e_data (first));
    remove_tmp_destroy (first);
    break;

  case 2:
    first = bltin_get_numeric_matrix ("min", d_arg, 1);
    second = bltin_get_numeric_matrix ("min", d_arg, 2);
    *return_ptr = (VPTR) matrix_2_min (e_data (first), e_data (second));
    remove_tmp_destroy (first);
    remove_tmp_destroy (second);
    break;

  default:
    error_1 ("min: 1 or 2 arguments allowed", 0);
  }
}

/* **************************************************************
 * min function that returns the corresponding index value.
 * ************************************************************** */
void
MinI (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  ListNode *arg;

  /* mini(A): index-returning min; exactly one numeric matrix. */
  if (n_args != 1)
    error_1 ("mini: 1 argument allowed", 0);

  arg = bltin_get_numeric_matrix ("mini", d_arg, 1);
  *return_ptr = (VPTR) matrix_Mini (e_data (arg));
  remove_tmp_destroy (arg);
}

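/* **************************************************************
 * Sort function. Sort the input vector, or each column of the input
 * matrix independently. Return a list with members "val" (the sorted
 * values) and "ind" (the 1-based original positions of those values,
 * within the vector or within each column). Complex input is ordered
 * by magnitude; string input is ordered with strcmp().
 * ************************************************************** */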

static void r_qsort _PROTO ((double *v, int left, int right, double *ind));
static void csort _PROTO ((char *v[], int left, int right, double *ind));

void
Sort (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  int i, j, n;
  Btree *btree;
  ListNode *M;
  Matrix *sind, *m, *mcopy;

  sind = 0;
  mcopy = 0;			/* Initialize */

  /* Check n_args */
  if (n_args != 1)
    error_1 ("sort: 1 argument allowed", 0);

  M = bltin_get_matrix ("sort", d_arg, 1);
  m = (Matrix *) e_data (M);
  switch (MTYPE (m))
  {
  case REAL:
    if (MNR (m) == 1 || MNC (m) == 1)
    {
      /* Vector sort */
      n = max (MNR (m), MNC (m));
      sind = matrix_CreateFill (1.0, (double) n, 1.0, 0);
      mcopy = matrix_Copy (m);
      r_qsort ((double *) MDPTRr (mcopy), 0, n - 1,
	       (double *) MDPTRr (sind));
    }
    else
    {
      /* Matrix sort (column-wise) */
      n = MNR (m);
      sind = matrix_CreateFillSind (MNR (m), MNC (m));
      mcopy = matrix_Copy (m);
      for (i = 0; i < MNC (m); i++)
	r_qsort ((double *) (mcopy->val.mr + (i * n)), 0, n - 1,
		 (double *) (sind->val.mr + (i * n)));
    }
    break;

  case COMPLEX:
    /* Complex values are ordered by magnitude (sort abs(m), then gather). */
    if (MNR (m) == 1 || MNC (m) == 1)
    {
      int size = MNR (m) * MNC (m);
      n = max (MNR (m), MNC (m));
      sind = matrix_CreateFill (1.0, (double) n, 1.0, 0);
      mcopy = matrix_Abs (m);
      r_qsort ((double *) MDPTRr (mcopy), 0, n - 1,
	       (double *) MDPTRr (sind));

      /* Now sort [m] according to [sind] */
      matrix_Destroy (mcopy);
      mcopy = matrix_CreateC (MNR (m), MNC (m));
      for (i = 1; i <= size; i++)
      {
	MATcvr1 (mcopy, i) = MATcvr1 (m, ((int) MATrv1 (sind, i)));
	MATcvi1 (mcopy, i) = MATcvi1 (m, ((int) MATrv1 (sind, i)));
      }
    }
    else
    {
      /* Matrix sort (column-wise) */
      n = MNR (m);
      sind = matrix_CreateFillSind (MNR (m), MNC (m));
      mcopy = matrix_Abs (m);
      for (i = 0; i < MNC (m); i++)
	r_qsort ((double *) (mcopy->val.mr + (i * n)), 0, n - 1,
		 (double *) (sind->val.mr + (i * n)));

      /* Now sort [m] according to [sind] */
      matrix_Destroy (mcopy);
      mcopy = matrix_CreateC (MNR (m), MNC (m));
      for (i = 1; i <= MNC (m); i++)
	for (j = 1; j <= MNR (m); j++)
	{
	  MATr (mcopy, j, i) = MATr (m, ((int) MAT (sind, j, i)), i);
	  MATi (mcopy, j, i) = MATi (m, ((int) MAT (sind, j, i)), i);
	}
    }
    break;

  case STRING:
    if (MNR (m) == 1 || MNC (m) == 1)
    {
      /* Vector sort */
      n = max (MNR (m), MNC (m));
      sind = matrix_CreateFill (1.0, (double) n, 1.0, 0);
      mcopy = matrix_Copy (m);
      csort ((char **) MDPTRs (mcopy), 0, n - 1, (double *) MDPTRr (sind));
    }
    else
    {
      /* Matrix sort (column-wise) */
      n = MNR (m);
      sind = matrix_CreateFillSind (MNR (m), MNC (m));
      mcopy = matrix_Copy (m);
      for (i = 0; i < MNC (m); i++)
	csort ((char **) (mcopy->val.ms + (i * n)), 0, n - 1,
	       (double *) (sind->val.mr + (i * n)));
    }
    break;

  default:
    /*
     * Any other matrix type previously fell through with mcopy/sind
     * still NULL, and those NULLs were installed into the result list.
     */
    remove_tmp_destroy (M);
    error_1 ("sort: invalid matrix type", 0);
    return;			/* defensive: nothing valid to return */
  }

  btree = btree_Create ();
  install (btree, cpstr ("val"), MATRIX, mcopy);
  install (btree, cpstr ("ind"), MATRIX, sind);
  *return_ptr = (VPTR) btree;
  remove_tmp_destroy (M);
}

/* **************************************************************
 * Sort() support functions (they do the real work).
 * ************************************************************** */

/*
 * Exchange v[i] and v[j], and apply the same exchange to ind so that
 * ind keeps tracking the original position of every value.
 */
static void
qswap (v, i, j, ind)
     double *v, *ind;
     int i, j;
{
  double tmp;

  tmp = v[i];			/* swap values */
  v[i] = v[j];
  v[j] = tmp;

  tmp = ind[i];			/* swap indices */
  ind[i] = ind[j];
  ind[j] = tmp;
}

/*
 * Sort v[left..right] (inclusive) into ascending order, permuting
 * ind[left..right] in lockstep.
 *
 * Only the smaller partition is sorted recursively; the larger one is
 * handled by the enclosing loop.  That bounds the stack depth to
 * O(log n) even on inputs with many equal keys, where the old
 * two-recursion form nested O(n) deep.  The two partitions are
 * disjoint, so the order in which they are sorted does not affect the
 * result: the final permutation is identical to the old version's.
 */
static void
r_qsort (v, left, right, ind)
     double *v, *ind;
     int left, right;
{
  int i, last;

  while (left < right)		/* fewer than two elements: done */
  {
    /* Move the middle element to v[left] as the partition element.
       left + (right - left) / 2 avoids int overflow in left + right. */
    qswap (v, left, left + (right - left) / 2, ind);
    last = left;

    for (i = left + 1; i <= right; i++)	/* Partition */
      if (v[i] < v[left])
	qswap (v, ++last, i, ind);

    qswap (v, left, last, ind);	/* Restore partition element */

    if (last - left < right - last)
    {
      r_qsort (v, left, last - 1, ind);
      left = last + 1;
    }
    else
    {
      r_qsort (v, last + 1, right, ind);
      right = last - 1;
    }
  }
}

/*
 * Interchange v[i] and v[j], mirroring the exchange in ind so that it
 * keeps tracking the original position of every string.
 */
static void
cswap (v, i, j, ind)
     char *v[];
     int i, j;
     double *ind;
{
  char *temp;
  double tmp;

  temp = v[i];
  v[i] = v[j];
  v[j] = temp;

  tmp = ind[i];
  ind[i] = ind[j];
  ind[j] = tmp;
}

/*
 * Simple character qsort: sort the string pointers v[left..right]
 * (inclusive) into ascending strcmp() order, permuting
 * ind[left..right] in lockstep.  Only the pointers move; the strings
 * themselves are not copied.
 *
 * Only the smaller partition is sorted recursively and the larger one
 * is handled by the loop, so the stack depth is O(log n) even with
 * many equal strings.  The partitions are disjoint, so the result is
 * identical to sorting them left-then-right recursively.
 */
static void
csort (v, left, right, ind)
     char *v[];
     int left, right;
     double *ind;
{
  int i, last;

  while (left < right)
  {
    /* Middle element becomes the partition element (overflow-safe). */
    cswap (v, left, left + (right - left) / 2, ind);
    last = left;
    for (i = left + 1; i <= right; i++)
      if (strcmp (v[i], v[left]) < 0)
	cswap (v, ++last, i, ind);
    cswap (v, left, last, ind);

    if (last - left < right - last)
    {
      csort (v, left, last - 1, ind);
      left = last + 1;
    }
    else
    {
      csort (v, last + 1, right, ind);
      right = last - 1;
    }
  }
}
