/*****************************************************************************\
* Copyright (c) 2002 Pelle Johansson.                                         *
* All rights reserved.                                                        *
*                                                                             *
* This file is part of the moftpd package. Use and distribution of            *
* this software is governed by the terms in the file LICENCE, which           *
* should have come with this package.                                         *
\*****************************************************************************/

/* $moftpd: sql.c 1251 2005-03-06 22:24:29Z morth $ */

#include "system.h"

#include "sql.h"

#ifdef USE_SQL

/* Directory of trusted CA certificates; sql_connect() hands it to
   mysql_ssl_set() as the capath argument.  Defined in another file. */
extern char *sslCertsPath;

#ifdef HAVE_MYSQL
/*
 * Split a "host[:port]" specification into a host name and a port number.
 * An IPv6 literal may be written in brackets ("[::1]:3306"): colons inside
 * the brackets are not taken as the port separator, and the brackets are
 * removed from the host name handed to the client library.
 *
 * rhost/rhostsize: output buffer for the NUL-terminated host name.
 * port: receives the port looked up with getservbyname() or parsed as a
 *       number, or 0 (client library default) when no port is given.
 * Returns 0 on success, or -1 with errno = EINVAL when the host name does
 * not fit in rhost or the port is neither a known service nor a number in
 * the range 0-65535.
 */
static int sql_split_host (const char *host, char *rhost, size_t rhostsize,
      unsigned int *port)
{
  const char *start = host, *end = NULL, *colon;
  size_t len;

  if (host[0] == '[')
    end = strchr (host, ']');
  if (end)
  {
    start = host + 1;
    colon = strchr (end, ':');
  }
  else
  {
    colon = strchr (host, ':');
    end = colon ? colon : host + strlen (host);
  }

  len = end - start;
  if (len >= rhostsize)
  {
    errno = EINVAL;
    return -1;
  }
  memcpy (rhost, start, len);
  rhost[len] = 0;

  *port = 0;
  if (colon && colon[1])
  {
    struct servent *serv = getservbyname (colon + 1, "tcp");

    if (serv)
      *port = ntohs (serv->s_port);
    else
    {
      char *ep;
      unsigned long n;

      errno = 0;
      n = strtoul (colon + 1, &ep, 0);
      if (errno || *ep || n > 65535)
      {
	errno = EINVAL;
	return -1;
      }
      *port = (unsigned int) n;
    }
  }
  return 0;
}
#endif

/*
 * Open the connection described by the arguments.  If SQL already holds a
 * handle nothing is done and 0 is returned.
 *
 * host: "host[:port]" or "[v6addr][:port]"; the port may be a service name
 *       or a number.  NULL lets the client library use its default host.
 * user, db, pass: passed unchanged to mysql_real_connect().
 * cert, key: when cert is non-NULL (and mysql_ssl_set() is available), SSL
 *       is requested with this certificate and key, using sslCertsPath as
 *       the CA certificate directory.
 * Returns 0 on success, -1 on failure; sql_error() describes the failure.
 */
int sql_connect (sql_t *sql, const char *host, const char *user, const char *db,
      const char *pass, const char *cert, const char *key)
{
  switch (sql->type)
  {
#ifdef HAVE_MYSQL
  case sqlMYSQL:
  {
    char rhost[256];
    const char *chost = NULL;
    unsigned int port = 0;

    if (sql->u.mysql.h)
      return 0;
    /* Parse before allocating the handle: a malformed host must not leave
       a handle behind that the check above would take as connected. */
    if (host)
    {
      if (sql_split_host (host, rhost, sizeof (rhost), &port))
	return -1;
      chost = rhost;
    }
    /* sql_error() prefers strerror(errno) whenever errno is set, so drop
       anything left over from the service lookup. */
    errno = 0;
    if (!(sql->u.mysql.h = mysql_init (NULL)))
    {
      /* mysql_init() fails only for lack of memory, and with no handle
	 there is no mysql_error() to consult. */
      errno = ENOMEM;
      return -1;
    }
#if defined(HAVE_MYSQL_SSL_SET)
    if (cert)
      mysql_ssl_set (sql->u.mysql.h, key, cert, NULL, sslCertsPath, NULL);
#endif
    /* NOTE(review): on failure the handle is kept so sql_error() can return
       mysql_error(); because the early return above treats any handle as
       connected, callers must sql_disconnect() before retrying. */
    if (!mysql_real_connect (sql->u.mysql.h, chost, user, pass, db, port,
	  NULL, 0))
      return -1;
    return 0;
  }
#endif
  default:
    errno = EINVAL;
    return -1;
  }
}

