/*
 *  Copyright (C) 1993 Michael Tiller
 *
 *  Permission is hereby granted by the author to use, copy, modify,
 *  and distribute this code for any purpose provided this copyright
 *  notice is included in all copies.  This code is provided "as is"
 *  without any expressed or implied warranty.
 *
 */

/* This code is at the heart of the expression manipulation routines */

#include	"expr.h"
#include	<stdio.h>
#include	<string.h>
#include	<math.h>
#include	<malloc.h>
#include	<assert.h>

/* Tolerance used by MakeExp: a constant exponent within EXPEPS of 1
   is treated as exactly 1 and dropped. */
#define	EXPEPS	(0.0001)

/* Some useful constants.  These nodes are statically allocated and
   shared by pointer: DeleteExpr skips freeing them, MakeMult, MakeDiv
   and MakeExp compare against &One to simplify, and Derivative returns
   &One as the derivative of var with respect to itself.  Initializer
   order must match struct Expr in expr.h (presumably type, expr1,
   expr2, val, id -- confirm there). */
static	struct	Expr	One = { Const, NULL, NULL, 1.0, NULL };
static	struct	Expr	MinusOne = { Const, NULL, NULL, -1.0, NULL };

/* I would have preferred to do all this in C++ (and in fact I have),
   but I wanted this version to be in C so it would be easy to call
   from anywhere (FORTRAN, Tcl, etc.) */


/* Delete an expression...kinda like a destructor in C++.
   Recursively frees ptr, both of its children and, for Id nodes, the
   name string.  The shared static constants One and MinusOne are never
   freed.  A NULL ptr is accepted and ignored, so the NULL "zero
   derivative" returned by Derivative can be passed straight in. */
void
DeleteExpr(ExprPtr ptr)
{
	if (ptr == NULL)
		return;

	if (ptr->type==Id)
		free(ptr->id);

	/* NULL children are handled by the guard above */
	DeleteExpr(ptr->expr1);
	DeleteExpr(ptr->expr2);

	if (ptr!=&One && ptr!=&MinusOne)
		free(ptr);
}

/* The "copy constructor" for an Expression.  Makes a deep copy: the
   node, both children (recursively) and, for Id nodes, the name string
   are all freshly allocated, so the copy can be released with
   DeleteExpr independently of the original.  CopyExpr(NULL) returns
   NULL, which lets the recursion handle leaf nodes uniformly. */
ExprPtr
CopyExpr(ExprPtr ptr)
{
	ExprPtr	ret;

	if (ptr == NULL)
		return NULL;

	ret = (ExprPtr)malloc(sizeof(struct Expr));
	*ret = *ptr;	/* type and val; id and children are replaced below */
	if (ret->type==Id)
		ret->id = strdup(ptr->id);

	/* A shallow copy would share subtrees with the original, and
	   DeleteExpr on either one would free nodes the other still uses */
	ret->expr1 = CopyExpr(ptr->expr1);
	ret->expr2 = CopyExpr(ptr->expr2);
	return ret;
}

/* The "constructor" for a Constant Expression */
ExprPtr
MakeConst(double d)
{
	ExprPtr	ret;
	ret = (ExprPtr)malloc(sizeof(struct Expr));
	ret->type = Const;
	ret->val = d;
	ret->expr1 = NULL;
	ret->expr2 = NULL;
	return ret;
}

/* The "constructor" for an Identifier Expression */
ExprPtr
MakeId(char *id)
{
	ExprPtr	ret;
	ret = (ExprPtr)malloc(sizeof(struct Expr));
	ret->type = Id;
	ret->id = id;
	ret->expr1 = NULL;
	ret->expr2 = NULL;
	return ret;
}

/* The "constructor" for an Addition Expression.
   Takes ownership of ptr1 and ptr2: they become children of the new
   node or, when both are constants, are folded into a single new
   constant and released.  A negated right operand is parenthesized so
   the sum prints as a+(-b) rather than a+-b. */
ExprPtr
MakeAdd(ExprPtr ptr1, ExprPtr ptr2)
{
	ExprPtr	ret;

	if (ptr1->type == Const && ptr2->type == Const)
	{
		double	sum = ptr1->val+ptr2->val;

		/* The operands are folded away; DeleteExpr never frees the
		   static One/MinusOne, and the guard avoids a double free
		   when the same node was passed twice */
		DeleteExpr(ptr1);
		if (ptr2 != ptr1)
			DeleteExpr(ptr2);
		return MakeConst(sum);
	}

	ret = (ExprPtr)malloc(sizeof(struct Expr));
	ret->type = Add;
	ret->expr1 = ptr1;
	if (ptr2->type==Neg)
		ret->expr2 = MakeSubExpr(ptr2);
	else
		ret->expr2 = ptr2;
	return ret;
}

