/*
 * Copyright (c) 2004 Nikos Mavroyanopoulos <nmav@gnutls.org>
 * Based on Peter 'Luna' Runestig's <peter@runestig.com> original code.
 *
 * This file is part of FEG GNU inetutils.
 *
 * FEG GNU inetutils is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * FEG GNU inetutils is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA
 *
 */


#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#ifdef TLS

#if 0
static char copyright[] =
    "@(#) Copyright (c) Nikos Mavroyanopoulos 2004 <nmav@gnutls.org>.\n";
#endif

#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <stdlib.h>
#include <errno.h>
#include <ctype.h>
#ifdef HAVE_NETDB_H
#include <netdb.h>
#endif				/* HAVE_NETDB_H */
#include <sys/param.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/poll.h>
#include <sys/stat.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#include <gnutls/openpgp.h>

#include <dirent.h>

#ifdef __STDC__
# include <stdarg.h>
#endif

#include "tlsutil.h"
#include "tlsnetrc.h"
#include "ftp_var.h"

/* Per-connection TLS state: the GnuTLS session plus the descriptors it
 * reads from and writes to.  -1 in infd/outfd means "no descriptor".
 */
typedef struct {
	gnutls_session session;
	int infd;
	int outfd;
} socket_st;

/* The data and control connections.  Both descriptors start out as -1 so
 * that no real descriptor (in particular 0, stdin) is mistaken for a TLS
 * socket by SOCK_TO_TLS_SESS()/SOCK_TO_SOCKET_ST() before a connection
 * has been set up.
 */
socket_st data_conn = { NULL, -1, -1 };
socket_st ctrl_conn = { NULL, -1, -1 };

/* Return the GnuTLS session that owns descriptor `s', or NULL.
 *
 * A descriptor belongs to a connection when it is either its read or its
 * write side; the data connection is consulted before the control one.
 * NOTE(review): this is a plain (non-static) C99 `inline' definition; under
 * strict C99 semantics no external definition is emitted unless some
 * declaration in this translation unit omits `inline' — confirm the build.
 */
inline gnutls_session SOCK_TO_TLS_SESS( int s) 
{
	if (s == data_conn.infd || s == data_conn.outfd)
		return data_conn.session;

	if (s == ctrl_conn.infd || s == ctrl_conn.outfd)
		return ctrl_conn.session;

	return NULL;
}

/* Map a descriptor to the connection record that owns it.
 *
 * Returns &data_conn or &ctrl_conn when `s' is either the read or the write
 * side of that connection (data connection first), NULL otherwise.
 * Negative descriptors never match: a cleared record has both fds at -1,
 * and returning it for an invalid descriptor would be meaningless.
 */
inline socket_st* SOCK_TO_SOCKET_ST(int s)
{
	if (s < 0)
		return NULL;

	if (s == data_conn.outfd || s == data_conn.infd)
		return &data_conn;

	if (s == ctrl_conn.infd || s == ctrl_conn.outfd)
		return &ctrl_conn;

	return NULL;
}

/* Mark both descriptors of a socket_st (passed as an lvalue) as unused. */
#define CLEAN_FD(x) ((x).infd = (x).outfd = -1)

/* Size, in bytes, of fputc_buffer below. */
#define FPUTC_BUFFERSIZE	1024

/* Forward declarations of functions defined further down in this file. */
void tls_cleanup(void);
static void tls_fatal_close( socket_st* sock, gnutls_alert_description alert);
void tls_fputc_fflush(int fd);

/* Buffer and fill level presumably used by tls_fputc_fflush() and related
 * output helpers defined later in the file — TODO confirm. */
static unsigned char fputc_buffer[FPUTC_BUFFERSIZE];
static int fputc_buflen = 0;
/* Set by tls_optarg() when a "cert=" option is given. */
static int x509rc_override = 0;
/* File names from tls_optarg() (strdup()ed; NULL when not given).
 * tls_init() copies them into a tlsparams_st and passes it to tlsparams(). */
static char *tls_key_file = NULL;
static char *tls_cert_file = NULL;
static char *tls_pgp_key_file = NULL;
static char *tls_pgp_cert_file = NULL;
static char *tls_ca_file = NULL;
static char *tls_pgp_ring_file = NULL;
static char *tls_crl_file = NULL;
char *tls_hostname = NULL;	/* hostname used by user to connect */
/* Credentials allocated by tls_init(), attached to every session by
 * tls_session_init() and released by tls_cleanup(). */
static gnutls_certificate_credentials cert_cred = NULL;
static gnutls_anon_client_credentials anon_cred = NULL;
static gnutls_srp_client_credentials  srp_cred = NULL;

static void save_cert( const gnutls_datum* raw_cert, int cert_type);

/* User's home directory (defined elsewhere); ~/.certs lives below it. */
extern char* home;

/* SRP user name remembered from the first SRP prompt (strdup()ed). */
char* srp_username = NULL;

/* Non-zero: use TLS on the control connection (cleared by "noprotect"). */
int tls_on_ctrl = 1;
/* Non-zero: skip server certificate checks (set by "certsok"). */
int tls_no_verify = 0;

/* Algorithm preference lists applied to every session by
 * tls_session_init(), most preferred first and terminated by 0.
 * MAX_PRIO_ELEMENTS leaves room for tls_optarg() to rewrite the key
 * exchange ("nosrp") and compression ("compress") lists in place.
 */
#define MAX_PRIO_ELEMENTS 16
static int protocol_priority[MAX_PRIO_ELEMENTS] =
    { GNUTLS_TLS1_1, GNUTLS_TLS1, GNUTLS_SSL3, 0 };
static int kx_priority[MAX_PRIO_ELEMENTS] =
    { GNUTLS_KX_SRP_DSS, GNUTLS_KX_SRP_RSA,
      GNUTLS_KX_DHE_DSS, GNUTLS_KX_DHE_RSA, GNUTLS_KX_RSA,
      GNUTLS_KX_SRP,
      0 };
static int cipher_priority[MAX_PRIO_ELEMENTS] =
    { GNUTLS_CIPHER_AES_256_CBC, GNUTLS_CIPHER_AES_128_CBC,
      GNUTLS_CIPHER_3DES_CBC, GNUTLS_CIPHER_ARCFOUR_128, 0
    };
static int comp_priority[MAX_PRIO_ELEMENTS] =
    { GNUTLS_COMP_NULL, 0 };
static int mac_priority[MAX_PRIO_ELEMENTS] =
    { GNUTLS_MAC_SHA, GNUTLS_MAC_MD5, 0 };
static int cert_type_priority[MAX_PRIO_ELEMENTS] =
    { GNUTLS_CRT_X509, GNUTLS_CRT_OPENPGP, 0 };


/* Return 1 if descriptor `s' currently carries a TLS session, else 0. */
int tls_active(int s)
{
	return SOCK_TO_TLS_SESS(s) != NULL;
}

/* Substitute "Unknown" for a NULL name, as returned by the GnuTLS
 * *_get_name() functions.  This is a function rather than a macro so that
 * its argument (usually a library call) is evaluated exactly once.
 */
static const char *SU(const char *x)
{
	return x == NULL ? "Unknown" : x;
}

/* Describe the protocol, key exchange and cipher negotiated on `s' as
 * "<protocol> cipher <kx>/<cipher> (<bits> bits)", or "clear" when `s' is
 * NULL.  Returns a static buffer that the next call overwrites.
 */
static char *tls_get_cipher(gnutls_session s)
{
	static char r[220];

	if (s == NULL) {
		snprintf(r, sizeof(r), "clear");
		return r;
	}

	/* gnutls_cipher_get_key_size() yields a size_t byte count; convert it
	 * explicitly so the argument matches the %d conversion. */
	snprintf(r, sizeof(r), "%s cipher %s/%s (%d bits)",
		 SU(gnutls_protocol_get_name(gnutls_protocol_get_version(s))),
		 SU(gnutls_kx_get_name(gnutls_kx_get(s))),
		 SU(gnutls_cipher_get_name(gnutls_cipher_get(s))),
		 (int) (gnutls_cipher_get_key_size(gnutls_cipher_get(s)) * 8));
	return r;
}