/*
 * Release everything SQL currently holds: first any unfreed result set,
 * then the connection handle.  Both pointers are NULL afterwards, so a
 * second call does nothing and a later sql_connect() opens a new handle.
 */
void sql_disconnect (sql_t *sql)
{
  switch (sql->type)
  {
#ifdef HAVE_MYSQL
  case sqlMYSQL:
    /* Free the pending result set before closing the handle. */
    if (sql->u.mysql.r != NULL)
      mysql_free_result (sql->u.mysql.r);
    sql->u.mysql.r = NULL;

    if (sql->u.mysql.h != NULL)
      mysql_close (sql->u.mysql.h);
    sql->u.mysql.h = NULL;
    break;
#endif
  default:
    /* No other back end is handled here. */
    break;
  }
}

/*
 * Return STR as a single-quoted SQL string literal.  Quotes (' and "),
 * backslashes, '%' and '_' are preceded by a backslash; newline and
 * carriage return are written as \n and \r.
 *
 * The result lives in a static buffer that the next call overwrites.
 * Returns NULL with errno = EINVAL if the quoted text would be longer
 * than 4096 characters.
 *
 * NOTE(review): MySQL honours \% and \_ only inside LIKE patterns; in other
 * contexts the backslash is kept, so a value containing % or _ compared
 * with '=' will not match.  Escaping also ignores the connection character
 * set (unlike mysql_real_escape_string()).  Confirm against the queries
 * that use this.
 */
char *sql_quote (const char *str)
{
  static char buf[4097];
  size_t len = 0;
  const char *src;

  buf[len++] = '\'';
  for (src = str; *src != '\0'; src++)
  {
    char c = *src;
    char esc = 0;

    switch (c)
    {
    case '\n':
      esc = 'n';
      break;
    case '\r':
      esc = 'r';
      break;
    case '\'':
    case '"':
    case '\\':
    case '%':
    case '_':
      esc = c;
      break;
    default:
      break;
    }

    if (esc != 0)
    {
      /* Two characters needed; keep room for the closing quote and NUL. */
      if (len >= 4094)
      {
	errno = EINVAL;
	return NULL;
      }
      buf[len++] = '\\';
      buf[len++] = esc;
    }
    else
    {
      if (len >= 4095)
      {
	errno = EINVAL;
	return NULL;
      }
      buf[len++] = c;
    }
  }
  buf[len++] = '\'';
  buf[len] = '\0';
  return buf;
}

int sql_query (sql_t *sql, const char *query, int nargs, const sql_arg_t *args)
{
  const char *qp = query;
  char rquery[4097], *rp = rquery;
  char quoting = 0, ch;
  int i;
  
  while (*qp)
  {
    switch ((ch = *qp++))
    {
    case '%':
      if (!quoting)
      {
	for (i = 0; i < nargs; i++)
	{
	  if (args[i].ch == *qp)
	    break;
	}
	if (i < nargs && args[i].str)
	{
	  qp++;
	  if ((rp - rquery) + strlen (args[i].str) >= sizeof (rquery))
	  {
	    errno = EINVAL;
	    return -1;
	  }
#ifdef HAVE_STPCPY
	  rp = stpcpy (rp, args[i].str);
#else
	  strcpy (rp, args[i].str);
	  rp += strlen (rp);
#endif
	  break;
	}
      }
      if (rp - rquery == 4096)
      {
	errno = EINVAL;
	return -1;
      }
      *rp++ = '%';
      break;
    case '\'':
    case '"':
      if (rp - rquery == 4096)
      {
	errno = EINVAL;
	return -1;
      }
      *rp++ = ch;
      if (quoting == ch)
	quoting = 0;
      else if (!quoting)
	quoting = ch;
      break;
    case '\\':
      if (rp - rquery >= 4095)
      {
	errno = EINVAL;
	return -1;
      }
      if (quoting || *qp != '?')
	*rp++ = '\\';
      if (*qp)
	*rp++ = *qp++;
      break;
    default:
      if (rp - rquery == 4096)
      {
	errno = EINVAL;
	return -1;
      }
      *rp++ = ch;
      break;
    }
  }
  
  if (*qp)
  {
    if (rp - rquery + strlen (qp) >= sizeof (rquery))
    {
      errno = EINVAL;
      return -1;
    }
    strcpy (rp, qp);
  }
  else
    *rp = 0;
  
  syslog (LOG_DEBUG, "SQL query: %s.", rquery);
  switch (sql->type)
  {
#ifdef HAVE_MYSQL
  case sqlMYSQL:
    errno = 0;
    if (mysql_query (sql->u.mysql.h, rquery))
      return -1;
    if (sql->u.mysql.r)
      mysql_free_result (sql->u.mysql.r);
    sql->u.mysql.r = mysql_store_result (sql->u.mysql.h);
    sql->u.mysql.ccol = 0;
    sql->u.mysql.crow = -1;
    if (sql->u.mysql.r)
      return mysql_num_rows (sql->u.mysql.r);
    if (mysql_errno (sql->u.mysql.h))
      return -1;
    return mysql_affected_rows (sql->u.mysql.h);
#endif
  default:
    errno = EINVAL;
    return -1;
  }
}

