/*
 * socktest.c  --  Socket testing program.  Useful for first-cut tests on
 *                 AF_{INET,INET6}.
 *
 *
 * Copyright 1995 by Dan McDonald, Bao Phan, and Randall Atkinson,
 *	All Rights Reserved.  
 *      All Rights under this copyright have been assigned to NRL.
 */

/*----------------------------------------------------------------------
#       @(#)COPYRIGHT   1.1a (NRL) 17 August 1995

COPYRIGHT NOTICE

All of the documentation and software included in this software
distribution from the US Naval Research Laboratory (NRL) are
copyrighted by their respective developers.

This software and documentation were developed at NRL by various
people.  Those developers have each copyrighted the portions that they
developed at NRL and have assigned All Rights for those portions to
NRL.  Outside the USA, NRL also has copyright on the software
developed at NRL. The affected files all contain specific copyright
notices and those notices must be retained in any derived work.

NRL LICENSE

NRL grants permission for redistribution and use in source and binary
forms, with or without modification, of the software and documentation
created at NRL provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.
3. All advertising materials mentioning features or use of this software
   must display the following acknowledgement:

        This product includes software developed at the Information
        Technology Division, US Naval Research Laboratory.

4. Neither the name of the NRL nor the names of its contributors
   may be used to endorse or promote products derived from this software
   without specific prior written permission.

THE SOFTWARE PROVIDED BY NRL IS PROVIDED BY NRL AND CONTRIBUTORS ``AS
IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL NRL OR
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

The views and conclusions contained in the software and documentation
are those of the authors and should not be interpreted as representing
official policies, either expressed or implied, of the US Naval
Research Laboratory (NRL).

----------------------------------------------------------------------*/

#include <stdio.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <unistd.h>
#include <stdio.h>
#include <ctype.h>
#include <errno.h>
#include <string.h>
#include <stdarg.h>
#include <stdlib.h>

#if FASTCTO
#include <signal.h>
#include <setjmp.h>
#endif /* FASTCTO */

#include "support.h"

extern char *optarg;
/*
 * errno is deliberately not declared here: <errno.h> provides it and may
 * define it as a macro, so a program-level declaration is not portable.
 */
extern int optind;

/* Socket type names accepted by -t. */
struct nrl_nametonum socktypes[] = {
  { SOCK_STREAM, "stream", 0 },
  { SOCK_DGRAM, "dgram", 0 },
  { 0, NULL, 0 }
};

int port = 7777;
char *portname = NULL;          /* second operand; main() defaults it to "7777" */
char *hostname = NULL;          /* first operand; "localhost" for clients by default */

int server = 0, type = SOCK_DGRAM;      /* -s flag, -t socket type */
int s = 0, s2 = 0, datasize = 0;        /* main socket, per-peer socket, last I/O size */
int af = 0;                             /* -a address family (getaddrinfo() hint) */
int patternlen = 0;                     /* -p starting pattern length */

int nflag = 0;                          /* -n; handed to nrl_getnameinfo() as its flags */

/*
 * Peer/local address buffers.  They are char arrays (callers cast them to
 * struct sockaddr *) sized for any socket address; the former 24 bytes
 * could not hold a struct sockaddr_in6.
 */
char addrbuf[sizeof(struct sockaddr_storage)];
char addrbuf2[sizeof(struct sockaddr_storage)];
struct sockaddr *addr;                  /* set by addrform() */
socklen_t addrlen;                      /* value-result length for addrbuf */

char *datablock;                        /* I/O buffer; main() allocates datablocklen + 1 bytes */
int datablocklen = 65535;               /* -b buffer size */
char mynamestr[256];                    /* line prefix printed by say() */

char *requeststr = NULL;                /* -S security request (IPSEC_NEWAPI) */

#ifdef ADDRFORM
int aform = 0;                          /* -f requested address form */
#endif /* ADDRFORM */

#ifdef FASTCTO
/*
 * Jump target for the connect() timeout.  main() fills it with setjmp()
 * right before arming alarm(FASTCTO) and calling connect().
 */
static jmp_buf timeout_env;

/*
 * SIGALRM handler: abandon the blocked connect() by jumping back to the
 * setjmp() point in main().  The signal number becomes setjmp()'s
 * (non-zero) return value.
 * NOTE(review): plain setjmp/longjmp need not restore the signal mask, so
 * SIGALRM may stay blocked after the jump on some systems;
 * sigsetjmp/siglongjmp would avoid that but requires changing main() too.
 */
