
/* Copyright (C) Gerhard Fuernkranz 1992 */

%{

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <unistd.h>
#include <string.h>
#include <stdarg.h>
#include <fcntl.h>
#include <netdb.h>
#include <stropts.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include "matcher.h"

/* Instruction emitters: pass both the opcode value and its stringized
 * name, so the assembler listing can print the mnemonic. */
#define GEN1(op)	gen1((op), #op)
#define GEN2(op, arg)	gen2((op), #op, (arg))
#define GEN3(op, arg)	gen3((op), #op, (arg))
#define GEN4(op, arg)	gen4((op), #op, (arg))
#define GEN5(op, arg)	gen5((op), #op, (arg))

typedef enum set_type {			/* element type stored in a SET */
	port_set, addr_set
} SET_TYPE;

typedef struct sym SYM;			/* forward */

typedef struct fix {			/* one pending forward reference */
	struct fix *next;
	unsigned short where;		/* word index into program[] to patch */
} FIX;

typedef struct addr {			/* inet-addr/network w/ mask */
	unsigned long addr;		/* inet address, host byte order */
	unsigned long mask;		/* netmask, host byte order */
} ADDR;

typedef struct port {			/* Port protocol & port number */
	unsigned short proto;		/* protocol number (6 = TCP) */
	unsigned short port;		/* port number, host byte order */
} PPORT;

typedef struct set {			/* an address/port set */
	struct set *next;
	SET_TYPE type;			/* which kind of set (port/addr/...) */
	unsigned short count;		/* # of elements */
	void *ptr;			/* malloc'd array of count PPORT or ADDR */
	SYM *label;			/* label to represent this set */
	int referenced;			/* set has been referenced (emit it) */
} SET;

struct sym {				/* symbol table entry */
	struct sym	*next;
	char		*name;
	int		token;		/* token for this name; set by 'def' */
	int		label_pc;	/* word index of label, -1 if undefined */
	FIX		*fixlist;	/* fwd. refs to fix */
	union {
	    PPORT	port;		/* for port names */
	    ADDR	addr;		/* for host/net names */
	    SET		*set;		/* for addr/port sets */
	} u;
};

/* Shorthand accessors for the value union in struct sym. */
#define sym_addr	u.addr
#define sym_port	u.port
#define sym_set		u.set

static int errors;			/* # of errors occured */

static PPORT ports[4096];		/* table to build port sets */
static PPORT *pptr;

static ADDR addrs[4096];		/* table to build addr sets */
static ADDR *aptr;

static unsigned short program[16*1024]; /* The program we produce */
static unsigned short *pc = program;

static SYM *symtab;			/* head of symbol table */
static SET *settab;			/* head of set table */

static void error(const char *, ...);
static void yyerror(const char *s);
static unsigned long netmask(unsigned long);
static SYM *lookup(const char *name);
static SYM *new_label(void);
static SET *new_set(SET_TYPE type);

%}

/* Semantic value types carried on the parser's value stack. */
%union {
	long	intval;		/* NUMBER, PROTOCOL */
	SYM	*sym;		/* symbol table entries (names, labels) */
	SET	*set;		/* port/address sets */
	char	*str;		/* STRING, DOTTED */
	ADDR	addr;		/* address with mask */
	PPORT	port;		/* protocol/port pair */
}

/* Keywords */
%token ACCEPT
%token DENY
%token DST
%token GOTO
%token HOST
%token IN
%token MASK
%token NET
%token PORT
%token SRC

/* STRING is resolved through getservbyname/gethostbyname/getnetbyname;
   DOTTED is parsed with inet_addr/inet_network. */
%token <str>		STRING
%token <str>		DOTTED

%token <intval>		NUMBER

/* Protocol number, used as a port's proto field. */
%token <intval>		PROTOCOL

/* Symbol tokens.  NAME is a not-yet-defined name; rule 'def' changes the
   symbol's token to PORT_NAME, ADDR_NAME, PORT_SET_NAME or ADDR_SET_NAME
   (presumably the scanner returns sym->token -- TODO confirm). */
%token <sym>		LABEL
%token <sym>		NAME
%token <sym>		PORT_NAME
%token <sym>		PORT_SET_NAME
%token <sym>		ADDR_NAME
%token <sym>		ADDR_SET_NAME

