/*****************************************************************************\
* Copyright (c) 2002 Pelle Johansson. *
* All rights reserved. *
* *
* This file is part of the moftpd package. Use and distribution of *
* this software is governed by the terms in the file LICENCE, which *
* should have come with this package. *
\*****************************************************************************/
/* $moftpd: sql.c 1251 2005-03-06 22:24:29Z morth $ */
#include "system.h"
#include "sql.h"
#ifdef USE_SQL
/* Defined elsewhere in the package.  sql_connect() hands it to
   mysql_ssl_set() in the CA-path position, so it is presumably a directory
   of trusted CA certificates -- confirm against its definition.  */
extern char *sslCertsPath;
/*
 * Open a connection to the database server for `sql`.
 *
 * `host` has the form "name", "name:port" or "[name]:port".  For the
 * bracketed form the port separator is only looked for after the closing
 * ']', so the name may itself contain colons; the brackets are passed on to
 * the client library as part of the host name.  The port may be a service
 * name (resolved with getservbyname() for "tcp") or a number in any base
 * strtoul() accepts.  Without a port, 0 is passed to the client library.
 *
 * `cert` and `key` are given to mysql_ssl_set() when `cert` is non-NULL and
 * the build has HAVE_MYSQL_SSL_SET.
 *
 * Returns 0 on success, or immediately if the handle is already set up.
 * Returns -1 on failure; errno is EINVAL for an unsupported sql->type or for
 * a host name that does not fit the 255-character internal buffer (it used
 * to be silently truncated, which would connect to the wrong host).
 *
 * NOTE(review): when mysql_real_connect() fails the handle stays allocated,
 * so a later call returns 0 without retrying unless sql_disconnect() is
 * called first -- presumably kept so sql_error() can still read the
 * library's message; confirm against callers.
 */
int sql_connect (sql_t *sql, const char *host, const char *user, const char *db,
                 const char *pass, const char *cert, const char *key)
{
  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      {
        const char *pp;
        char rhost[256];
        size_t hostlen;
        int port = 0;

        /* Already initialised by an earlier call.  */
        if (sql->u.mysql.h)
          return 0;

        /* Locate the ':' separating host and port, skipping a leading
           "[...]" part so colons inside it are not mistaken for it.  */
        pp = NULL;
        if (host[0] == '[')
          pp = strchr (host, ']');
        if (!pp)
          pp = host;
        pp = strchr (pp, ':');

        /* Parse before allocating anything, so a rejected host does not
           leave a handle behind that would make the next call return 0.  */
        hostlen = pp ? (size_t) (pp - host) : strlen (host);
        if (hostlen >= sizeof (rhost))
          {
            errno = EINVAL;
            return -1;
          }
        /* Bounded copy of exactly hostlen bytes; terminated by hand.  */
        strncpy (rhost, host, hostlen);
        rhost[hostlen] = 0;

        if (pp)
          {
            struct servent *serv = getservbyname (++pp, "tcp");

            if (serv)
              port = ntohs (serv->s_port);
            else
              port = (int) strtoul (pp, NULL, 0);
          }

        /* sql_error() prefers errno over the library message, so clear it
           after the lookups above, which may have set it.  */
        errno = 0;
        if (!(sql->u.mysql.h = mysql_init (NULL)))
          return -1;
#if defined(HAVE_MYSQL_SSL_SET)
        if (cert)
          mysql_ssl_set (sql->u.mysql.h, key, cert, NULL, sslCertsPath, NULL);
#endif
        if (!mysql_real_connect (sql->u.mysql.h, rhost, user, pass, db, port, NULL, 0))
          return -1;
        return 0;
      }
#endif
    default:
      errno = EINVAL;
      return -1;
    }
}
/*
 * Release everything held by `sql`: the stored result set, if any, and then
 * the connection handle, if any.  Both pointers are reset to NULL, so the
 * call may be repeated.  Unknown backend types are ignored.
 */