/* Name of the compression method negotiated on `s', or "none" when `s' is
 * NULL.  Returns a static buffer that the next call overwrites.
 */
static char *tls_get_comp(gnutls_session s)
{
	static char r[220];
	const char *method = "none";

	if (s != NULL)
		method = SU(gnutls_compression_get_name(gnutls_compression_get(s)));

	snprintf(r, sizeof(r), "%s", method);
	return r;
}


char *tls_get_cipher_info_string(int fd)
{
	return tls_get_cipher(SOCK_TO_TLS_SESS(fd));
}

int tls_optarg(char *optarg)
{
	char *p;

	if ((p = strchr(optarg, '='))) {
		*p++ = 0;
		if (!strcmp(optarg, "cert")) {
			tls_cert_file = strdup(p);
			x509rc_override = 1;
		} else if (!strcmp(optarg, "key"))
			tls_key_file = strdup(p);
		else if (!strcmp(optarg, "pgpkey"))
			tls_pgp_key_file = strdup(p);
		else if (!strcmp(optarg, "pgpring"))
			tls_pgp_ring_file = strdup(p);
		else if (!strcmp(optarg, "pgpcert"))
			tls_pgp_cert_file = strdup(p);
		else if (!strcmp(optarg, "ca"))
			tls_ca_file = strdup(p);
		else if (!strcmp(optarg, "crl"))
			tls_crl_file = strdup(p);
		else
			return 1;
	} else if (!strcmp(optarg, "noprotect")) {
		tls_on_ctrl = 0;
		tls_on_data = 0;
		require_tls = 0;
	} else if (!strcmp(optarg, "private"))
		tls_on_data = 1;
	else if (!strcmp(optarg, "requiretls"))
		require_tls = 1;
	else if (!strcmp(optarg, "compress")) {
		comp_priority[0] = GNUTLS_COMP_DEFLATE;
		comp_priority[1] = GNUTLS_COMP_NULL;
		comp_priority[2] = 0;
	} else if (!strcmp(optarg, "nosrp")) {
		kx_priority[0] = GNUTLS_KX_DHE_DSS;
		kx_priority[1] = GNUTLS_KX_DHE_RSA;
		kx_priority[2] = GNUTLS_KX_RSA;
		kx_priority[3] = 0;
	} else if (!strcmp(optarg, "certsok"))
		tls_no_verify = 1;
	else
		return 1;

	return 0;
}

/* Human-readable description of how the server authenticated itself in
 * `session': anonymous, SRP, or the certificate type in use.
 */
static const char *tls_get_auth_type( gnutls_session session)
{
	switch (gnutls_auth_get_type(session)) {
	case GNUTLS_CRD_ANON:
		return "Anonymous";
	case GNUTLS_CRD_SRP:
		return "Anonymous SRP";
	case GNUTLS_CRD_CERTIFICATE:
		switch (gnutls_certificate_type_get(session)) {
		case GNUTLS_CRT_OPENPGP:
			return "OpenPGP certificate";
		case GNUTLS_CRT_X509:
			return "X.509 certificate";
		default:
			return "Unknown certificate";
		}
	default:
		return "Unknown authentication type.";
	}
}

/* Describe the peer certificate of `session'.
 *
 * Returns "Anonymous" for anonymous/SRP sessions, the subject (X.509 DN or
 * the first OpenPGP user name) for certificate sessions, or NULL when no
 * usable certificate is available.  *issuer is set to the issuer DN
 * (X.509 only) and *fpr to the colon-separated hex fingerprint (MD5 for
 * X.509); either stays "None" when unavailable.  Returned strings live in
 * static buffers that the next call overwrites.
 */
static const char *tls_get_subject_info(gnutls_session session, 
	const char** issuer, const char** fpr)
{
	const gnutls_datum *raw_cert_list;
	unsigned int raw_cert_list_length;
	static char name[256];
	static char issuer_name[256];
	static char fingerprint[256];
	unsigned char tmp_fingerprint[64];
	int ret, cred_type;
	size_t i, name_size, fpr_size = 0;

	cred_type = gnutls_auth_get_type(session);

	*issuer = "None";
	*fpr = "None";

	if (cred_type == GNUTLS_CRD_ANON || cred_type == GNUTLS_CRD_SRP)
		return "Anonymous";

	if (cred_type != GNUTLS_CRD_CERTIFICATE) {
		warnx("Unknown credentials type: %d", cred_type);
		return NULL;	/* unknown cred type. */
	}

	raw_cert_list =
	    gnutls_certificate_get_peers(session, &raw_cert_list_length);
	if (raw_cert_list == NULL || raw_cert_list_length == 0)
		return NULL;

	/* The getters below may leave partial data on failure, so every
	 * buffer is reset rather than trusted when a call fails. */
	name[0] = 0;

	ret = gnutls_certificate_type_get(session);
	if (ret == GNUTLS_CRT_X509) {
		gnutls_x509_crt cert = NULL;

		if (gnutls_x509_crt_init(&cert) < 0)
			return NULL;

		if (gnutls_x509_crt_import(cert, &raw_cert_list[0],
					   GNUTLS_X509_FMT_DER) < 0) {
			gnutls_x509_crt_deinit(cert);
			return NULL;
		}

		name_size = sizeof(name);
		if (gnutls_x509_crt_get_dn(cert, name, &name_size) < 0)
			name[0] = 0;

		name_size = sizeof(issuer_name);
		if (gnutls_x509_crt_get_issuer_dn(cert, issuer_name,
						  &name_size) >= 0)
			*issuer = issuer_name;

		fpr_size = sizeof(tmp_fingerprint);
		if (gnutls_x509_crt_get_fingerprint(cert, GNUTLS_DIG_MD5,
						    tmp_fingerprint,
						    &fpr_size) < 0)
			fpr_size = 0;

		gnutls_x509_crt_deinit(cert);
	} else if (ret == GNUTLS_CRT_OPENPGP) {
		gnutls_openpgp_key cert = NULL;

		if (gnutls_openpgp_key_init(&cert) < 0)
			return NULL;

		if (gnutls_openpgp_key_import(cert, &raw_cert_list[0],
					      GNUTLS_OPENPGP_FMT_RAW) < 0) {
			gnutls_openpgp_key_deinit(cert);
			return NULL;
		}

		name_size = sizeof(name);
		if (gnutls_openpgp_key_get_name(cert, 0, name, &name_size) < 0)
			name[0] = 0;

		fpr_size = sizeof(tmp_fingerprint);
		if (gnutls_openpgp_key_get_fingerprint(cert, tmp_fingerprint,
						       &fpr_size) < 0)
			fpr_size = 0;

		gnutls_openpgp_key_deinit(cert);
	} else
		return NULL;

	if (fpr_size == 0)
		return name;	/* no fingerprint: *fpr stays "None" */

	/* "XX:" per byte, minus the last colon, plus the terminating NUL. */
	if (fpr_size * 3 + 1 >= sizeof(fingerprint))
		return NULL;

	for (i = 0; i < fpr_size; i++)
		sprintf(&fingerprint[i * 3], "%02X%s", tmp_fingerprint[i],
			(i == fpr_size - 1) ? "" : ":");

	*fpr = fingerprint;
	return name;
}

/* Read one line of user input from stdin and return its first character.
 *
 * Returns '\0' on end of file or read error (the original returned an
 * uninitialised byte there).  If the line is longer than the internal
 * buffer, the rest of it, up to and including the newline, is discarded so
 * that it is not taken as the answer to the next prompt.
 */
static char read_char(void)
{
	char inl[10];
	int c;

	if (fgets(inl, sizeof(inl), stdin) == NULL)
		return '\0';

	if (strchr(inl, '\n') == NULL) {
		while ((c = getchar()) != EOF && c != '\n')
			;
	}

	return inl[0];
}

/* GnuTLS SRP client credentials callback.
 *
 * A call with times == 0 is declined so that the user is only prompted
 * once SRP has actually been negotiated.  Otherwise the user name and
 * password for tls_hostname are requested, the first user name is kept in
 * srp_username, and gnutls_strdup()ed copies are handed back to GnuTLS.
 * Returns 0 on success, -1 to abort.
 */