%type <port>		port
%type <addr>		addr
%type <addr>		addr_prefix
%type <set>		port_set
%type <set>		addr_set
%type <sym>		anyname

%%

/* A program is a (possibly empty) sequence of rules. */
program:  /* empty */
	| program rule
	;

/* A port: a number (TCP unless followed by '/' PROTOCOL); a service
   name looked up with getservbyname(), written "name" (tcp) or
   "name/proto"; or a previously defined port name.  Numeric ports
   outside 0..65535 are reported instead of being silently truncated. */
port:	  NUMBER
		{
		    if ($1 < 0 || $1 > 65535)
			error("Illegal port number %ld\n",$1);
		    $$.port = $1; $$.proto = 6; /* TCP */
		}
	| PORT NUMBER
		{
		    if ($2 < 0 || $2 > 65535)
			error("Illegal port number %ld\n",$2);
		    $$.port = $2; $$.proto = 6; /* TCP */
		}
	| NUMBER '/' PROTOCOL
		{
		    if ($1 < 0 || $1 > 65535)
			error("Illegal port number %ld\n",$1);
		    $$.port = $1; $$.proto = $3;
		}
	| PORT NUMBER '/' PROTOCOL
		{
		    if ($2 < 0 || $2 > 65535)
			error("Illegal port number %ld\n",$2);
		    $$.port = $2; $$.proto = $4;
		}
	| PORT STRING
		{
		    char *p;
		    struct servent *sp;
		    struct protoent *pp;
		    /* Split "service/proto" in place; default protocol is tcp. */
		    if ((p = strchr($2,'/')) == NULL)
			p = "tcp";
		    else
			*p++ = 0;
		    if ((pp = getprotobyname(p)) == NULL) {
			error("Unknown protocol '%s'\n",p);
			$$.proto = 6; /* TCP */
		    }
		    else {
			$$.proto = pp->p_proto;
		    }
		    if ((sp = getservbyname($2,p)) == NULL) {
			error("Unknown service '%s/%s'\n",$2,p);
			$$.port = 0;
		    }
		    else {
			$$.port = ntohs(sp->s_port);
		    }
		}
	| PORT_NAME
		{ $$ = $1->sym_port; }
/* conflict
	| NAME
		{ $$.port = 0; $$.proto = 6;
		  error("Name '%s' undefined\n",$1->name); }
*/
	;

/* An address prefix, optionally with an explicit numeric mask that
   replaces the prefix's mask; address bits outside the mask are cleared. */
addr:	  addr_prefix
	| addr_prefix MASK NUMBER
		{ $$ = $1; $$.addr &= $3; $$.mask = $3; }
	;

/* An address prefix, kept in host byte order:
     HOST ...  a single host (mask 0xffffffff) given as a number, a
               dotted address or a host name;
     NET ...   a network (classful mask from netmask()) given as a
               number, a dotted network or a network name;
     STRING    tried first as a host name, then as a network name;
     ADDR_NAME a previously defined address.
   Lookup failures are compared against INADDR_NONE as in_addr_t; a
   test against -1 on an unsigned long never matches where long is
   wider than 32 bits.  (inet_addr cannot tell 255.255.255.255 from an
   error, so that address is always reported as illegal.) */