/* The "constructor" for a Subtraction Expression.
   Takes ownership of ptr1 and ptr2: they become children of the new
   node or, when both are constants, are folded into a single new
   constant and released.  A right operand that is a sum, difference or
   negation is parenthesized, because a-(b+c) must not print as a-b+c
   and a-(b-c) must not print as a-b-c. */
ExprPtr
MakeSub(ExprPtr ptr1, ExprPtr ptr2)
{
	ExprPtr	ret;

	if (ptr1->type == Const && ptr2->type == Const)
	{
		double	diff = ptr1->val-ptr2->val;

		/* The operands are folded away; DeleteExpr never frees the
		   static One/MinusOne, and the guard avoids a double free
		   when the same node was passed twice */
		DeleteExpr(ptr1);
		if (ptr2 != ptr1)
			DeleteExpr(ptr2);
		return MakeConst(diff);
	}

	ret = (ExprPtr)malloc(sizeof(struct Expr));
	ret->type = Sub;
	ret->expr1 = ptr1;
	if (ptr2->type == Neg || ptr2->type == Add || ptr2->type == Sub)
		ret->expr2 = MakeSubExpr(ptr2);
	else
		ret->expr2 = ptr2;
	return ret;
}

/* The "constructor" for a Multiplication Expression.
   Takes ownership of ptr1 and ptr2.  Multiplying by the shared constant
   One returns the other operand unchanged; multiplying two constants
   folds them into a single new constant and releases the operands.
   Sum/difference operands are parenthesized so the product prints
   correctly. */
ExprPtr
MakeMult(ExprPtr ptr1, ExprPtr ptr2)
{
	ExprPtr	ret;

	/* Identity: One*x == x*One == x (also covers One*One) */
	if (ptr1==&One)
		return ptr2;
	if (ptr2==&One)
		return ptr1;

	if (ptr1->type == Const && ptr2->type == Const)
	{
		double	prod = ptr1->val*ptr2->val;

		/* The operands are folded away; DeleteExpr never frees the
		   static MinusOne, and the guard avoids a double free when
		   the same node was passed twice */
		DeleteExpr(ptr1);
		if (ptr2 != ptr1)
			DeleteExpr(ptr2);
		return MakeConst(prod);
	}

	ret = (ExprPtr)malloc(sizeof(struct Expr));
	ret->type = Mult;

	if (ptr1->type == Add || ptr1->type == Sub)
		ret->expr1 = MakeSubExpr(ptr1);
	else
		ret->expr1 = ptr1;

	if (ptr2->type == Add || ptr2->type == Sub)
		ret->expr2 = MakeSubExpr(ptr2);
	else
		ret->expr2 = ptr2;
	return ret;
}

/* The "constructor" for a Division Expression.
   Takes ownership of ptr1 and ptr2.  Dividing by the shared constant
   One returns ptr1 unchanged; dividing two constants folds them into a
   single new constant and releases the operands.  A sum/difference
   numerator is parenthesized, and so is a denominator that is a sum,
   difference, product or quotient: u/(v*w) must not print as u/v*w,
   which reads as (u/v)*w. */
ExprPtr
MakeDiv(ExprPtr ptr1, ExprPtr ptr2)
{
	ExprPtr	ret;

	/* Identity: x/One == x (also covers One/One) */
	if (ptr2==&One)
		return ptr1;

	if (ptr1->type == Const && ptr2->type == Const)
	{
		double	quot = ptr1->val/ptr2->val;

		/* The operands are folded away; DeleteExpr never frees the
		   static One/MinusOne, and the guard avoids a double free
		   when the same node was passed twice */
		DeleteExpr(ptr1);
		if (ptr2 != ptr1)
			DeleteExpr(ptr2);
		return MakeConst(quot);
	}

	ret = (ExprPtr)malloc(sizeof(struct Expr));
	ret->type = Div;

	if (ptr1->type == Add || ptr1->type == Sub)
		ret->expr1 = MakeSubExpr(ptr1);
	else
		ret->expr1 = ptr1;

	if (ptr2->type == Add || ptr2->type == Sub ||
	    ptr2->type == Mult || ptr2->type == Div)
		ret->expr2 = MakeSubExpr(ptr2);
	else
		ret->expr2 = ptr2;
	return ret;
}

