/*****************************************************************************\
* Copyright (c) 2003-2004 Pelle Johansson.                                    *
* All rights reserved.                                                        *
*                                                                             *
* This file is part of the moftpd package. Use and distribution of            *
* this software is governed by the terms in the file LICENCE, which           *
* should have come with this package.                                         *
\*****************************************************************************/

/* $moftpd: accounter.c 1242 2004-12-07 22:29:09Z morth $ */

#include "system.h"

#include "accounter.h"

#include "main.h"
#include "utf8fs/memory.h"
#include "events.h"

/* Listening socket of the accounting server, or -1 when not listening. */
int accSock = -1;

/* Accounting role of this process: getpid () while it owns the listening
   socket, 0 when another process serves accounting (or after
   close_accounter ()), -1 when no listening socket is used and accounting
   is served by this process itself over socketpairs. */
int accServer;

/* Path of the accounting UNIX-domain socket; NULL disables the socket. */
const char *accPath = VARDIR "/run/" PACKAGE_NAME ".acct";

/* Pool root owning one accounter_t per connected client. */
static void *accinfos;
/* Pool root owning one accserver_t per named server. */
static void *accservers;

/*
 * Return the accounting record for the server called NAME.  If none exists
 * yet, create one (allocated under the accservers pool root, with zero
 * connections and no users).  Returns NULL if the pool root, the record or
 * its name could not be allocated.
 */
accserver_t *find_acc_server (const char *name)
{
  accserver_t *serv;
  
  if (accservers)
  {
    for (serv = pchild (accservers, NULL); serv; serv = pchild (accservers, serv))
    {
      if (!strcmp (name, serv->name))
	return serv;
    }
  }
  else
  {
    accservers = proot ();
    if (!accservers)
      return NULL;
  }
  
  serv = palloc (sizeof (accserver_t), accservers, NULL);
  if (!serv)
    return NULL;
  
  serv->name = pstring (name, serv);
  if (!serv->name)
  {
    // A record without a name would crash the strcmp () above on the
    // next lookup, so drop it instead of keeping it in the pool.
    pfree (serv, accservers);
    return NULL;
  }
  serv->numConnects = 0;
  serv->users = NULL;
  
  return serv;
}

/*
 * Return the record for user NAME on server SERV.  If none exists, create
 * one with zero logins, allocated as a child of SERV, and push it onto the
 * front of SERV's user list.  Returns NULL if the record or its name could
 * not be allocated.
 */
accuser_t *find_acc_user (accserver_t *serv, const char *name)
{
  accuser_t *user;
  
  for (user = serv->users; user; user = user->next)
  {
    if (!strcmp (name, user->name))
      return user;
  }
  
  user = palloc (sizeof (accuser_t), serv, NULL);
  if (!user)
    return NULL;
  
  user->name = pstring (name, user);
  if (!user->name)
  {
    // Not linked into serv->users yet; a nameless entry would crash the
    // strcmp () above on the next lookup.
    pfree (user, serv);
    return NULL;
  }
  user->numLogins = 0;
  user->next = serv->users;
  serv->users = user;
  
  return user;
}

/*
 * Set up accounting for this process.
 *
 * If accPath exists, is owned by our effective uid and accepts a
 * connection, another accounting server is already running: accServer is
 * set to 0 and nothing else is done.  Otherwise we try to become the server
 * by binding and listening on accPath.  On success accServer is our pid and
 * the listening socket is registered with accounter_accepter ().  If accPath
 * is NULL or bind () fails, accServer is set to -1 and no listening socket
 * is kept.
 *
 * Returns 0 on all of the above, -1 if a socket could not be created or
 * listen () failed.
 */
