/*****************************************************************************\
* Copyright (c) 2003-2004 Pelle Johansson. *
* All rights reserved. *
* *
* This file is part of the moftpd package. Use and distribution of *
* this software is governed by the terms in the file LICENCE, which *
* should have come with this package. *
\*****************************************************************************/
/* $moftpd: accounter.c 1242 2004-12-07 22:29:09Z morth $ */
#include "system.h"
#include "accounter.h"
#include "main.h"
#include "utf8fs/memory.h"
#include "events.h"
/* Listening socket of the accounting server owned by this process, or -1
   when this process is not listening (see start_accounter and
   close_accounter).  */
int accSock = -1;
/* Who provides accounting:
     pid  - the process with this pid bound accPath (set after a successful
            bind/listen in start_accounter);
     0    - not started here, or another server answered on accPath;
     -1   - no socket server could be set up; connect_accounter then hands
            out socketpairs served by this process.  */
int accServer;
/* Filesystem path of the unix-domain accounting socket.  start_accounter and
   connect_accounter skip the socket entirely when this is NULL.  */
const char *accPath = VARDIR "/run/" PACKAGE_NAME ".acct";
/* Pool roots: accinfos parents one accounter_t per connected client,
   accservers parents one accserver_t per server name.  Both are created
   lazily with proot ().  */
static void *accinfos, *accservers;
/*
 * Return the accounting record for the server called NAME, creating it
 * (with zero connections and no users) the first time the name is seen.
 *
 * Records live as children of the file-local `accservers' pool, which is
 * itself created on first use.
 *
 * Returns NULL when the pool, the record or the copy of NAME cannot be
 * allocated.
 */
accserver_t *find_acc_server (const char *name)
{
  accserver_t *serv;

  if (accservers)
  {
    for (serv = pchild (accservers, NULL); serv; serv = pchild (accservers, serv))
    {
      if (!strcmp (name, serv->name))
        return serv;
    }
  }
  else
  {
    accservers = proot ();
    if (!accservers)
      return NULL;
  }

  serv = palloc (sizeof (accserver_t), accservers, NULL);
  if (!serv)
    return NULL;

  serv->name = pstring (name, serv);
  if (!serv->name)
  {
    // A record without a name must not stay in the pool: the lookup loop
    // above hands serv->name to strcmp () on every later call.
    // NOTE(review): assumes pstring () reports failure by returning NULL.
    pfree (serv, accservers);
    return NULL;
  }
  serv->numConnects = 0;
  serv->users = NULL;
  return serv;
}
/*
 * Return the record for user NAME on server SERV, creating it with a login
 * count of zero if the server has no such user yet.
 *
 * New records are allocated as children of SERV and pushed onto the front
 * of SERV's singly linked `users' list.
 *
 * Returns NULL when the record or the copy of NAME cannot be allocated; in
 * that case the list is left untouched.
 */
accuser_t *find_acc_user (accserver_t *serv, const char *name)
{
  accuser_t *user;

  for (user = serv->users; user; user = user->next)
  {
    if (!strcmp (name, user->name))
      return user;
  }

  user = palloc (sizeof (accuser_t), serv, NULL);
  if (!user)
    return NULL;

  user->name = pstring (name, user);
  if (!user->name)
  {
    // Never link a nameless user: the search loop above would pass the NULL
    // name to strcmp () on the next lookup.
    // NOTE(review): assumes pstring () reports failure by returning NULL.
    pfree (user, serv);
    return NULL;
  }
  user->numLogins = 0;
  user->next = serv->users;
  serv->users = user;
  return user;
}
/*
 * Make sure an accounting server is reachable.
 *
 * If accPath is set, first try to connect to it: a successful connect to a
 * socket file owned by our effective uid means another process already
 * serves accounting, so accServer is set to 0 and nothing else is done.
 * Otherwise the stale path is removed and this process tries to bind and
 * listen on it, registering accounter_accepter for incoming connections
 * and recording its own pid in accServer.
 *
 * If accPath is NULL or bind fails, accServer is set to -1 and accSock to
 * -1 (see the comment at the end of the function).
 *
 * Returns 0 on success (including the "nothing to do" cases) and -1 when
 * socket () or listen () fails.
 */
