/*****************************************************************************\
* Copyright (c) 2004 Pelle Johansson.                                         *
* All rights reserved.                                                        *
*                                                                             *
* This file is part of the moftpd package. Use and distribution of            *
* this software is governed by the terms in the file LICENCE, which           *
* should have come with this package.                                         *
\*****************************************************************************/

/* $moftpd: tls_openssl.c 1186 2004-10-25 14:27:26Z morth $ */

#include "system.h"

#include "tls.h"

#include "utf8fs/memory.h"
#include "utf8fs/file.h"

/* Set to 1 once gnutls_global_init() has succeeded, so that the lazy
   initialisation in tls_open(), tls_read_cert() and tls_read_key() only
   happens once.
   NOTE(review): not guarded by any lock — presumably the callers are
   single-threaded; confirm. */
static int gnutls_inited = 0;

/* Directory that tls_open() scans for PEM trust files when client
   certificate verification is requested; defined elsewhere in the project. */
extern char *sslCertsPath;

/*
 * Return the default certificate directory for this backend.
 * The returned string is static and must not be freed or modified.
 */
const char *tls_get_cert_dir (void)
{
  static const char default_cert_dir[] = "/etc/certs";

  return default_cert_dir;
}

tls_t tls_open (int fd, int options, tlscert_t cert, tlskey_t key)
{
  tls_t res;
  gnutls_session session;
  gnutls_certificate_credentials creds;
  
  if (fd < 0 || !cert || !key)
    return NULL;
  
  if (!gnutls_inited)
  {
    if (gnutls_global_init ())
      return NULL;
    gnutls_inited = 1;
  }
  
  if (gnutls_init (&session, GNUTLS_SERVER))
    return NULL;
  if (gnutls_certificate_allocate_credentials (&creds))
  {
    gnutls_deinit (session);
    return NULL;
  }
  
  gnutls_certificate_set_x509_key (creds, &cert, 1, key);
  
  if (options & tlsVerifyClient)
  {
    DIR *dir;
    struct dirent *ent;
    gnutls_certificate_server_set_request (session, GNUTLS_CERT_REQUEST);
    char path[4097], *fp;
    
    dir = opendir (sslCertsPath);
    if (dir)
    {
      strcpy (path, sslCertsPath);
      fp = path + strlen (path);
      *fp++ = '/';
      while ((ent = readdir (dir)))
      {
#ifdef HAVE_STRUCT_DIRENT_D_TYPE
        if (ent->d_type == DT_DIR)
          continue;
#endif
        if (fp - path + ent->d_namlen >= sizeof (path))
          continue;
        strcpy (fp, ent->d_name);
        gnutls_certificate_set_x509_trust_file (creds, path, GNUTLS_X509_FMT_PEM);
      }
    }
  }
  
  gnutls_set_default_priority (session);
  gnutls_credentials_set (session, GNUTLS_CRD_CERTIFICATE, creds);
  
  res = palloc (sizeof (*res), NULL, NULL);
  if (!res)
  {
    gnutls_deinit (session);
    gnutls_certificate_free_credentials (creds);
    return NULL;
  }
  res->session = session;
  res->creds = creds;
  res->options = options;
  
  gnutls_transport_set_ptr (res->session, (gnutls_transport_ptr)fd);
  
  return res;
}

/*
 * Nothing to do for the GnuTLS backend: the handshake itself is driven
 * by tls_accept().
 */
void tls_start (tls_t tls)
{
  (void) tls;
}

/*
 * Send the TLS close-notify alert and wait for the peer's reply
 * (GNUTLS_SHUT_RDWR).
 *
 * Returns 1 when shutdown completed, 0 when it must be retried
 * (GNUTLS_E_AGAIN / GNUTLS_E_INTERRUPTED), otherwise the GnuTLS error code.
 */
int tls_stop (tls_t tls)
{
  int rc;

  rc = gnutls_bye (tls->session, GNUTLS_SHUT_RDWR);
  switch (rc)
  {
  case 0:
    return 1;
  case GNUTLS_E_AGAIN:
  case GNUTLS_E_INTERRUPTED:
    return 0;
  default:
    return rc;
  }
}

/*
 * Release a handle created by tls_open(): the GnuTLS session first, then
 * the credentials it referenced, then the handle itself.
 */
void tls_free (tls_t tls)
{
  gnutls_session sess = tls->session;
  gnutls_certificate_credentials cred = tls->creds;

  gnutls_deinit (sess);
  gnutls_certificate_free_credentials (cred);
  pfree (tls, NULL);
}