void sql_disconnect (sql_t *sql)
{
  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      /* Drop the stored result before closing the connection.  */
      if (sql->u.mysql.r != NULL)
        {
          mysql_free_result (sql->u.mysql.r);
          sql->u.mysql.r = NULL;
        }
      if (sql->u.mysql.h != NULL)
        {
          mysql_close (sql->u.mysql.h);
          sql->u.mysql.h = NULL;
        }
      return;
#endif
    default:
      /* Nothing to release for other backend types.  */
      return;
    }
}
/*
 * Return `str` as a single-quoted, backslash-escaped string literal.
 *
 * Newline and carriage return become "\n" and "\r"; the characters
 * ' " \ % and _ are each preceded by a backslash; everything else is copied
 * unchanged.
 *
 * The result lives in a static buffer that is overwritten by the next call,
 * so the function is not reentrant and the caller must not free the result.
 *
 * Returns NULL with errno set to EINVAL when `str` is NULL or when the
 * quoted form (including both quotes) would exceed 4096 characters.
 */
char *sql_quote (const char *str)
{
  static char buf[4097];
  char *bp = buf;
  const char *sp;

  if (!str)
    {
      errno = EINVAL;
      return NULL;
    }

  *bp++ = '\'';
  for (sp = str; *sp; sp++)
    {
      char ch = *sp;
      int escape = 1;

      switch (ch)
        {
        case '\n':
          ch = 'n';
          break;
        case '\r':
          ch = 'r';
          break;
        case '\'':
        case '"':
        case '\\':
        case '%':
        case '_':
          break;
        default:
          escape = 0;
          break;
        }

      /* Room needed: the one or two characters written now, plus the
         closing quote and the terminating NUL.  */
      if ((size_t) (bp - buf) + (escape ? 4 : 3) > sizeof (buf))
        {
          errno = EINVAL;
          return NULL;
        }
      if (escape)
        *bp++ = '\\';
      *bp++ = ch;
    }
  *bp++ = '\'';
  *bp = 0;
  return buf;
}
/*
 * Expand the template `query` and execute it on the connection in `sql`.
 *
 * Template rules, applied left to right:
 *  - Outside a quoted section, "%c" is replaced by args[i].str for the first
 *    of the `nargs` entries whose `ch` equals c.  If no entry matches, or
 *    that entry's `str` is NULL, the '%' is copied literally and c is then
 *    processed as ordinary input.  The replacement text is inserted as-is.
 *  - ' or " opens a quoted section when none is open and closes a section
 *    opened by the same character; inside a section '%' is literal.
 *  - A backslash copies the following character without interpreting it.
 *    The backslash itself is kept, except outside quotes before '?'.
 *
 * The expanded text is limited to 4096 characters and is logged through
 * syslog() at LOG_DEBUG before being executed.
 *
 * Returns, for MySQL, the number of rows in the stored result set, or the
 * affected-row count when the statement produced no result set.  Returns -1
 * on failure; errno is EINVAL when the expansion does not fit or sql->type
 * is unsupported, and 0 when the client library reported the error.
 */
