/* routines for state objects
 * Copyright (C) 1997 Angelos D. Keromytis.
 * Copyright (C) 1998, 1999  D. Hugh Redelmeier.
 * 
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2 of the License, or (at your
 * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
 * 
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 *
 * RCSID $Id: state.c,v 1.43 1999/04/11 00:44:23 dhr Exp $
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <fcntl.h>

#include <freeswan.h>

#include "constants.h"
#include "defs.h"
#include "connections.h"
#include "state.h"
#include "kernel.h"
#include "log.h"
#include "rnd.h"
#include "timer.h"
#include "whack.h"

/*
 * Global variables: had to go somewhere, might as well be this file.
 */

u_int16_t pluto_port = PORT;	/* Pluto's port; defaults to PORT */
struct sockaddr_in mask32;	/* 255.255.255.255; filled in by init_states() */
#ifdef ROAD_WARRIOR_FUDGE
struct sockaddr_in mask0;	/* 0.0.0.0; filled in by init_states() */
#endif /* ROAD_WARRIOR_FUDGE */

/*
 * This file has the functions that handle the
 * state hash table and the Message ID list.
 */

/* Nonzero iff x and y (sockaddr-like lvalues, not pointers) have the same
 * sa_family and identical first FULL_INET_ADDRESS_SIZE bytes of sa_data.
 * Each argument is evaluated more than once: pass only side-effect-free
 * expressions.
 */
#define SA_EQUAL(x, y) \
    ((x).sa_family == (y).sa_family \
     && memcmp(&(x).sa_data, &(y).sa_data, FULL_INET_ADDRESS_SIZE) == 0)

/* Message-IDs
 *
 * Each Phase 2 / Quick mode exchange must have a Message ID.
 * The message ID is used to tell which exchange a message belongs to.
 * Each message ID must be unique per initiator/responder pair.
 * The message ID is picked by the side that initiates the Phase 2.
 * The message ID ought to be random to reduce the probability
 * that both sides initiate simultaneously with the same message ID.
 *
 * A MessageID is a 32 bit number.  We represent the value internally in
 * network order -- they are just blobs to us.  This makes it more
 * convenient when hashing.
 *
 * The following mechanism is used to allocate message IDs.  This
 * requires that we keep track of which numbers have already been used
 * so that we don't allocate one in use.
 *
 * ??? Eventually, we should free each number (and more importantly,
 * the space that its record requires) after the exchange has been
 * dead for a long enough interval.  Tricky point: we'll need
 * usage counts because we allow multiple use.  Actually, before
 * space becomes a problem, the simple-minded linear search will
 * become unreasonable.
 */

/* One reserved (peer, Message ID) pair.  Entries are pushed onto the
 * singly-linked msgidlist by reserve_msgid(); nothing in this file
 * ever removes or frees them.
 */
struct msgid_desc
{
    struct in_addr       md_peer;	/* peer the ID is reserved for */
    msgid_t               md_msgid; /* network order */
    struct msgid_desc     *md_next;	/* next entry; NULL ends the list */
};

/* Head of the reserved Message ID list, most recent reservation first.
 * Reset to NULL by init_states().
 */
static struct msgid_desc *msgidlist;

/* Try to reserve Message ID msgid for exchanges with peer.
 * Returns FALSE if that (peer, msgid) pair is already on msgidlist;
 * otherwise records the pair at the head of msgidlist and returns TRUE.
 */
bool
reserve_msgid(struct in_addr peer, msgid_t msgid)
{
    struct msgid_desc *p;

    /* linear scan for an existing reservation of this pair */
    for (p = msgidlist; p != NULL; p = p->md_next)
    {
	bool same_id = p->md_msgid == msgid;
	bool same_peer = p->md_peer.s_addr == peer.s_addr;

	if (same_id && same_peer)
	    return FALSE;	/* cannot be reserved */
    }

    p = alloc_thing(struct msgid_desc, "msgid_desc in reserve_msgid()");
    DBG(DBG_CONTROL, DBG_log("reserving msgid 0x%08lx for %s",
	(unsigned long) msgid, inet_ntoa(peer)));
    p->md_peer = peer;
    p->md_msgid = msgid;

    /* push onto the front of the list */
    p->md_next = msgidlist;
    msgidlist = p;

    return TRUE;
}

