#include "displays/d_walker.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <nets/sarsa.h>
#include <utils/transfer.h>

// ---- Experiment size ------------------------------------------------------
#define max_episodes  100   // main() runs at most this many episodes
#define max_epochs    1000  // an episode is cut off once its step count exceeds this
#define states        26    // number of walk positions; also the state-vector length

// Expands to TWO comma-separated arguments (layer count, then layer sizes)
// at the end of the sarsa constructor call.
#define hidden_layers 1, /* <-- # of | sizes --> */ states

#define init_weights_min -0.1
#define init_weights_max  0.1

// Exploratory action: 0 or 1 chosen at random.  The argument is ignored.
#define epsilon_policy(X) (irand(0,1))

// Square of X.  Arguments are parenthesized so that expressions such as
// sqr(a+b) expand correctly.  NOTE: X is still evaluated twice, so avoid
// arguments with side effects.
#define sqr(X) ((X)*(X))

// epsilon_test selects which exploration test main() uses.
#define epsilon_test pure_random
// True when a random integer in [0, epsilon] happens to equal epsilon.
#define pure_random (irand(0, epsilon) == epsilon)
// Alternative test based on the network's last error; requires a variable
// named `walker` to be in scope where the macro is expanded.
#define confidence (rrand(0,1)  <  \
                     1.0 + theta - exp(-(sqr(walker->query_last_error()))/rho))

// alpha, lambda and gamma are handed to the sarsa constructor in main().
// NOTE(review): presumably learning rate, trace decay and discount factor --
// confirm against nets/sarsa.h.
#define alpha    0.07
#define epsilon  states //only used for the pure_random epsilon_test
#define gamma    1.0
#define lambda   0.8
#define rho      0.5   //only used for the confidence epsilon_test
#define theta    0.25  //only used for the confidence epsilon_test

#define average_depth 10 //how many elements to keep in the average.

// Training stops early once the running average return reaches this value.
#define good_avg_return         0.85
#define reward_for_wandering   -0.01  // step that stays inside the walk
#define reward_for_good_side    1.0   // stepping past l_state_c
#define reward_for_bad_side    -1.0   // stepping below f_state_c
#define reward_for_max_epochs reward_for_bad_side

#define f_state_c   'a'
#define f_state_c_h 'A'  // the header version
#define l_state_c   (f_state_c+(states-1))

// Random integer in [low, high] (inclusive), built on rand().
// Arguments are parenthesized so expression arguments keep their meaning.
#define irand(low,high) ((rand()%((high)-(low)+1)+(low)))
// Random real in [low, high], built on rand()/RAND_MAX.
#define rrand(low,high) (((real)rand()/(real)RAND_MAX)*\
                                     ((real)(high)-(real)(low))+(real)(low))

// Action codes.  main() moves the walker toward l_state_c on a non-zero
// action and toward f_state_c on a zero action.
#define right 1
#define left  0

// Display object shared by every function in this file.  Created in main()
// and destroyed by cleanup().
walker_display *disp;
// Scratch state vector of `states` entries, rewritten in place by
// calc_state_vector_for() and handed to the walker via set_state().
real *state = new real[states];

// Writes a one-hot encoding of position `name` into the global `state`
// vector: the entry matching `name` becomes 1, every other entry becomes 0.
//
// name: position character, expected to lie in [f_state_c, l_state_c].
//       Any other value is reported through the display and the program
//       terminates with exit status 1.
void calc_state_vector_for(char name) {
    int in_range = (name >= f_state_c) && (name <= l_state_c);
    if(!in_range) {
        disp->show_str("State vector calculation failed.");
        exit(1);
    }

    // Index of the single entry that must be switched on.
    int hot = name - f_state_c;
    for(int slot=0; slot<states; slot++) {
        if(slot == hot) state[slot] = 1;
        else            state[slot] = 0;
    }
}

// Returns a randomly chosen starting position for an episode.
//
// The position is f_state_c plus a random offset drawn with
// irand(1, states-3), so neither the first position nor the last two
// positions are ever chosen as a start.
// NOTE(review): the upper bound states-3 leaves the second-to-last position
// unreachable as a start; confirm whether states-2 was intended.
char pick_state_at_random() {
    int offset = irand(1, states-3);
    return f_state_c + offset;
}