addr_prefix:
	  HOST NUMBER
		{ $$.addr = $2; $$.mask = 0xffffffff; }
	| HOST DOTTED
		{
		    in_addr_t a = inet_addr($2);
		    if (a == INADDR_NONE)
			error("Illegal internet address '%s'\n",$2);
		    $$.addr = ntohl(a);
		    $$.mask = 0xffffffff;
		}
	| DOTTED
		{
		    in_addr_t a = inet_addr($1);
		    if (a == INADDR_NONE)
			error("Illegal internet address '%s'\n",$1);
		    $$.addr = ntohl(a);
		    $$.mask = 0xffffffff;
		}
	| HOST STRING
		{
		    struct hostent *hp;
		    struct in_addr ia;
		    /* Copy exactly one IPv4 address out of h_addr rather than
		       dereferencing it as a (possibly 8-byte) long. */
		    if ((hp = gethostbyname($2)) != NULL
			&& hp->h_addrtype == AF_INET
			&& hp->h_length == (int)sizeof(ia)) {
			memcpy(&ia,hp->h_addr,sizeof(ia));
			$$.addr = ntohl(ia.s_addr);
		    }
		    else {
			error("Unknown host '%s'\n",$2);
			$$.addr = 0;
		    }
		    $$.mask = 0xffffffff;
		}
	| NET NUMBER
		{
		    unsigned long net = $2;
		    $$.addr = ntohl(inet_makeaddr(net,0).s_addr);
		    $$.mask = netmask($$.addr);
		}
	| NET DOTTED
		{
		    in_addr_t net = inet_network($2);
		    if (net == INADDR_NONE)
			error("Illegal network '%s'\n",$2);
		    $$.addr = ntohl(inet_makeaddr(net,0).s_addr);
		    $$.mask = netmask($$.addr);
		}
	| NET STRING
		{
		    struct netent *np;
		    unsigned long net = 0;
		    if ((np = getnetbyname($2)) != NULL)
			net = np->n_net;
		    else
			error("Unknown network '%s'\n",$2);
		    $$.addr = ntohl(inet_makeaddr(net,0).s_addr);
		    $$.mask = netmask($$.addr);
		}
	| STRING
		{
		    unsigned long net;
		    struct netent *np;
		    struct hostent *hp;
		    struct in_addr ia;
		    if ((hp = gethostbyname($1)) != NULL
			&& hp->h_addrtype == AF_INET
			&& hp->h_length == (int)sizeof(ia)) {
			memcpy(&ia,hp->h_addr,sizeof(ia));
			$$.addr = ntohl(ia.s_addr);
			$$.mask = 0xffffffff;
		    }
		    else if ((np = getnetbyname($1)) != NULL) {
			net = np->n_net;
			$$.addr = ntohl(inet_makeaddr(net,0).s_addr);
			$$.mask = netmask($$.addr);
		    }
		    else {
			$$.addr = 0;
			$$.mask = 0xffffffff;
			error("Unknown host/network '%s'\n",$1);
		    }
		}
	| ADDR_NAME
		{ $$ = $1->sym_addr; }
/* conflict
	| NAME
		{ $$.addr = 0; $$.mask = 0xffffffff;
		  error("Name '%s' undefined\n",$1->name); }
*/
	;

/* A port set: a brace-enclosed list, copied out of the ports[] scratch
   table into its own heap array, or a previously defined set name
   (the existing SET is shared, not copied).  Out of memory is fatal. */
port_set: '{' port_list '}'
		{
		    size_t n;
		    $$ = new_set(port_set);
		    $$->count = pptr - ports;
		    n = $$->count * sizeof(ports[0]);
		    if (($$->ptr = malloc(n)) == NULL) {
			perror("malloc");
			exit(1);
		    }
		    memcpy($$->ptr,ports,n);
		}
	| PORT_SET_NAME
		{ $$ = $1->sym_set; }
	;

/* Collect the members of a port set into ports[], with pptr pointing
   one past the last one.  Members are ports or the contents of named
   port sets.  The table is bounded: excess members are reported and
   dropped, so they cannot overrun ports[]. */
port_list:
	  port
		{ ports[0] = $1; pptr = ports+1; }
	| PORT_SET_NAME
		{
		    /* Every set was built in this same table, so its count
		       never exceeds the table size. */
		    memcpy(ports,$1->sym_set->ptr,
			$1->sym_set->count * sizeof(ports[0]));
		    pptr = ports + $1->sym_set->count;
		}
	| port_list ',' port
		{
		    if (pptr < ports + sizeof(ports) / sizeof(ports[0]))
			*pptr++ = $3;
		    else
			error("Too many ports in set (max %d)\n",
			    (int)(sizeof(ports) / sizeof(ports[0])));
		}
	| port_list ',' PORT_SET_NAME
		{
		    size_t room = (ports + sizeof(ports) / sizeof(ports[0])) - pptr;
		    size_t cnt = $3->sym_set->count;
		    if (cnt > room) {
			error("Too many ports in set (max %d)\n",
			    (int)(sizeof(ports) / sizeof(ports[0])));
			cnt = room;
		    }
		    memcpy(pptr,$3->sym_set->ptr,cnt * sizeof(ports[0]));
		    pptr += cnt;
		}
	;

