#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <arch/neuron.h>
#include <nets/sarsa.h>
#include <arch/layer.h>
#include <arch/hidden_layers.h>
#include <utils/transfer.h>

// Build a SARSA network: one input layer, `num_hidden_layers` hidden layers
// and one output layer, wired output <- hidden <- input.
//
//   _alpha, _lambda, _gamma  stored as the alpha / lambda / gamma members
//   inputs, outputs          number of units in the input / output layers
//   num_hidden_layers        how many `int` size arguments follow
//   ...                      one `int` per hidden layer (its unit count)
//
// The variadic sizes are collected with <stdarg.h>.  The previous code read
// them through `(&num_hidden_layers)+1`, which assumes the arguments are
// laid out contiguously on the stack; that is undefined behaviour and does
// not hold on ABIs that pass arguments in registers.
sarsa::sarsa(real _alpha, real _lambda, real _gamma, int inputs, int outputs, 
                                                 int num_hidden_layers, ...) {
    time_t t;
    srand((rand())*(unsigned)time(&t));

    // Gather the per-layer sizes into a temporary array.  It lives until the
    // end of this constructor, which matches the lifetime the old
    // stack-pointer trick could offer at best.
    // NOTE(review): assumes hidden_layers does not keep this pointer after
    // construction -- confirm against arch/hidden_layers.h.
    int *hidden_sizes = new int[num_hidden_layers > 0 ? num_hidden_layers : 0];
    va_list sizes;
    va_start(sizes, num_hidden_layers);
    for(int i = 0; i < num_hidden_layers; i++)
        hidden_sizes[i] = va_arg(sizes, int);
    va_end(sizes);

     input = new layer("input",  inputs, 1.0);
    output = new layer("output", outputs);
    hidden = new hidden_layers(num_hidden_layers, hidden_sizes, SIGMOID, 1.0);

     input->set_transfer_function(SUM);  // input units broadcast ...
    hidden->set_transfer_function(SIGMOID);
    output->set_transfer_function(SIGMOID);

    output->dendrites_touch(hidden->query_head_layer());
    hidden->query_tail_layer()->dendrites_touch(input);

    output->reinitialize_weights_with(0.1, -0.1);
    hidden->reinitialize_weights_with(0.1, -0.1);

    alpha       = _alpha;
    lambda      = _lambda;
    gamma       = _gamma;
    num_outputs = outputs;
    num_inputs  = inputs;

    state         =  0;
    action        =  0;
    actionarg     = -1;
    nextstate     =  0;
    nextaction    =  0;
    nextactionarg = -1;

    // These were previously left indeterminate until the first call to
    // start_new_episode() / calc_deltas().
    last_error    =  0;
    chosen_neuron =  0;

    delete[] hidden_sizes;
}

// Forward `type` to the output layer's set_transfer_function().
// The constructor passes constants such as SUM and SIGMOID to the same call;
// presumably `type` is one of those -- see utils/transfer.h to confirm.
void sarsa::set_transfer_function_for_output(int type) {
    output->set_transfer_function(type);
}

// Forward `type` to set_transfer_function() on the hidden_layers object.
// The constructor passes constants such as SUM and SIGMOID to the same call;
// presumably `type` is one of those -- see utils/transfer.h to confirm.
void sarsa::set_transfer_function_for_hidden(int type) {
    hidden->set_transfer_function(type);
}

// Re-initialise the weights of the output layer and then of the hidden
// layers, passing `max` and `min` through unchanged.  Note the argument
// order is (max, min); the constructor calls this pattern with (0.1, -0.1).
void sarsa::reinitialize_weights_with(real max, real min) {
    output->reinitialize_weights_with(max, min);
    hidden->reinitialize_weights_with(max, min);
}

