/*****************************************************************************\
 * Copyright (c) 2002 Pelle Johansson.                                       *
 * All rights reserved.                                                      *
 *                                                                           *
 * This file is part of the moftpd package. Use and distribution of          *
 * this software is governed by the terms in the file LICENCE, which         *
 * should have come with this package.                                       *
\*****************************************************************************/

/* $moftpd: sql.c 1251 2005-03-06 22:24:29Z morth $ */

#include "system.h"

#include "sql.h"

#ifdef USE_SQL

extern char *sslCertsPath;

/* Maximum length, excluding the terminating NUL, of a quoted string
   produced by sql_quote() and of an expanded query built by sql_query(). */
enum { SQL_TEXT_MAX = 4096 };

/* Append LEN bytes from SRC at *POS inside BUF (SIZE bytes), advancing *POS.
   Fails with errno = EINVAL, writing nothing, unless room for a terminating
   NUL remains afterwards. */
static int
sql_append (char *buf, size_t size, char **pos, const char *src, size_t len)
{
  if ((size_t) (*pos - buf) + len >= size)
    {
      errno = EINVAL;
      return -1;
    }
  memcpy (*pos, src, len);
  *pos += len;
  return 0;
}

#ifdef HAVE_MYSQL
/* Message of the last failed MySQL connection attempt.  sql_connect() closes
   the handle on failure (so that a later call really reconnects instead of
   reporting success on a dead handle), after which mysql_error() can no
   longer be asked; sql_error() returns this copy instead.  Shared by all
   sql_t objects and not thread-safe. */
static char sql_connect_error[256];

/* Remember MSG (truncated if necessary) as the last connection error. */
static void
sql_save_connect_error (const char *msg)
{
  strncpy (sql_connect_error, msg ? msg : "", sizeof (sql_connect_error) - 1);
  sql_connect_error[sizeof (sql_connect_error) - 1] = 0;
}

/* Free the current MySQL result set, if any, and forget the cached row and
   field cursor so no pointer into the freed result can be handed out. */
static void
sql_mysql_drop_result (sql_t *sql)
{
  if (sql->u.mysql.r)
    {
      mysql_free_result (sql->u.mysql.r);
      sql->u.mysql.r = NULL;
    }
  sql->u.mysql.row = NULL;
  sql->u.mysql.crow = -1;
  sql->u.mysql.ccol = 0;
}

/* Split HOST, written as "name", "name:port", "[v6addr]" or "[v6addr]:port",
   into the bare host name, stored NUL-terminated in RHOST (RHOSTSZ bytes,
   IPv6 brackets removed), and a TCP port stored in *PORT (0 when no port is
   given).  The port is looked up as a service name first and otherwise
   parsed with strtoul (base 0).
   Returns 0, or -1 with errno = EINVAL if the host name does not fit. */
static int
sql_parse_host (const char *host, char *rhost, size_t rhostsz, int *port)
{
  const char *name = host, *name_end = NULL, *svc;
  size_t len;

  /* A bracketed IPv6 literal contains colons of its own; only look for the
     port separator after the closing bracket. */
  if (host[0] == '[')
    name_end = strchr (host, ']');
  svc = strchr (name_end ? name_end : host, ':');

  if (name_end)
    name = host + 1;
  else
    name_end = svc ? svc : host + strlen (host);

  len = (size_t) (name_end - name);
  if (len >= rhostsz)
    {
      errno = EINVAL;
      return -1;
    }
  memcpy (rhost, name, len);
  rhost[len] = 0;

  *port = 0;
  if (svc)
    {
      struct servent *serv = getservbyname (++svc, "tcp");

      if (serv)
        *port = ntohs (serv->s_port);
      else
        *port = (int) strtoul (svc, NULL, 0);
    }
  return 0;
}
#endif /* HAVE_MYSQL */

/* Connect SQL (backend chosen by sql->type) to database DB on HOST as USER
   with password PASS.  HOST may carry a ":port" suffix (see
   sql_parse_host); NULL is passed on to the client library unchanged.
   When the library provides mysql_ssl_set(), CERT and KEY select client SSL
   credentials and sslCertsPath is used as the CA directory.
   Returns 0 on success or if a connection is already open, otherwise -1
   with the reason available from sql_error(). */
