/*
 * RSA signature key generation
 * Copyright (C) 1999, 2000  Henry Spencer.
 * 
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2 of the License, or (at your
 * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
 * 
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 *
 * RCSID $Id: rsasigkey.c,v 1.4 2000/01/06 22:01:57 henry Exp $
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <limits.h>
#include <errno.h>
#include <assert.h>
#include <sys/types.h>
#include <fcntl.h>
#include <unistd.h>
#include <getopt.h>
#include <freeswan.h>
#include "gmp.h"

#ifndef DEVICE
#define	DEVICE	"/dev/random"
#endif
#ifndef MAXBITS
#define	MAXBITS	4096
#endif

/* usage line; lists every option a user would normally pass */
char usage[] =
	"Usage: rsasigkey [--verbose] [--random device] [--rounds n] nbits";

/* long options for getopt_long; "val" is what main()'s switch sees */
struct option opts[] = {
	{"verbose",	no_argument,		NULL,	'v'},
	{"random",	required_argument,	NULL,	'r'},
	{"rounds",	required_argument,	NULL,	'p'},
	{"help",	no_argument,		NULL,	'h'},
	{"version",	no_argument,		NULL,	'V'},
	{NULL,		0,			NULL,	0},	/* terminator */
};
int verbose = 0;		/* narrate the action? */
char *device = DEVICE;		/* where to get randomness */
int nrounds = 30;		/* rounds of prime checking; 25 is good */

char me[] = "ipsec rsasigkey";	/* for messages */

/* forward declarations for the routines defined later in this file */
void report(char *msg);				/* progress note, verbose only */
char *hexout(mpz_t var);			/* "0x..." with even digit count */
void getrandom(size_t nbytes, char *buf);	/* fill buf from the device */
void initrandom(mpz_t var, int nbits);		/* init var to nbits random bits */
void initprime(mpz_t var, int nbits, int eval);	/* init var to a random prime */
void rsasigkey(int nbits);			/* generate and print a key */

/*
 - numarg - parse a whole argument as a decimal int
 * Returns 1 and stores the value through valp if s is entirely a decimal
 * number that fits in an int; returns 0 (leaving *valp alone) otherwise.
 * Unlike atoi, this rejects empty strings, trailing junk, and overflow.
 */
static int
numarg(s, valp)
char *s;
int *valp;
{
	char *end;
	long val;

	errno = 0;
	val = strtol(s, &end, 10);
	if (end == s || *end != '\0' || errno == ERANGE ||
					val < INT_MIN || val > INT_MAX)
		return 0;
	*valp = (int)val;
	return 1;
}

/*
 - main - mostly argument parsing
 * Usage errors exit with status 2, bad bit counts with status 1, --help
 * and --version exit 0; otherwise the key is generated by rsasigkey()
 * and we exit 0.
 */
int
main(argc, argv)
int argc;
char *argv[];
{
	int opt;
	extern int optind;
	extern char *optarg;
	int errflg = 0;
	int nbits;

	/* getopt_long returns -1 when the options are exhausted */
	while ((opt = getopt_long(argc, argv, "", opts, NULL)) != -1)
		switch (opt) {
		case 'v':	/* verbose description */
			verbose = 1;
			break;
		case 'r':	/* nonstandard /dev/random */
			device = optarg;
			break;
		case 'p':	/* number of prime-check rounds */
			if (!numarg(optarg, &nrounds)) {
				fprintf(stderr, "%s: invalid rounds count (%s)\n",
								me, optarg);
				exit(2);
			}
			if (nrounds <= 0) {
				fprintf(stderr, "%s: rounds must be > 0\n", me);
				exit(2);
			}
			break;
		case 'h':	/* help */
			printf("%s\n", usage);
			exit(0);
			break;
		case 'V':	/* version */
			printf("1\n");
			exit(0);
			break;
		case '?':
		default:
			errflg = 1;
			break;
		}
	if (errflg || optind != argc-1) {
		fprintf(stderr, "%s\n", usage);
		exit(2);
	}

	/* reject non-numeric input outright instead of atoi's silent 0 */
	if (!numarg(argv[optind], &nbits)) {
		fprintf(stderr, "%s: invalid bit count (%s)\n", me,
							argv[optind]);
		exit(1);
	}
	if (nbits <= 0) {
		fprintf(stderr, "%s: invalid bit count (%d)\n", me, nbits);
		exit(1);
	}
	if (nbits > MAXBITS) {
		fprintf(stderr, "%s: overlarge bit count (max %d)\n", me,
								MAXBITS);
		exit(1);
	}
	/* each prime gets nbits/2 bits, which must be whole bytes */
	if (nbits % (CHAR_BIT*2) != 0) {
		fprintf(stderr, "%s: bit count (%d) not multiple of %d\n", me,
						nbits, (int)CHAR_BIT*2);
		exit(1);
	}

	rsasigkey(nbits);
	exit(0);
}