static int srp_username_cb(gnutls_session session,
                 unsigned int times, char **username, char **password) 
{
	const char *login;
	char *secret;

	if (times == 0)
		return -1;

	printf("[Negotiating TLS-SRP. Please supply your credentials.]\r\n");

	if ((login = get_username(tls_hostname)) == NULL)
		return -1;
	if ((secret = get_password(tls_hostname, 1)) == NULL)
		return -1;

	if (!srp_username)
		srp_username = strdup(login);

	*username = gnutls_strdup(login);
	*password = gnutls_strdup(secret);
	delete_password(secret);

	return 0;
}

/* GnuTLS debug log hook (installed by tls_init() when verbose > 3): each
 * message is forwarded to warnx() with a "TLS: " prefix.  The log level
 * is not used.
 */
static void tls_log_func(int level, const char *str)
{
	(void) level;
	warnx("TLS: %s", str);
}

/* Initialise GnuTLS and the client credentials shared by every session.
 *
 * The command-line TLS settings are copied into a tlsparams_st and passed
 * to tlsparams() together with tls_hostname (presumably to merge in
 * per-host settings — confirm in tlsnetrc).  Certificate, anonymous and SRP
 * credentials are then allocated, the CA and CRL files loaded, and the
 * client's own X.509 and OpenPGP certificates installed.  Finally
 * tls_on_data/require_tls are taken from the merged settings, with -1
 * mapped to 0.
 *
 * Returns 0 on success, 1 on failure.
 */
int tls_init(void)
{
	int ret;
	tlsparams_st st;

	ret = gnutls_global_init();
	if (ret < 0) {
		warnx("gnutls_global_init(): %s", gnutls_strerror(ret));
		return 1;
	}

	ret = gnutls_global_init_extra();
	if (ret < 0) {
		warnx("gnutls_global_init_extra(): %s", gnutls_strerror(ret));
		return 1;
	}

	if (verbose > 3) {
		gnutls_global_set_log_function(tls_log_func);
		gnutls_global_set_log_level(2);
	}

	st.cafile = tls_ca_file;
	st.crlfile = tls_crl_file;
	st.cert = tls_cert_file;
	st.key = tls_key_file;
	st.pgpcert = tls_pgp_cert_file;
	st.pgpkey = tls_pgp_key_file;
	st.pgpring = tls_pgp_ring_file;
	st.require_tls = require_tls;
	st.private = tls_on_data;

	ret = tlsparams(tls_hostname, &st);
	if (ret < 0)
		return 1;

	if ((ret =
	     gnutls_certificate_allocate_credentials(&cert_cred)) < 0) {
		warnx("gnutls_certificate_allocate_credentials(): %s",
			gnutls_strerror(ret));
		return 1;
	}

	if ((ret =
	     gnutls_anon_allocate_client_credentials(&anon_cred)) < 0) {
		warnx("gnutls_anon_allocate_client_credentials(): %s",
			gnutls_strerror(ret));
		return 1;
	}

	if ((ret =
	     gnutls_srp_allocate_client_credentials(&srp_cred)) < 0) {
		warnx("gnutls_srp_allocate_client_credentials(): %s",
			gnutls_strerror(ret));
		return 1;
	}

	gnutls_srp_set_client_credentials_function(srp_cred,
		srp_username_cb);

	/* Trusted CAs, if configured.  A load failure is reported but is not
	 * fatal; ret holds the number of CAs loaded (or the error). */
	ret = 0;
	if (st.cafile) {
		ret =
		    gnutls_certificate_set_x509_trust_file(cert_cred,
							   st.cafile,
							   GNUTLS_X509_FMT_PEM);
		if (ret < 0) {
			warnx("gnutls_certificate_set_x509_trust_file(): %s",
				gnutls_strerror(ret));
		}
	}
	if (verbose > 1) {
		printf("Loaded %d trusted CAs\n", ret);
	}

	/* CRLs, if configured; same non-fatal handling as above. */
	ret = 0;
	if (st.crlfile) {
		ret =
		    gnutls_certificate_set_x509_crl_file(cert_cred,
							 st.crlfile,
							 GNUTLS_X509_FMT_PEM);
		if (ret < 0) {
			warnx("gnutls_certificate_set_x509_crl_file(): %s",
				gnutls_strerror(ret));
		}
	}
	if (verbose > 1) {
		printf("Loaded %d trusted CRLs\n", ret);
	}

	if (st.cert) {
		/* Without a separate key file the private key is expected in
		 * the certificate file itself. */
		const char *key_file = st.key ? st.key : st.cert;

		ret = gnutls_certificate_set_x509_key_file(cert_cred,
							   st.cert,
							   key_file,
							   GNUTLS_X509_FMT_PEM);
		if (ret < 0) {
			warnx("gnutls_certificate_set_x509_key_file(\"%s\") %s\r\n",
				st.cert, gnutls_strerror(ret));
			return 1;
		}
	}

	if (st.pgpcert) {
		/* Same fallback for the OpenPGP secret key. */
		const char *pgp_key_file = st.pgpkey ? st.pgpkey : st.pgpcert;

		ret = gnutls_certificate_set_openpgp_key_file(cert_cred,
							   st.pgpcert,
							   pgp_key_file);
		if (ret < 0) {
			warnx("gnutls_certificate_set_openpgp_key_file(\"%s\") %s\r\n",
				st.pgpcert, gnutls_strerror(ret));
			return 1;
		}
	}

	if (st.pgpring) {
		ret = gnutls_certificate_set_openpgp_keyring_file(cert_cred,
							   st.pgpring);
		if (ret < 0) {
			warnx("gnutls_certificate_set_openpgp_keyring_file(\"%s\") %s\r\n",
				st.pgpring, gnutls_strerror(ret));
			return 1;
		}
	}

	tls_on_data = st.private;
	require_tls = st.require_tls;

	/* -1 means "not set anywhere"; default to off. */
	if (tls_on_data == -1)
		tls_on_data = 0;
	if (require_tls == -1)
		require_tls = 0;

	return 0;
}


/* Create a client session in *s bound to the given read/write descriptors,
 * apply the file-scope algorithm preference lists and attach the
 * certificate, anonymous and SRP credentials created by tls_init().
 *
 * Returns 0 on success or the negative error code from gnutls_init().
 */
static int tls_session_init(gnutls_session * s, int infd, int outfd)
{
	int rc;
	gnutls_session sess;

	if ((rc = gnutls_init(s, GNUTLS_CLIENT)) < 0) {
		warnx("gnutls_init() %s\r\n", gnutls_strerror(rc));
		return rc;
	}
	sess = *s;

	gnutls_cipher_set_priority(sess, cipher_priority);
	gnutls_kx_set_priority(sess, kx_priority);
	gnutls_mac_set_priority(sess, mac_priority);
	gnutls_protocol_set_priority(sess, protocol_priority);
	gnutls_compression_set_priority(sess, comp_priority);
	gnutls_certificate_type_set_priority(sess, cert_type_priority);

	/* The descriptors travel through GnuTLS's pointer-sized transport
	 * slots; going via long avoids int-to-pointer size warnings. */
	gnutls_transport_set_ptr2(sess, (gnutls_transport_ptr) (long) infd,
				  (gnutls_transport_ptr) (long) outfd);

	gnutls_credentials_set(sess, GNUTLS_CRD_CERTIFICATE, cert_cred);
	gnutls_credentials_set(sess, GNUTLS_CRD_ANON, anon_cred);
	gnutls_credentials_set(sess, GNUTLS_CRD_SRP, srp_cred);

	return 0;
}

/* Warn that the host name the user connected to (s1) does not match the
 * name found in the server's certificate (s2).
 *
 * Always returns 1, so callers may use the result directly as a
 * "mismatch" status.
 */
static inline int show_hostname_warning(const char *s1, const char *s2)
{
	warnx("Hostname (\"%s\") and server's certificate (\"%s\") don't match.",
		s1, s2);
	return 1;
}


/* Fallback for systems that do not define INADDR_NONE: the value inet_addr()
 * returns for an unparsable address.  It is compared against an
 * in_addr.s_addr in check_server_x509_cert(), where -1 converts to the
 * all-ones address. */