/* The "constructor" for an Exponent Expression (ptr1 ^ ptr2).
   Takes ownership of ptr1 and ptr2.  Simplifications:
     One ^ x           -> One (x is released)
     x ^ One, x ^ c    -> x   when c is a constant within EXPEPS of 1
                              (c is released)
     c1 ^ c2           -> a new constant pow(c1, c2) (both released)
   Any compound operand (sum, difference, product, quotient, negation
   or another power) is parenthesized on either side, so the printed
   form never depends on how '^' associates. */
ExprPtr
MakeExp(ExprPtr ptr1, ExprPtr ptr2)
{
	ExprPtr	ret;

	if (ptr1==&One)
	{
		if (ptr2 != ptr1)
			DeleteExpr(ptr2);
		return &One;
	}
	if (ptr2==&One)
		return ptr1;
	if (ptr2->type == Const && fabs(ptr2->val-1.0)< EXPEPS)
	{
		if (ptr2 != ptr1)
			DeleteExpr(ptr2);
		return ptr1;
	}

	if (ptr1->type == Const && ptr2->type == Const)
	{
		double	val = pow(ptr1->val,ptr2->val);

		/* DeleteExpr never frees the static One/MinusOne, and the
		   guard avoids a double free when the same node was passed
		   twice */
		DeleteExpr(ptr1);
		if (ptr2 != ptr1)
			DeleteExpr(ptr2);
		return MakeConst(val);
	}

	ret = (ExprPtr)malloc(sizeof(struct Expr));
	ret->type = Exp;

	if (ptr1->type == Add || ptr1->type == Sub ||
	    ptr1->type == Mult || ptr1->type == Div ||
	    ptr1->type == Neg || ptr1->type == Exp)
		ret->expr1 = MakeSubExpr(ptr1);
	else
		ret->expr1 = ptr1;

	if (ptr2->type == Add || ptr2->type == Sub ||
	    ptr2->type == Mult || ptr2->type == Div ||
	    ptr2->type == Neg || ptr2->type == Exp)
		ret->expr2 = MakeSubExpr(ptr2);
	else
		ret->expr2 = ptr2;
	return ret;
}

/* The "constructor" for a Unary Minus Expression.
   Takes ownership of ptr.  A sum or difference operand is wrapped in a
   Sub-Expression so the result prints as -(a+b) rather than -a+b, which
   would read back as (-a)+b. */
ExprPtr
MakeNeg(ExprPtr ptr)
{
	ExprPtr	ret;

	ret = (ExprPtr)malloc(sizeof(struct Expr));
	ret->type = Neg;
	if (ptr->type == Add || ptr->type == Sub)
		ret->expr1 = MakeSubExpr(ptr);
	else
		ret->expr1 = ptr;
	ret->expr2 = NULL;
	return ret;
}

/* The "constructor" for a Sub-Expression: a parenthesized wrapper
   around ptr (printed as "(...)").  The new node takes ownership of
   ptr as its only child. */
ExprPtr
MakeSubExpr(ExprPtr ptr)
{
	ExprPtr	wrapper = (ExprPtr)malloc(sizeof *wrapper);

	wrapper->type = SubExpr;
	wrapper->expr2 = NULL;
	wrapper->expr1 = ptr;
	return wrapper;
}

/* Returns the derivative of expr with respect to the variable named
   var, as a newly built expression (apart from the shared constant
   One, which is returned directly for d(var)/d(var)).  It returns NULL
   when the derivative is zero, so callers must treat NULL as 0.  This
   is called recursively for each sub-expression; expr is not modified.

   total: when nonzero, every identifier other than var is assumed to
   depend on var and gets the new identifier "id,var" as its derivative
   (presumably notation for d(id)/d(var) -- confirm with callers); when
   zero, such identifiers are constants and contribute nothing.

   Exponents that depend on var are not supported: a warning is printed
   and the exponent is treated as constant (power rule only). */