/* Pick a random Message ID for a new Phase 2 exchange with peer and
 * reserve it via reserve_msgid().
 *
 * The result is never 0: a state whose st_msgid is 0 is what
 * find_phase1_state() treats as an ISAKMP (Phase 1) SA, so handing out 0
 * for a Phase 2 exchange would confuse the two.
 *
 * After 32 candidates that could not be reserved, gives up, logs, and
 * returns the last candidate anyway (it is nonzero, but NOT reserved and
 * may duplicate an ID already in use with this peer).
 */
msgid_t
generate_msgid(struct in_addr peer)
{
    int timeout = 32;	/* only try so hard for unique msgid */
    msgid_t msgid;

    for (;;)
    {
	/* Redraw zeros here rather than counting them against timeout:
	 * otherwise a zero on the final attempt would be returned.
	 */
	do
	    get_rnd_bytes((void *) &msgid, sizeof(msgid));
	while (msgid == 0);

	if (reserve_msgid(peer, msgid))
	    break;

	if (--timeout == 0)
	{
	    log("gave up looking for unique msgid; using 0x%08lx for %s",
		(unsigned long) msgid, inet_ntoa(peer));
	    break;
	}
    }
    return msgid;
}


/* state table functions */

/* Number of buckets in statetable; state_hash() reduces its sum modulo this. */
#define STATE_TABLE_SIZE 32

/* Hash table of all live state objects.  Each bucket heads a doubly-linked
 * chain through st_hashchain_next / st_hashchain_prev (see insert_state()
 * and unhash_state()).
 */
static struct state *statetable[STATE_TABLE_SIZE];

/* Return the address of the statetable bucket head for this cookie pair
 * and peer.  The index is the byte sum of both COOKIE_SIZE-byte cookies
 * plus peer.s_addr, reduced modulo STATE_TABLE_SIZE (unsigned arithmetic,
 * so wrap-around is well defined).  peer.s_addr is in network byte order,
 * so which of its bits influence the index depends on host byte order.
 * Callers both walk the chain from this slot and link/unlink at it.
 */
static struct state **
state_hash(const u_char *icookie, const u_char *rcookie, struct in_addr peer)
{
    u_int i = 0, j;

    DBG(DBG_RAW | DBG_CONTROL,
	DBG_dump("ICOOKIE:", icookie, COOKIE_SIZE);
	DBG_dump("RCOOKIE:", rcookie, COOKIE_SIZE);
	DBG_dump("peer:", &peer, sizeof(peer)));

    for (j = 0; j < COOKIE_SIZE; j++)
	i += icookie[j] + rcookie[j];

    i = (i + peer.s_addr) % STATE_TABLE_SIZE;

    /* i is unsigned: use %u so the conversion matches the argument type */
    DBG(DBG_CONTROL, DBG_log("state hash entry %u", i));

    return &statetable[i];
}

/* Get a state object.
 * Returns a freshly allocated state carrying a new serial number, with no
 * whack socket (NULL_FD) and every pointer member managed by this file
 * explicitly set to NULL.
 * Caller must set st_connection (insert_state() hashes on it).
 * Caller must schedule an event for this object so that it doesn't leak.
 * Caller must insert_state().
 */
