/************************************************************************************************************
/*							  
/*							    Last Update: 1/31/2003	
/*								  
/*							  Author :  Jiang Li  and Dr. Manry
/*						
/*								EE Department, 
/*						Image Processing and Neural Network Lab
/*							University of Texas at Arlington
/*								
/*
/*			This program is the test program for the MLP Network, which is used for classification. 
/*
/*
/***************************************************************************************************************/

#include "errno.h"
#include "math.h"
#include "stdio.h"
#include "stdlib.h"
#include "string.h"

/* Path of the report file that main() opens with mode "w" and fills with the
   per-pattern results and the final error summary. */
#define RESULT "BP_OR_Testing_Result.txt"
/* Size, in chars, of the fixed-size character buffers declared in main(). */
#define FILENAMELENGTH 100

/* Largest hidden-unit count accepted from the weights file.  The cap keeps
   N + Nh + 1 far away from int overflow (N is limited to 100 by the prompt). */
enum { MAX_HIDDEN_UNITS = 1000000 };

/* Report a truncated or malformed weights file and terminate. */
static void weights_file_error(const char *name)
{
	fprintf(stderr, " The weights file %s is incomplete or malformed!\n", name);
	exit(1);
}

/* Allocate a zero-filled array of count elements, terminating on failure.
   A count of 0 still yields a valid (one element) block. */
static void *alloc_array(size_t count, size_t size)
{
	void *block = calloc(count ? count : 1, size);

	if (block == NULL) {
		fprintf(stderr, "\nMemory allocation error\n");
		exit(1);
	}
	return block;
}

/* Read one whitespace-delimited word of at most size-1 chars into buffer.
   Returns 1 on success, 0 on end of file or read error. */
static int read_word(FILE *fp, char *buffer, int size)
{
	char format[32];

	sprintf(format, "%%%ds", size - 1);     /* e.g. "%99s": bounded read */
	return fscanf(fp, format, buffer) == 1;
}

/* Read the next word and convert it with atof (a non-numeric word yields 0.0).
   Returns 1 when a word was read, 0 on end of file or read error. */
static int read_double(FILE *fp, double *value)
{
	char token[FILENAMELENGTH];

	if (!read_word(fp, token, FILENAMELENGTH))
		return 0;
	*value = atof(token);
	return 1;
}

/* Read one pattern: n_inputs values and, when has_class is non-zero, the
   integer class id that follows them.  Returns 1 only for a complete pattern. */
static int read_pattern(FILE *fp, int n_inputs, int has_class,
                        double *inputs, int *class_id)
{
	int i;

	for (i = 0; i < n_inputs; i++) {
		if (!read_double(fp, &inputs[i]))
			return 0;
	}
	if (has_class && fscanf(fp, "%d", class_id) != 1)
		return 0;
	return 1;
}

/*
 * Test driver for a trained MLP classifier.
 *
 * Prompts for the testing file, whether it carries desired class ids, the
 * number of inputs and classes, and the weights file.  The weights file is
 * read as: one word (skipped), the input count, the class count, the hidden
 * unit count, one word naming the training algorithm, then (N+Nh+1)*Nout
 * output weights (class index varying fastest) followed by Nh*(N+1) hidden
 * unit weights.  Every pattern of the testing file is pushed through the
 * network and the winning class (1-based index of the largest output) is
 * written to RESULT.  When desired class ids are present, the per-class
 * error counts and the error percentage are printed, written to RESULT, and
 * the percentage is appended to "testresult.txt".
 *
 * Returns 0 on success and 1 when an output file could not be written;
 * fatal input problems terminate through exit(1).
 */