ExprPtr
Derivative(ExprPtr expr, char *var, int total)
{
	ExprPtr	e1, e2, quot;
	char	*name;

	switch(expr->type)
	{
		case Add:
			e1 = Derivative(expr->expr1, var, total);
			e2 = Derivative(expr->expr2, var, total);

			if (e1==NULL && e2!=NULL)
				return e2;
			if (e2==NULL && e1!=NULL)
				return e1;
			if (e1==NULL && e2==NULL)
				return NULL;

			return MakeAdd(e1,e2);
		case Sub:
			e1 = Derivative(expr->expr1, var, total);
			e2 = Derivative(expr->expr2, var, total);

			if (e1==NULL && e2!=NULL)
				return MakeNeg(e2);
			if (e2==NULL && e1!=NULL)
				return e1;
			if (e1==NULL && e2==NULL)
				return NULL;

			return MakeSub(e1,e2);
		case Mult:
			/* product rule: (u*v)' = u'*v + u*v' */
			e1 = Derivative(expr->expr1, var, total);
			e2 = Derivative(expr->expr2, var, total);
			if (e1==NULL && e2==NULL)
				return NULL;
			if (e1==NULL && e2!=NULL)
				return MakeMult(CopyExpr(expr->expr1),e2);
			if (e2==NULL && e1!=NULL)
				return MakeMult(e1,CopyExpr(expr->expr2));

			return MakeAdd(MakeMult(e1,CopyExpr(expr->expr2)),
			               MakeMult(CopyExpr(expr->expr1),e2));
		case Div:
			/* quotient rule: (u/v)' = u'/v - u*v'/(v*v) */
			e1 = Derivative(expr->expr1, var, total);
			e2 = Derivative(expr->expr2, var, total);
			if (e1==NULL && e2==NULL)
				return NULL;
			if (e2==NULL)
				return MakeDiv(e1,CopyExpr(expr->expr2));

			/* The denominator uses two separate copies of v instead of
			   aliasing expr->expr2 twice (which shared the original's
			   nodes and double-freed them on DeleteExpr), and is
			   parenthesized so it prints as u*v'/(v*v), not u*v'/v*v */
			quot = MakeDiv(MakeMult(CopyExpr(expr->expr1),e2),
			               MakeSubExpr(MakeMult(CopyExpr(expr->expr2),
			                                    CopyExpr(expr->expr2))));
			if (e1==NULL)
				return MakeNeg(quot);

			return MakeSub(MakeDiv(e1,CopyExpr(expr->expr2)), quot);
		case Neg:
			e1 = Derivative(expr->expr1, var, total);
			if (e1==NULL)
				return NULL;
			return MakeNeg(e1);
		case Exp:
			e1 = Derivative(expr->expr1, var, total);
			e2 = Derivative(expr->expr2, var, total);

			if (e2!=NULL)
			{
				fprintf(stderr, "Warning, derivatives of exponents not supported!\n");
				DeleteExpr(e2);
			}

			if (e1==NULL)
				return NULL;

			/* power rule: (u^n)' = u^(n-1) * n * u' */
			return MakeMult(MakeMult(MakeExp(CopyExpr(expr->expr1),
			       MakeSub(CopyExpr(expr->expr2),&One)),
			       CopyExpr(expr->expr2)),
			       e1);
		case SubExpr:
			e1 = Derivative(expr->expr1, var, total);
			if (e1==NULL)
				return NULL;

			return MakeSubExpr(e1);
		case Id:
			if (strcmp(expr->id,var)==0)
				return &One;
			if (total)
			{
				/* Sized exactly ("id" + ',' + "var" + NUL); a fixed
				   buffer would overflow on long names.  MakeId takes
				   ownership of the string. */
				name = malloc(strlen(expr->id) + strlen(var) + 2);
				sprintf(name, "%s,%s", expr->id, var);
				return MakeId(name);
			}
			return NULL;
		case Const:
			return NULL;
		default:
			fprintf(stderr, "Unknown expression type %d!\n", expr->type);
			break;
	}

	/* Unknown node type: report no derivative rather than returning an
	   indeterminate value */
	return NULL;
}

/* Print an expression to standard output in infix form.  Binary nodes
   print as "<left><op><right>" with no added parentheses; grouping
   comes only from explicit SubExpr nodes.  An unknown node type is
   reported on stderr and prints nothing. */
void
PrintExpr(ExprPtr expr)
{
	char	op;

	switch(expr->type)
	{
		case	Add:	op = '+'; break;
		case	Sub:	op = '-'; break;
		case	Mult:	op = '*'; break;
		case	Div:	op = '/'; break;
		case	Exp:	op = '^'; break;
		case	Neg:
			putchar('-');
			PrintExpr(expr->expr1);
			return;
		case	Id:
			fputs(expr->id, stdout);
			return;
		case	Const:
			printf("%g", expr->val);
			return;
		case	SubExpr:
			putchar('(');
			PrintExpr(expr->expr1);
			putchar(')');
			return;
		default:
			fprintf(stderr, "Unknown expression type %d!\n", expr->type);
			return;
	}

	/* Only binary operators reach this point */
	PrintExpr(expr->expr1);
	putchar(op);
	PrintExpr(expr->expr2);
}