/*
 * Return the text of cell (ROW, COL) of the current result set, or the name
 * of column COL when ROW is -1.
 *
 * Asking for the next row, or the next column name, continues from the
 * previous position; any other position is reached with mysql_data_seek()
 * or mysql_field_seek().  The returned string belongs to the result set and
 * must not be freed; it is invalidated when the result set is freed.
 * Returns NULL for negative or out-of-range positions, when there is no
 * result set, and for SQL NULL values.
 */
const char *sql_fetch_cell (sql_t *sql, int row, int col)
{
  if (col < 0 || row < -1)
    return NULL;
  switch (sql->type)
  {
#ifdef HAVE_MYSQL
  case sqlMYSQL:
    /* sql_query() stores no result set for statements that produce none,
       and sql_free_result() drops it; the accessors below would then
       dereference NULL. */
    if (!sql->u.mysql.r)
      return NULL;
    if (row == -1)
    {
      MYSQL_FIELD *f;

      /* ccol is the column the field cursor will return next. */
      if (col != sql->u.mysql.ccol)
      {
	mysql_field_seek (sql->u.mysql.r, col);
	sql->u.mysql.ccol = col + 1;
      }
      else
	sql->u.mysql.ccol++;
      f = mysql_fetch_field (sql->u.mysql.r);
      if (!f)
	return NULL;
      return f->name;
    }
    if ((unsigned int) col >= mysql_num_fields (sql->u.mysql.r))
      return NULL;
    /* crow is the index of the row cached in sql->u.mysql.row: -1 before
       the first fetch after sql_query(), -2 after a fetch past the end. */
    if (row != sql->u.mysql.crow)
    {
      if (row != sql->u.mysql.crow + 1)
	mysql_data_seek (sql->u.mysql.r, row);
      sql->u.mysql.row = mysql_fetch_row (sql->u.mysql.r);
      if (!sql->u.mysql.row)
      {
	sql->u.mysql.crow = -2;
	return NULL;
      }
      sql->u.mysql.crow = row;
    }
    return sql->u.mysql.row[col];
#endif
  default:
    errno = EINVAL;
    return NULL;
  }
}

/*
 * Free the current result set, if there is one.  The connection stays open.
 * The cached row and the row/column cursors are reset as well, since the
 * cached row pointed into the freed result.
 */
void sql_free_result (sql_t *sql)
{
  switch (sql->type)
  {
#ifdef HAVE_MYSQL
  case sqlMYSQL:
    /* There may be no result set: the statement produced none, or it was
       already freed. */
    if (sql->u.mysql.r)
    {
      mysql_free_result (sql->u.mysql.r);
      sql->u.mysql.r = NULL;
    }
    sql->u.mysql.row = NULL;
    sql->u.mysql.crow = -1;
    sql->u.mysql.ccol = 0;
    break;
#endif
  default:
    break;
  }
}

/*
 * Describe the last failure of an sql_* call on SQL.
 *
 * A non-zero errno takes precedence: the functions in this file set errno
 * for their own errors and clear it before calling into the client library.
 * Call this right after the failing call, before anything else can change
 * errno.  Otherwise the client library's message is returned.
 */
const char *sql_error (sql_t *sql)
{
  if (errno)
    return strerror (errno);
  switch (sql->type)
  {
#ifdef HAVE_MYSQL
  case sqlMYSQL:
    /* No handle to ask (e.g. after sql_disconnect()); mysql_error() would
       dereference NULL. */
    if (!sql->u.mysql.h)
      return "Not connected to SQL server.";
    return mysql_error (sql->u.mysql.h);
#endif
  default:
    return "Unknown sql type.";
  }
}

#endif /*USE_SQL*/


syntax highlighted by Code2HTML, v. 0.9.1