int main(void)
{
	FILE *fpResult, *fpData, *fpWeights, *paper;
	double **Wo, **W;       /* Wo: weights into the outputs, W: weights into the hidden units */
	double *new_xa, *y;     /* new_xa: inputs, bias and hidden activations; y: network outputs */
	double net, percentage, Max_y;
	int *ErrorForClass;     /* misclassified pattern count per desired class */
	int N, Nout, Nh, Nv, Na;
	int ClassId = 0, Class_obtained, Out_Flag, Check;
	int i, j, err_pattern;
	int status = 0;
	char TrainingAlgorithm[FILENAMELENGTH], temp[FILENAMELENGTH];
	char *FileName, *WeightsFileName;

	char *get_string(char *);          /* prompt for and read one line */
	int get_int(char *, int, int);     /* prompt for an int inside a range */

	/* Getting relevant information from the user */
	FileName = get_string("Please input Testing File Name: ");
	fpData = fopen(FileName, "r");
	if (fpData == NULL) {
		perror(FileName);
		exit(1);
	}

	printf("\n Does the testing file has desired output? \n");
	printf("\n Choose (0) for NO \n");
	printf("\n        (1) for YES : ");
	Out_Flag = get_int("", 0, 1);

	N = get_int("Please input Number of inputs: ", 1, 100);
	Nout = get_int("Please input Number of classes: ", 1, 100);

	/* Keep asking until a weights file can be opened; the file is opened
	   exactly once and rejected names are released. */
	for (;;) {
		WeightsFileName = get_string("Please input weights file name: ");
		fpWeights = fopen(WeightsFileName, "r");
		if (fpWeights != NULL)
			break;
		perror(WeightsFileName);
		free(WeightsFileName);
	}

	/* Weights file header: the leading word is not used */
	if (!read_word(fpWeights, temp, FILENAMELENGTH) ||
	    fscanf(fpWeights, "%d", &Check) != 1)
		weights_file_error(WeightsFileName);
	if (Check != N) {
		printf(" The input file's input number does not match weights file!\n");
		exit(1);
	}
	if (fscanf(fpWeights, "%d", &Check) != 1)
		weights_file_error(WeightsFileName);
	if (Check != Nout) {
		printf(" The input file's classed number does not match weights file!\n");
		exit(1);
	}
	if (fscanf(fpWeights, "%d", &Nh) != 1 || Nh < 0 || Nh > MAX_HIDDEN_UNITS)
		weights_file_error(WeightsFileName);

	/* Allocate the working vectors and the weight matrices */
	Na = N + Nh + 1;        /* inputs + bias + hidden unit activations */
	new_xa = alloc_array((size_t) Na, sizeof *new_xa);
	y = alloc_array((size_t) Nout, sizeof *y);
	ErrorForClass = alloc_array((size_t) Nout, sizeof *ErrorForClass);
	Wo = alloc_array((size_t) Nout, sizeof *Wo);
	W = alloc_array((size_t) Nh, sizeof *W);
	for (i = 0; i < Nout; i++)
		Wo[i] = alloc_array((size_t) Na, sizeof **Wo);
	for (i = 0; i < Nh; i++)
		W[i] = alloc_array((size_t) N + 1, sizeof **W);

	if (!read_word(fpWeights, TrainingAlgorithm, FILENAMELENGTH))
		weights_file_error(WeightsFileName);

	/* Output weights are stored with the class index varying fastest */
	for (i = 0; i < Na; i++) {
		for (j = 0; j < Nout; j++) {
			if (!read_double(fpWeights, &Wo[j][i]))
				weights_file_error(WeightsFileName);
		}
	}

	/* Hidden unit weights: one row of N+1 values per hidden unit */
	for (i = 0; i < Nh; i++) {
		for (j = 0; j < N + 1; j++) {
			if (!read_double(fpWeights, &W[i][j]))
				weights_file_error(WeightsFileName);
		}
	}
	fclose(fpWeights);

	/* Report header */
	fpResult = fopen(RESULT, "w");
	if (fpResult == NULL) {
		perror(RESULT);
		exit(1);
	}
	fprintf(fpResult, "\n\n\tThe testing data file name is: \t%s", FileName);
	fprintf(fpResult, "\n\n\tThe testing weights file name is: \t%s", WeightsFileName);
	fprintf(fpResult, "\n\n\tNo. of inputs: \t%d", N);
	fprintf(fpResult, "\n\n\tNo. of Hidden units: \t%d", Nh);
	fprintf(fpResult, "\n\n\tNo. of Classed: \t%d", Nout);
	fprintf(fpResult, "\n\n\t%s\n\n", TrainingAlgorithm);
	fprintf(fpResult, "Index\t");
	for (i = 0; i < N; i++)
		fprintf(fpResult, "Input[%d] ", i + 1);
	fprintf(fpResult, "Actual_Id  Desired_ClassId ");

	/* Start testing.  Patterns are counted as they are successfully read, so
	   the last pattern is kept whether or not the file ends with a newline. */
	Nv = 0;
	err_pattern = 0;
	while (read_pattern(fpData, N, Out_Flag, new_xa, &ClassId)) {
		Nv++;
		new_xa[N] = 1.0;        /* bias input */

		/* Hidden layer: sigmoid of each net function */
		for (i = 0; i < Nh; i++) {
			net = 0.0;
			for (j = 0; j < N + 1; j++)
				net += W[i][j] * new_xa[j];
			new_xa[N + 1 + i] = 1.0 / (1.0 + exp(-net));
		}

		/* Output layer: linear in inputs, bias and hidden activations */
		for (i = 0; i < Nout; i++) {
			y[i] = 0.0;
			for (j = 0; j < Na; j++)
				y[i] += Wo[i][j] * new_xa[j];
		}

		fprintf(fpResult, "\n");
		fprintf(fpResult, "%d\t", Nv);
		for (i = 0; i < N; i++)
			fprintf(fpResult, "%f ", new_xa[i]);

		/* The winning class is the 1-based index of the largest output;
		   the first one wins on ties. */
		Max_y = y[0];
		Class_obtained = 1;
		for (i = 1; i < Nout; i++) {
			if (y[i] > Max_y) {
				Max_y = y[i];
				Class_obtained = i + 1;
			}
		}
		fprintf(fpResult, "%d   ", Class_obtained);

		if (Out_Flag) {
			if (ClassId != Class_obtained) {
				/* A class id outside 1..Nout is still an error pattern but
				   has no per-class counter to update. */
				if (ClassId >= 1 && ClassId <= Nout)
					ErrorForClass[ClassId - 1]++;
				err_pattern++;
			}
			fprintf(fpResult, "%d   ", ClassId);
		} else {
			fprintf(fpResult, "N/A   ");
		}
	}
	if (!feof(fpData)) {
		fprintf(stderr, "\n Warning: stopped at malformed data after pattern %d of %s\n",
		        Nv, FileName);
	}
	fclose(fpData);

	/* Testing finished, calculate the error percentage and write to file */
	printf("\n\n There are %d classes.\n", Nout);
	fprintf(fpResult, "\n\n There are %d classes.\n", Nout);

	if (Out_Flag) {
		/* Classes are reported with the same 1-based ids used everywhere else */
		for (i = 0; i < Nout; i++) {
			printf(" Class %d has %d error patterns.\n", i + 1, ErrorForClass[i]);
			fprintf(fpResult, " Class %d has %d error patterns.\n", i + 1, ErrorForClass[i]);
		}

		/* An empty data file has no errors; avoid dividing by zero */
		percentage = (Nv > 0) ? (double) err_pattern / (double) Nv * 100.0 : 0.0;

		fprintf(fpResult, "\n\nThe number of patterns are in error:\t%d", err_pattern);
		fprintf(fpResult, "\nThe error percentage:\t%f", percentage);

		printf("\n The number of patterns are in error:\t%d", err_pattern);
		printf("\n\n The error percentage:\t%f\n", percentage);

		/* The running log only receives a line when a percentage exists */
		paper = fopen("testresult.txt", "a+");
		if (paper == NULL) {
			perror("testresult.txt");
			status = 1;
		} else {
			fprintf(paper, "%f\n", percentage);
			if (fclose(paper) != 0) {
				perror("testresult.txt");
				status = 1;
			}
		}
	}

	if (fclose(fpResult) != 0) {
		perror(RESULT);
		status = 1;
	}

	for (i = 0; i < Nout; i++)
		free(Wo[i]);
	for (i = 0; i < Nh; i++)
		free(W[i]);
	free(Wo);
	free(W);
	free(ErrorForClass);
	free(y);
	free(new_xa);
	free(WeightsFileName);
	free(FileName);

	return status;
}