/*
 * Drive the server-side handshake.
 *
 * Returns 1 once the handshake is complete (and, with tlsVerifyClient,
 * the peer certificate verified cleanly), 0 if the call must be retried
 * (GNUTLS_E_AGAIN / GNUTLS_E_INTERRUPTED), or a negative GnuTLS error.
 * A verification status with any flag set yields GNUTLS_E_CERTIFICATE_ERROR.
 */
int tls_accept (tls_t tls)
{
  unsigned int status;
  int rc = gnutls_handshake (tls->session);

  if (rc == GNUTLS_E_AGAIN || rc == GNUTLS_E_INTERRUPTED)
    return 0;
  if (rc != 0)
    return rc;

  if (!(tls->options & tlsVerifyClient))
    return 1;

  rc = gnutls_certificate_verify_peers2 (tls->session, &status);
  if (rc != 0)
    return rc;
  if (status != 0)
    return GNUTLS_E_CERTIFICATE_ERROR;
  return 1;
}

/*
 * Receive up to MAXLEN bytes of decrypted application data into BUF.
 * The result of gnutls_record_recv() is returned unchanged.
 */
ssize_t tls_read (tls_t tls, void *buf, size_t maxlen)
{
  gnutls_session session = tls->session;

  return gnutls_record_recv (session, buf, maxlen);
}

/*
 * Encrypt and send LEN bytes from BUF.
 * The result of gnutls_record_send() is returned unchanged.
 */
ssize_t tls_write (tls_t tls, const void *buf, size_t len)
{
  gnutls_session session = tls->session;

  return gnutls_record_send (session, buf, len);
}

/*
 * Send the NUM buffers in VECS in order, like writev().
 *
 * Returns the total number of bytes sent if any were sent; otherwise the
 * result of the last tls_write() (0, or a negative GnuTLS error code).
 * Stops at the first error or short write: gnutls_record_send() may send
 * fewer bytes than requested (e.g. capped at the record size), and carrying
 * on with the next vector would silently drop the unsent tail.
 */
ssize_t tls_write_vecs (tls_t tls, struct iovec *vecs, int num)
{
  ssize_t total = 0;
  ssize_t n = 0;
  int i;

  for (i = 0; i < num; i++)
  {
    n = tls_write (tls, vecs[i].iov_base, vecs[i].iov_len);
    if (n < 0)
      break;
    total += n;
    if ((size_t) n < vecs[i].iov_len)
      break;
  }
  return total ? total : n;
}

/*
 * Load a PEM-encoded X.509 certificate from FILE.
 *
 * Initialises GnuTLS globally on first use. Returns a certificate handle
 * (release with tls_free_cert()), or NULL if FILE is NULL, cannot be read,
 * or does not hold a valid PEM certificate.
 */
tlscert_t tls_read_cert (const char *file)
{
  gnutls_x509_crt res;
  gnutls_datum datum;
  size_t sz = 0;

  if (!file)
    return NULL;

  if (!gnutls_inited)
  {
    if (gnutls_global_init ())
      return NULL;
    gnutls_inited = 1;
  }

  datum.data = read_file (file, &sz);
  /* Only rely on SZ once read_file() has reported success. */
  if (!datum.data)
    return NULL;
  datum.size = sz;

  /* NOTE(review): the buffer returned by read_file() is never released
     here; how to free it depends on read_file()'s allocator — confirm. */
  if (gnutls_x509_crt_init (&res))
    return NULL;

  if (gnutls_x509_crt_import (res, &datum, GNUTLS_X509_FMT_PEM))
  {
    gnutls_x509_crt_deinit (res);
    return NULL;
  }
  return res;
}

/*
 * Return the first certificate the peer presented during the handshake,
 * decoded from DER. Returns NULL if the peer sent no certificate or it
 * cannot be decoded. The caller releases the result with tls_free_cert().
 */
tlscert_t tls_get_peer_cert (const tls_t tls)
{
  const gnutls_datum *certs;
  /* gnutls_certificate_get_peers() writes the count through unsigned int *. */
  unsigned int ncerts = 0;
  gnutls_x509_crt res;

  certs = gnutls_certificate_get_peers (tls->session, &ncerts);
  if (!certs || ncerts == 0)
    return NULL;

  if (gnutls_x509_crt_init (&res))
    return NULL;

  if (gnutls_x509_crt_import (res, &certs[0], GNUTLS_X509_FMT_DER))
  {
    gnutls_x509_crt_deinit (res);
    return NULL;
  }
  return res;
}