int start_accounter (void)
{
  if (accSock >= 0)
    close (accSock);

  if (accPath)
  {
    struct sockaddr_un addr;
    struct stat st;

    super_privs (0);
    // Only the process that created the socket file may remove it up front.
    if (accServer == getpid ())
      unlink (accPath);

    accSock = socket (PF_UNIX, SOCK_STREAM, 0);
    if (accSock < 0)
      return -1;

    addr.sun_family = AF_UNIX;
#ifdef HAVE_STRUCT_SOCKADDR_IN_SIN_LEN
    addr.sun_len = sizeof (addr);
#endif
    strncpy (addr.sun_path, accPath, sizeof (addr.sun_path) - 1);
    addr.sun_path[sizeof (addr.sun_path) - 1] = 0;

    if (!stat (accPath, &st) && st.st_uid == geteuid () &&
        !connect (accSock, (struct sockaddr*)&addr, sizeof (addr)))
    {
      // There's already a server running.
      close (accSock);
      accSock = -1;
      accServer = 0;
      return 0;
    }

    // No server running. Try to start a new one.
    unlink (accPath);
    if (!bind (accSock, (struct sockaddr*)&addr, sizeof (addr)))
    {
      chmod (accPath, 0600);
      if (listen (accSock, 128))
      {
        close (accSock);
        accSock = -1;
        return -1;
      }
      accServer = getpid ();
      add_read_fd (accSock, accounter_accepter, NULL);
      return 0;
    }

    // bind () failed.  Release the descriptor before accSock is reset below,
    // otherwise it would leak on every failed attempt.
    close (accSock);
  }

  // If we get here, we don't need to do anything. Accounting will be
  // provided for children to us.
  accServer = -1;
  accSock = -1;
  return 0;
}
/*
 * Shut down the accounting state of this process.
 *
 * LEVEL selects how much is torn down:
 *   0      only the listening socket is unregistered and closed;
 *   >= 1   additionally every client socket is closed, the client pool is
 *          freed and accServer is reset to 0;
 *   >= 2   additionally the socket file is unlinked, provided this process
 *          is the one recorded in accServer.
 */
void close_accounter (int level)
{
  accounter_t *client;

  if (accSock >= 0)
  {
    remove_read_fd (accSock);
    close (accSock);
    accSock = -1;
  }

  if (level == 0)
    return;

  if (accinfos != NULL)
  {
    client = pchild (accinfos, NULL);
    while (client != NULL)
    {
      // remove_read_fd () is deliberately not called here since
      // events_init () will be called.
      if (client->sock >= 0)
        close (client->sock);
      client = pchild (accinfos, client);
    }
    pfree (accinfos, NULL);
    accinfos = NULL;
  }

  if (level >= 2)
  {
    if (accServer == getpid ())
      unlink (accPath);
  }
  accServer = 0;
}
int accounter_add (int sock)
{
accounter_t *acc = NULL, *a, *na;
time_t t;
if (!accinfos)
accinfos = proot ();
if (accinfos)
acc = palloc (sizeof (*acc), accinfos, NULL);
if (!acc)
{
write (sock, "ERROR\n", 6);
close (sock);
return -1;
}
acc->sock = sock;
acc->id = 0;
time (&t);
for (a = pchild (accinfos, NULL); a; a = na)
{
na = pchild (accinfos, a);
if (a->removed && a->removed <= t)
{
acc->id = a->id;
pfree (a, accinfos);
break;
}
if (a != acc && a->id >= acc->id)
acc->id = a->id + 1;
}
return add_read_fd (sock, accounter_client_handler, acc);
}
/*
 * Obtain a descriptor connected to an accounting server.
 *
 * Up to 20 rounds are made.  Each round first tries the unix socket at
 * accPath (if set), accepting it only when the socket file is owned by our
 * effective uid.  If that fails and no server is known, start_accounter ()
 * is asked to set one up.  Once accServer is non-zero a socketpair is
 * created instead: one end is registered locally through accounter_add ()
 * and the other end is returned.
 *
 * Returns the connected descriptor, or -1 on failure.  When all rounds are
 * used up errno is set to EINVAL.
 */
int connect_accounter (void)
{
  int attempt;

  super_privs (0);

  for (attempt = 0; attempt < 20; attempt++)
  {
    if (accPath != NULL)
    {
      struct sockaddr_un addr;
      struct stat st;
      int fd = socket (PF_UNIX, SOCK_STREAM, 0);

      if (fd < 0)
        return -1;

      addr.sun_family = AF_UNIX;
#ifdef HAVE_STRUCT_SOCKADDR_IN_SIN_LEN
      addr.sun_len = sizeof (addr);
#endif
      strncpy (addr.sun_path, accPath, sizeof (addr.sun_path) - 1);
      addr.sun_path[sizeof (addr.sun_path) - 1] = 0;

      if (stat (accPath, &st) == 0
          && st.st_uid == geteuid ()
          && connect (fd, (struct sockaddr*)&addr, sizeof (addr)) == 0)
        return fd;

      close (fd);
    }

    if (accServer == 0)
    {
      if (start_accounter () != 0)
        return -1;
    }

    if (accServer != 0)
    {
      int ends[2];

      if (socketpair (PF_UNIX, SOCK_STREAM, 0, ends) != 0)
        return -1;
      // Serve one end from this process, give the other to the caller.
      accounter_add (ends[0]);
      return ends[1];
    }
  }

  errno = EINVAL;
  return -1;
}
/*
 * Read-event callback for the listening accounting socket: accept one
 * pending connection on SOCK and register it with accounter_add ().
 * USER and URGENT are not used.  Always returns 0, also when accept ()
 * fails.
 */