/* An address set: a brace-enclosed list, copied out of the addrs[]
   scratch table into its own heap array, or a previously defined set
   name (the existing SET is shared, not copied).  Out of memory is fatal. */
addr_set: '{' addr_list '}'
		{
		    size_t n;
		    $$ = new_set(addr_set);
		    $$->count = aptr - addrs;
		    n = $$->count * sizeof(addrs[0]);
		    if (($$->ptr = malloc(n)) == NULL) {
			perror("malloc");
			exit(1);
		    }
		    memcpy($$->ptr,addrs,n);
		}
	| ADDR_SET_NAME
		{ $$ = $1->sym_set; }
	;

/* Collect the members of an address set into addrs[], with aptr
   pointing one past the last one.  Members are addresses or the
   contents of named address sets.  The table is bounded: excess
   members are reported and dropped, so they cannot overrun addrs[]. */
addr_list:
	  addr
		{ addrs[0] = $1; aptr = addrs+1; }
	| ADDR_SET_NAME
		{
		    /* Every set was built in this same table, so its count
		       never exceeds the table size. */
		    memcpy(addrs,$1->sym_set->ptr,
			$1->sym_set->count * sizeof(ADDR));
		    aptr = addrs + $1->sym_set->count;
		}
	| addr_list ',' addr
		{
		    if (aptr < addrs + sizeof(addrs) / sizeof(addrs[0]))
			*aptr++ = $3;
		    else
			error("Too many addresses in set (max %d)\n",
			    (int)(sizeof(addrs) / sizeof(addrs[0])));
		}
	| addr_list ',' ADDR_SET_NAME
		{
		    size_t room = (addrs + sizeof(addrs) / sizeof(addrs[0])) - aptr;
		    size_t cnt = $3->sym_set->count;
		    if (cnt > room) {
			error("Too many addresses in set (max %d)\n",
			    (int)(sizeof(addrs) / sizeof(addrs[0])));
			cnt = room;
		    }
		    memcpy(aptr,$3->sym_set->ptr,cnt * sizeof(ADDR));
		    aptr += cnt;
		}
	;

/* NAME = value: bind a not-yet-defined name to a port, address, port
   set or address set.  The symbol's token is changed so that later
   occurrences of the name parse as PORT_NAME, ADDR_NAME, ... (assumes
   the scanner returns sym->token for known names -- TODO confirm). */
def:	  NAME '=' port
		{ $1->token = PORT_NAME; $1->sym_port = $3; }
	| NAME '=' addr
		{ $1->token = ADDR_NAME; $1->sym_addr = $3; }
	| NAME '=' port_set
		{ $1->token = PORT_SET_NAME; $1->sym_set = $3; }
	| NAME '=' addr_set
		{ $1->token = ADDR_SET_NAME; $1->sym_set = $3; }
	;

/* A condition compiles to one test instruction on the source (SRC) or
   destination (DST) address or port:
     == / !=    compare with one address (gen3) or port (gen2);
     IN / ! IN  test membership in a set (gen4, referencing its label);
     < > <= >=  compare ports only.
   Multi-character operators are written as sequences of single-char
   tokens ('=' '=', '!' '=', ...). */
cond:
	  SRC '=' '=' addr
		{ GEN3(SAEQ,$4); }
	| SRC '!' '=' addr
		{ GEN3(SANE,$4); }
	| SRC IN addr_set
		{ GEN4(SAIN,$3); }
	| SRC '!' IN addr_set
		{ GEN4(SANIN,$4); }

	/* destination address tests */
	| DST '=' '=' addr
		{ GEN3(DAEQ,$4); }
	| DST '!' '=' addr
		{ GEN3(DANE,$4); }
	| DST IN addr_set
		{ GEN4(DAIN,$3); }
	| DST '!' IN addr_set
		{ GEN4(DANIN,$4); }

	/* source port tests */
	| SRC '=' '=' port
		{ GEN2(SPEQ,$4); }
	| SRC '!' '=' port
		{ GEN2(SPNE,$4); }
	| SRC '<' port
		{ GEN2(SPLT,$3); }
	| SRC '>' port
		{ GEN2(SPGT,$3); }
	| SRC '<' '=' port
		{ GEN2(SPLE,$4); }
	| SRC '>' '=' port
		{ GEN2(SPGE,$4); }
	| SRC IN port_set
		{ GEN4(SPIN,$3); }
	| SRC '!' IN port_set
		{ GEN4(SPNIN,$4); }

	/* destination port tests */
	| DST '=' '=' port
		{ GEN2(DPEQ,$4); }
	| DST '!' '=' port
		{ GEN2(DPNE,$4); }
	| DST '<' port
		{ GEN2(DPLT,$3); }
	| DST '>' port
		{ GEN2(DPGT,$3); }
	| DST '<' '=' port
		{ GEN2(DPLE,$4); }
	| DST '>' '=' port
		{ GEN2(DPGE,$4); }
	| DST IN port_set
		{ GEN4(DPIN,$3); }
	| DST '!' IN port_set
		{ GEN4(DPNIN,$4); }
	;