int start_accounter (void)
{
  if (accSock >= 0)
    close (accSock);
  
  if (accPath)
  {
    struct sockaddr_un addr;
    struct stat st;
    
    super_privs (0);
    
    if (accServer == getpid ())
      unlink (accPath);
    
    accSock = socket (PF_UNIX, SOCK_STREAM, 0);
    if (accSock < 0)
      return -1;
    
    addr.sun_family = AF_UNIX;
#ifdef HAVE_STRUCT_SOCKADDR_IN_SIN_LEN
    addr.sun_len = sizeof (addr);
#endif
    strncpy (addr.sun_path, accPath, sizeof (addr.sun_path) - 1);
    addr.sun_path[sizeof (addr.sun_path) - 1] = 0;
    
    if (!stat (accPath, &st) && st.st_uid == geteuid ())
    {
      if (!connect (accSock, (struct sockaddr*)&addr, sizeof (addr)))
      {
	// There's already a server running.
	close (accSock);
	accSock = -1;
	accServer = 0;
	return 0;
      }
      
      // After a failed connect () the socket's state is unspecified, so
      // don't try to bind () it; start over with a fresh socket.
      close (accSock);
      accSock = socket (PF_UNIX, SOCK_STREAM, 0);
      if (accSock < 0)
	return -1;
    }
    
    // No server running. Try to start a new one.
    unlink (accPath);
    if (!bind (accSock, (struct sockaddr*)&addr, sizeof (addr)))
    {
      chmod (accPath, 0600);
      
      if (listen (accSock, 128))
      {
	close (accSock);
	accSock = -1;
	return -1;
      }
      
      accServer = getpid ();
      add_read_fd (accSock, accounter_accepter, NULL);
      return 0;
    }
    
    // bind () failed: release the socket before falling back below,
    // otherwise the descriptor leaks when accSock is reset.
    close (accSock);
  }
  
  // If we get here, no listening socket is needed: accounting will be
  // provided to our children by this process itself.
  accServer = -1;
  accSock = -1;
  return 0;
}

/*
 * Tear down accounting state.
 *
 * Always unregisters and closes the listening socket.  With LEVEL >= 1 it
 * also closes every client socket and frees all client records, then resets
 * accServer to 0.  With LEVEL >= 2, a process that owns the listening socket
 * also removes accPath from the filesystem.
 */
void close_accounter (int level)
{
  accounter_t *entry;
  
  if (accSock >= 0)
  {
    remove_read_fd (accSock);
    close (accSock);
    accSock = -1;
  }
  
  if (level == 0)
    return;
  
  if (accinfos != NULL)
  {
    entry = pchild (accinfos, NULL);
    while (entry != NULL)
    {
      // No remove_read_fd () here: events_init () will be called later.
      if (entry->sock >= 0)
	close (entry->sock);
      entry = pchild (accinfos, entry);
    }
    pfree (accinfos, NULL);
    accinfos = NULL;
  }
  
  if (level >= 2 && accServer == getpid ())
    unlink (accPath);
  accServer = 0;
}

/*
 * Register client socket SOCK with the accounting server.
 *
 * Allocates a record for the client under the accinfos pool root and gives
 * it an id.  The id of a disconnected client is reused once that client's
 * reuse grace time (accounter_t.removed) has passed; otherwise the id is one
 * more than the highest id in use.  The socket is then handed to the event
 * loop with accounter_client_handler ().
 *
 * On allocation failure "ERROR" is written to SOCK, SOCK is closed and -1
 * is returned; otherwise the result of add_read_fd () is returned.
 */
int accounter_add (int sock)
{
  accounter_t *acc = NULL, *a, *na;
  time_t t;
  
  if (!accinfos)
    accinfos = proot ();
  if (accinfos)
    acc = palloc (sizeof (*acc), accinfos, NULL);
  if (!acc)
  {
    write (sock, "ERROR\n", 6);
    close (sock);
    return -1;
  }
  // palloc () is not relied on to clear memory elsewhere in this file (see
  // find_acc_server ()), yet server, user, removed, pid, sent and the
  // string fields are all read before they are first assigned.
  memset (acc, 0, sizeof (*acc));
  acc->sock = sock;
  acc->id = 0;
  time (&t);
  for (a = pchild (accinfos, NULL); a; a = na)
  {
    na = pchild (accinfos, a);
    if (a->removed && a->removed <= t)
    {
      // Expired entry: take over its id and discard it.
      acc->id = a->id;
      pfree (a, accinfos);
      break;
    }
    if (a != acc && a->id >= acc->id)
      acc->id = a->id + 1;
  }
  return add_read_fd (sock, accounter_client_handler, acc);
}

/*
 * Obtain a socket connected to the accounting server.
 *
 * Up to 20 attempts are made.  Each attempt first tries the UNIX socket at
 * accPath, provided it is owned by our effective uid.  If that fails and
 * this process has no accounting role yet, start_accounter () is run.  When
 * this process itself serves accounting, a socketpair is created: one end
 * is registered with accounter_add () and the other end is returned.
 *
 * Returns a connected descriptor, or -1 (errno EINVAL once all attempts are
 * used up).
 */