int accounter_accepter (int sock, void *user, int urgent)
{
  int client = accept (sock, NULL, NULL);

  if (client >= 0)
    accounter_add (client);
  return 0;
}
/* Return the client record whose id is ID, or NULL when there is no such
   record or the first record with that id has been marked as removed.  */
static accounter_t *acc_find_live (int id)
{
  accounter_t *a;

  for (a = pchild (accinfos, NULL); a; a = pchild (accinfos, a))
  {
    if (a->id == id)
      break;
  }
  if (a && a->removed)
    return NULL;
  return a;
}

/* Send SIGURG to the process that registered itself on record A with
   "SET PID", if any.  */
static void acc_signal (accounter_t *a)
{
  if (a->pid)
  {
    super_privs (0);
    kill (a->pid, SIGURG);
  }
}

/* Send "WORD\n", or "WORD TEXT\n" when TEXT is non-empty, to the client of
   record A.  The line is built in place in TEXT, so the caller must provide
   room for strlen (WORD) + 2 extra bytes and must not reuse the old
   contents afterwards.  */
static void acc_send_word (accounter_t *a, const char *word, char *text)
{
  size_t wlen = strlen (word);
  size_t len = strlen (text);

  if (len)
  {
    memmove (text + wlen + 1, text, len);
    text[wlen] = ' ';
    len += wlen + 1;
  }
  else
    len = wlen;
  memmove (text, word, wlen);
  text[len++] = '\n';
  send (a->sock, text, len, 0);
}

/*
 * Read-event callback for one accounting client.
 *
 * SOCK is the client's descriptor and USER the accounter_t registered for
 * it by accounter_add ().  URGENT is not used.  Always returns 0.
 *
 * One read () of at most 4096 bytes is done per call.  On end of file or a
 * read error other than EINTR the client's connection and login counts are
 * released, its id is reserved for another 30 seconds, and the descriptor
 * is unregistered and closed.  Otherwise every complete '\n'-terminated
 * line in the data is handled as one command; a trailing partial line is
 * dropped.  '\r' is not stripped.
 *
 * Commands and replies, as implemented below:
 *   LIST                           one line per connected client, then END
 *   CONNECT rhost server [max]     ALLOW <id> | DENY | ERROR
 *   LOGIN user [acct [email [max]]]  ALLOW | DENY | ERROR | NOSERVER
 *   LOGOUT                         (no reply)
 *   SENDING size text              ALLOW
 *   SENT size token                (no reply)
 *   GETTING text                   ALLOW
 *   GOT number token               (no reply)
 *   SET PID value                  (no reply)
 *   QUERY VERSION                  VERSION <version>
 *   MSG id text / MSGALL text      OK | INVALID
 *   ABORT id [text]                OK | INVALID
 *   DISCONNECT id [text]           OK | INVALID
 *   DISCABORT id [text]            OK | INVALID
 *   RELOAD                         (no reply)
 * Anything else is answered with UNKNOWN.
 */