static void timeout_handler(int signo)
{
  int jumpval = signo;

  longjmp(timeout_env, jumpval);
}
#endif /* FASTCTO */

/*
 * Print one status line to stdout: the mynamestr prefix, then the
 * printf-style message built from `message` and the remaining arguments,
 * then a newline.
 */
void say(char *message, ...)
{
  va_list ap;

  fputs(mynamestr, stdout);
  va_start(ap, message);
  vprintf(message, ap);
  va_end(ap);
  putchar('\n');
}

#ifdef ADDRFORM
/*
 * Ask the kernel to switch socket s to address form `form` via the
 * IPV6_ADDRFORM socket option, then refresh the globals addrbuf, addrlen
 * and addr with the socket's local name.
 *
 * Exits the process if any getsockopt/setsockopt/getsockname call fails.
 * If the form read back afterwards differs from `form`, this is reported
 * on stderr but is not fatal.
 */
void addrform(int form)
{
  int before, after;
  socklen_t optlen, namelen;

  optlen = sizeof(before);
  if (getsockopt(s, IPPROTO_IPV6, IPV6_ADDRFORM, &before, &optlen) < 0) {
    perror("socktest: getsockopt(IPV6_ADDRFORM...)");
    exit(1);
  }

  if (setsockopt(s, IPPROTO_IPV6, IPV6_ADDRFORM, &form, sizeof(form)) < 0) {
    perror("socktest: setsockopt(IPV6_ADDRFORM...)");
    exit(1);
  }

  optlen = sizeof(after);
  if (getsockopt(s, IPPROTO_IPV6, IPV6_ADDRFORM, &after, &optlen) < 0) {
    perror("socktest: getsockopt(IPV6_ADDRFORM...)");
    exit(1);
  }

  /* Re-read the local address after the form change.  A socklen_t local
   * is used because the option/name APIs take socklen_t *, not int *. */
  namelen = sizeof(addrbuf);
  if (getsockname(s, (struct sockaddr *)addrbuf, &namelen) < 0) {
    perror("socktest: getsockname");
    exit(1);
  }
  addrlen = namelen;
  addr = (struct sockaddr *)addrbuf;

  if (after != form) {
    fprintf(stderr, "socktest: address form error: started with %s, requested %s\n", nrl_afnumtoname(before), nrl_afnumtoname(form));
    fprintf(stderr, "socktest: address form error: but ended up with %s!\n", nrl_afnumtoname(after));
  }
}
#endif /* ADDRFORM */

/*
 * Print the command-line synopsis for `myname` to stderr and exit with
 * status 1.  The option list covers every option main() passes to
 * getopt(); the ADDRFORM/IPSEC options appear only when compiled in.
 */
void usage(char *myname)
{
  fprintf(stderr, "usage: %s [-t socket_type] [-a address_fam]", myname);
#ifdef ADDRFORM
  fprintf(stderr, " [-f address_form]");
#endif /* ADDRFORM */
#ifdef IPSEC
#ifdef IPSEC_NEWAPI
  fprintf(stderr, " [-S security_spec]");
#else /* IPSEC_NEWAPI */
  fprintf(stderr, " [-A authlevel] [-T transportencrlevel] [-N netencrlevel]");
#endif /* IPSEC_NEWAPI */
#endif /* IPSEC */
  fprintf(stderr, " [-s] [-n] [-b buffer_size] [-p pattern_start] [host] [port]\n");
  exit(1);
}

/*
 * Client mode: read lines from stdin, send each one (trailing newline
 * removed) on socket s and print the reply.  Returns at end of input or
 * on a stdin read error; exits the process on socket errors.
 */
static void run_interactive_client(void)
{
  char *p;

  for (;;) {
    printf("send: ");
    fflush(stdout);             /* the prompt has no newline */

    /*
     * Test fgets() itself instead of feof() beforehand.  At end of input
     * fgets() returns NULL and leaves the buffer alone, so the old loop
     * re-sent stale data (uninitialised memory on the first pass).
     */
    if (fgets(datablock, datablocklen, stdin) == NULL)
      break;
    if ((p = strchr(datablock, '\n')) != NULL)
      *p = 0;

    datasize = strlen(datablock);

    if (write(s, datablock, datasize) < 0) {
      perror("write");
      exit(1);
    }

    say("Sent '%s'", datablock);
    fflush(stdout);

    if ((datasize = read(s, datablock, datablocklen)) < 0) {
      perror("read");
      exit(1);
    }

    datablock[datasize] = 0;    /* fits: datablock has datablocklen + 1 bytes */
    say("Got  '%s'", datablock);
    fflush(stdout);
  }
}

/*
 * Pattern mode (-p): send blocks of repeating digits "0123456789...".
 * The first block is patternlen bytes and each round grows by one byte,
 * up to datablocklen.  Each echo is checked for length and content.
 * Mismatches are reported on stderr; socket errors exit the process.
 */
static void run_pattern_client(void)
{
  int i;

  for (; patternlen <= datablocklen; patternlen++) {
    for (i = 0; i < patternlen; i++)
      datablock[i] = '0' + (i % 10);

    if (write(s, datablock, patternlen) < 0) {
      perror("write");
      exit(1);
    }

    say("Sent %5d bytes", patternlen);

    if ((datasize = read(s, datablock, datablocklen)) < 0) {
      perror("read");
      exit(1);
    }

    say("Got  %5d bytes", datasize);

    if (datasize != patternlen) {
      fprintf(stderr, "socktest: Length mismatch (%d != %d)\n", patternlen, datasize);
      continue;
    }

    for (i = 0; i < patternlen; i++) {
      if (datablock[i] != '0' + (i % 10)) {
        fprintf(stderr, "socktest: Data mismatch at byte %d\n", i);
        break;
      }
    }
  }
}

/*
 * Echo server (-s).
 *
 * Stream sockets: accept one connection at a time and echo data back
 * until a zero-length read, then accept the next connection.
 * Datagram sockets (accept() fails with EOPNOTSUPP/EINVAL): serve s
 * directly and echo each datagram back to its sender.
 *
 * hbuf/sbuf (with their sizes) hold the printable name of the current
 * peer and are updated in place.  Never returns; exits on errors.
 */
static void run_server(char *hbuf, size_t hbuflen, char *sbuf, size_t sbuflen)
{
  s2 = -1;
  datasize = 0;

  for (;;) {
    if (!datasize) {
      /*
       * Close only a connection we accepted.  For datagram sockets s2 is
       * the bound socket s itself.  Closing it on an empty datagram would
       * make the next accept() fail with EBADF.
       */
      if (s2 >= 0 && s2 != s) {
        say("Closing connection to %s.%s", hbuf, sbuf);
        close(s2);
      }

      addrlen = sizeof(addrbuf);
      if ((s2 = accept(s, (struct sockaddr *)addrbuf, &addrlen)) >= 0) {
        if (nrl_getnameinfo((struct sockaddr *)addrbuf, addrlen, hbuf, hbuflen, sbuf, sbuflen, nflag)) {
          fprintf(stderr, "getnameinfo() failed!\n");
          exit(1);
        }

        say("Accepted connection from %s.%s", hbuf, sbuf);

        memcpy(addrbuf2, addrbuf, sizeof(addrbuf));
      } else if ((errno == EOPNOTSUPP) || (errno == EINVAL)) {
        s2 = s;                 /* connectionless: serve the bound socket */
      } else {
        perror("accept");
        exit(1);
      }
    }

    addrlen = sizeof(addrbuf);
    if ((datasize = recvfrom(s2, datablock, datablocklen - 1, 0,
                             (struct sockaddr *)addrbuf, &addrlen)) < 0) {
      perror("recvfrom");
      exit(1);
    }

    /* Refresh the printable peer name whenever the sender changes. */
    if (datasize && memcmp(addrbuf, addrbuf2, sizeof(addrbuf))) {
      memcpy(addrbuf2, addrbuf, sizeof(addrbuf));

      if (nrl_getnameinfo((struct sockaddr *)addrbuf, addrlen, hbuf, hbuflen, sbuf, sbuflen, nflag)) {
        fprintf(stderr, "getnameinfo() failed!\n");
        exit(1);
      }
    }

    if (datasize) {
      datablock[datasize] = 0;

      if (patternlen)
        say("Got  %5d bytes from %s.%s", datasize, hbuf, sbuf);
      else
        say("Got  '%s' from %s.%s", datablock, hbuf, sbuf);
      fflush(stdout);
      if (sendto(s2, datablock, datasize, 0, (struct sockaddr *)addrbuf,
                 addrlen) < 0) {
        perror("sendto");
        exit(1);
      }
    }
  }
}

/*
 * socktest entry point.
 *
 * 1. Parse options.
 * 2. Resolve the [host] [port] operands with getaddrinfo().
 * 3. Open the first address that can be bound and listened on (server,
 *    -s) or connected to (client).
 * 4. Run the echo server, the pattern client (-p) or the interactive
 *    client.
 */