/*
 - rsasigkey - generate an RSA signature key
 * We take e to be 3, without discussion.  That would not be wise if these
 * keys were to be used for encryption, but for signatures there are some
 * real speed advantages.
 * nbits is the modulus size; each prime is requested with nbits/2 bits.
 * The key is printed to stdout; everything after the "secret" comment
 * line is private key material.
 */
void
rsasigkey(nbits)
int nbits;
{
	mpz_t p;
	mpz_t q;
	mpz_t n;
	mpz_t e;
#	define	E	3
	mpz_t d;
	mpz_t m;			/* internal modulus, (p-1)*(q-1) */
	mpz_t t;			/* temporary */
	mpz_t exp1;
	mpz_t exp2;
	mpz_t coeff;
	char *hexp;
	int success;
	time_t now;

	/* the easy stuff */
	report("computing primes and modulus...");
	initprime(p, nbits/2, E);
	initprime(q, nbits/2, E);
	while (mpz_cmp(p, q) == 0) {	/* n = p^2 would be trivially factored */
		report("primes are equal, picking another q");
		mpz_clear(q);		/* initprime re-initializes it */
		initprime(q, nbits/2, E);
	}
	mpz_init(t);
	if (mpz_cmp(p, q) < 0) {	/* p to be the larger of the primes */
		report("swapping primes so p is the larger");
		mpz_set(t, p);
		mpz_set(p, q);
		mpz_set(q, t);
	}
	mpz_init(n);
	mpz_mul(n, p, q);		/* n = p*q */
	mpz_init_set_ui(e, E);

	/* internal modulus */
	report("computing (p-1)*(q-1)...");
	mpz_init_set(m, p);
	mpz_sub_ui(m, m, 1);
	mpz_set(t, q);
	mpz_sub_ui(t, t, 1);
	mpz_mul(m, m, t);		/* m = (p-1)*(q-1) */
	mpz_gcd(t, m, e);
	assert(mpz_cmp_ui(t, 1) == 0);	/* m and e relatively prime */

	/* decryption key */
	report("computing d...");
	mpz_init(d);
	success = mpz_invert(d, e, m);
	assert(success);		/* e has an inverse mod m */
	if (mpz_cmp_ui(d, 0) < 0)	/* normalize into [0, m) */
		mpz_add(d, d, m);
	assert(mpz_cmp(d, m) < 0);

	/* the speedup hacks (CRT parameters) */
	report("computing exp1, exp2, coeff...");
	mpz_init(exp1);
	mpz_sub_ui(t, p, 1);
	mpz_mod(exp1, d, t);		/* exp1 = d mod p-1 */
	mpz_init(exp2);
	mpz_sub_ui(t, q, 1);
	mpz_mod(exp2, d, t);		/* exp2 = d mod q-1 */
	mpz_init(coeff);
	success = mpz_invert(coeff, q, p);	/* coeff = q^-1 mod p */
	assert(success);		/* exists since p, q are distinct primes */
	if (mpz_cmp_ui(coeff, 0) < 0)
		mpz_add(coeff, coeff, p);
	assert(mpz_cmp(coeff, p) < 0);

	/* and the output */
	report("output...\n");		/* deliberate extra newline */
	now = time((time_t *)NULL);
	printf("\t# %d bits, %s", nbits, ctime(&now));	/* ctime provides \n */
	printf("\t# for signatures only, UNSAFE FOR ENCRYPTION\n");
	hexp = hexout(n);
	/* hexp+2 skips hexout's "0x"; 01 is the exponent length in bytes */
	printf("\t#pubkey=0x01%02x%s\n", E, hexp+2);	/* RFC2537ish format */
	printf("\tModulus: %s\n", hexp);
	printf("\tPublicExponent: %s\n", hexout(e));
	printf("\t# everything after this point is secret\n");
	printf("\tPrivateExponent: %s\n", hexout(d));
	printf("\tPrime1: %s\n", hexout(p));
	printf("\tPrime2: %s\n", hexout(q));
	printf("\tExponent1: %s\n", hexout(exp1));
	printf("\tExponent2: %s\n", hexout(exp2));
	printf("\tCoefficient: %s\n", hexout(coeff));
}

/*
 - initprime - initialize an mpz_t to a random prime of specified size
 * Incrementing by 2 rather than 1 would be tidier, but that means having
 * to ensure that the first value is odd, and that's too much trouble; the
 * prime checker efficiently rejects even numbers anyway.  Efficiency tweak:
 * we reject primes that are 1 higher than a multiple of e, since they will
 * make the internal modulus not relatively prime to e.
 */