struct state *
new_state(void)
{
    static so_serial_t serialno = SOS_NOBODY;
    struct state *st;

    st = alloc_thing(struct state, "struct state in new_state()");

    st->st_serialno = ++serialno;
    passert(serialno != SOS_NOBODY);	/* overflow can't happen! */
    DBG(DBG_CONTROL, DBG_log("creating state object #%lu at %p",
	st->st_serialno, (void *) st));

    st->st_whack_sock = NULL_FD;

    /* Note: pointers are not guaranteed to be set to NULL
     * by zeroing their bytes
     */
    st->st_connection = NULL;	/* dereferenced unconditionally elsewhere; caller sets it */
    st->st_oakley.hasher = NULL;
    st->st_tpacket = NULL;
    st->st_rpacket = NULL;
    st->st_proposal = NULL;
    st->st_p1isa = NULL;
    st->st_ni = NULL;
    st->st_nr = NULL;
    st->st_skeyid = NULL;
    st->st_skeyid_d = NULL;
    st->st_skeyid_a = NULL;
    st->st_skeyid_e = NULL;
    st->st_enc_key = NULL;
    st->st_myidentity = NULL;
    st->st_peeridentity = NULL;
    st->st_ah.our_keymat = NULL;
    st->st_ah.peer_keymat = NULL;
    st->st_esp.our_keymat = NULL;
    st->st_esp.peer_keymat = NULL;
    st->st_event = NULL;
    st->st_hashchain_next = NULL;
    st->st_hashchain_prev = NULL;

    return st;
}

/*
 * Initialize the state table (and mask32).
 * Fills in the address masks (mask32, and mask0 when ROAD_WARRIOR_FUDGE
 * is defined), empties every statetable bucket, and empties the
 * reserved Message ID list.
 */
void
init_states(void)
{
    struct state **bucket;

    /* address masks */
    mksin(mask32, inet_addr("255.255.255.255"), 0);
#ifdef ROAD_WARRIOR_FUDGE
    mksin(mask0, inet_addr("0.0.0.0"), 0);
#endif /* ROAD_WARRIOR_FUDGE */

    /* assign NULL explicitly rather than relying on zeroed storage */
    for (bucket = statetable; bucket != &statetable[STATE_TABLE_SIZE]; bucket++)
	*bucket = NULL;

    msgidlist = NULL;
}

/* Insert a state object at the beginning of its hash chain.
 * Needs cookies and connection (they select the bucket) and msgid
 * (find_state() matches on it).  st must not already be on a chain.
 */
void
insert_state(struct state *st)
{
    struct state **head = state_hash(st->st_icookie, st->st_rcookie
	, st->st_connection->that.host);
    struct state *old_first = *head;

    passert(st->st_hashchain_prev == NULL && st->st_hashchain_next == NULL);

    /* the former first element now has st in front of it */
    if (old_first != NULL)
    {
	passert(old_first->st_hashchain_prev == NULL);
	old_first->st_hashchain_prev = st;
    }
    st->st_hashchain_next = old_first;
    *head = st;
}

/* Unlink a state object from the hash table, but don't free it.
 * On return both of st's chain links are NULL, so it may be re-inserted.
 */
void
unhash_state(struct state *st)
{
    struct state *prev = st->st_hashchain_prev;
    struct state *next = st->st_hashchain_next;
    struct state **link;	/* the pointer that currently points at st */

    /* st is either the bucket head or its predecessor's successor */
    if (prev != NULL)
	link = &prev->st_hashchain_next;
    else
	link = state_hash(st->st_icookie, st->st_rcookie, st->st_connection->that.host);

    /* forward chain: step over st */
    passert(*link == st);
    *link = next;

    /* backward chain: step over st */
    if (next != NULL)
    {
	passert(next->st_hashchain_prev == st);
	next->st_hashchain_prev = prev;
    }

    st->st_hashchain_prev = NULL;
    st->st_hashchain_next = NULL;
}

/* Close st's Whack socket, if it has one.
 * Closing it is also how Whack learns that we're done.
 * If the global whack_log_fd refers to the same descriptor it is reset
 * to NULL_FD, so nothing logs to a closed descriptor afterwards.
 */