/* Any symbol may be named as a jump target.  A target never defined
   with "LABEL :" keeps label_pc == -1 and is reported as an undefined
   label by fix_forward(). */
anyname:  NAME
	| PORT_NAME
	| PORT_SET_NAME
	| ADDR_NAME
	| ADDR_SET_NAME
	;

/* A rule is a definition; a label; a conditional action (accept, deny,
   or jump to a named target when the condition holds); an unconditional
   accept, deny or goto; or the yacc 'error' token, which lets the parser
   resynchronise after a syntax error. */
rule:	  def
	| LABEL ':'
		{ gen_label($1); }
	| cond '-' '>' ACCEPT
		{ GEN1(TACCEPT); }
	| cond '-' '>' DENY
		{ GEN1(TDENY); }
	| cond '-' '>' anyname
		{ GEN5(TJMP,$4); }
	| ACCEPT
		{ GEN1(UACCEPT); }
	| DENY
		{ GEN1(UDENY); }
	| GOTO anyname
		{ GEN5(JMP,$2); }
	| error
	;

%%

#include "lex.yy.c"

/*
 * Print a diagnostic on stderr, prefixed with the current input line
 * (yylineno), and count it in 'errors'.  FMT is a printf-style format.
 */

static void
error(const char *fmt, ...)
{
    va_list args;

    fprintf(stderr,"Line %d: ",yylineno);
    va_start(args,fmt);
    vfprintf(stderr,fmt,args);
    va_end(args);
    ++errors;
}

/* Report yacc error */

static void
yyerror(const char *s)
{
    error("%s before or at '%s'\n",s,yytext);
}

/* Return the classful netmask for host-order address NET: the class A
 * or class B mask when NET is in that class, else the class C mask. */

static unsigned long
netmask(unsigned long net)
{
    return IN_CLASSA(net) ? IN_CLASSA_NET
	 : IN_CLASSB(net) ? IN_CLASSB_NET
	 : IN_CLASSC_NET;
}

/*
 * Look up NAME in the symbol table.  If it is absent, create a new
 * zero-filled entry with label_pc = -1 ("label not yet defined"), push
 * it on the front of symtab and return it.  Never returns NULL: running
 * out of memory is fatal.
 */

static SYM *
lookup(const char *name)
{
    SYM *p;

    for (p = symtab; p != NULL; p = p->next)
	if (strcmp(name,p->name) == 0)
	    return p;

    p = calloc(1,sizeof(*p));
    if (p == NULL || (p->name = strdup(name)) == NULL) {
	perror("lookup");
	exit(1);
    }
    p->next = symtab;
    p->label_pc = -1;
    symtab = p;
    return p;
}

/* Create a zeroed set of the given TYPE with a fresh label, and link it
 * onto the front of settab.  The caller fills in count and ptr.  Running
 * out of memory is fatal. */

static SET *
new_set(SET_TYPE type)
{
    SET *s;

    if ((s = calloc(1,sizeof(*s))) == NULL) {
	perror("new_set");
	exit(1);
    }
    s->type = type;
    s->label = new_label();
    s->next = settab;
    settab = s;
    return s;
}

/* Create a new label symbol with a generated, unique name ("$L1",
 * "$L2", ...) and return its symbol-table entry.  NOTE(review): the '$'
 * prefix presumably keeps these from clashing with names in the rule
 * file -- confirm against the scanner. */