/* Release a certificate obtained from tls_read_cert() or tls_get_peer_cert(). */
void tls_free_cert (tlscert_t cert)
{
  gnutls_x509_crt doomed = cert;

  gnutls_x509_crt_deinit (doomed);
}

/*
 * Return the first Common Name (CN) of CERT's subject DN as a
 * NUL-terminated string allocated with talloc(), or NULL if CERT is NULL,
 * has no CN, or an allocation/GnuTLS call fails.
 */
const char *tls_get_cn (tlscert_t cert)
{
  /* Must start at 0: GnuTLS reads it as the capacity of the (NULL) buffer. */
  size_t sz = 0;
  char *buf;

  if (!cert)
    return NULL;

  /* Size probe: with no buffer GnuTLS reports the required size and returns
     GNUTLS_E_SHORT_MEMORY_BUFFER. Any other result (e.g. no CN present)
     leaves SZ meaningless, so there is nothing to return. */
  if (gnutls_x509_crt_get_dn_by_oid (cert, GNUTLS_OID_X520_COMMON_NAME, 0, 0,
                                     NULL, &sz) != GNUTLS_E_SHORT_MEMORY_BUFFER
      || !sz)
    return NULL;

  /* One spare byte so the result is always NUL-terminated. */
  buf = talloc (sz + 1);
  if (!buf)
    return NULL;
  if (gnutls_x509_crt_get_dn_by_oid (cert, GNUTLS_OID_X520_COMMON_NAME, 0, 0,
                                     buf, &sz))
    return NULL;
  buf[sz] = '\0';
  return buf;
}

/*
 * Compare two certificates by their public-key identifiers.
 *
 * Returns 0 when the key IDs are identical, otherwise a negative or positive
 * value giving a consistent ordering. If C1's key ID cannot be obtained the
 * result is -1; if only C2's fails it is 1.
 */
int tls_compare_certs (const tlscert_t c1, const tlscert_t c2)
{
  /* gnutls_x509_crt_get_key_id() writes raw bytes through unsigned char *. */
  unsigned char id1[100], id2[100];
  size_t len1 = sizeof (id1), len2 = sizeof (id2);

  if (gnutls_x509_crt_get_key_id (c1, 0, id1, &len1))
    return -1;
  if (gnutls_x509_crt_get_key_id (c2, 0, id2, &len2))
    return 1;
  /* Compare explicitly: len1 - len2 wraps as size_t, and converting the
     wrapped value to int has an implementation-defined sign. */
  if (len1 != len2)
    return len1 < len2 ? -1 : 1;
  return memcmp (id1, id2, len1);
}

/*
 * Load a PEM-encoded X.509 private key from FILE.
 *
 * Initialises GnuTLS globally on first use. Returns a key handle (release
 * with tls_free_key()), or NULL if FILE is NULL, cannot be read, or does not
 * hold a valid PEM private key.
 */
tlskey_t tls_read_key (const char *file)
{
  gnutls_x509_privkey res;
  gnutls_datum datum;
  size_t sz = 0;

  if (!file)
    return NULL;

  if (!gnutls_inited)
  {
    if (gnutls_global_init ())
      return NULL;
    gnutls_inited = 1;
  }

  datum.data = read_file (file, &sz);
  /* Only rely on SZ once read_file() has reported success. */
  if (!datum.data)
    return NULL;
  datum.size = sz;

  /* NOTE(review): the buffer returned by read_file() (holding key material)
     is never released or wiped here; how to free it depends on
     read_file()'s allocator — confirm. */
  if (gnutls_x509_privkey_init (&res))
    return NULL;

  if (gnutls_x509_privkey_import (res, &datum, GNUTLS_X509_FMT_PEM))
  {
    gnutls_x509_privkey_deinit (res);
    return NULL;
  }
  return res;
}

/* Release a private key obtained from tls_read_key(). */
void tls_free_key (tlskey_t key)
{
  gnutls_x509_privkey doomed = key;

  gnutls_x509_privkey_deinit (doomed);
}

/*
 * Map a GnuTLS error code RES (as returned by the other tls_* calls) to a
 * human-readable message. TLS is not consulted by this backend.
 */
const char *tls_error (const tls_t tls, int res)
{
  (void) tls;
  return gnutls_strerror (res);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */