/* odei.c */

/* This file is a part of RLaB ("Our"-LaB)
   Copyright (C) 1994  Ian R. Searle

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

   See the file ./COPYING
   ********************************************************************** */

#include "rlab.h"
#include "symbol.h"
#include "mem.h"
#include "list.h"
#include "btree.h"
#include "bltin.h"
#include "scop1.h"
#include "matop1.h"
#include "matop2.h"
#include "r_string.h"
#include "util.h"
#include "mathl.h"
#include "function.h"
#include "lp.h"
#include "odei.h"

#include <math.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>

/* Absolute value.  NOTE: the argument is evaluated twice, so do not
   pass an expression with side effects. */
#define rabs(x) ((x) >= 0 ? (x) : -(x))

/* Callback handed to ODE(): evaluates the user's derivative function. */
int rks_func _PROTO ((double *t, double *y, double *yp));
/* Calls the user's output function and stores its result in row i+1
   of the output matrix. */
int out_func _PROTO ((double *t, double *y, int i));
static double epsilon _PROTO ((void));

/*
 * File-scope state shared between odei() and the two callbacks.
 * odei() fills these in before integrating; the callbacks read them.
 */

/* Number of equations: element count of the YSTART vector. */
static int neq;
/* Scalar that carries the current time to the user functions. */
static Scalar *stime;
/* my:  matrix header whose data pointer is aimed at the current state
        vector before each user-function call (it owns no data).
   out: the result matrix returned by odei(). */
static Matrix *my, *out;
/* List nodes wrapping stime ("t") and my ("y"). */
static ListNode *tent, *my_ent;
/* Names of the derivative function and of the optional output
   function (out_fname is 0 when no output function was given). */
static char *fname, *out_fname;
/* Argument vectors (t, y) passed to call_rlab_script(). */
static Datum rks_args[2];
static Datum out_args[2];
/* Non-zero until out_func() has sized and created `out'. */
static int first_out = 1;
/* Number of values per row produced by the output function. */
static int nout = 0;
/* Temporary-list node that holds `out' until it is returned. */
static ListNode *tmp5;
/* Number of output steps between tstart and tend. */
static int nstep = 0;

/* **************************************************************
 * Builtin interface to RKSUITE
 * ************************************************************** */
void
odei (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  double dtout, eps, t, t0, tend, tstart;
  int i, j, lenwrk;
  ListNode *FUNC, *YSTART, *DTOUT, *RELERR, *ABSERR, *OUTF;
  F_DOUBLE abserr, relerr;
  F_INT iflag, iwork[5];
  ListNode *tmp1, *tmp2;
  Matrix *work, *y, *ystart;

  eps = epsilon ();

  /*
   * Check arguments.
   */

  if (n_args < 4)
    error_1 ("ode: requires at least 4 arguments", 0);

  dtout = 0.0;         /* Initialize */
  DTOUT = RELERR = OUTF = 0;

  /* Get function ptr */
  FUNC = bltin_get_func ("ode", d_arg, 1);
  fname = e_name (FUNC);

  /* Get tstart */
  tstart = bltin_get_numeric_double ("ode", d_arg, 2);

  /* Get tend */
  tend = bltin_get_numeric_double ("ode", d_arg, 3);

  if (tend == tstart)
    error_1 ("ode: error -- tstart == tend", 0);

  /* Get ystart */
  YSTART = bltin_get_numeric_matrix ("ode", d_arg, 4);
  ystart = (Matrix *) e_data (YSTART);
  if (MTYPE (ystart) != REAL)
    error_1 ("ode: YSTART must be REAL", 0);

  /* Extract neq from ystart */
  if (MNR (ystart) != 1 && MNC (ystart) != 1)
    error_1 ("ode: YSTART must be a row or column vector", 0);

  neq = MNR (ystart) * MNC (ystart);
  tmp2 = install_tmp (MATRIX, y = matrix_Create (neq, 1),
		      matrix_Destroy);

  if (n_args > 4)
  {
    /* Get dtout */
    DTOUT = bltin_get_entity ("ode", d_arg, 5);
    if (e_type (DTOUT) == UNDEF)
    {
      /* Default value */
      dtout = (tend - tstart) / 100;
    }
    else
    {
      if (e_type (DTOUT) == MATRIX && MTYPE (e_data (DTOUT)) == REAL)
      {
	dtout = (double) MAT (e_data (DTOUT), 1, 1);
	remove_tmp_destroy (DTOUT);
      }
      else
      {
	remove_tmp_destroy (FUNC);
	remove_tmp_destroy (YSTART);
	remove_tmp_destroy (DTOUT);
	error_1 ("ode: DTOUT must be numeric-real", 0);
      }
    }
  }
  else
  {
    /* Default value */
    dtout = (tend - tstart) / 100;
  }

  if (dtout == 0)
  {
    remove_tmp_destroy (FUNC);
    remove_tmp_destroy (YSTART);
    remove_tmp_destroy (DTOUT);
    error_1 ("ode: dout must be non-zero", 0);
  }

  if (n_args > 5)
  {
    /* Get relerr */
    RELERR = bltin_get_entity ("ode", d_arg, 6);
    if (e_type (RELERR) == UNDEF)
    {
      /* Default value */
      relerr = (F_DOUBLE) 1.e-6;
    }
    else
    {
      if (e_type (RELERR) == MATRIX && MTYPE (e_data (RELERR)) == REAL)
      {
	relerr = (double) MAT (e_data (RELERR), 1, 1);
	remove_tmp_destroy (RELERR);
      }
      else
      {
	remove_tmp_destroy (FUNC);
	remove_tmp_destroy (YSTART);
	remove_tmp_destroy (DTOUT);
	remove_tmp_destroy (RELERR);
	error_1 ("ode: RELERR must be numeric-real", 0);
      }
    }
  }
  else
  {
    /* Default */
    relerr = (F_DOUBLE) 1.e-6;
  }

  if (n_args > 6)
  {
    /* Get tol */
    ABSERR = bltin_get_entity ("ode", d_arg, 7);
    if (e_type (ABSERR) == UNDEF)
    {
      /* Default value */
      abserr = (F_DOUBLE) 1.0e-6;
    }
    else
    {
      if (e_type (ABSERR) == MATRIX && MTYPE (e_data (ABSERR)) == REAL)
      {
	abserr = (double) MAT (e_data (ABSERR), 1, 1);
	remove_tmp_destroy (ABSERR);
      }
      else
      {
	remove_tmp_destroy (FUNC);
	remove_tmp_destroy (YSTART);
	remove_tmp_destroy (DTOUT);
	remove_tmp_destroy (RELERR);
	remove_tmp_destroy (ABSERR);
	error_1 ("ode: ABSERR must be numeric-real", 0);
      }
    }
  }
  else
  {
    /* Default value */
    abserr = (F_DOUBLE) 1.0e-6;
  }

  nstep = (rabs (tend - tstart) / dtout + .5);

  if (n_args > 7)
  {
    /* Get output function ptr */
    OUTF = bltin_get_func ("ode", d_arg, 8);
    out_fname = e_name (OUTF);
    first_out = 1;
    nout = 0;
  }
  else
  {
    out_fname = 0;
    /* Set up output array */
    tmp5 = install_tmp (MATRIX, out = matrix_Create (nstep + 1, neq + 1),
			matrix_Destroy);
  }

  /*
   * Done with argument processing.
   * Initialize some things...
   */

  lenwrk = 100 + 21 * neq;
  tmp1 = install_tmp (MATRIX, work = matrix_Create (lenwrk, 1),
		      matrix_Destroy);
  iflag = 1;

  /*
   * Call integrator repeatedley.
   */


  /*
   * Set up ENTITIES for user-function.
   */

  tent = listNode_Create ();
  listNode_AttachData (tent, SCALAR, stime = scalar_Create (0.0),
		       scalar_Destroy);
  listNode_SetKey (tent, cpstr ("t"));

  my_ent = listNode_Create ();
  listNode_AttachData (my_ent, MATRIX, my = matrix_Create (0, 0),
		       matrix_Destroy);
  listNode_SetKey (my_ent, cpstr ("y"));

  /*
   * Set these manually so that we can just 
   * copy the pointer later, and not duplicate
   * the space.
   */

  scalar_SetName (stime, cpstr ("t"));
  matrix_SetName (my, cpstr ("y"));
  my->nrow = neq;
  my->ncol = 1;

  rks_args[0].u.ent = tent;
  rks_args[0].type = ENTITY;
  rks_args[1].u.ent = my_ent;
  rks_args[1].type = ENTITY;

  /* Save initial conditions and setup y[] */
  if (out_fname)
  {
    out_args[0].u.ent = tent;
    out_args[0].type = ENTITY;
    out_args[1].u.ent = my_ent;
    out_args[1].type = ENTITY;

    out_func (&tstart, MDPTRr (ystart), 0);
    for (j = 2; j <= neq + 1; j++)
    {
      MATrv1 (y, j - 1) = MATrv1 (ystart, j - 1);
    }
  }
  else
  {
    MAT (out, 1, 1) = tstart;
    for (j = 2; j <= neq + 1; j++)
    {
      MAT (out, 1, j) = MATrv1 (ystart, j - 1);
      MATrv1 (y, j - 1) = MATrv1 (ystart, j - 1);
    }
  }

  /* Now step through output points */
  t0 = tstart;
  for (i = 1; i <= nstep; i++)
  {
    t = tstart + i * dtout;
    if (i == nstep)
      t = tend;

    ODE (rks_func, &neq, MDPTRr (y), &t0, &t,
	 &relerr, &abserr, &iflag, MDPTRr (work), iwork);

    /* Check for errors */
    if (iflag > 3)
    {
      /* Check for different types of failures (later) */
      printf ("ode: iflag = %i\n", (int) iflag);
    }

    /* Reset the time */
    t0 = t;

    /* Save the output */
    if (out_fname)
    {
      out_func (&t, MDPTRr (y), i);
    }
    else
    {
      MAT (out, i + 1, 1) = t;
      for (j = 2; j <= neq + 1; j++)
	MAT (out, i + 1, j) = MATrv1 (y, j - 1);
    }
  }

  /* Clean Up */
  remove_tmp_destroy (tmp1);
  remove_tmp_destroy (tmp2);
  remove_tmp (tmp5);

  /* Clean up time, my */
  listNode_Destroy (tent);
  my->nrow = 0;
  my->ncol = 0;
  my->val.mr = 0;
  listNode_Destroy (my_ent);

  remove_tmp_destroy (FUNC);
  remove_tmp_destroy (YSTART);
  if (n_args == 8)
    remove_tmp_destroy (OUTF);

  first_out = 0;
  nout = 0;

  *return_ptr = (VPTR) out;
}