int
sql_connect (sql_t *sql, const char *host, const char *user, const char *db,
             const char *pass, const char *cert, const char *key)
{
  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      {
        char rhost[256];
        const char *hostp = NULL;
        int port = 0;

        if (sql->u.mysql.h)
          return 0;

        sql_connect_error[0] = 0;
        if (host)
          {
            if (sql_parse_host (host, rhost, sizeof (rhost), &port))
              return -1;
            hostp = rhost;
          }

        /* sql_error() prefers errno over mysql_error(); clear any stale
           value (e.g. from getservbyname) so only errors raised by the
           connection attempt itself are reported that way. */
        errno = 0;
        if (!(sql->u.mysql.h = mysql_init (NULL)))
          {
            /* mysql_init() only fails when it cannot allocate the handle. */
            if (!errno)
              errno = ENOMEM;
            return -1;
          }
#if defined(HAVE_MYSQL_SSL_SET)
        if (cert)
          mysql_ssl_set (sql->u.mysql.h, key, cert, NULL, sslCertsPath, NULL);
#endif

        if (!mysql_real_connect (sql->u.mysql.h, hostp, user, pass, db,
                                 port, NULL, 0))
          {
            int saved_errno = errno;

            sql_save_connect_error (mysql_error (sql->u.mysql.h));
            mysql_close (sql->u.mysql.h);
            sql->u.mysql.h = NULL;
            errno = saved_errno;
            return -1;
          }
        return 0;
      }
#endif
    default:
      errno = EINVAL;
      return -1;
    }
}

/* Release the current result set and close the connection, if open.
   Safe to call on an unconnected or already disconnected object. */
void
sql_disconnect (sql_t *sql)
{
  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      sql_mysql_drop_result (sql);
      if (sql->u.mysql.h)
        {
          mysql_close (sql->u.mysql.h);
          sql->u.mysql.h = NULL;
        }
      break;
#endif
    default:
      break;
    }
}

/* Return STR wrapped in single quotes, with newline and carriage return
   written as \n and \r and each of ' " \ % _ preceded by a backslash.
   The result lives in a static buffer that the next call overwrites.
   Returns NULL with errno = EINVAL if STR is NULL or the quoted text would
   exceed SQL_TEXT_MAX characters.
   NOTE(review): MySQL only turns \% and \_ back into % and _ inside LIKE
   patterns; in an '=' comparison they remain two characters, so a value
   such as "a_b" would not match — confirm callers only use LIKE. */
char *
sql_quote (const char *str)
{
  static char buf[SQL_TEXT_MAX + 1];
  char *bp = buf;
  const char *sp;

  if (!str)
    {
      errno = EINVAL;
      return NULL;
    }

  *bp++ = '\'';
  for (sp = str; *sp; sp++)
    {
      char ch = *sp;
      size_t need;

      switch (ch)
        {
        case '\n':
          ch = 'n';
          need = 2;
          break;
        case '\r':
          ch = 'r';
          need = 2;
          break;
        case '\'':
        case '"':
        case '\\':
        case '%':
        case '_':
          need = 2;
          break;
        default:
          need = 1;
          break;
        }

      /* Always keep room for the closing quote and the terminating NUL. */
      if ((size_t) (bp - buf) + need + 2 > sizeof (buf))
        {
          errno = EINVAL;
          return NULL;
        }
      if (need == 2)
        *bp++ = '\\';
      *bp++ = ch;
    }
  *bp++ = '\'';
  *bp = 0;
  return buf;
}

/* Expand QUERY and run it.  Outside single or double quotes, "%c" is
   replaced by the str of the first ARGS entry (of NARGS) whose ch is 'c';
   it stays literal when no entry matches or that str is NULL.  Values are
   inserted as-is (presumably pre-quoted by the caller, e.g. with
   sql_quote()).  Outside quotes "\?" becomes "?"; any other backslash is
   kept, and the character after a backslash is always copied verbatim (so
   "\%" is not substituted and "\'" does not toggle quoting).
   Returns the number of rows of a result-set statement, the number of
   affected rows otherwise, or -1 on error (see sql_error()); EINVAL if the
   expansion exceeds SQL_TEXT_MAX characters, ENOTCONN if not connected. */