#ifndef INADDR_NONE
#define INADDR_NONE -1
#endif				/* !INADDR_NONE */

/* Check that session s1 (the data connection) presents the same server
 * certificate as session s2 (the control connection).
 *
 * Returns 0 when they match.  It also returns 0 when verification is
 * disabled via tls_no_verify, or when both sessions use the same
 * non-certificate authentication.  Returns 1 otherwise.  Only the chain
 * lengths and the first certificate of each chain are compared, byte for
 * byte.
 */
static int check_cert_match(gnutls_session s1, gnutls_session s2)
{
	const gnutls_datum *peer1, *peer2;
	unsigned int n1, n2;
	int auth1, auth2;

	if (tls_no_verify)
		return 0;
	if (!s1 || !s2)
		return 1;

	auth1 = gnutls_auth_get_type(s1);
	auth2 = gnutls_auth_get_type(s2);
	if (auth1 == auth2 && auth1 != GNUTLS_CRD_CERTIFICATE)
		return 0;	/* fine even without certificates */

	peer1 = gnutls_certificate_get_peers(s1, &n1);
	if (!peer1)
		return 1;
	peer2 = gnutls_certificate_get_peers(s2, &n2);
	if (!peer2)
		return 1;

	return !(n1 == n2 && peer1->size == peer2->size &&
		 memcmp(peer1->data, peer2->data, peer1->size) == 0);
}


/* returns 0 if hostname and server's cert matches, else 1 
 */
static int check_server_x509_cert(const gnutls_datum* raw_cert_list)
{
	struct in_addr ia;
	gnutls_x509_crt cert;
	int ret = 0, i, err;
	char buf[256];
	size_t buf_size;

	/* Import the certificate to the x509_crt format.
	 */
	if (gnutls_x509_crt_init(&cert) < 0) {
		return 1;
	}

	if (gnutls_x509_crt_import
	    (cert, &raw_cert_list[0], GNUTLS_X509_FMT_DER) < 0) {
		err = 1;
		goto finish;
	}

	/* first we check if `tls_hostname' is in fact an ip address */
	if ((ia.s_addr = inet_addr(tls_hostname)) != INADDR_NONE) {

		/* First, check the subjectAltName X509v3 extensions, as is proper, for
		 * the IP address and FQDN.  If enough people clamor for backward
		 * compatibility, I'll amend this to check commonName later.  Otherwise,
		 * for now, only look in the extensions.
		 */
		for (i = 0; !(ret < 0); i++) {
			buf_size = sizeof(buf);

			ret =
			    gnutls_x509_crt_get_subject_alt_name(cert, i,
								 buf,
								 &buf_size,
								 NULL);

			if (ret == GNUTLS_SAN_IPADDRESS) {
				if (strcasecmp(buf, tls_hostname) == 0) {
					err = 0;
					goto finish;
				}
			}
		}

		show_hostname_warning(tls_hostname, "NO IP IN CERT");
		goto finish;
	}

	if (gnutls_x509_crt_check_hostname(cert, tls_hostname)) {
		err = 0;
		goto finish;
	}

	/* otherwise ask or fail.
	 */
	buf_size = sizeof(buf);
	if (gnutls_x509_crt_get_dn_by_oid
	    (cert, GNUTLS_OID_X520_COMMON_NAME, 0, 0, buf,
	     &buf_size) >= 0) {
		if (strcasecmp(buf, tls_hostname) == 0) {
			return 0;
		}
		show_hostname_warning(tls_hostname, buf);
		goto finish;
	}

	err = 1;

      finish:
	gnutls_x509_crt_deinit(cert);
	return err;
}

/* Check that the server's OpenPGP key (raw_cert, raw format) belongs to
 * tls_hostname.
 *
 * The key is first checked with gnutls_openpgp_key_check_hostname().  As a
 * fallback, its first user name (index 0) is compared case-insensitively
 * with the host name.  A warning is printed on a mismatch.
 *
 * Returns 0 if they match, 1 otherwise (including when the key cannot be
 * parsed).
 */
static int check_server_pgp_cert(const gnutls_datum* raw_cert)
{
	gnutls_openpgp_key key;
	int err = 1;
	char buf[256];
	size_t buf_size;

	/* Import the key into the openpgp_key format. */
	if (gnutls_openpgp_key_init(&key) < 0)
		return 1;

	if (gnutls_openpgp_key_import(key, raw_cert,
				      GNUTLS_OPENPGP_FMT_RAW) < 0)
		goto finish;

	if (gnutls_openpgp_key_check_hostname(key, tls_hostname)) {
		err = 0;
		goto finish;
	}

	buf_size = sizeof(buf);
	if (gnutls_openpgp_key_get_name(key, 0, buf, &buf_size) >= 0) {
		if (strcasecmp(buf, tls_hostname) == 0)
			err = 0;
		else
			show_hostname_warning(tls_hostname, buf);
	}

finish:
	gnutls_openpgp_key_deinit(key);
	return err;
}


/* Build the path of the cached server certificate for tls_hostname:
 * "<home>/.certs/<tls_hostname>.pem" for X.509, or ".pgp.asc" instead of
 * ".pem" for any other type.
 * Returns a malloc()ed string the caller must free, or NULL if home or
 * tls_hostname is unset or memory is exhausted.
 */
inline
static char* create_home_cert_file( int cert_type)
{
	const char *suffix = (cert_type == GNUTLS_CRT_X509) ? ".pem" : ".pgp.asc";
	size_t len;
	char *path;

	if (!home || !tls_hostname)
		return NULL;

	len = strlen(home) + strlen(tls_hostname) + 32;
	path = malloc(len);
	if (!path)
		return NULL;

	snprintf(path, len, "%s/.certs/%s%s", home, tls_hostname, suffix);
	return path;
}

/* Build the path of the certificate cache directory, "<home>/.certs/".
 * Returns a malloc()ed string the caller must free, or NULL if home is
 * unset or memory is exhausted.
 */
inline
static char* create_home_cert_dir( void)
{
	size_t len;
	char *dir;

	if (!home)
		return NULL;

	len = strlen(home) + 32;
	if ((dir = malloc(len)) == NULL)
		return NULL;

	snprintf(dir, len, "%s/.certs/", home);
	return dir;
}

/* Load `file' into a malloc()ed buffer.
 *
 * If null_term is set, a NUL byte is stored after the data; it is not
 * counted in .size.  Returns a datum with .data == NULL if the file is
 * NULL, cannot be opened or stat()ed, is empty, or cannot be read.
 * Release the result with unload_file().
 */
static gnutls_datum load_file(const char *file, int null_term)
{
	FILE *fp;
	gnutls_datum ret = { NULL, 0 };
	struct stat stat_st;
	size_t tot_size, data_read;

	if (file == NULL)
		return ret;

	fp = fopen(file, "r");
	if (fp == NULL)
		return ret;

	/* Without a successful fstat() the size would be garbage. */
	if (fstat(fileno(fp), &stat_st) != 0 || stat_st.st_size <= 0)
		goto error;
	tot_size = (size_t) stat_st.st_size;

	ret.data = malloc(tot_size + 1);
	if (ret.data == NULL)
		goto error;

	data_read = fread(ret.data, 1, tot_size, fp);
	if (data_read == 0)
		goto error;

	fclose(fp);

	ret.size = data_read;
	if (null_term)
		ret.data[ret.size] = 0;

	return ret;

error:
	fclose(fp);
	free(ret.data);
	ret.data = NULL;
	ret.size = 0;
	return ret;
}

/* Free a buffer obtained from load_file() and reset the datum, so that a
 * second call on the same datum is harmless.
 */
static void unload_file(gnutls_datum * mem_file)
{
	void *data = mem_file->data;

	mem_file->data = NULL;
	mem_file->size = 0;
	free(data);
}


/* Return non-zero if the PEM certificate in `stored' re-encodes to exactly
 * the DER bytes of `raw_cert'.
 */