int sql_query (sql_t *sql, const char *query, int nargs, const sql_arg_t *args)
{
  char rquery[4097];
  char *rp = rquery;
  /* Last slot of the buffer; reserved for the terminating NUL.  */
  char *const rend = rquery + sizeof (rquery) - 1;
  const char *qp = query;
  char quoting = 0;
  char ch;
  int i;

  while (*qp)
    {
      ch = *qp++;
      switch (ch)
        {
        case '%':
          /* The *qp test keeps a trailing '%' from matching an entry whose
             ch is 0 and stepping past the end of the template.  */
          if (!quoting && *qp)
            {
              for (i = 0; i < nargs; i++)
                {
                  if (args[i].ch == *qp)
                    break;
                }
              if (i < nargs && args[i].str)
                {
                  size_t len = strlen (args[i].str);

                  if (len > (size_t) (rend - rp))
                    {
                      errno = EINVAL;
                      return -1;
                    }
                  strcpy (rp, args[i].str);
                  rp += len;
                  qp++;
                  break;
                }
            }
          if (rp == rend)
            {
              errno = EINVAL;
              return -1;
            }
          *rp++ = '%';
          break;

        case '\'':
        case '"':
          if (rp == rend)
            {
              errno = EINVAL;
              return -1;
            }
          *rp++ = ch;
          if (quoting == ch)
            quoting = 0;
          else if (!quoting)
            quoting = ch;
          break;

        case '\\':
          /* Up to two characters are written below.  */
          if (rend - rp < 2)
            {
              errno = EINVAL;
              return -1;
            }
          if (quoting || *qp != '?')
            *rp++ = '\\';
          if (*qp)
            *rp++ = *qp++;
          break;

        default:
          if (rp == rend)
            {
              errno = EINVAL;
              return -1;
            }
          *rp++ = ch;
          break;
        }
    }
  *rp = 0;

  syslog (LOG_DEBUG, "SQL query: %s.", rquery);

  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      /* Cleared so sql_error() falls through to the library message.  */
      errno = 0;
      if (mysql_query (sql->u.mysql.h, rquery))
        return -1;
      /* Replace any result kept from a previous query.  */
      if (sql->u.mysql.r)
        mysql_free_result (sql->u.mysql.r);
      sql->u.mysql.r = mysql_store_result (sql->u.mysql.h);
      /* Reset the cursors that sql_fetch_cell() caches.  */
      sql->u.mysql.ccol = 0;
      sql->u.mysql.crow = -1;
      if (sql->u.mysql.r)
        return mysql_num_rows (sql->u.mysql.r);
      /* No result set: either an error or a statement that returns none.  */
      if (mysql_errno (sql->u.mysql.h))
        return -1;
      return mysql_affected_rows (sql->u.mysql.h);
#endif
    default:
      errno = EINVAL;
      return -1;
    }
}
/*
 * Return one cell of the result stored by the last sql_query().
 *
 * With row == -1 the name of column `col` is returned.  Otherwise the value
 * at (`row`, `col`) is returned, both counted from 0.
 *
 * Returns NULL when `col` < 0, `row` < -1, there is no stored result, the
 * column or row cannot be fetched, or (with errno set to EINVAL) sql->type
 * is unsupported.  NOTE(review): the client library presumably also yields
 * NULL for SQL NULL values, so NULL alone does not prove an error.
 *
 * The cursor positions are cached in sql->u.mysql.ccol / crow so that
 * sequential access (next column name, same or next row) needs no seek.
 */
const char *sql_fetch_cell (sql_t *sql, int row, int col)
{
  if (col < 0 || row < -1)
    return NULL;

  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      /* sql_query() leaves no result for statements that return none.  */
      if (!sql->u.mysql.r)
        return NULL;

      if (row == -1)
        {
          MYSQL_FIELD *f;

          /* Seek only when the field cursor is not already on `col`;
             either way the fetch below advances it by one.  */
          if (col != sql->u.mysql.ccol)
            {
              mysql_field_seek (sql->u.mysql.r, col);
              sql->u.mysql.ccol = col + 1;
            }
          else
            sql->u.mysql.ccol++;
          f = mysql_fetch_field (sql->u.mysql.r);
          if (!f)
            return NULL;
          return f->name;
        }

      if (col >= mysql_num_fields (sql->u.mysql.r))
        return NULL;

      if (row != sql->u.mysql.crow)
        {
          /* The row right after the cached one is next in line and can be
             fetched without seeking.  */
          if (row != sql->u.mysql.crow + 1)
            mysql_data_seek (sql->u.mysql.r, row);
          sql->u.mysql.row = mysql_fetch_row (sql->u.mysql.r);
          if (!sql->u.mysql.row)
            {
              /* -2 matches no valid request, forcing a seek next time.  */
              sql->u.mysql.crow = -2;
              return NULL;
            }
          sql->u.mysql.crow = row;
        }
      return sql->u.mysql.row[col];
#endif
    default:
      errno = EINVAL;
      return NULL;
    }
}
/*
 * Discard the result set stored in `sql` and clear the pointer to it.
 * Unknown backend types are ignored.
 */
void sql_free_result (sql_t *sql)
{
  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      /* Called unconditionally, whatever the pointer currently holds.  */
      mysql_free_result (sql->u.mysql.r);
      sql->u.mysql.r = NULL;
      return;
#endif
    default:
      return;
    }
}
/*
 * Describe the most recent failure for `sql`.
 *
 * A non-zero errno takes precedence and is rendered with strerror().
 * Otherwise the backend's own message is returned, or a fixed string when
 * sql->type is not a supported backend.
 */
const char *sql_error (sql_t *sql)
{
  const char *msg;

  if (errno != 0)
    return strerror (errno);

  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      msg = mysql_error (sql->u.mysql.h);
      break;
#endif
    default:
      msg = "Unknown sql type.";
      break;
    }
  return msg;
}
#endif /*USE_SQL*/
/* syntax highlighted by Code2HTML, v. 0.9.1 */