void
initprime(var, nbits, eval)
mpz_t var;
int nbits;			/* known to be a multiple of CHAR_BIT */
int eval;			/* value of e; 0 means don't bother */
{
	unsigned long tries;

	initrandom(var, nbits);

	report("looking for a prime starting there");
	tries = 1;
	while (!( mpz_probab_prime_p(var, nrounds) &&
				(eval == 0 || mpz_fdiv_ui(var, eval) != 1) )) {
		mpz_add_ui(var, var, 1);
		tries++;
	}

	if (verbose)		/* /2 because don't count the even numbers! */
		fprintf(stderr, "found it after %lu tries\n", tries/2);
}

/*
 - initrandom - initialize an mpz_t to a specified number of random bits
 * Going via hex is a bit strange, but it's the best route GMP gives us.
 * var is mpz_init'ed here, so it must not already be initialized.  The
 * static staging buffers briefly hold would-be key material, so they are
 * wiped as soon as GMP has its own copy.
 */
void
initrandom(var, nbits)
mpz_t var;
int nbits;			/* known to be a multiple of CHAR_BIT */
{
	size_t nbytes = (size_t)(nbits / CHAR_BIT);
	static char bitbuf[MAXBITS/CHAR_BIT];
	static char hexbuf[2 + MAXBITS/4 + 1];	/* "0x", 2 digits/byte, NUL */
	size_t hsize = sizeof(hexbuf);
	size_t needed;
	int bad;

	assert(nbytes <= sizeof(bitbuf));
	getrandom(nbytes, bitbuf);

	/* NOTE(review): bytestoa returns 0 on failure, else the buffer size
	 * it needed, per the FreeS/WAN docs -- confirm for this version */
	needed = bytestoa(bitbuf, nbytes, 'x', hexbuf, hsize);
	if (needed == 0) {
		fprintf(stderr, "%s: can't-happen hex conversion failure\n", me);
		exit(1);
	}
	if (needed > hsize) {
		fprintf(stderr, "%s: can't-happen buffer overflow\n", me);
		exit(1);
	}

	bad = (mpz_init_set_str(var, hexbuf, 0) < 0);
	memset(bitbuf, 0, sizeof(bitbuf));	/* done with the raw randomness */
	memset(hexbuf, 0, sizeof(hexbuf));
	if (bad) {
		fprintf(stderr, "%s: can't-happen hex conversion error\n", me);
		exit(1);
	}
}

/*
 - getrandom - get some random bytes from /dev/random (or wherever)
 * Reads exactly nbytes bytes from the file named by the global `device`
 * into buf, looping over short reads and retrying reads interrupted by a
 * signal.  Open failure, read error, or premature EOF is fatal (exit 1).
 */
void
getrandom(nbytes, buf)
size_t nbytes;
char *buf;			/* known to be big enough */
{
	size_t ndone;
	int dev;
	ssize_t got;		/* signed: read() returns -1 on error */

	dev = open(device, O_RDONLY);
	if (dev < 0) {
		fprintf(stderr, "%s: could not open %s (%s)\n", me,
						device, strerror(errno));
		exit(1);
	}

	ndone = 0;
	if (verbose)
		fprintf(stderr, "getting %lu random bytes from %s\n",
					(unsigned long)nbytes, device);
	while (ndone < nbytes) {
		got = read(dev, buf + ndone, nbytes - ndone);
		if (got < 0) {
			if (errno == EINTR)
				continue;	/* interrupted by a signal; retry */
			fprintf(stderr, "%s: read error on %s (%s)\n", me,
						device, strerror(errno));
			exit(1);
		}
		if (got == 0) {
			fprintf(stderr, "%s: eof on %s!?!\n", me, device);
			exit(1);
		}
		ndone += (size_t)got;
	}

	close(dev);
}

/*
 - hexout - prepare hex output, guaranteeing even number of digits
 * (The current FreeS/WAN conversion routines want an even digit count,
 * but mpz_get_str doesn't promise one.)
 */
char *				/* pointer to static buffer (ick) */
hexout(var)
mpz_t var;
{
	static char hexbuf[3 + MAXBITS/4 + 1];
	char *hexp;

	mpz_get_str(hexbuf+3, 16, var);
	if (strlen(hexbuf+3)%2 == 0)	/* even number of hex digits */
		hexp = hexbuf+1;
	else {				/* odd, must pad */
		hexp = hexbuf;
		hexp[2] = '0';
	}
	hexp[0] = '0';
	hexp[1] = 'x';

	return hexp;
}

/*
 - report - narrate progress on stderr
 * Emits msg followed by a newline, but only when --verbose was given;
 * otherwise it is silent.
 */
void
report(msg)
char *msg;
{
	if (verbose) {
		fprintf(stderr, "%s\n", msg);
	}
}