static int x509_file_matches(const gnutls_datum *stored,
			     const gnutls_datum *raw_cert)
{
	gnutls_x509_crt cert;
	unsigned char *buf = NULL;
	size_t buf_size;
	int match = 0;

	if (gnutls_x509_crt_init(&cert) < 0)
		return 0;

	if (gnutls_x509_crt_import(cert, stored, GNUTLS_X509_FMT_PEM) < 0)
		goto out;

	/* A buffer of the peer's size suffices: a larger export can't match. */
	buf_size = raw_cert->size;
	buf = malloc(buf_size > 0 ? buf_size : 1);
	if (buf == NULL)
		goto out;

	if (gnutls_x509_crt_export(cert, GNUTLS_X509_FMT_DER, buf,
				   &buf_size) >= 0 &&
	    buf_size == raw_cert->size &&
	    memcmp(buf, raw_cert->data, buf_size) == 0)
		match = 1;

out:
	free(buf);
	gnutls_x509_crt_deinit(cert);
	return match;
}

/* Import an OpenPGP key from `data' (format `fmt') and write its
 * fingerprint into `fpr', whose capacity/length is in *fpr_size.
 * Returns 0 on success, -1 on failure.
 */
static int pgp_key_fingerprint(const gnutls_datum *data, int fmt,
			       unsigned char *fpr, size_t *fpr_size)
{
	gnutls_openpgp_key key;
	int ret;

	if (gnutls_openpgp_key_init(&key) < 0)
		return -1;

	ret = gnutls_openpgp_key_import(key, data, fmt);
	if (ret >= 0)
		ret = gnutls_openpgp_key_get_fingerprint(key, fpr, fpr_size);

	gnutls_openpgp_key_deinit(key);
	return ret < 0 ? -1 : 0;
}

/* Return non-zero if the ASCII-armored OpenPGP key in `stored' has the
 * same fingerprint as the raw key `raw_cert'.
 */
static int pgp_file_matches(const gnutls_datum *stored,
			    const gnutls_datum *raw_cert)
{
	unsigned char stored_fpr[20], peer_fpr[20];
	size_t stored_size = sizeof(stored_fpr);
	size_t peer_size = sizeof(peer_fpr);

	if (pgp_key_fingerprint(stored, GNUTLS_OPENPGP_FMT_BASE64,
				stored_fpr, &stored_size) < 0)
		return 0;
	if (pgp_key_fingerprint(raw_cert, GNUTLS_OPENPGP_FMT_RAW,
				peer_fpr, &peer_size) < 0)
		return 0;

	return stored_size == peer_size &&
	       memcmp(stored_fpr, peer_fpr, stored_size) == 0;
}

/* Check whether the server certificate `raw_cert' (type `cert_type') was
 * saved earlier in the ~/.certs directory (see save_cert()).
 *
 * X.509 certificates must be byte-identical.  OpenPGP keys are compared by
 * fingerprint.  Returns non-zero on a match, 0 if there is no saved
 * certificate, it differs, or it cannot be read.
 */
static int check_file_cert( const gnutls_datum* raw_cert, int cert_type)
{
	char *cert_file = create_home_cert_file(cert_type);
	gnutls_datum mem_file;
	int match = 0;

	if (cert_file == NULL)
		return 0;

	mem_file = load_file(cert_file, 1);
	free(cert_file);

	if (mem_file.data == NULL)
		return 0;

	if (cert_type == GNUTLS_CRT_X509)
		match = x509_file_matches(&mem_file, raw_cert);
	else if (cert_type == GNUTLS_CRT_OPENPGP)
		match = pgp_file_matches(&mem_file, raw_cert);

	unload_file(&mem_file);
	return match;
}

/* Convert a DER X.509 certificate to PEM.  Returns a malloc()ed buffer
 * whose length is stored in *out_size, or NULL on failure.
 */
static unsigned char *export_x509_pem(const gnutls_datum *raw_cert,
				      size_t *out_size)
{
	gnutls_x509_crt cert;
	unsigned char *buf = NULL;
	size_t size = 0;

	if (gnutls_x509_crt_init(&cert) < 0)
		return NULL;

	if (gnutls_x509_crt_import(cert, raw_cert, GNUTLS_X509_FMT_DER) < 0)
		goto out;

	/* The first call only reports the required size. */
	gnutls_x509_crt_export(cert, GNUTLS_X509_FMT_PEM, NULL, &size);
	if (size == 0 || (buf = malloc(size)) == NULL)
		goto out;

	if (gnutls_x509_crt_export(cert, GNUTLS_X509_FMT_PEM, buf, &size) < 0) {
		free(buf);
		buf = NULL;
	} else
		*out_size = size;

out:
	gnutls_x509_crt_deinit(cert);
	return buf;
}

/* Convert a raw OpenPGP key to its ASCII-armored form.  Returns a malloc()ed
 * buffer whose length is stored in *out_size, or NULL on failure.
 */
static unsigned char *export_pgp_base64(const gnutls_datum *raw_cert,
					size_t *out_size)
{
	gnutls_openpgp_key key;
	unsigned char *buf = NULL;
	size_t size = 0;

	if (gnutls_openpgp_key_init(&key) < 0)
		return NULL;

	if (gnutls_openpgp_key_import(key, raw_cert, GNUTLS_OPENPGP_FMT_RAW) < 0)
		goto out;

	/* The first call only reports the required size. */
	gnutls_openpgp_key_export(key, GNUTLS_OPENPGP_FMT_BASE64, NULL, &size);
	if (size == 0 || (buf = malloc(size)) == NULL)
		goto out;

	if (gnutls_openpgp_key_export(key, GNUTLS_OPENPGP_FMT_BASE64, buf,
				      &size) < 0) {
		free(buf);
		buf = NULL;
	} else
		*out_size = size;

out:
	gnutls_openpgp_key_deinit(key);
	return buf;
}

/* Save the server certificate `raw_cert' in ~/.certs/ (X.509 as PEM,
 * OpenPGP as an ASCII-armored key) so that check_file_cert() accepts it on
 * later connections.
 *
 * Best effort: problems are reported with warn() but not returned.  The
 * certificate is encoded before the file is opened, so an existing file is
 * not truncated when encoding fails.
 */
static void save_cert( const gnutls_datum* raw_cert, int cert_type)
{
	char *cert_dir, *cert_file;
	unsigned char *buf = NULL;
	size_t buf_size = 0;
	FILE *fd;
	int failed;

	cert_dir = create_home_cert_dir();
	if (cert_dir) {
		/* Failure (typically EEXIST) is ignored; fopen() below reports
		 * any real problem. */
		mkdir(cert_dir, S_IWUSR|S_IRUSR|S_IXUSR);
		free(cert_dir);
	}

	cert_file = create_home_cert_file(cert_type);
	if (cert_file == NULL)
		return;

	if (cert_type == GNUTLS_CRT_X509)
		buf = export_x509_pem(raw_cert, &buf_size);
	else if (cert_type == GNUTLS_CRT_OPENPGP)
		buf = export_pgp_base64(raw_cert, &buf_size);
	if (buf == NULL)
		goto out;

	fd = fopen(cert_file, "w");
	if (fd == NULL) {
		warn("Could not open: %s", cert_file);
		goto out;
	}

	failed = fwrite(buf, 1, buf_size, fd) != buf_size;
	if (fclose(fd) != 0)
		failed = 1;
	if (failed)
		warn("Could not write: %s", cert_file);

out:
	free(buf);
	free(cert_file);
}


/* Decide whether the server's identity in session `s' is acceptable.
 *
 * SRP sessions are accepted as they are: SRP authenticates the server
 * through the shared password.  Other sessions without a certificate are
 * rejected.  A certificate already saved in ~/.certs is accepted.
 * Otherwise the chain is verified against the loaded CAs/CRLs and matched
 * against tls_hostname.  On any problem the user is asked whether to
 * continue: "no" calls quit(0, 0), "yes" saves the certificate with
 * save_cert() so it is trusted next time.
 *
 * Returns 0 if the connection may proceed, 1 otherwise.
 */