static SYM *
new_label(void)
{
    static int current_nr = 0;	/* last label number handed out */
    char name[16];		/* "$L" + up to 10 digits + NUL fits */

    snprintf(name,sizeof(name),"$L%d",++current_nr);
    return lookup(name);
}

/*
 * Produce the assembler listing on stdout.  HOW selects the action:
 *   1  remember the current pc as the start of an instruction;
 *   2  (op, fmt, ...) print the words emitted since the last HOW==1 call,
 *      then mnemonic OP and its operands formatted with FMT and the
 *      remaining arguments;
 *   3  (name) print a label line for NAME at the current pc;
 *   4  print the words emitted since the last HOW==1 call only.
 */

static void
asmlist(int how, ...)
{
    static unsigned short *save_pc;
    va_list ap;
    const char *op;
    const char *fmt;
    int col;

    if (how == 1) {
	save_pc = pc;
	return;
    }
    /* va_start must name the last named parameter, which is 'how'. */
    va_start(ap,how);
    if (how == 3) {
	op = va_arg(ap,const char *);
	printf("%5d:\t\t%s:\n",(int)(pc - program),op);
	va_end(ap);
	return;
    }
    col = printf("%5d: ",(int)(save_pc - program));
    while (save_pc < pc)
	col += printf("%04x ",(unsigned)*save_pc++);
    if (how == 4) {
	/* No variadic arguments are passed in this mode. */
	printf("\n");
	va_end(ap);
	return;
    }
    op = va_arg(ap,const char *);
    fmt = va_arg(ap,const char *);
    /* Align the mnemonic at column 24; start a new line if the hex
     * dump already reaches that column. */
    if (col >= 24) {
	printf("\n\t\t\t");
	col = 24;
    }
    while (col < 24) {
	putchar(' ');
	col++;
    }
    printf("%-7s ",op);
    vprintf(fmt,ap);
    printf("\n");
    va_end(ap);
}

/*
 * Append a 32-bit word to the program at pc and advance pc past it.
 * The value's bytes are stored as-is; callers pass values that are
 * already in network order (htonl).  memcpy avoids the misaligned,
 * type-punned store of an 'unsigned long', which would also write 8
 * bytes on LP64 systems.  NOTE(review): assumes the kernel matcher
 * expects 32-bit address/mask words -- confirm against matcher.h.
 */

static void
gen_long(unsigned long l)
{
    uint32_t w = (uint32_t)l;

    memcpy(pc,&w,sizeof(w));
    pc += sizeof(w) / sizeof(*pc);
}

/* Define LABEL at the current pc (a second definition is reported as an
 * error and ignored) and emit a label line in the listing. */
static void
gen_label(SYM *label)
{
    if (label->label_pc < 0) {
	label->label_pc = pc - program;
    } else {
	error("Label '%s' multiply defined\n",label->name);
    }
    asmlist(3,label->name);
}

/*
 * Emit a one-word reference to LABEL, as an offset relative to the word
 * being emitted.  If the label is already defined, the offset is written
 * directly.  Otherwise -here is stored and the word is put on the
 * label's fixlist; fix_forward() later adds label_pc, which leaves the
 * same label_pc - here.  Running out of memory is fatal.
 */
static void
gen_label_ref(SYM *label)
{
    FIX *f;
    int here = pc - program;	/* word index of the operand */

    if (label->label_pc >= 0) {
	*pc++ = label->label_pc - here;
	return;
    }
    if ((f = malloc(sizeof(*f))) == NULL) {
	perror("malloc");
	exit(1);
    }
    f->next = label->fixlist;
    label->fixlist = f;
    f->where = here;
    /* Compute the value before advancing pc: reading and incrementing
     * pc in one unsequenced expression is undefined behavior. */
    *pc++ = -here;
}

/* Emit a single-word instruction (the opcode alone) and list it. */
static void
gen1(OPCODE op, const char *opstr)
{
    asmlist(1);			/* mark start of instruction */
    *pc = op;
    pc++;
    asmlist(2,opstr,"");
}