/*
 * The interface to the user-specified function.
 */

int
rks_func (t, y, yp)
     double *t, *y, *yp;
{
  int k, kind;
  VPTR res;

  /*
   * Expose the current time and state through the shared argument
   * entities: the time is copied, the state vector is only pointed at.
   */

  SVALr (stime) = *t;
  my->val.mr = y;

  /*
   * Evaluate the user's derivative function.
   */

  res = call_rlab_script (fname, rks_args, 2);

  /* The leading int of the returned object is its type tag. */
  kind = (int) *((int *) res);

  if (kind == MATRIX)
  {
    if (MNR (res) * MNC (res) != neq)
      error_1 ("ode: incorrectly dimensioned derivitive", 0);
    if (MTYPE (res) != REAL)
      error_1 ("ode: rhs function must return REAL matrix", 0);

    /* Copy the neq derivative values out for the integrator. */
    for (k = 0; k < neq; k++)
      yp[k] = MATrv (res, k);

    /* A result without a name is destroyed here. */
    if (matrix_GetName (res) == 0)
      matrix_Destroy ((Matrix *) res);

    return (1);
  }

  if (kind == SCALAR)
  {
    /* A scalar derivative only fits a single-equation system. */
    if (neq != 1)
      error_1 ("ode: incorrectly dimensioned derivitive", 0);

    yp[0] = SVALr (res);
    if (scalar_GetName (res) == 0)
      scalar_Destroy ((Scalar *) res);

    return (1);
  }

  error_1 ("ode: derivitive function must return a NUMERIC entity", 0);
  return (1);
}

/*
 * Halve a value starting from 1.0 until adding it to 1.0 no longer
 * changes the sum, and return that value.
 */
static double
epsilon ()
{
  double e;

  for (e = 1.0; (1.0 + e) != 1.0; e = e / 2.0)
    ;

  return (e);
}

/*
 * Bridge to the user-specified OUTPUT function.
 *
 * Calls the function named by out_fname with (t, y) and stores what
 * it returns in row I+1 of `out'.  The first call fixes the row width
 * (nout) and creates `out'; every later call must return the same
 * number of values.  Always returns 1.
 */

int
out_func (t, y, I)
     double *t, *y;
     int I;
{
  int k, kind, width;
  VPTR res;

  /*
   * Expose the current time and state through the shared entities.
   */

  SVALr (stime) = *t;
  my->val.mr = y;

  /*
   * Call user/builtin function.
   */

  res = call_rlab_script (out_fname, out_args, 2);

  /* The leading int of the returned object is its type tag. */
  kind = (int) *((int *) res);

  if (kind != MATRIX && kind != SCALAR)
  {
    error_1 ("ode: output function must return a NUMERIC entity", 0);
    return (1);
  }

  /* Number of values delivered by this call. */
  if (kind == MATRIX)
    width = MNR (res) * MNC (res);
  else
    width = 1;

  if (first_out)
  {
    /* First result: it decides the output size. */
    nout = width;
    first_out = 0;

    /* Set up output array */
    tmp5 = install_tmp (MATRIX, out = matrix_Create (nstep + 1, nout),
                        matrix_Destroy);
  }
  else if (width != nout)
  {
    /* Reset the sizing state before reporting the mismatch. */
    first_out = 1;
    nout = 0;
    error_1 ("ode: iconsistent dimension from output function", 0);
  }

  /* Copy the result into row I+1; an unnamed result is destroyed. */
  if (kind == MATRIX)
  {
    for (k = 0; k < nout; k++)
      MAT (out, I+1, k+1) = MATrv (res, k);

    if (matrix_GetName (res) == 0)
      matrix_Destroy ((Matrix *) res);
  }
  else
  {
    MAT (out, I+1, 1) = SVALr (res);

    if (scalar_GetName (res) == 0)
      scalar_Destroy ((Scalar *) res);
  }

  return (1);
}