// Advance one time step and present a new state to the network.
//
// The previous "current" state/action buffers are released, the "next"
// state/action/actionarg become the current ones, and a private copy of
// `state_vector` (num_inputs elements) becomes the new next state and is
// handed to the input layer.  nextaction and nextactionarg are reset, so
// query_action() / set_action() must be called again for the new state.
//
//   state_vector  caller-owned array of at least num_inputs values; it is
//                 copied, not retained.
void sarsa::set_state(real *state_vector) {
    // Both buffers come from new[] (here and in query_action()), so they
    // must be released with delete[]; deleting a null pointer is a no-op.
    delete[] state;
    delete[] action;

    /* move forward in time: */
    state     = nextstate;
    action    = nextaction;
    actionarg = nextactionarg;

    /* make a copy of our state-vector for next time */
    nextaction = 0;
    nextactionarg = -1;
    nextstate = new real[num_inputs];
    for(int c = 0; c < num_inputs; c++)
        nextstate[c] = state_vector[c];
    input->set_input(nextstate);
}

// Record `action_init` as the index of the action taken in the most recently
// set state (stored in nextactionarg).  No range check happens here;
// update_weights() later rejects values outside [0, num_outputs-1].
void sarsa::set_action(int action_init) {
    nextactionarg = action_init;
}

// Return whatever the output layer's query_output() returns.
// query_action() indexes that result as an array of num_outputs values.
// NOTE(review): ownership/lifetime of the returned pointer is decided by
// layer::query_output() -- confirm before caching it.
real *sarsa::query_action_values() {
    return output->query_output();
}

// Snapshot the current output-layer values into nextaction and return the
// index of the largest one (the first such index on ties, because the
// comparison is strict).
//
// Any previous nextaction snapshot is released first.  The snapshot is what
// update_weights() later reads as nextaction[nextactionarg].
int sarsa::query_action() {
    real *outputs = output->query_output();
    real max      = outputs[0];
    int  mcounter = 0;

    // nextaction is allocated with new[] below, so it needs delete[];
    // deleting a null pointer is a no-op.
    delete[] nextaction;

    nextaction = new real[num_outputs];

    for(int i=0; i<num_outputs; i++) {
        nextaction[i] = outputs[i];
        if(nextaction[i] > max) {
            max      = nextaction[i];
            mcounter = i;
        }
    }

    return mcounter;
}

// Multiply every eligibility stored on the dendrites of each neuron in `l`
// by lambda * gamma.
void sarsa::decay_eligibilities(layer *l) {
    for(l->start_foreach(); l->has_a_neuron(); l->next_neuron()) {
        neuron *unit = l->query_current_neuron();

        for(unit->dendrites->start_foreach();
            unit->dendrites->has_an_eligibility();
            unit->dendrites->next_eligibility()) {
            unit->dendrites->set_current_eligibility_to(
                lambda * gamma * unit->dendrites->query_current_eligibility()
            );
        }
    }
}

// Decay the eligibilities of the whole network: the output layer first,
// then each hidden layer in the order hidden's foreach yields them.
void sarsa::decay_eligibilities() {
    decay_eligibilities(output);

    for(hidden->start_foreach(); hidden->has_a_layer(); hidden->next_layer())
        decay_eligibilities(hidden->query_current_layer());
}

// Reset to 0.0 every eligibility stored on the dendrites of each neuron
// in `l`.
void sarsa::start_new_episode(layer *l) {
    for(l->start_foreach(); l->has_a_neuron(); l->next_neuron()) {
        neuron *unit = l->query_current_neuron();

        for(unit->dendrites->start_foreach();
            unit->dendrites->has_an_eligibility();
            unit->dendrites->next_eligibility()) {
            unit->dendrites->set_current_eligibility_to(0.0);
        }
    }
}

// Begin a fresh episode: clear last_error, then zero the eligibilities of
// the output layer and of every hidden layer.
void sarsa::start_new_episode() {
    last_error = 0;

    start_new_episode(output);
    for(hidden->start_foreach(); hidden->has_a_layer(); hidden->next_layer())
        start_new_episode(hidden->query_current_layer());
}

void sarsa::calc_deltas() {
    int counter = 0;
    layer  *l;
    neuron *n;

    output->start_foreach();
    while(output->has_a_neuron()) {
        if(counter != actionarg) {
            output->query_current_neuron()->set_delta(0);
        } else {
            output->query_current_neuron()->set_delta(
                output->query_current_neuron()->query_output_dot()
            );
            chosen_neuron = output->query_current_neuron();
        }
        output->next_neuron(); 
        counter++;
    }

    hidden->start_foreach();
    while(hidden->has_a_layer()) {
        l = hidden->query_current_layer();
        l->start_foreach();
        while(l->has_a_neuron()) {
            n = l->query_current_neuron();
            n->set_delta(
                  (n->axon_fires_at->weighted_sum_of_delta_from_above())
                * (n->query_output_dot())
            );
            l->next_neuron();
        }
        hidden->next_layer();
    }
}

// Set each eligibility on n's dendrites to
//     n->query_delta() * (output of the neuron firing into that dendrite).
//
// NOTE(review): the stored value is replaced, not added to, so whatever
// decay_eligibilities() left on these dendrites is discarded.  Confirm that
// overwriting (rather than accumulating) the trace is intended.
void sarsa::calc_eligibilities_for_dendrites_of(neuron *n) {
    for(n->dendrites->start_foreach();
        n->dendrites->has_an_eligibility();
        n->dendrites->next_eligibility()) {
        n->dendrites->set_current_eligibility_to(
              (n->query_delta()) * 
              (n->dendrites->query_current_fireing_output())
        );
    }
}

void sarsa::calc_eligibilities() {
    layer *l;

    calc_eligibilities_for_dendrites_of(chosen_neuron);

    hidden->start_foreach();
    while(hidden->has_a_layer()) {
        l = hidden->query_current_layer();
        l->start_foreach();
        while(l->has_a_neuron()) {
            calc_eligibilities_for_dendrites_of( l->query_current_neuron() );
            l->next_neuron();
        }
        hidden->next_layer();
    }
}

void sarsa::calc_deltas_and_eligibilities() {
    input->set_input(state);

    decay_eligibilities();
    calc_deltas();  // sets the chosen_neuron from the action_curr
    calc_eligibilities();
}

// Adjust every weight on n's dendrites by
//     alpha * last_error * (eligibility paired with that weight).
//
// Per the original author's note, the weight cursor and the eligibility
// cursor advance together during a foreach over either one, which is why
// query_current_eligibility() can be used inside a weight loop.
void sarsa::update_weights_for_dendrites_of(neuron *n) {
    for(n->dendrites->start_foreach();
        n->dendrites->has_a_weight();
        n->dendrites->next_weight()) {
        n->dendrites->change_current_weight_by(
            alpha * last_error * n->dendrites->query_current_eligibility()
        );
    }
}

// Return the most recently stored error term.
// last_error is written by start_new_episode() (reset to 0),
// update_weights(real) and update_weights_final(real).
real sarsa::query_last_error() {
    return last_error;
}

// Apply one non-terminal SARSA weight update.
//
// Computes
//     last_error = reward + gamma * nextaction[nextactionarg]
//                         - action[actionarg]
// and then changes the weights on the dendrites of chosen_neuron and of
// every hidden-layer neuron by alpha * last_error * eligibility.
//
// Preconditions (violations print a message to stderr and exit(1)):
//   - actionarg and nextactionarg are within [0, num_outputs-1];
//   - action and nextaction snapshots exist, i.e. query_action() was called
//     for both the current and the next state.  set_state() resets
//     nextaction to 0, so skipping query_action() would otherwise be a
//     null-pointer dereference below.
// chosen_neuron must have been set by calc_deltas().
void sarsa::update_weights(real reward) {
    if(actionarg < 0) {
        fprintf(stderr, "Action arg was never set!\n");
        exit(1);
    }

    if(actionarg > num_outputs-1) {
        fprintf(stderr, "Action arg is too big!\n");
        exit(1);
    }

    if(nextactionarg < 0) {
        fprintf(stderr, "NextAction arg was never set!\n");
        exit(1);
    }

    if(nextactionarg > num_outputs-1) {
        fprintf(stderr, "NextAction arg is too big!\n");
        exit(1);
    }

    if(!action) {
        fprintf(stderr, "Action values were never queried!\n");
        exit(1);
    }

    if(!nextaction) {
        fprintf(stderr, "NextAction values were never queried!\n");
        exit(1);
    }

    last_error = 
        reward + gamma * (nextaction[nextactionarg]) - action[actionarg];

    update_weights_for_dendrites_of(chosen_neuron);

    for(hidden->start_foreach(); hidden->has_a_layer(); hidden->next_layer()) {
        layer *l = hidden->query_current_layer();
        for(l->start_foreach(); l->has_a_neuron(); l->next_neuron())
            update_weights_for_dendrites_of(l->query_current_neuron());
    }
}

// Apply the weight update for the last step of an episode, where there is
// no next action to bootstrap from.
//
// Computes last_error = reward - action[actionarg] and then changes the
// weights on the dendrites of chosen_neuron and of every hidden-layer
// neuron by alpha * last_error * eligibility.
//
// Preconditions (violations print a message to stderr and exit(1)):
//   - actionarg is within [0, num_outputs-1];
//   - the action snapshot exists, i.e. query_action() was called for the
//     current state; otherwise `action` is null and would be dereferenced.
// chosen_neuron must have been set by calc_deltas().
void sarsa::update_weights_final(real reward) {
    if(actionarg < 0 || actionarg > num_outputs-1) {
        fprintf(stderr, "Action arg is bad!\n");
        exit(1);
    }

    if(!action) {
        fprintf(stderr, "Action values were never queried!\n");
        exit(1);
    }

    last_error = 
        reward - action[actionarg];

    update_weights_for_dendrites_of(chosen_neuron);

    for(hidden->start_foreach(); hidden->has_a_layer(); hidden->next_layer()) {
        layer *l = hidden->query_current_layer();
        for(l->start_foreach(); l->has_a_neuron(); l->next_neuron())
            update_weights_for_dendrites_of(l->query_current_neuron());
    }
}

// One non-terminal learning step: recompute deltas and eligibilities for the
// current state/action, then apply update_weights(reward).
void sarsa::learn_from(real reward) {
    calc_deltas_and_eligibilities();
    update_weights(reward);
}

// Terminal learning step: recompute deltas and eligibilities for the current
// state/action, then apply update_weights_final(reward), which uses no
// next-action value.
void sarsa::learn_from_final(real reward) {
    calc_deltas_and_eligibilities();
    update_weights_final(reward);
}

// Write the network's weights to `filename` (opened with mode "w", so an
// existing file is truncated): hidden layers first, then the output layer.
//
// If the file cannot be opened, or closing it reports an error, the reason
// is printed with perror() and the process exits with status 1, matching
// how the rest of this file treats errors.  Previously a failed fopen()
// passed a null FILE* on to save_weights() and fclose().
void sarsa::save_net(char *filename) {
    FILE *F = fopen(filename, "w");
    if(!F) {
        perror(filename);
        exit(1);
    }

    hidden->save_weights(F);
    output->save_weights(F);

    // Buffered write failures (e.g. disk full) surface at fclose().
    if(fclose(F) != 0) {
        perror(filename);
        exit(1);
    }
}

// Load the network's weights from `filename` (opened with mode "r"), in the
// same order save_net() writes them: hidden layers first, then the output
// layer.
//
// If the file cannot be opened, the reason is printed with perror() and the
// process exits with status 1, matching how the rest of this file treats
// errors.  Previously a failed fopen() passed a null FILE* on to
// restore_weights() and fclose().
void sarsa::restore_net(char *filename) {
    FILE *F = fopen(filename, "r");
    if(!F) {
        perror(filename);
        exit(1);
    }

    hidden->restore_weights(F);
    output->restore_weights(F);

    fclose(F);
}