/* Emit a port test: opcode, protocol word, port word. */
static void
gen2(OPCODE op, const char *opstr, PPORT port)
{
    int proto = port.proto;
    int num = port.port;

    asmlist(1);
    pc[0] = op;
    pc[1] = port.proto;
    pc[2] = port.port;
    pc += 3;
    asmlist(2,opstr,"%d,%d",proto,num);
}

/* Emit an address test: opcode, then address and mask as network-order
 * long words. */
static void
gen3(OPCODE op, const char *opstr, ADDR addr)
{
    unsigned long a = addr.addr;
    unsigned long m = addr.mask;

    asmlist(1);
    *pc = op;
    pc++;
    gen_long(htonl(a));
    gen_long(htonl(m));
    asmlist(2,opstr,"0x%08lx,0x%08lx",a,m);
}

/* Emit a set-membership test: opcode, reference to the set's label, and
 * element count.  Marks the set as referenced so gen_sets() emits it. */
static void
gen4(OPCODE op, const char *opstr, SET *set)
{
    SYM *lab = set->label;
    int n = set->count;

    asmlist(1);
    *pc = op;
    pc++;
    gen_label_ref(lab);
    *pc = set->count;
    pc++;
    asmlist(2,opstr,"%s,%d",lab->name,n);
    set->referenced = 1;
}

/* Emit a jump: opcode followed by a relative reference to LABEL. */
static void
gen5(OPCODE op, const char *opstr, SYM *label)
{
    asmlist(1);
    *pc = op;
    pc++;
    gen_label_ref(label);
    asmlist(2,opstr,"%s",label->name);
}

/*
 * Emit the data of every referenced set, each preceded by its label.
 * Port sets are emitted as (proto, port) word pairs; address sets as
 * (addr, mask) long words converted with htonl.  Sets never referenced
 * by gen4() are skipped.
 */
static void
gen_sets(void)
{
    SET *s;

    for (s = settab; s != NULL; s = s->next) {
	int i;

	if (s->referenced == 0)
	    continue;
	gen_label(s->label);
	for (i = 0; i < s->count; i++) {
	    asmlist(1);
	    if (s->type == port_set) {
		const PPORT *pp = (const PPORT *)s->ptr + i;

		*pc++ = pp->proto;
		*pc++ = pp->port;
	    } else {
		const ADDR *pa = (const ADDR *)s->ptr + i;

		gen_long(htonl(pa->addr));
		gen_long(htonl(pa->mask));
	    }
	    asmlist(4);
	}
    }
}

/* Fix forward refs */

static void
fix_forward(void)
{
    SYM *s;
    FIX *f;

    printf("\nFixing forward references\n\n");
    for (s = symtab; s; s = s->next) {
	if (s->fixlist && s->label_pc < 0)
	    error("Undefined label '%s'\n",s->name);
	else {
	    for (f = s->fixlist; f; f = f->next) {
		program[f->where] += s->label_pc;
		printf("%5d: %04x\n",(int)f->where,(int)program[f->where]);
	    }
	}
    }
}

/*
 * Main program:
 *   1. parse the rules;
 *   2. append a final unconditional UDENY;
 *   3. emit the data of all referenced sets and patch forward references;
 *   4. if no errors were reported, load the program into /dev/ip with an
 *      I_STR STREAMS ioctl.
 * Exits with status 1 on any failure and returns 0 on success.
 */

int
main(void)
{
    int fd;
    struct strioctl io;

    yyparse();
    GEN1(UDENY);		/* catch-all after the last rule */
    gen_sets();
    fix_forward();
    if (errors != 0) {
	error("Found %d errors - aborting\n",errors);
	exit(1);
    }
    if ((fd = open("/dev/ip",O_RDWR)) == -1) {
	perror("/dev/ip");
	exit(1);
    }
    /* Command word 'I','P','A',0.  NOTE(review): ic_timout == -1 is
     * taken to mean "no timeout" per SVR4 streamio -- confirm. */
    io.ic_cmd = 'I' << 24 | 'P' << 16 | 'A' << 8 | 0;
    io.ic_timout = -1;
    io.ic_len = (int)((char*)pc - (char*)program);
    io.ic_dp = (void*)program;
    if (ioctl(fd,I_STR,&io) == -1) {
	perror("ioctl");
	close(fd);
	exit(1);
    }
    close(fd);
    return 0;
}