int connect_accounter (void)
{
  int attempt;
  
  super_privs (0);
  
  for (attempt = 0; attempt < 20; attempt++)
  {
    if (accPath)
    {
      struct sockaddr_un sa;
      struct stat info;
      int fd = socket (PF_UNIX, SOCK_STREAM, 0);
      
      if (fd < 0)
	return -1;
      
      sa.sun_family = AF_UNIX;
#ifdef HAVE_STRUCT_SOCKADDR_IN_SIN_LEN
      sa.sun_len = sizeof (sa);
#endif
      strncpy (sa.sun_path, accPath, sizeof (sa.sun_path) - 1);
      sa.sun_path[sizeof (sa.sun_path) - 1] = 0;
      
      if (stat (accPath, &info) == 0 && info.st_uid == geteuid ()
	  && connect (fd, (struct sockaddr *) &sa, sizeof (sa)) == 0)
	return fd;
      
      close (fd);
    }
    
    if (accServer == 0 && start_accounter () != 0)
      return -1;
    
    if (accServer != 0)
    {
      // We serve accounting ourselves: talk to our own event loop.
      int ends[2];
      
      if (socketpair (PF_UNIX, SOCK_STREAM, 0, ends) != 0)
	return -1;
      accounter_add (ends[0]);
      return ends[1];
    }
  }
  
  errno = EINVAL;
  return -1;
}

/*
 * Event-loop callback for the listening socket SOCK: accept one pending
 * connection and register it via accounter_add ().  Failed accepts are
 * ignored.  USER and URGENT are unused.  Always returns 0.
 */
int accounter_accepter (int sock, void *user, int urgent)
{
  int client;
  
  client = accept (sock, NULL, NULL);
  if (client >= 0)
    accounter_add (client);
  return 0;
}