void
release_whack(struct state *st)
{
    int fd = st->st_whack_sock;

    if (fd == NULL_FD)
	return;		/* nothing to release */

    if (fd == whack_log_fd)
	whack_log_fd = NULL_FD;
    st->st_whack_sock = NULL_FD;
    close(fd);
}

/*
 * Delete a state object and free it.
 *
 * In order: unhashes st; asks the kernel to delete the IPSEC SA when st is
 * in STATE_QUICK_I2 or STATE_QUICK_R2; clears the connection's
 * newest_ipsec_sa / newest_isakmp_sa if they name st; (ROAD_WARRIOR_FUDGE)
 * deletes an rw_instance connection no other state uses; releases the
 * whack socket; frees everything st owns and then st itself.
 *
 * Precondition: st's timer event is already gone (st_event == NULL);
 * delete_states_by_connection() calls delete_event() first for this reason.
 * st must not be used after this returns.
 */
void
delete_state(struct state *st)
{
    /* Check that no timer event is left dangling.
     * We could actually delete it here, but in most
     * cases this is a "can't happen".
     */
    struct state *old_cur_state = cur_state;

    /* make st current for the duration (presumably so that log and passert
     * output is attributed to it); restored before release_whack() below
     */
    cur_state = st;
    passert(st->st_event == NULL);

    /* effectively, this deletes any ISAKMP SA that this state represents */
    unhash_state(st);

    /* tell kernel to delete any IPSEC SA
     * ??? we ought to tell peer to delete IPSEC SAs
     * NOTE(review): the boolean presumably distinguishes initiator (I2) from
     * responder (R2) -- confirm against delete_ipsec_sa().
     */
    switch (st->st_state)
    {
    case STATE_QUICK_I2:
	delete_ipsec_sa(st, TRUE);
	break;
    case STATE_QUICK_R2:
	delete_ipsec_sa(st, FALSE);
	break;
    }

    /* the connection must not keep referring to a deleted serial number */
    if (st->st_connection->newest_ipsec_sa == st->st_serialno)
	st->st_connection->newest_ipsec_sa = SOS_NOBODY;

    if (st->st_connection->newest_isakmp_sa == st->st_serialno)
	st->st_connection->newest_isakmp_sa = SOS_NOBODY;

#ifdef ROAD_WARRIOR_FUDGE
    /* If we're the only state using a rw_instance connection,
     * we must free it.  Must be careful to avoid circularity:
     * if we are being deleted because the connection is being
     * deleted, hands off!
     * NOTE(review): this block does not itself detect that case; presumably
     * the connection-deletion path clears rw_instance first -- confirm.
     */
    if (st->st_connection->rw_instance)
    {
	/* are there any states still using it?
	 * (st was unhashed above, so the scan cannot find st itself)
	 */
	struct connection *c = st->st_connection;
	struct state *ost = NULL;
	int i;

	/* stops at the first state found using c, or leaves ost NULL */
	for (i = 0; ost == NULL && i < STATE_TABLE_SIZE; i++)
	    for (ost = statetable[i]
	    ; ost != NULL && ost->st_connection != c
	    ; ost = ost->st_hashchain_next)
		;
	if (ost == NULL)
	{
	    delete_connection(c);
	    st->st_connection = NULL;	/* redundant, but careful */
	}
    }
#endif /* ROAD_WARRIOR_FUDGE */

    cur_state = old_cur_state;

    release_whack(st);

    /* from here on we are just freeing RAM
     * NOTE(review): key material (st_skeyid*, st_enc_key, *_keymat) is handed
     * to pfreeany() with no explicit wipe here; consider zeroing it first.
     */

    /* the *_in_use flags record which multi-precision values were initialized */
    if (st->st_gi_in_use)
	mpz_clear(&(st->st_gi));

    if (st->st_gr_in_use)
	mpz_clear(&(st->st_gr));

    if (st->st_sec_in_use)
	mpz_clear(&(st->st_sec));

    if (st->st_shared_in_use)
	mpz_clear(&(st->st_shared));

    pfreeany(st->st_tpacket);
    pfreeany(st->st_rpacket);
    pfreeany(st->st_proposal);
    pfreeany(st->st_p1isa);
    pfreeany(st->st_ni);
    pfreeany(st->st_nr);
    pfreeany(st->st_skeyid);
    pfreeany(st->st_skeyid_d);
    pfreeany(st->st_skeyid_a);
    pfreeany(st->st_skeyid_e);
    pfreeany(st->st_enc_key);
    pfreeany(st->st_myidentity);
    pfreeany(st->st_peeridentity);
    pfreeany(st->st_ah.our_keymat);
    pfreeany(st->st_ah.peer_keymat);
    pfreeany(st->st_esp.our_keymat);
    pfreeany(st->st_esp.peer_keymat);

    pfree(st);
}