// Pushes a new return value into the running-average window.
//
// e:            value to append.
// avg_elements: window of average_depth values, oldest first.  The oldest
//               entry is dropped, the rest shift down one slot, and `e`
//               is stored in the last slot.
//
// The window length comes from average_depth (rather than a literal) so this
// stays in step with calc_avg_from_elements() and clear_avg_elements().
void insert_avg_element(real e, real *avg_elements) {
    for(int i=0; i<average_depth-1; i++)
        avg_elements[i] = avg_elements[i+1];
    avg_elements[average_depth-1] = e;
}

// Returns the arithmetic mean of the average_depth values in avg_elements.
// Slots that have not yet been filled with a real return still count toward
// the mean with whatever value they hold.
real calc_avg_from_elements(real *avg_elements) {
    real total = 0;
    int idx = 0;
    while(idx < average_depth) {
        total += avg_elements[idx];
        idx++;
    }
    return (real)total/(real)average_depth;
}

// Resets every one of the average_depth slots in avg_elements to 0.0.
void clear_avg_elements(real *avg_elements) {
    real *slot = avg_elements;
    real *stop = avg_elements + average_depth;
    while(slot != stop) {
        *slot = 0.0;
        slot++;
    }
}

// Exit handler (registered with atexit() in main): destroys the global
// display object and then prints a newline.
void cleanup() {
    // Destroy the display first so the newline below is written after it.
    // NOTE(review): presumably the destructor restores the terminal -- confirm
    // in displays/d_walker.h.
    delete disp;
    printf("\n");
}

void main() {
    sarsa *walker = new sarsa(alpha, lambda, gamma, states, 2, hidden_layers);
    disp          = new walker_display(f_state_c_h, states);
    char *msg     = new char[80];
    char curr;
    int  action;
    int  finished;
    real reward;
    real t_reward;
    real *avg_elements = new real[average_depth];

    atexit(*cleanup);   // needed in case of error exits,
                        // otherwise your display get's hosed.

    walker->reinitialize_weights_with(init_weights_max, init_weights_min);
    walker->set_transfer_function_for_output(JBIPOLR); // a jet invention
    walker->set_transfer_function_for_hidden(BIPOLAR);

    clear_avg_elements(avg_elements);

    disp->show_str("Note:  Avg calc is for the last average_depth returns.");
    disp->show_str("Numbers smaller if the curr episode is < average_depth.");

    for(int episode=0; episode<max_episodes; episode++) {
        if(calc_avg_from_elements(avg_elements) >= good_avg_return) {
            break;
        }
          disp->start_new_episode();
        walker->start_new_episode();
        t_reward = 0;
        finished = 0;
        curr = pick_state_at_random(); 
        calc_state_vector_for(curr);
        walker->set_state(state);
        disp->show_state(f_state_c, curr);
        action = walker->query_action();
        walker->set_action(action);
        disp->show_epsilon(0);
        disp->show_action(action);
        curr += (action) ? 1 : -1; 

        int epoch=0;
        while(1) {
            epoch++;
            calc_state_vector_for(curr);
            walker->set_state(state);
            disp->show_state(f_state_c, curr);
            action = walker->query_action();
            if(epsilon_test) {
                action = epsilon_policy(episode);
                disp->show_epsilon(1);
            } else {
                disp->show_epsilon(0);
            }
            walker->set_action(action);
            disp->show_action(action);
            curr += (action) ? 1 : -1; 

                 if(curr < f_state_c) reward = reward_for_bad_side;
            else if(curr > l_state_c) reward = reward_for_good_side;
            else if(epoch>max_epochs) reward = reward_for_max_epochs;
            else                      reward = reward_for_wandering;

            finished = (curr < f_state_c || curr > l_state_c);

            if(finished) walker->learn_from_final(reward);
            else         walker->learn_from(reward);

            t_reward += reward;

            if(epoch>max_epochs) {
                insert_avg_element(t_reward, avg_elements);
                sprintf(msg,
                    "Max Epochs reached.  Return was %5.1f (avg %1.1f).", 
                    t_reward,
                    calc_avg_from_elements(avg_elements)
                );
                disp->show_str(msg);
                break;
            }

            if(curr < f_state_c || curr > l_state_c) {
                insert_avg_element(t_reward, avg_elements);
                sprintf(msg,
                    "Return for episode %2i was %5.1f (avg %5.1f).", 
                    episode,
                    t_reward,
                    calc_avg_from_elements(avg_elements)
                );
                disp->show_str(msg);
                break;
            }
        }
    }
}