char	*StringExpr(ExprPtr expr);

/* Render a binary node as "<left><op><right>" in a newly malloc'ed
   string.  Both operand strings are always released.  Returns NULL if
   either operand could not be rendered or memory is exhausted. */
static char *
StringBinary(ExprPtr expr, char op)
{
	char	*e1, *e2, *ret = NULL;

	e1 = StringExpr(expr->expr1);
	e2 = StringExpr(expr->expr2);
	if (e1 != NULL && e2 != NULL)
	{
		/* left + op + right + terminating NUL */
		ret = malloc(strlen(e1) + strlen(e2) + 2);
		if (ret != NULL)
			sprintf(ret, "%s%c%s", e1, op, e2);
	}
	free(e1);
	free(e2);
	return ret;
}

/* Print an expression to a newly malloc'ed string, which the caller
   must free.  Uses the same infix notation as PrintExpr and is called
   recursively for each sub-expression.  Returns NULL if the tree
   contains an unknown node type (reported on stderr) or memory runs
   out. */
char *
StringExpr(ExprPtr expr)
{
	char	*ret, *e1;

	switch(expr->type)
	{
		case	Add:
			return StringBinary(expr, '+');
		case	Sub:
			return StringBinary(expr, '-');
		case	Mult:
			return StringBinary(expr, '*');
		case	Div:
			return StringBinary(expr, '/');
		case	Exp:
			return StringBinary(expr, '^');
		case	Neg:
			e1 = StringExpr(expr->expr1);
			if (e1 == NULL)
				return NULL;
			ret = malloc(strlen(e1)+2);	/* '-' + operand + NUL */
			if (ret != NULL)
				sprintf(ret, "-%s", e1);
			free(e1);
			return ret;
		case	Id:
			/* strdup sizes the copy to fit identifiers of any length */
			return strdup(expr->id);
		case	Const:
		{
			char	temp[100];	/* ample for any "%g" rendering */
			sprintf(temp, "%g", expr->val);
			return strdup(temp);
		}
		case	SubExpr:
			e1 = StringExpr(expr->expr1);
			if (e1 == NULL)
				return NULL;
			ret = malloc(strlen(e1)+3);	/* '(' + operand + ')' + NUL */
			if (ret != NULL)
				sprintf(ret, "(%s)", e1);
			free(e1);
			return ret;
		default:
			fprintf(stderr, "Unknown expression type %d!\n", expr->type);
			return NULL;
	}
}

/* Recursively compute the numeric value of expr into *out.
   Returns 1 on success, or 0 if the expression contains an identifier
   or an unknown node type (the latter is reported on stderr); *out is
   unspecified on failure.  Both operands of a binary node are always
   evaluated so diagnostics from either side are still reported. */
static int
EvalValue(ExprPtr expr, double *out)
{
	double	a = 0.0, b = 0.0;
	int	ok1, ok2;

	switch(expr->type)
	{
		case	Add:
		case	Sub:
		case	Mult:
		case	Div:
		case	Exp:
			ok1 = EvalValue(expr->expr1, &a);
			ok2 = EvalValue(expr->expr2, &b);
			if (!ok1 || !ok2)
				return 0;
			if (expr->type == Add)
				*out = a + b;
			else if (expr->type == Sub)
				*out = a - b;
			else if (expr->type == Mult)
				*out = a * b;
			else if (expr->type == Div)
				*out = a / b;
			else
				*out = pow(a, b);
			return 1;
		case	Neg:
			if (!EvalValue(expr->expr1, &a))
				return 0;
			*out = -a;
			return 1;
		case	SubExpr:
			return EvalValue(expr->expr1, out);
		case	Const:
			*out = expr->val;
			return 1;
		case	Id:
			return 0;		/* Error...can't evaluate */
		default:
			fprintf(stderr, "Unknown expression type %d!\n", expr->type);
			return 0;
	}
}

/* Evaluate an expression that contains no identifiers.
   Returns a newly allocated Const expression holding the value (free it
   with DeleteExpr), or NULL if the expression contains an identifier or
   an unknown node type.  Intermediate values are plain doubles, so no
   partial results can leak on failure. */
ExprPtr
EvalExpr(ExprPtr expr)
{
	double	val;

	if (!EvalValue(expr, &val))
		return NULL;
	return MakeConst(val);
}