void
delete_states_by_connection(struct connection *c)
{
    int i;

    for (i = 0; i < STATE_TABLE_SIZE; i++)
    {
	struct state *st;

	for (st = statetable[i]; st != NULL; )
	{
	    struct state *this = st;

	    st = st->st_hashchain_next;	/* before this is deleted */

	    if (this->st_connection == c)
	    {
		struct state *old_cur_state = cur_state;

		cur_state = this;
		log("deleting state (%s)"
		    , enum_show(&state_names, this->st_state));
		passert(this->st_event != NULL);
		delete_event(this);
		delete_state(this);
		cur_state = old_cur_state;
	    }
	}
    }
}

/* Duplicate a Phase 1 state object, to create a Phase 2 object.
 * The copy gets a fresh serial number (from new_state()) and shares the
 * cookies, connection, DOI, situation, Oakley parameters and identity
 * types.  It gets private copies of SKEYID_d/e/a, the encryption key and
 * both identities.  Its PFS group is OAKLEY_GROUP_UNSET.
 * Caller must schedule an event for this object so that it doesn't leak.
 * Caller must insert_state().
 */
struct state *
duplicate_state(const struct state *st)
{
    struct state *nst;

    DBG(DBG_CONTROL, DBG_log("duplicating state object #%lu",
	st->st_serialno));

    nst = new_state();

    /* same ISAKMP SA: same cookies and connection */
    memcpy(nst->st_icookie, st->st_icookie, COOKIE_SIZE);
    memcpy(nst->st_rcookie, st->st_rcookie, COOKIE_SIZE);
    nst->st_connection = st->st_connection;

    nst->st_doi = st->st_doi;
    nst->st_situation = st->st_situation;

    /* deep-copy buffer fld (st->fld_len bytes) and its length into nst;
     * the allocation is named "<fld> in duplicate_state"
     */
#   define COPY_CHUNK(fld, fld_len) do { \
	nst->fld = clone_bytes(st->fld, st->fld_len, #fld " in duplicate_state"); \
	nst->fld_len = st->fld_len; \
    } while (0)

    COPY_CHUNK(st_skeyid_d, st_skeyid_d_len);
    COPY_CHUNK(st_skeyid_e, st_skeyid_e_len);
    COPY_CHUNK(st_skeyid_a, st_skeyid_a_len);
    COPY_CHUNK(st_enc_key, st_enc_key_len);
    COPY_CHUNK(st_myidentity, st_myidentity_len);
    COPY_CHUNK(st_peeridentity, st_peeridentity_len);

#   undef COPY_CHUNK

    nst->st_myidentity_type = st->st_myidentity_type;
    nst->st_peeridentity_type = st->st_peeridentity_type;

    nst->st_oakley = st->st_oakley;

    nst->st_pfs_group = OAKLEY_GROUP_UNSET;

    return nst;
}

/*
 * Find the state object matching both cookies, the peer address (compared
 * with st_connection->that.host) and Message ID (network order).
 * Only the bucket chosen by state_hash() is searched.
 * Returns NULL if there is no match.
 */