/*************************************************************************/

/*
 * Prompt for an integer until the user enters one inside [low_limit, up_limit].
 *
 * title_string  text shown before the "[low...up]" range hint
 * low_limit     smallest accepted value
 * up_limit      largest accepted value (may equal low_limit)
 *
 * Returns the accepted value.  Terminates with exit(1) when low_limit is
 * greater than up_limit or when the prompt string cannot be allocated.
 * A reply is rejected when it is empty, has trailing characters after the
 * number, overflows a long, or lies outside the range.
 */
int get_int(char *title_string,int low_limit, int up_limit)
{
	 char *get_string(char *);       /* get string routine */
	 char *prompt;                   /* title plus range hint */
	 char *reply, *end;              /* user reply and end of the parsed number */
	 long value;                     /* kept as long so the range test sees the real value */
	 int valid;

/* check for limit error, low may equal high but not greater */
	 if(low_limit > up_limit) {
		  printf("\nLimit error, lower > upper\n");
		  exit(1);
	 }

/* make prompt string; 60 spare chars are ample for " [%d...%d]" */
	 prompt = malloc(strlen(title_string) + 60);
	 if(!prompt) {
		  printf("\nString allocation error in get_int\n");
		  exit(1);
	 }
	 sprintf(prompt,"%s [%d...%d]",title_string,low_limit,up_limit);

/* ask until the reply is a clean number inside the range */
	 do {
		  reply = get_string(prompt);
		  errno = 0;
		  value = strtol(reply,&end,10);
		  /* compare before narrowing: casting to int first would let an
		     out-of-range reply wrap into the accepted interval */
		  valid = (end != reply) && (*end == '\0') && (errno != ERANGE)
		          && (value >= (long) low_limit) && (value <= (long) up_limit);
		  free(reply);                                /* free string space */
	 } while(!valid);

/* free prompt string and return result */
	 free(prompt);
	 return((int) value);
}