static int check_server_cert(gnutls_session s)
{
	const gnutls_datum *raw_cert_list;
	unsigned int raw_cert_list_length;
	int status, verified = 1, cred_type, cert_type, ret;
	char inp;

	if (tls_no_verify != 0)
		return 0;

	/* The credential type must be checked before asking for peer
	 * certificates: a cert-less session has none. */
	cred_type = gnutls_auth_get_type(s);
	if (cred_type == GNUTLS_CRD_SRP)
		return 0;
	if (cred_type != GNUTLS_CRD_CERTIFICATE)
		return 1;

	/* Kept separate from cred_type: the two enumerations overlap
	 * numerically, so mixing them silently picks the wrong checker. */
	cert_type = gnutls_certificate_type_get(s);

	raw_cert_list =
	    gnutls_certificate_get_peers(s, &raw_cert_list_length);
	if (raw_cert_list == NULL || raw_cert_list_length == 0)
		return 1;

	/* Firstly check if the certificate exists in the trusted certificates
	 * database (actually a directory). */
	if (check_file_cert(&raw_cert_list[0], cert_type) != 0)
		return 0;

	status = gnutls_certificate_verify_peers(s);
	if (status < 0)
		verified = 0;

	if (verified) {
		if (status & GNUTLS_CERT_SIGNER_NOT_FOUND) {
			warnx("Server's certificate signer is not trusted.");
			verified = 0;
		}

		if (status & GNUTLS_CERT_SIGNER_NOT_CA) {
			warnx("Server's certificate signer is not a CA.");
			verified = 0;
		}

		if (status & GNUTLS_CERT_REVOKED) {
			warnx("Server's certificate has been revoked.");
			verified = 0;
		}

		if (status & GNUTLS_CERT_INVALID) {
			/* only report this when no specific reason was given */
			if (verified)
				warnx("Server's certificate signature is invalid.");
			verified = 0;
		}
	}

	if (cert_type == GNUTLS_CRT_X509)
		ret = check_server_x509_cert(raw_cert_list);
	else if (cert_type == GNUTLS_CRT_OPENPGP)
		ret = check_server_pgp_cert(raw_cert_list);
	else
		return 1;

	if (!verified || ret) {
		warnx("Errors while verifying the server's certificate chain, continue? (Y/N) ");
		inp = read_char();

		if (!(inp == 'y' || inp == 'Y'))
			quit(0,0);

		/* The user accepted it: remember the certificate for
		 * future connections. */
		save_cert(&raw_cert_list[0], cert_type);
	}

	return 0;
}

/* Report a TLS alert received from the server.  `ret' is the GnuTLS error
 * code that signalled it.  Codes other than the fatal/warning
 * alert-received ones are ignored.
 */
static
void handle_alert( gnutls_session s, int ret)
{
	int alert;

	if (ret != GNUTLS_E_FATAL_ALERT_RECEIVED &&
	    ret != GNUTLS_E_WARNING_ALERT_RECEIVED)
		return;

	alert = gnutls_alert_get(s);
	warnx("Server sent %s TLS alert[%d]: %s",
		ret == GNUTLS_E_FATAL_ALERT_RECEIVED ? "fatal" : "warning",
		alert, SU(gnutls_alert_get_name(alert)));
}

/* Session data saved after the most recent successful handshake.
 * tls_handshake() offers it for resumption on the next handshake (e.g. a
 * data connection following the control connection).  A size of 0 means
 * nothing is stored. */
static unsigned char last_session_data[5*1024];
static size_t last_session_data_size = 0;

/* Run the TLS handshake on `conn' and validate the outcome.
 *
 * Session data from the previous successful handshake is offered for
 * resumption, and this session's data is saved for the next one.
 *
 * Control connection (data == 0): the authentication type, subject,
 * issuer, fingerprint, cipher and compression are printed, and
 * check_server_cert() is applied.  If no certificate information is
 * available, the user is asked whether to continue.
 *
 * Data connection (data != 0): the server must present the same
 * certificate as on the control connection.
 *
 * Returns 0 on success, 4 if the user declined to continue, 5 if the
 * certificate was rejected and 6 if the handshake failed.  For 5 and 6 the
 * connection is closed with a fatal alert.
 */
int tls_handshake( socket_st* conn, int data)
{
	const char *subject, *issuer, *fpr;
	char answer;
	int rc;

	if (last_session_data_size > 0)
		gnutls_session_set_data(conn->session, last_session_data,
					last_session_data_size);

	for (;;) {
		rc = gnutls_handshake(conn->session);
		if (rc != GNUTLS_E_AGAIN && rc != GNUTLS_E_INTERRUPTED)
			break;
	}

	if (rc < 0) {
		warnx("TLS error: %s", gnutls_strerror(rc));
		handle_alert(conn->session, rc);
		tls_fatal_close(conn, GNUTLS_A_HANDSHAKE_FAILURE);
		return 6;
	}

	/* keep this session's data so a later handshake can resume it */
	last_session_data_size = sizeof(last_session_data);
	if (gnutls_session_get_data(conn->session, last_session_data,
				    &last_session_data_size) < 0)
		last_session_data_size = 0;

	if (data != 0) {
		/* data connection: must reuse the control connection's certificate */
		if (check_cert_match(conn->session, ctrl_conn.session) == 0)
			return 0;
		warnx("ERROR: The server sent a different certificate in data connection");
		tls_fatal_close(conn, GNUTLS_A_BAD_CERTIFICATE);
		return 5;
	}

	printf("[Server authentication type: %s]\r\n",
	       tls_get_auth_type(conn->session));

	subject = tls_get_subject_info(conn->session, &issuer, &fpr);
	if (subject != NULL) {
		printf("[Subject: %s]\r\n", subject);
		printf("[Issuer:  %s]\r\n", issuer);
		printf("[Fingerprint:  %s]\r\n", fpr);
	} else {
		warnx("Server didn't provide a certificate, continue? (Y/N) ");
		answer = read_char();
		if (answer != 'y' && answer != 'Y')
			return 4;
	}

	printf("[%s]\r\n", tls_get_cipher(conn->session));
	printf("Compression: %s\r\n", tls_get_comp(conn->session));

	/* the certificate did not match and the user didn't answer `Y' */
	if (check_server_cert(conn->session) != 0) {
		tls_fatal_close(conn, GNUTLS_A_BAD_CERTIFICATE);
		return 5;
	}

	return 0;
}

/* Start TLS on the control connection, which reads from `in' and writes
 * to `out'.
 *
 * Returns 1 if the control connection already has a session, 5 if the
 * session cannot be created, and otherwise tls_handshake()'s result
 * (0 on success).
 */
int tls_connect_ctrl(int in, int out)
{
	int rc;

	if (ctrl_conn.session != NULL) {
		printf("Already TLS connected!\r\n");
		return 1;
	}

	if ((rc = tls_session_init(&ctrl_conn.session, in, out)) < 0) {
		warnx("tls_session_init() %s\r\n", gnutls_strerror(rc));
		return 5;
	}

	ctrl_conn.infd = in;
	ctrl_conn.outfd = out;

	printf("[Starting SSL/TLS negotiation...]\r\n");
	return tls_handshake(&ctrl_conn, 0);
}

/* Start TLS on the data connection socket `s', which is used for both
 * reading and writing.
 *
 * Returns 1 if the data connection already has a session, 2 if the
 * session cannot be created, and otherwise tls_handshake()'s result
 * (0 on success).
 */
int tls_connect_data(int s)
{
	int rc;

	if (data_conn.session != NULL) {
		printf("Already TLS connected!\r\n");
		return 1;
	}

	if ((rc = tls_session_init(&data_conn.session, s, s)) < 0) {
		warnx("tls_session_init() %s\r\n", gnutls_strerror(rc));
		return 2;
	}

	data_conn.infd = s;
	data_conn.outfd = s;

	return tls_handshake(&data_conn, 1);
}

/* Abort connection `sock': send a fatal `alert', destroy its session,
 * close its descriptors and mark the record unused.
 *
 * A missing session is tolerated.  A descriptor shared by the read and
 * write sides (as on the data connection) is closed only once.
 */