int
sql_query (sql_t *sql, const char *query, int nargs, const sql_arg_t *args)
{
  char rquery[SQL_TEXT_MAX + 1], *rp = rquery;
  const char *qp = query;
  char quoting = 0, ch;
  int i;

  while ((ch = *qp++))
    {
      switch (ch)
        {
        case '%':
          /* The *qp test keeps a trailing '%' from matching an argument
             whose ch is NUL and stepping past the end of QUERY. */
          if (!quoting && *qp)
            {
              for (i = 0; i < nargs; i++)
                if (args[i].ch == *qp)
                  break;
              if (i < nargs && args[i].str)
                {
                  qp++;
                  if (sql_append (rquery, sizeof (rquery), &rp, args[i].str,
                                  strlen (args[i].str)))
                    return -1;
                  break;
                }
            }
          if (sql_append (rquery, sizeof (rquery), &rp, &ch, 1))
            return -1;
          break;

        case '\'':
        case '"':
          if (sql_append (rquery, sizeof (rquery), &rp, &ch, 1))
            return -1;
          if (quoting == ch)
            quoting = 0;
          else if (!quoting)
            quoting = ch;
          break;

        case '\\':
          if (quoting || *qp != '?')
            {
              if (sql_append (rquery, sizeof (rquery), &rp, &ch, 1))
                return -1;
            }
          if (*qp)
            {
              if (sql_append (rquery, sizeof (rquery), &rp, qp, 1))
                return -1;
              qp++;
            }
          break;

        default:
          if (sql_append (rquery, sizeof (rquery), &rp, &ch, 1))
            return -1;
          break;
        }
    }
  *rp = 0;

  /* NOTE(review): the expanded query is logged verbatim, including any
     substituted values (possibly credentials) — confirm this is intended. */
  syslog (LOG_DEBUG, "SQL query: %s.", rquery);

  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      if (!sql->u.mysql.h)
        {
          errno = ENOTCONN;
          return -1;
        }
      errno = 0;
      if (mysql_query (sql->u.mysql.h, rquery))
        return -1;
      sql_mysql_drop_result (sql);
      sql->u.mysql.r = mysql_store_result (sql->u.mysql.h);
      if (sql->u.mysql.r)
        return (int) mysql_num_rows (sql->u.mysql.r);
      /* No result set: either an error or a statement without rows. */
      if (mysql_errno (sql->u.mysql.h))
        return -1;
      return (int) mysql_affected_rows (sql->u.mysql.h);
#endif
    default:
      errno = EINVAL;
      return -1;
    }
}

/* Return the value in column COL of row ROW (0-based) of the current result
   set, or the name of column COL when ROW is -1.  Returns NULL for
   out-of-range indices, when there is no result set, or for an SQL NULL
   value.  Sequential row/column access avoids reseeking the cursors.  The
   pointer stays valid until the next query or sql_free_result(). */
const char *
sql_fetch_cell (sql_t *sql, int row, int col)
{
  if (col < 0 || row < -1)
    return NULL;

  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      /* Nothing queried yet, result freed, or statement without rows. */
      if (!sql->u.mysql.r)
        return NULL;

      if (row == -1)
        {
          MYSQL_FIELD *f;

          /* ccol is the field cursor's position; seek only when needed. */
          if (col != sql->u.mysql.ccol)
            mysql_field_seek (sql->u.mysql.r, col);
          sql->u.mysql.ccol = col + 1;
          f = mysql_fetch_field (sql->u.mysql.r);
          return f ? f->name : NULL;
        }

      if ((unsigned int) col >= mysql_num_fields (sql->u.mysql.r))
        return NULL;
      if (row != sql->u.mysql.crow)
        {
          if (row != sql->u.mysql.crow + 1)
            mysql_data_seek (sql->u.mysql.r, row);
          sql->u.mysql.row = mysql_fetch_row (sql->u.mysql.r);
          if (!sql->u.mysql.row)
            {
              /* -2 can never be "previous row + 1", forcing a seek next. */
              sql->u.mysql.crow = -2;
              return NULL;
            }
          sql->u.mysql.crow = row;
        }
      return sql->u.mysql.row[col];
#endif
    default:
      errno = EINVAL;
      return NULL;
    }
}

/* Release the current result set, if any, and its cached row. */
void
sql_free_result (sql_t *sql)
{
  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      sql_mysql_drop_result (sql);
      break;
#endif
    default:
      break;
    }
}

/* Describe the last failure: strerror(errno) when errno is set, otherwise
   the backend's own message.  Without an open MySQL handle, the message of
   the last failed sql_connect() (or "Not connected.") is returned. */
const char *
sql_error (sql_t *sql)
{
  if (errno)
    return strerror (errno);

  switch (sql->type)
    {
#ifdef HAVE_MYSQL
    case sqlMYSQL:
      if (sql->u.mysql.h)
        return mysql_error (sql->u.mysql.h);
      if (sql_connect_error[0])
        return sql_connect_error;
      return "Not connected.";
#endif
    default:
      return "Unknown sql type.";
    }
}

#endif /*USE_SQL*/