int main(int argc, char *argv[])
{
  int i;
  int ch;                       /* int, not char: getopt() ends with -1 */
  struct addrinfo *ai, *ai2;
  int ahlev = 0, esplev = 0, esptlev = 0;
  /* Sized like NI_MAXHOST/NI_MAXSERV; 32 bytes could not hold every host
   * name or scoped IPv6 literal, making getnameinfo() fail. */
  char hbuf[1025], sbuf[32];

  while ((ch = getopt(argc, argv, "S:A:N:T:a:f:t:snp:b:")) != -1)
    switch (ch) {
    case 'a':
      if ((af = nrl_afnametonum(optarg)) == -1) {
        fprintf(stderr,"socktest: invalid address family: %s\n", optarg);
        exit(1);
      }
      break;
#ifdef ADDRFORM
    case 'f':
      if ((aform = nrl_afnametonum(optarg)) == -1) {
        fprintf(stderr,"socktest: invalid address family: %s\n", optarg);
        exit(1);
      }
      break;
#endif /* ADDRFORM */
#if IPSEC
#if IPSEC_NEWAPI
    case 'S':
      requeststr = optarg;
      break;
#else /* IPSEC_NEWAPI */
    case 'A':
      ahlev = atoi(optarg);
      break;
    case 'T':
      /* Transport mode encryption */
      esplev = atoi(optarg);
      break;
    case 'N':
      /* Network(tunnel) mode encryption */
      esptlev = atoi(optarg);
      break;
#endif /* IPSEC_NEWAPI */
#endif /* IPSEC */
    case 's':
      server = 1;
      break;
    case 't':
      if ((type = nrl_nametonum(socktypes, optarg)) == -1) {
        fprintf(stderr,"socktest: invalid socket type: %s\n", optarg);
        exit(1);
      }
      break;
    case 'n':
      nflag = 1;
      break;
    case 'b':
      /* Negative sizes would become huge size_t lengths later. */
      datablocklen = atoi(optarg);
      if (datablocklen <= 0) {
        fprintf(stderr, "socktest: %s: invalid buffer size\n", optarg);
        exit(1);
      }
      break;
    case 'p':
      patternlen = atoi(optarg);
      if (patternlen <= 0) {
        fprintf(stderr, "socktest: %s: invalid test pattern start\n", optarg);
        exit(1);
      }
      break;
    default:
      usage(argv[0]);
    }

  if (optind < argc)
    hostname = argv[optind++];

  if (!hostname && !server)
    hostname = "localhost";

  if (optind < argc)
    portname = argv[optind++];

  if (!portname)
    portname = "7777";

  {
    struct addrinfo req;

    memset(&req, 0, sizeof(req));
    req.ai_family = af;
    req.ai_socktype = type;
    if (server)
      req.ai_flags |= AI_PASSIVE;

    if ((i = getaddrinfo(hostname, portname, &req, &ai)) != 0) {
      /* hostname is NULL for a server with no host operand. */
      fprintf(stderr, "socktest: getaddrinfo: %s.%s: %s\n",
              hostname ? hostname : "*", portname, nrl_gai_strerror(i));
      exit(1);
    }
  }

#ifdef FASTCTO
  signal(SIGALRM, &timeout_handler);
#endif /* FASTCTO */

  /*
   * Try each resolved address in turn.  Every failure path leaves s == -1,
   * so s >= 0 after the loop implies ai points at the address in use.
   */
  for (ai2 = ai, s = -1; ai; ai = ai->ai_next) {
    if (nrl_getnameinfo(ai->ai_addr, ai->ai_addrlen, hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), 0)) {
      printf("socktest: getnameinfo() failed!\n");
      continue;
    }

    printf("Trying %s.%s...\n", hbuf, sbuf);

    if ((s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol)) < 0) {
      perror("socket");
      continue;
    }

#if IPSEC
#if IPSEC_NEWAPI
    if (requeststr) {
      void *request;
      int requestlen;

      if (parsesec(requeststr, &request, &requestlen)) {
        fprintf(stderr, "parsesec() failed!\n");
        exit(1);
      }

      if (setsockopt(s, SOL_SOCKET, SO_SECURITY, request, requestlen) < 0) {
        perror("setsockopt SO_SECURITY");
        exit(1);
      }
    }
#else /* IPSEC_NEWAPI */
    {
      int val;
      socklen_t len;

      if (setsockopt(s, SOL_SOCKET, SO_SECURITY_AUTHENTICATION, &ahlev,
                     len = sizeof(int)) < 0) {
        perror("setsockopt (auth)");
        exit(1);
      }
      if (getsockopt(s, SOL_SOCKET, SO_SECURITY_AUTHENTICATION, &val,
                     &len) < 0) {
        perror("getsockopt (auth)");
        exit(1);
      }
      if (val != ahlev) {
        fprintf(stderr, "Requested auth level %d and got %d!\n", ahlev, val);
      }
      if (setsockopt(s, SOL_SOCKET, SO_SECURITY_ENCRYPTION_TRANSPORT, &esplev,
                     len = sizeof(int)) < 0) {
        perror("setsockopt (espt)");
        exit(1);
      }
      if (getsockopt(s, SOL_SOCKET, SO_SECURITY_ENCRYPTION_TRANSPORT, &val,
                     &len) < 0) {
        perror("getsockopt (espt)");
        exit(1);
      }
      if (val != esplev) {
        fprintf(stderr, "Requested esp-transport level %d and got %d!\n", esplev, val);
      }
      if (setsockopt(s, SOL_SOCKET, SO_SECURITY_ENCRYPTION_NETWORK, &esptlev,
                     len = sizeof(int)) < 0) {
        perror("setsockopt (espn)");
        exit(1);
      }
      if (getsockopt(s, SOL_SOCKET, SO_SECURITY_ENCRYPTION_NETWORK, &val,
                     &len) < 0) {
        perror("getsockopt (espn)");
        exit(1);
      }
      if (val != esptlev) {
        fprintf(stderr, "Requested esp-tunnel level %d and got %d!\n", esptlev, val);
      }
    }
#endif /* IPSEC_NEWAPI */
#endif /* IPSEC */

    /*
     * Test our own request rather than ai->ai_flags: POSIX does not
     * specify the ai_flags of entries returned by getaddrinfo().
     */
    if (server) {
      if (bind(s, ai->ai_addr, ai->ai_addrlen) < 0) {
        perror("bind");
        close(s);
        s = -1;
        continue;
      }

      if ((listen(s, 1) < 0) && (errno != EOPNOTSUPP)) {
        perror("listen");
        close(s);
        s = -1;
        continue;
      }
    } else {
#if FASTCTO
      if (setjmp(timeout_env)) {
        fprintf(stderr, "socktest: Connection timed out\n");
        /* Drop the half-open socket.  Otherwise s stays >= 0 and, if this
         * was the last address, ai would be NULL after the loop. */
        close(s);
        s = -1;
        continue;
      }

      alarm(FASTCTO);
#endif /* FASTCTO */

      if (connect(s, ai->ai_addr, ai->ai_addrlen) < 0) {
#ifdef FASTCTO
        alarm(0);
#endif /* FASTCTO */
        perror("connect");
        close(s);
        s = -1;
        continue;
      }
#ifdef FASTCTO
      alarm(0);
#endif /* FASTCTO */
    }
    break;
  }
  if (s < 0)
    exit(1);

  if (ai->ai_protocol == IPPROTO_TCP) {
    int val = 1;
    if (setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val)) < 0) {
      perror("setsockopt (TCP_NODELAY)");
      exit(1);
    }
  }

#ifdef ADDRFORM
  if (aform && (aform != af))
    addrform(aform);
#endif /* ADDRFORM */

  if (nrl_getnameinfo(ai->ai_addr, ai->ai_addrlen, hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), nflag)) {
    fprintf(stderr, "getnameinfo() failed!\n");
    exit(1);
  }

  freeaddrinfo(ai2);

  /* Bounded: hbuf can now be far longer than mynamestr. */
  snprintf(mynamestr, sizeof(mynamestr), "%s [%s.%s]:\n\t",
           server ? "Server" : "Client", hbuf, sbuf);

  if (patternlen > datablocklen) {
    fprintf(stderr, "socktest: invalid pattern length\n");
    exit(1);
  }

  /* +1 for the terminating NUL; done in size_t so INT_MAX cannot overflow. */
  if (!(datablock = malloc((size_t)datablocklen + 1))) {
    fprintf(stderr, "socktest: can't allocate %d bytes buffer space\n", datablocklen);
    exit(1);
  }

  say("Ready");

  if (server)
    run_server(hbuf, sizeof(hbuf), sbuf, sizeof(sbuf));
  else if (patternlen)
    run_pattern_client();
  else
    run_interactive_client();

  return 0;
}