static void tls_fatal_close( socket_st* sock, gnutls_alert_description alert)
{
	if (sock == NULL)
		return;

	if (sock->session != NULL) {
		gnutls_alert_send(sock->session, GNUTLS_AL_FATAL, alert);
		gnutls_deinit(sock->session);
		sock->session = NULL;
	}

	if (sock->infd >= 0)
		close(sock->infd);
	if (sock->outfd >= 0 && sock->outfd != sock->infd)
		close(sock->outfd);

	CLEAN_FD(*sock);
}

/* Close connection `sock' cleanly.
 *
 * Steps:
 *  - flush pending output via tls_fputc_fflush();
 *  - send a TLS close_notify (GNUTLS_SHUT_WR) and destroy the session;
 *  - close the descriptors and mark the record unused.
 *
 * NULL is ignored.  A descriptor shared by the read and write sides is
 * closed only once.
 */
static void tls_close_int(socket_st *sock)
{
	if (sock == NULL)
		return;

	if (sock->outfd >= 0)
		tls_fputc_fflush(sock->outfd);

	if (sock->session) {
		gnutls_bye(sock->session, GNUTLS_SHUT_WR);
		gnutls_deinit(sock->session);
	}
	sock->session = NULL;

	if (sock->infd != -1) {
		close(sock->infd);
		if (sock->outfd != sock->infd)
			close(sock->outfd);
	}
	CLEAN_FD(*sock);
}

/* fclose() replacement: first tears down any TLS connection that owns the
 * stream's descriptor (see tls_close_int()), then closes the stream and
 * returns fclose()'s result.
 */
int tls_fclose(FILE * stream)
{
	tls_close_int(SOCK_TO_SOCKET_ST(fileno(stream)));

	return fclose(stream);
}

/* close() replacement: first tears down any TLS connection that owns `fd'
 * (see tls_close_int()), then closes `fd' and returns close()'s result.
 */
int tls_close(int fd)
{
	tls_close_int(SOCK_TO_SOCKET_ST(fd));

	return close(fd);
}

/* shutdown() replacement.
 *
 * With s == -1, both TLS connections and all library state are torn down
 * (tls_cleanup()), and 0 is returned.  Otherwise the TLS connection owning
 * `s' is closed, and shutdown(s, how)'s result is returned.
 * NOTE(review): tls_close_int() closes the descriptors it owns, so for a
 * TLS socket the shutdown() call runs on an already-closed fd — confirm
 * callers tolerate a -1/EBADF result here.
 */
int tls_shutdown(int s, int how)
{
	if (s != -1) {
		tls_close_int(SOCK_TO_SOCKET_ST(s));
		return shutdown(s, how);
	}

	tls_close_int(&data_conn);
	tls_close_int(&ctrl_conn);
	tls_cleanup();

	return 0;
}

void tls_free_ssls(void)
{
	tls_close_int(&data_conn);
	tls_close_int(&ctrl_conn);
}

/* Release everything tls_init() set up.
 *
 * Steps:
 *  - close any connection that still has a session;
 *  - free the certificate, anonymous and SRP credentials (and the
 *    remembered SRP user name), resetting the pointers to NULL;
 *  - deinitialise GnuTLS.
 */
void tls_cleanup(void)
{
	if (data_conn.session != NULL)
		tls_close_int(&data_conn);
	if (ctrl_conn.session != NULL)
		tls_close_int(&ctrl_conn);

	if (cert_cred != NULL) {
		gnutls_certificate_free_credentials(cert_cred);
		cert_cred = NULL;
	}

	if (anon_cred != NULL) {
		gnutls_anon_free_client_credentials(anon_cred);
		anon_cred = NULL;
	}

	if (srp_cred != NULL) {
		gnutls_srp_free_client_credentials(srp_cred);
		srp_cred = NULL;
		free(srp_username);
		srp_username = NULL;
	}

	gnutls_global_deinit();
}

/* Handle a GnuTLS error that the caller cannot recover from.
 *
 * error == 0 is a no-op.  Any other code is reported together with
 * `where' (the calling function).  All TLS state is then torn down and
 * the program exits with status 1.
 */
static void handle_ssl_error(int error, char *where)
{
	if (error == 0)
		return;

	warnx("unhandled TLS error: %s in %s\r\n",
		gnutls_strerror(error), where);

	/* tls_shutdown(-1, ...) already runs tls_cleanup(); calling it again
	 * would deinitialise GnuTLS a second time. */
	tls_shutdown(-1, 0);
	exit(1);
}

/*
 * Wait up to 20 seconds for `rfd` to become readable.
 * Returns > 0 when readable (end-of-file/hang-up also counts), 0 on
 * timeout, and -1 with errno set on error.
 *
 * poll(2) is used instead of select(2). FD_SET() on a descriptor
 * >= FD_SETSIZE writes past the end of the fd_set, while poll() has no such
 * limit. An invalid descriptor is reported as -1/EBADF, as select() would.
 */
static int select_read(int rfd)
{
	const int timeout_ms = 20 * 1000;
	struct pollfd pfd;
	int ret;

	pfd.fd = rfd;
	pfd.events = POLLIN;
	pfd.revents = 0;

	ret = poll(&pfd, 1, timeout_ms);
	if (ret > 0 && (pfd.revents & POLLNVAL)) {
		errno = EBADF;
		return -1;
	}
	return ret;
}

/*
 * read(2) replacement. If `fd` belongs to a TLS connection, decrypted data
 * is received through gnutls_record_recv(); otherwise plain read(2) is used.
 * Returns the number of bytes read, 0 for a real or emulated EOF, or -1
 * with errno set. GnuTLS errors not listed below go to handle_ssl_error(),
 * which terminates the program.
 */
ssize_t tls_read(int fd, void *buf, size_t count)
{
	gnutls_session s = SOCK_TO_TLS_SESS(fd);

      retry:
	if (s) {
		ssize_t c = gnutls_record_recv(s, buf, count);
		if (c < 0) {
			int err = c;
			/* read(2) returns only the generic error number -1 */
			c = -1;
			switch (err) {
			case GNUTLS_E_AGAIN:
			case GNUTLS_E_INTERRUPTED:
				/* GNUTLS needs more data from the wire to finish the current block,
				 * so we wait a little while for it. */
				err = select_read(fd);
				if (err > 0)
					goto retry;
				else if (err == 0)
					/* still missing data after timeout, emulate an EINTR and return. */
					errno = EINTR;
				/* if err < 0, i.e. some error from the select(), everything is already
				 * in place; errno is properly set and this function returns -1. */
				break;
			case GNUTLS_E_PULL_ERROR: /* server had problems */
			case GNUTLS_E_UNEXPECTED_PACKET_LENGTH:
			case GNUTLS_E_INVALID_SESSION:
				/* return an EOF emulation.
				 */
				return 0;
			case GNUTLS_E_REHANDSHAKE:
				/* Renegotiation is not supported yet: refuse it with a
				 * warning alert and let the caller retry the read.
				 */
				gnutls_alert_send( s, GNUTLS_AL_WARNING, GNUTLS_A_NO_RENEGOTIATION);
				errno = EINTR;
				break;
			case GNUTLS_E_WARNING_ALERT_RECEIVED:
			case GNUTLS_E_FATAL_ALERT_RECEIVED:
				
				handle_alert( s, err);

				if (err == GNUTLS_E_FATAL_ALERT_RECEIVED) {
					/* Set errno explicitly. Otherwise a stale value
					 * (e.g. EINTR from an earlier call) would make
					 * read(2)-style callers retry on a dead session. */
					errno = ECONNRESET;
					return -1;
				}

				/* Otherwise simulate EINTR and return;
				 */
				errno = EINTR;

				break;
			default:
				handle_ssl_error(err, "tls_read()");
				break;
			}
		}
		return c;
	} else
		return read(fd, buf, count);
}

/*
 * fgetc(3) replacement. On a TLS-backed stream one byte is read with
 * tls_read(), retrying while it fails with EINTR. Returns the byte as an
 * int in 0..255, or EOF if anything other than exactly one byte came back.
 * Plain streams go straight to fgetc().
 */
int tls_fgetc(FILE * stream)
{
	int fd = fileno(stream);
	unsigned char ch;
	int n;

	if (SOCK_TO_TLS_SESS(fd) == NULL)
		return fgetc(stream);

	for (;;) {
		n = tls_read(fd, &ch, 1);
		if (n >= 0 || errno != EINTR)
			break;
	}
	return (n == 1) ? (int) ch : EOF;
}

int tls_recv(int fd, void *buf, size_t len, int flags)
{
	gnutls_session s = SOCK_TO_TLS_SESS(fd);

	if (s)
		return (int) tls_read(fd, buf, len);
	else
		return recv(fd, buf, len, flags);
}

/*
 * write(2) replacement. On a TLS descriptor the data is sent with
 * gnutls_record_send() and the number of bytes sent is returned. Failures
 * return -1 with errno set as follows:
 *   GNUTLS_E_AGAIN / GNUTLS_E_INTERRUPTED -> EINTR (caller should retry),
 *   GNUTLS_E_PUSH_ERROR                   -> EPIPE.
 * Any other GnuTLS error goes to handle_ssl_error(), which exits.
 * Non-TLS descriptors use plain write().
 */
ssize_t tls_write(int fd, const void *buf, size_t count)
{
	gnutls_session sess = SOCK_TO_TLS_SESS(fd);
	ssize_t rv;
	int code;

	if (sess == NULL)
		return write(fd, buf, count);

	rv = gnutls_record_send(sess, buf, count);
	if (rv >= 0)
		return rv;

	code = (int) rv;
	if (code == GNUTLS_E_AGAIN || code == GNUTLS_E_INTERRUPTED)
		errno = EINTR;
	else if (code == GNUTLS_E_PUSH_ERROR)
		errno = EPIPE;
	else
		handle_ssl_error(code, "tls_write()");
	return -1;
}

/*
 * fputs(3) replacement. On a TLS-backed stream the string is pushed through
 * tls_write(), retrying after EINTR, until every byte has been sent or the
 * loop is stopped by a zero-length write or a non-EINTR failure. Returns
 * EOF if the last write failed; otherwise returns the byte count of the
 * last write. Plain streams use fputs().
 */
int tls_fputs(const char *str, FILE * stream)
{
	int fd = fileno(stream);
	int total, done = 0, n;

	if (SOCK_TO_TLS_SESS(fd) == NULL)
		return fputs(str, stream);

	total = strlen(str);
	for (;;) {
		n = tls_write(fd, str + done, total - done);
		if (n > 0)
			done += n;
		else if (n == 0 || errno != EINTR)
			break;
		if (done == total)
			break;
	}
	return (n < 0) ? EOF : n;
}

/*
 * Send every byte queued by tls_fputc() to `fd` and empty the queue.
 *
 * gnutls_record_send() may send fewer bytes than requested, and
 * GNUTLS_E_AGAIN/INTERRUPTED (reported by tls_write() as EINTR) must be
 * retried. The old single write silently dropped the rest of the buffer in
 * both cases. This loop retries like tls_fputs() and gives up only on a
 * hard error or a write that makes no progress.
 */
void tls_fputc_fflush(int fd)
{
	int sent = 0;
	int w;

	while (sent < fputc_buflen) {
		w = tls_write(fd, fputc_buffer + sent, fputc_buflen - sent);
		if (w > 0)
			sent += w;
		else if (!(w < 0 && errno == EINTR))
			break;	/* error other than EINTR, or no progress */
	}
	fputc_buflen = 0;
}

/*
 * fputc(3) replacement. On a TLS-backed stream the character is appended to
 * the shared fputc buffer. When the buffer reaches FPUTC_BUFFERSIZE it is
 * sent in full with tls_write(), retrying on EINTR and after partial writes.
 * Returns the character written (as unsigned char), or EOF if the flush
 * hits a hard error; the buffer is emptied in that case so it cannot
 * overflow. Plain streams use fputc().
 *
 * The previous version emptied the buffer on EINTR but left its error flag
 * set, so it looped re-appending `c`. That lost the buffered data and sent
 * `c` FPUTC_BUFFERSIZE times.
 */
int tls_fputc(int c, FILE * stream)
{
	int fd = fileno(stream);
	unsigned char uc = c;
	int sent = 0;
	int w;

	if (SOCK_TO_TLS_SESS(fd) == NULL)
		return fputc(c, stream);

	fputc_buffer[fputc_buflen++] = uc;
	if (fputc_buflen < FPUTC_BUFFERSIZE)
		return (int) uc;

	while (sent < fputc_buflen) {
		w = tls_write(fd, fputc_buffer + sent, fputc_buflen - sent);
		if (w > 0)
			sent += w;
		else if (!(w < 0 && errno == EINTR)) {
			fputc_buflen = 0;
			return EOF;
		}
	}
	fputc_buflen = 0;
	return (int) uc;
}

int tls_send(int fd, const void *msg, size_t len, int flags)
{
	gnutls_session s = SOCK_TO_TLS_SESS(fd);

	if (s)
		return (int) tls_write(fd, msg, len);
	else
		return send(fd, msg, len, flags);
}

/*
 * vfprintf(3) replacement. For a TLS-backed stream the output is formatted
 * into a 1024-byte stack buffer, or into a heap buffer when the result is
 * longer but shorter than SNP_MAXBUF. It is then sent with tls_write(),
 * retrying on EINTR. If the larger buffer is not used (too long or malloc
 * failure), output is truncated to 1023 bytes. Returns the number of bytes
 * actually sent; errors are not reported as a negative value. Plain
 * streams return vfprintf()'s result.
 */
int tls_vfprintf(FILE * stream, const char *format, va_list ap)
{
#define SNP_MAXBUF 1024000
	gnutls_session ssl = SOCK_TO_TLS_SESS(fileno(stream));

	if (ssl) {
		char sbuf[1024] = { 0 }, *buf = sbuf, *lbuf = NULL;
		int sent = 0, size, ret, w;
		va_list aq;

		/* vsnprintf() consumes the va_list it is given. The first pass
		 * works on a copy so `ap` is still valid for the second one. */
		va_copy(aq, ap);
		ret = vsnprintf(sbuf, sizeof(sbuf), format, aq);
		va_end(aq);

		/* C99 vsnprintf() returns the length it wanted to write.
		 * Compare as int so an error result (< 0) is not promoted to
		 * SIZE_MAX, which would lead to malloc(0) and an unterminated
		 * buffer. */
		if (ret >= (int) sizeof(sbuf) && ret < SNP_MAXBUF) {
			lbuf = malloc(ret + 1);
			if (lbuf) {
				vsnprintf(lbuf, ret + 1, format, ap);
				buf = lbuf;
			}
		}

		size = strlen(buf);
		do {
			w = tls_write(fileno(stream), buf + sent,
				      size - sent);
			if (w > 0)
				sent += w;
			else if (!(w < 0 && errno == EINTR))
				break;	/* other error than EINTR or w == 0 */
		} while (sent != size);
		free(lbuf);
		return sent;
	} else
		return vfprintf(stream, format, ap);
}

/*
 * fprintf(3) replacement: packages the variadic arguments and forwards
 * them to tls_vfprintf(), returning its result. va_end() is required to
 * pair with va_start() before the function returns.
 */
#ifdef __STDC__
int tls_fprintf(FILE * stream, const char *fmt, ...)
#else
int tls_fprintf(stream, fmt, va_alist)
FILE *stream;
char *fmt;
va_dcl
#endif
{
	va_list ap;
	int ret;
#ifdef __STDC__
	va_start(ap, fmt);
#else
	va_start(ap);
#endif
	ret = tls_vfprintf(stream, fmt, ap);
	va_end(ap);
	return ret;
}

/*
 * fflush(3) replacement. TLS-backed streams bypass stdio buffering, so
 * flushing them through stdio is pointless and 0 is returned without doing
 * anything. NULL (flush all streams) and plain streams go to fflush().
 *
 * NOTE(review): bytes queued by tls_fputc() are not sent here; they go out
 * when the buffer fills or when tls_fputc_fflush() runs (e.g. from
 * tls_close_int()).
 */
int tls_fflush(FILE * stream)
{
	if (stream != NULL && SOCK_TO_TLS_SESS(fileno(stream)) != NULL)
		return 0;
	return fflush(stream);
}

#endif /* TLS */