struct state *
find_state(const u_char *icookie, const u_char *rcookie,
		const struct in_addr peer, msgid_t /*network order*/  msgid)
{
    struct state *st;

    for (st = *state_hash(icookie, rcookie, peer);
	 st != NULL;
	 st = st->st_hashchain_next)
    {
	bool match = peer.s_addr == st->st_connection->that.host.s_addr
	    && memcmp(icookie, st->st_icookie, COOKIE_SIZE) == 0
	    && memcmp(rcookie, st->st_rcookie, COOKIE_SIZE) == 0
	    && msgid == st->st_msgid;

	if (match)
	    break;
    }

    DBG(DBG_CONTROL,
	if (st != NULL)
	    DBG_log("state object #%lu found, in %s",
		st->st_serialno,
		enum_show(&state_names, st->st_state));
	else
	    DBG_log("state object not found"));

    return st;
}

/*
 * Find an ISAKMP SA state object for peer.
 * Candidates have st_msgid == 0 and are in STATE_MAIN_I4 or STATE_MAIN_R3.
 * Of those, the one with the highest serial number (the most recently
 * created) is returned.  Every bucket is scanned.
 * Returns NULL if there is none.
 */
struct state *
find_phase1_state(const struct in_addr peer)
{
    struct state *best = NULL;
    int bucket;

    for (bucket = 0; bucket < STATE_TABLE_SIZE; bucket++)
    {
	struct state *st;

	for (st = statetable[bucket]; st != NULL; st = st->st_hashchain_next)
	{
	    if (st->st_connection->that.host.s_addr != peer.s_addr)
		continue;	/* some other host */
	    if (st->st_msgid != 0)
		continue;	/* not an ISAKMP SA */
	    if (st->st_state != STATE_MAIN_I4 && st->st_state != STATE_MAIN_R3)
		continue;	/* not in either accepted Main Mode state */
	    if (best == NULL || st->st_serialno > best->st_serialno)
		best = st;
	}
    }
    return best;
}

/* Send whack (on wfd) one RC_COMMENT line per state object, giving:
 * serial number, connection name, state name and story, the pending timer
 * event, the seconds until it fires (negative if already overdue), and
 * whether this state is its connection's newest ISAKMP SA, newest IPSEC
 * SA, and/or eroute owner.
 * Every state in the table is required to have a pending event.
 */
void
show_states_status(int wfd)
{
    time_t now = time((time_t *) NULL);
    int i;

    for (i = 0; i < STATE_TABLE_SIZE; i++)
    {
	struct state *st;

	for (st = statetable[i]; st != NULL; st = st->st_hashchain_next)
	{
	    long delta;
	    const char *np1, *np2, *eo;

	    /* Must come before the first use of st_event; checking after
	     * the dereference could never catch a NULL event.
	     */
	    passert(st->st_event != NULL);

	    /* what the heck is interesting about a state? */

	    /* compute in the direction that cannot go negative in time_t */
	    delta = st->st_event->ev_time >= now
		? (long)(st->st_event->ev_time - now)
		: -(long)(now - st->st_event->ev_time);
	    np1 = st->st_connection->newest_isakmp_sa == st->st_serialno
		? "; newest ISAKMP" : "";
	    np2 = st->st_connection->newest_ipsec_sa == st->st_serialno
		? "; newest IPSEC" : "";
	    eo = st->st_connection->eroute_owner == st->st_serialno
		? "; eroute owner" : "";

	    whack_log(wfd, RC_COMMENT
		, "#%lu: \"%s\" %s (%s); %s in %lds%s%s%s"
		, st->st_serialno
		, st->st_connection->name
		, enum_name(&state_names, st->st_state)
		, state_story[st->st_state - STATE_MAIN_R0]
		, enum_name(&timer_event_names, st->st_event->ev_type)
		, delta
		, np1, np2, eo);
	}
    }
}