int accounter_client_handler (int sock, void *user, int urgent)
{
  int l, max, id, end;
  char buf[4097], *bp, *nbp, msg[4097];
  accounter_t *acc = user, *a, *na;
  long long sz;
  time_t t;

  l = read (sock, buf, sizeof (buf) - 1);
  if (l <= 0)
  {
    if (!l || errno != EINTR)
    {
      if (acc)
      {
        if (acc->server && acc->server->numConnects)
          acc->server->numConnects--;
        if (acc->user && acc->user->numLogins)
          acc->user->numLogins--;
        time (&acc->removed);
        acc->removed += 30; // 30 s to reuse id.
        acc->sock = -1;
      }
      remove_read_fd (sock);
      close (sock);
    }
    return 0;
  }
  buf[l] = 0;

  for (bp = buf; (nbp = strchr (bp, '\n')); bp = nbp)
  {
    *nbp++ = 0;
    if (!strlen (bp))
      continue;

    max = 0;
    msg[0] = 0;
    // Set by %n only when a format was matched up to that point.
    end = -1;

    if (!strcmp (bp, "LIST"))
    {
      time (&t);
      for (a = pchild (accinfos, NULL); a; a = na)
      {
        na = pchild (accinfos, a);
        if (a->removed)
        {
          // Expired records are collected while we are walking the list.
          if (a->removed <= t)
            pfree (a, accinfos);
        }
        else if (a->server)
        {
          snprintf (msg, 4096, "%d: %d %s %s %s %s %s %s\n",
                    a->id,
                    a->pid,
                    a->rhost,
                    a->server->name,
                    a->user? a->user->name : "-",
                    a->acct[0]? a->acct : "-",
                    a->email[0]? a->email : "-",
                    a->work);
          write (sock, msg, strlen (msg));
        }
      }
      write (sock, "END\n", 4);
    }
    else if (sscanf (bp, "CONNECT %255s %255s %d", acc->rhost, msg, &max) >= 2)
    {
      accserver_t *serv;

      msg[255] = 0;
      serv = find_acc_server (msg);
      if (!serv)
        write (sock, "ERROR\n", 6);
      else if (max && serv->numConnects >= max)
        write (sock, "DENY\n", 5);
      else
      {
        if (acc->server)
        {
          if (acc->server->numConnects)
            acc->server->numConnects--;
          pfree (acc->server, acc);
        }
        acc->server = pattach (serv, acc);
        acc->server->numConnects++;
        sprintf (msg, "ALLOW %d\n", acc->id);
        write (sock, msg, strlen (msg));
      }
    }
    else if (sscanf (bp, "LOGIN %255s %255s %255s %d", msg, acc->acct, acc->email, &max) >= 1)
    {
      if (!strcmp (acc->acct, "-"))
        acc->acct[0] = 0;
      if (!strcmp (acc->email, "-"))
        acc->email[0] = 0;
      if (!acc->server)
        write (sock, "NOSERVER\n", 9);
      else
      {
        accuser_t *found;

        // A LOGIN without a preceding LOGOUT replaces the current user;
        // give back its login slot and reference first, the same way
        // CONNECT does for a previous server.
        if (acc->user)
        {
          if (acc->user->numLogins)
            acc->user->numLogins--;
          pfree (acc->user, acc);
          acc->user = NULL;
        }

        found = find_acc_user (acc->server, msg);
        if (found)
          acc->user = pattach (found, acc);

        if (!acc->user)
          write (sock, "ERROR\n", 6);
        else if (max && acc->user->numLogins >= max)
        {
          write (sock, "DENY\n", 5);
          pfree (acc->user, acc);
          acc->user = NULL;
        }
        else
        {
          acc->user->numLogins++;
          write (sock, "ALLOW\n", 6);
        }
      }
    }
    else if (!strcmp (bp, "LOGOUT"))
    {
      if (acc->user)
      {
        if (acc->user->numLogins)
          acc->user->numLogins--;
        pfree (acc->user, acc);
        acc->user = NULL;
      }
      acc->acct[0] = 0;
      acc->email[0] = 0;
    }
    else if (sscanf (bp, "SENDING %lld %[^\n]", &sz, msg) == 2)
    {
      snprintf (acc->work, 4096, "RETR: %s", msg);
      write (sock, "ALLOW\n", 6);
    }
    // Suppressed conversions (%*s, %*d) are not counted in sscanf's return
    // value, so "both fields present" is detected with %n instead.  The
    // former "== 2" tests on SENT and GOT could never succeed.
    else if (sscanf (bp, "SENT %lld %*s%n", &sz, &end) == 1 && end >= 0)
    {
      acc->sent += sz;
      acc->work[0] = 0;
    }
    else if (sscanf (bp, "GETTING %[^\n]", msg) == 1)
    {
      snprintf (acc->work, 4096, "STOR: %s", msg);
      write (sock, "ALLOW\n", 6);
    }
    else if (sscanf (bp, "GOT %*d %*s%n", &end) == 0 && end >= 0)
      acc->work[0] = 0;
    // The value is located with %n rather than scanned back into bp:
    // letting sscanf write into the string it is reading is undefined.
    else if (sscanf (bp, "SET %s %n", msg, &end) == 1 && end >= 0 && bp[end])
    {
      if (!strcmp (msg, "PID"))
      {
        acc->pid = atoi (bp + end);
        if (acc->pid < 0)
          acc->pid = 0;
      }
    }
    else if (sscanf (bp, "QUERY %s", msg) == 1)
    {
      if (!strcmp (msg, "VERSION"))
      {
        strcpy (msg, "VERSION " PACKAGE_VERSION "\r\n");
        write (sock, msg, strlen (msg));
      }
      else
        write (sock, "UNKNOWN\n", 8);
    }
    else if (sscanf (bp, "MSG %d %[^\n]", &id, msg) == 2)
    {
      a = acc_find_live (id);
      if (a)
      {
        l = strlen (msg);
        if (l > 4096 - 9)
          l = 4096 - 9;
        memmove (msg + 8, msg, l);
        strcpy (msg, "MSG 200");
        msg[7] = ' ';
        l += 8;
        msg[l++] = '\n';
        send (a->sock, msg, l, 0);
        acc_signal (a);
        write (sock, "OK\n", 3);
      }
      else
        write (sock, "INVALID\n", 8);
    }
    else if (sscanf (bp, "MSGALL %[^\n]", msg) == 1)
    {
      l = strlen (msg);
      if (l > 4096 - 8)
        l = 4096 - 8;
      memmove (msg + 8, msg, l);
      l += 8;
      msg[l++] = '\n';
      strcpy (msg, "MSG 200");
      msg[7] = ' ';
      for (a = pchild (accinfos, NULL); a; a = pchild (accinfos, a))
      {
        if (!a->removed)
        {
          send (a->sock, msg, l, 0);
          acc_signal (a);
        }
      }
      write (sock, "OK\n", 3);
    }
    else if (sscanf (bp, "ABORT %d %[^\n]", &id, msg) >= 1)
    {
      a = acc_find_live (id);
      if (a)
      {
        acc_send_word (a, "ABORT", msg);
        acc_signal (a);
        write (sock, "OK\n", 3);
      }
      else
        write (sock, "INVALID\n", 8);
    }
    else if (sscanf (bp, "DISCONNECT %d %[^\n]", &id, msg) >= 1)
    {
      a = acc_find_live (id);
      if (a)
      {
        acc_send_word (a, "DISCONNECT", msg);
        acc_signal (a);
        write (sock, "OK\n", 3);
      }
      else
        write (sock, "INVALID\n", 8);
    }
    else if (sscanf (bp, "DISCABORT %d %[^\n]", &id, msg) >= 1)
    {
      a = acc_find_live (id);
      if (a)
      {
        // Abort first, then disconnect (optionally with the given text).
        send (a->sock, "ABORT\n", 6, 0);
        acc_send_word (a, "DISCONNECT", msg);
        acc_signal (a);
        write (sock, "OK\n", 3);
      }
      else
        write (sock, "INVALID\n", 8);
    }
    else if (!strcmp (bp, "RELOAD"))
    {
      // Only clients that never issued CONNECT are told to reload.
      for (a = pchild (accinfos, NULL); a; a = pchild (accinfos, a))
      {
        if (!a->removed && !a->server)
        {
          send (a->sock, "RELOAD\n", 7, 0);
          acc_signal (a);
        }
      }
    }
    else
      write (sock, "UNKNOWN\n", 8);
  }
  return 0;
}
/*
 * Format a message printf-style and write it to the accounting socket SOCK.
 *
 * Every newline except one in the very last position is replaced by '?',
 * so the message always goes out as a single protocol line.  If the
 * message cannot be formatted, a copy of the raw FORMAT string is sent
 * instead.
 *
 * Returns 0 when write () succeeded and -1 when it failed or no buffer at
 * all could be obtained.  A short write is not detected.
 */
int accounter (int sock, const char *format, ...)
{
  va_list ap;
  char *str = NULL, *sp;
  int res;
#ifdef HAVE_VASPRINTF
  int doFree = 1;
#endif

  va_start (ap, format);
#ifdef HAVE_VASPRINTF
  // On failure vasprintf () leaves the pointer undefined, so only its
  // return value tells whether str may be used.
  if (vasprintf (&str, format, ap) < 0)
    str = NULL;
#else
  str = talloc (4097);
  if (str)
    vsnprintf (str, 4096, format, ap);
#endif
  va_end (ap);

  if (!str)
  {
    // Fall back to the unformatted text.  This copy is not freed here.
    // NOTE(review): presumably tstring () memory is reclaimed elsewhere.
    str = tstring (format);
#ifdef HAVE_VASPRINTF
    doFree = 0;
#endif
    if (!str)
      return -1;
  }

  while ((sp = strchr (str, '\n')) && *(sp + 1))
    *sp = '?';

  res = write (sock, str, strlen (str));
#ifdef HAVE_VASPRINTF
  if (doFree)
    free (str);
#endif
  if (res < 0)
    return -1;
  return 0;
}
/* syntax highlighted by Code2HTML, v. 0.9.1 */