#include <arch/neuron.h>
#include <arch/hidden_layers.h>
#include <arch/layer.h>

// sarsa: an agent whose members (current and next state/action buffers,
// eligibility helpers, alpha/lambda/gamma coefficients) and name suggest a
// SARSA(lambda) learner backed by a layered neural network.
// NOTE(review): semantics inferred from names only; confirm against the .cpp.
class sarsa {
    public:
        // Construction. `inputs` / `outputs` size the input and output layers.
        // `num_hidden_layers` is followed by a variadic argument list;
        // presumably one size per hidden layer (TODO confirm in the .cpp).
        sarsa(real _alpha, real _lambda, real _gamma,
              int inputs, int outputs,
              int num_hidden_layers, ...);

        // ---- Queries ----------------------------------------------------
        int   query_action();
        real *query_action_values();
        real  query_last_error();

        // ---- Per-step / per-episode learning cycle ---------------------
        void set_state(real *state_vector);
        void set_action(int action_init);
        void learn_from(real reward);
        void learn_from_final(real reward);
        void start_new_episode();

        // ---- Network configuration -------------------------------------
        // Note the argument order: upper bound first, then lower bound.
        void reinitialize_weights_with(real max, real min);
        void set_transfer_function_for_output(int type);
        void set_transfer_function_for_hidden(int type);

        // ---- Persistence -----------------------------------------------
        void save_net(char *filename);
        void restore_net(char *filename);

    private:
        // Dimensions and learning coefficients. Presumably alpha is the
        // step size, lambda the trace decay, gamma the discount (TODO confirm).
        int  num_outputs;
        int  num_inputs;
        real alpha;
        real lambda;
        real gamma;

        // Presumably the value reported by query_last_error() (TODO confirm).
        real last_error;

        // Network topology.
        layer         *input;
        layer         *output;
        hidden_layers *hidden;

        neuron *chosen_neuron;

        // Current (state, action) pair followed by the next pair.
        real *state;
        real *action;
        int   actionarg;      // offset of `action`
        real *nextstate;
        real *nextaction;
        int   nextactionarg;

        // ---- Internal learning steps -----------------------------------
        void decay_eligibilities(layer *l);
        void decay_eligibilities();

        void start_new_episode(layer *l);

        void calc_eligibilities_for_dendrites_of(neuron *this_neuron);
        void calc_eligibilities();

        void calc_deltas();
        void calc_deltas_and_eligibilities();

        void update_weights_for_dendrites_of(neuron *n);
        void update_weights(real reward);
        void update_weights_final(real reward);
};