int accounter_client_handler (int sock, void *user, int urgent)
{
  int l, max;
  char buf[4097], *bp, *nbp, msg[4097];
  accounter_t *acc = user, *a, *na;
  long long sz;
  time_t t;
    
  l = read (sock, buf, sizeof (buf) - 1);
  if (l <= 0)
  {
    if (!l || errno != EINTR)
    {
      if (acc)
      {
	if (acc->server && acc->server->numConnects)
	  acc->server->numConnects--;
	if (acc->user && acc->user->numLogins)
	  acc->user->numLogins--;
	time (&acc->removed);
	acc->removed += 30; // 30 s to reuse id.
        acc->sock = -1;
      }
      remove_read_fd (sock);
      close (sock);
    }
    return 0;
  }
  
  buf[l] = 0;
  for (bp = buf; (nbp = strchr (bp, '\n')); bp = nbp)
  {
#if 0
    if (nbp > bp && *(nbp - 1) == '\r')
      *(nbp - 1) = 0;
#endif
    *nbp++ = 0;
    if (!strlen (bp))
      continue;
    max = 0;
    msg[0] = 0;
    if (!strcmp (bp, "LIST"))
    {
      time (&t);
      for (a = pchild (accinfos, NULL); a; a = na)
      {
	na = pchild (accinfos, a);
	if (a->removed)
	{
	  if (a->removed <= t)
	    pfree (a, accinfos);
	}
	else if (a->server)
	{
	  snprintf (msg, 4096, "%d: %d %s %s %s %s %s %s\n",
		a->id,
		a->pid,
		a->rhost,
		a->server->name,
		a->user? a->user->name : "-",
		a->acct[0]? a->acct : "-",
		a->email[0]? a->email : "-",
		a->work);
	  write (sock, msg, strlen (msg));
	}
      }
      write (sock, "END\n", 4);
    }
    else if (sscanf (bp, "CONNECT %255s %255s %d", acc->rhost, msg, &max) >= 2)
    {
      accserver_t *serv;

      msg[255] = 0;
      serv = find_acc_server (msg);
      if (!serv)
	write (sock, "ERROR\n", 6);
      else if (max && serv->numConnects >= max)
	write (sock, "DENY\n", 5);
      else
      {
	if (acc->server)
	{
	  if (acc->server->numConnects)
	    acc->server->numConnects--;
	  pfree (acc->server, acc);
	}
	acc->server = pattach (serv, acc);
	acc->server->numConnects++;
	sprintf (msg, "ALLOW %d\n", acc->id);
	write (sock, msg, strlen (msg));
      }
    }
    else if (sscanf (bp, "LOGIN %255s %255s %255s %d", msg, acc->acct, acc->email, &max) >= 1)
    {
      if (!strcmp (acc->acct, "-"))
	acc->acct[0] = 0;
      if (!strcmp (acc->email, "-"))
	acc->email[0] = 0;
      if (!acc->server)
	write (sock, "NOSERVER\n", 9);
      else
      {
	acc->user = pattach (find_acc_user (acc->server, msg), acc);
	if (!acc->user)
	  write (sock, "ERROR\n", 6);
	else if (max && acc->user->numLogins >= max)
	{
	  write (sock, "DENY\n", 5);
	  pfree (acc->user, acc);
	  acc->user = NULL;
	}
	else
	{
	  acc->user->numLogins++;
	  write (sock, "ALLOW\n", 6);
	}
      }
    }
    else if (!strcmp (bp, "LOGOUT"))
    {
      if (acc->user)
      {
	if (acc->user->numLogins)
	  acc->user->numLogins--;
	pfree (acc->user, acc);
	acc->user = NULL;
      }
      acc->acct[0] = 0;
      acc->email[0] = 0;
    }
    else if (sscanf (bp, "SENDING %lld %[^\n]", &sz, msg) == 2)
    {
      snprintf (acc->work, 4096, "RETR: %s", msg);
      write (sock, "ALLOW\n", 6);
    }
    else if (sscanf (bp, "SENT %lld %*s", &sz) == 2)
    {
      acc->sent += sz;
      acc->work[0] = 0;
    }
    else if (sscanf (bp, "GETTING %[^\n]", msg) == 1)
    {
      snprintf (acc->work, 4096, "STOR: %s", msg);
      write (sock, "ALLOW\n", 6);
    }
    else if (sscanf (bp, "GOT %*d %*s") == 2)
      acc->work[0] = 0;
    else if (sscanf (bp, "SET %s %[^\n]", msg, bp) == 2)
    {
      if (!strcmp (msg, "PID"))
      {
	acc->pid = atoi (bp);
	if (acc->pid < 0)
	  acc->pid = 0;
      }
    }
    else if (sscanf (bp, "QUERY %s", msg) == 1)
    {
      if (!strcmp (msg, "VERSION"))
      {
	strcpy (msg, "VERSION " PACKAGE_VERSION "\r\n");
	write (sock, msg, strlen (msg));
      }
      else
	write (sock, "UNKNOWN\n", 9);
    }
    else if (sscanf (bp, "MSG %d %[^\n]", &l, msg) == 2)
    {
      for (a = pchild (accinfos, NULL); a; a = pchild (accinfos, a))
      {
	if (a->id == l)
	  break;
      }
      if (a && !a->removed)
      {
	l = strlen (msg);
	if (l > 4096 - 9)
	  l = 4096 - 9;
	memmove (msg + 8, msg, l);
	strcpy (msg, "MSG 200");
	msg[7] = ' ';
	l += 8;
	msg[l++] = '\n';
	send (a->sock, msg, l, 0);
	if (a->pid)
	{
	  super_privs (0);
	  kill (a->pid, SIGURG);
	}
	write (sock, "OK\n", 3);
      }
      else
	write (sock, "INVALID\n", 8);
    }
    else if (sscanf (bp, "MSGALL %[^\n]", msg) == 1)
    {
      l = strlen (msg);
      if (l > 4096 - 8)
	l = 4096 - 8;
      memmove (msg + 8, msg, l);
      l += 8;
      msg[l++] = '\n';
      strcpy (msg, "MSG 200");
      msg[7] = ' ';
      for (a = pchild (accinfos, NULL); a; a = pchild (accinfos, a))
      {
	if (!a->removed)
	{
	  send (a->sock, msg, l, 0);
	  if (a->pid)
	  {
	    super_privs (0);
	    kill (a->pid, SIGURG);
	  }
	}
      }
      write (sock, "OK\n", 3);
    }
    else if (sscanf (bp, "ABORT %d %[^\n]", &l, msg) >= 1)
    {
      for (a = pchild (accinfos, NULL); a; a = pchild (accinfos, a))
      {
	if (a->id == l)
	  break;
      }
      if (a && !a->removed)
      {
	if (msg [0])
	{
	  l = strlen (msg);
	  memmove (msg + 6, msg, l);
	  strcpy (msg, "ABORT");
	  msg[5] = ' ';
	  l += 6;
	  msg[l++] = '\n';
	  send (a->sock, msg, l, 0);
	}
	else
	  send (a->sock, "ABORT\n", 6, 0);
	if (a->pid)
	{
	  super_privs (0);
	  kill (a->pid, SIGURG);
	}
	write (sock, "OK\n", 3);
      }
      else
	write (sock, "INVALID\n", 8);
    }
    else if (sscanf (bp, "DISCONNECT %d %[^\n]", &l, msg) >= 1)
    {
      for (a = pchild (accinfos, NULL); a; a = pchild (accinfos, a))
      {
	if (a->id == l)
	  break;
      }
      if (a && !a->removed)
      {
	if (msg [0])
	{
	  l = strlen (msg);
	  memmove (msg + 11, msg, l);
	  strcpy (msg, "DISCONNECT");
	  msg[10] = ' ';
	  l += 11;
	  msg[l++] = '\n';
	  send (a->sock, msg, l, 0);
	}
	else
	  send (a->sock, "DISCONNECT\n", 11, 0);
	if (a->pid)
	{
	  super_privs (0);
	  kill (a->pid, SIGURG);
	}
	write (sock, "OK\n", 3);
      }
      else
	write (sock, "INVALID\n", 8);
    }
    else if (sscanf (bp, "DISCABORT %d %[^\n]", &l, msg) >= 1)
    {
      for (a = pchild (accinfos, NULL); a; a = pchild (accinfos, a))
      {
	if (a->id == l)
	  break;
      }
      if (a && !a->removed)
      {
	send (a->sock, "ABORT\n", 6, 0);
	if (msg [0])
	{
	  l = strlen (msg);
	  memmove (msg + 11, msg, l);
	  strcpy (msg, "DISCONNECT");
	  msg[10] = ' ';
	  l += 11;
	  msg[l++] = '\n';
	  send (a->sock, msg, l, 0);
	}
	else
	  send (a->sock, "DISCONNECT\n", 11, 0);
	if (a->pid)
	{
	  super_privs (0);
	  kill (a->pid, SIGURG);
	}
	write (sock, "OK\n", 3);
      }
      else
	write (sock, "INVALID\n", 8);
    }
    else if (!strcmp (bp, "RELOAD"))
    {
      for (a = pchild (accinfos, NULL); a; a = pchild (accinfos, a))
      {
	if (!a->removed && !a->server)
	{
	  send (a->sock, "RELOAD\n", 7, 0);
	  if (a->pid)
	  {
	    super_privs (0);
	    kill (a->pid, SIGURG);
	  }
	}
      }
    }
    else
      write (sock, "UNKNOWN\n", 9);
  }
  return 0;
}