/*****************************************************************************/
/*
 * Print title_string as a prompt and read one line from standard input.
 *
 * Returns a newly allocated, NUL-terminated copy of the line without its
 * trailing newline; the caller owns it and must free() it.  Lines of any
 * length are accepted because the buffer grows as needed.  Terminates with
 * exit(1) when memory cannot be allocated or when standard input ends
 * before any character of the line could be read, since callers retry in a
 * loop and could otherwise never finish.
 */
char *get_string(char *title_string)
{
	 char *alpha;                            /* result string pointer */
	 size_t capacity = 80;                   /* current size of alpha */
	 size_t length = 0;                      /* characters stored so far */
	 int ch;

	 alpha = malloc(capacity);
	 if(!alpha) {
		  printf("\nString allocation error in get_string\n");
		  exit(1);
	 }
	 printf(" %s ",title_string);
	 fflush(stdout);                         /* the prompt has no newline */

	 while((ch = getchar()) != EOF && ch != '\n') {
		  if(length + 1 >= capacity) {       /* keep room for the terminator */
			   char *grown = realloc(alpha, capacity * 2);
			   if(!grown) {
					free(alpha);
					printf("\nString allocation error in get_string\n");
					exit(1);
			   }
			   alpha = grown;
			   capacity *= 2;
		  }
		  alpha[length++] = (char) ch;
	 }

	 if(ch == EOF && length == 0) {
		  free(alpha);
		  printf("\nUnexpected end of input in get_string\n");
		  exit(1);
	 }
	 alpha[length] = '\0';

	 return(alpha);
}

/****************************************************************************/