/*
 * Write a printf-style formatted message to SOCK.
 *
 * Every newline except one as the very last character is replaced with '?',
 * so the message stays a single protocol line.  If the message cannot be
 * formatted for lack of memory, FORMAT itself is sent unexpanded.
 *
 * Returns 0 if write () succeeded, -1 otherwise (including when no buffer
 * at all could be obtained).
 */
int accounter (int sock, const char *format, ...)
{
  va_list ap;
  char *str, *sp;
  int res;
#ifdef HAVE_VASPRINTF
  int doFree;
#endif
  
  va_start (ap, format);
#ifdef HAVE_VASPRINTF
  // On failure vasprintf () leaves str undefined; make it NULL explicitly
  // so the fallback below is taken.
  if (vasprintf (&str, format, ap) < 0)
    str = NULL;
  doFree = 1;
#else
  str = talloc (4097);
  if (str)
    vsnprintf (str, 4096, format, ap);
#endif
  va_end (ap);
  
  if (!str)
  {
    str = tstring (format);
#ifdef HAVE_VASPRINTF
    doFree = 0;
#endif
    if (!str)
      return -1;
  }
  
  // Scan forward once, leaving a trailing newline intact.
  for (sp = str; (sp = strchr (sp, '\n')) && sp[1]; sp++)
    *sp = '?';
  
  res = write (sock, str, strlen (str));
  
#ifdef HAVE_VASPRINTF
  if (doFree)
    free (str);
#endif
  
  if (res < 0)
    return -1;
  return 0;
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */