/*****************************************************************************\
* Copyright (c) 2002 Pelle Johansson.                                         *
* All rights reserved.                                                        *
*                                                                             *
* This file is part of the moftpd package. Use and distribution of            *
* this software is governed by the terms in the file LICENCE, which           *
* should have come with this package.                                         *
\*****************************************************************************/

/* $moftpd: server.c 1264 2005-04-06 13:32:27Z morth $ */

#include "system.h"

#include "server.h"

#include "main.h"
#include "connection.h"
#include "utf8fs/memory.h"
#include "defaults.h"
#include "accounter.h"
#include "events.h"

/* Pool root that owns every server_t made by new_server(); created lazily
   there and released (and reset to NULL) by quit_all_servers(). */
static void *servers;
/* Listening socket descriptors opened by create_server_sockets() and how
   many are in use; released by close_server_sockets(). */
static int *serverSockets, numServerSockets;

/* Defined elsewhere. Used here by connection_accepter(): forkConnections
   selects fork-per-connection, debug adds LOG_PERROR when syslog is
   reopened, accClientSock is replaced in a forked child, selfFailedFork
   decides how a failed fork is handled, and doFakeChroot is copied into
   fakeChroot in the child. */
extern int forkConnections, debug, accClientSock, selfFailedFork;
extern int fakeChroot, doFakeChroot;

/*
 * Destructor handed to palloc() by new_server(); called with the server_t
 * being released. Frees the server's TLS certificate and key (USE_TLS
 * builds, only when set) and disconnects its SQL handle (USE_SQL builds).
 */
static void free_server (void *ptr)
{
  server_t *const doomed = ptr;

  (void) doomed;  /* unused when neither USE_TLS nor USE_SQL is defined */

#ifdef USE_TLS
  if (doomed->tlsCert != NULL)
    tls_free_cert (doomed->tlsCert);
  if (doomed->tlsKey != NULL)
    tls_free_key (doomed->tlsKey);
#endif
#ifdef USE_SQL
  sql_disconnect (&doomed->sql);
#endif
}

/*
 * Create a server definition named NAME, owned by the global servers pool
 * (free_server() is registered as its destructor).
 *
 * With PARENT, the new server inherits PARENT's ports, aliases, login,
 * passive-port, message, idle/attempt limits and (when compiled in) PAM and
 * SQL settings. Arrays are copied; strings and other pool objects are shared
 * by linking them to the new server with pattach(). Without PARENT the
 * compiled-in def* defaults are used.
 *
 * Not taken from PARENT: bindings, address ranges, TLS certificate/key,
 * allowUnbound and tlsOptions.
 * NOTE(review): allowUnbound/tlsOptions staying at their initial value for
 * child servers may be deliberate - confirm. Fields not assigned here are
 * assumed to be zeroed by palloc() (free_server() tests tlsCert/tlsKey on
 * every server) - confirm.
 *
 * Returns the new server, or NULL on allocation failure.
 */
server_t *new_server(const char *name, server_t *parent)
{
  server_t *res;
  int i;

  if(!servers)
  {
    servers = proot();
    if(!servers)
      return NULL;
  }

  res = palloc(sizeof(server_t), servers, free_server);
  if(!res)
    return NULL;

  res->name = pstring(name, res);
  if(!res->name)
  {
    pfree(res, servers);
    return NULL;
  }

  if(parent)
  {
    res->numPorts = parent->numPorts;
    if(res->numPorts)
    {
      res->ports = palloc(sizeof(*res->ports) * res->numPorts, res, NULL);
      if(!res->ports)
      {
        pfree(res, servers);
        return NULL;
      }
      memcpy(res->ports, parent->ports, sizeof(*res->ports) * res->numPorts);
    }
    res->passIfInvalid = parent->passIfInvalid;
    res->allowForeign = parent->allowForeign;
    res->allowOutOfRange = parent->allowOutOfRange;
    res->allowLowPorts = parent->allowLowPorts;
    res->minPasvPort = parent->minPasvPort;
    res->maxPasvPort = parent->maxPasvPort;
    res->users = pattach(parent->users, res);
    res->access = pattach(parent->access, res);
    res->defHardLink = parent->defHardLink;
    res->chroot = pattach(parent->chroot, res);
    res->allowLogin = parent->allowLogin;
    res->passwordNeeded = parent->passwordNeeded;
    res->loginFailedMsg = pattach(parent->loginFailedMsg, res);
    res->userInvalidMsg = pattach(parent->userInvalidMsg, res);
    res->passRequestMsg = pattach(parent->passRequestMsg, res);
    res->anonPassMsg = pattach(parent->anonPassMsg, res);
    res->welcomeMsg = pattach(parent->welcomeMsg, res);
    res->dirMsgFile = pattach(parent->dirMsgFile, res);
    res->maxIdle = parent->maxIdle;
    res->sleepOnFail = parent->sleepOnFail;
    res->maxLoginAttempts = parent->maxLoginAttempts;
    res->allowSecLogin = parent->allowSecLogin;
    res->numAliases = parent->numAliases;
    res->maxConnects = parent->maxConnects;
    if(res->numAliases)
    {
      res->aliases = palloc(sizeof(*res->aliases) * res->numAliases, res, NULL);
      if(!res->aliases)
      {
        pfree(res, servers);
        return NULL;
      }
      memcpy(res->aliases, parent->aliases, sizeof(*res->aliases) * res->numAliases);
      /* The memcpy copies only the string pointers; the strings themselves
         are pool children of parent->aliases (server_add_alias() allocates
         them there). Link each one to this server's array as well, so that
         server_add_alias()'s pfree(user, serv->aliases) on this server
         unlinks a string this array actually holds, and so that replacing
         or freeing the parent's entry cannot release a string this server
         still points to. This assumes pfree(ptr, ctx) keeps ptr alive while
         another context still holds it, as the pattach() sharing above
         already relies on. */
      for(i = 0; i < res->numAliases; i++)
      {
        pattach((void *)res->aliases[i].alias, res->aliases);
        pattach((void *)res->aliases[i].user, res->aliases);
      }
    }
#ifdef HAVE_LIBPAM
    res->pam_service = pattach (parent->pam_service, res);
#endif
#ifdef USE_SQL
    res->sql.type = parent->sql.type;
    res->sqlHost = pattach (parent->sqlHost, res);
    res->sqlUser = pattach (parent->sqlUser, res);
    res->sqlDB = pattach (parent->sqlDB, res);
    res->sqlPass = pattach (parent->sqlPass, res);
    res->sqlCert = pattach (parent->sqlCert, res);
    res->sqlKey = pattach (parent->sqlKey, res);
    res->sqlUserQuery = pattach (parent->sqlUserQuery, res);
    res->sqlDirQuery = pattach (parent->sqlDirQuery, res);
    res->sqlConnectQuery = pattach (parent->sqlConnectQuery, res);
#endif
  }
  else
  {
    res->allowForeign = defAllowForeign;
    res->allowOutOfRange = defAllowOutOfRange;
    res->allowLowPorts = defAllowLowPorts;
    res->allowUnbound = defAllowUnbound;
    res->minPasvPort = defMinPasvPort;
    res->maxPasvPort = defMaxPasvPort;
    res->passIfInvalid = defPassIfInvalid;
    res->allowLogin = defAllowLogin;
    res->allowSecLogin = defAllowSecLogin;
    res->passwordNeeded = defPasswordNeeded;
    res->maxIdle = defMaxIdle;
    res->sleepOnFail = defSleepOnFail;
    res->maxLoginAttempts = defMaxLoginAttempts;
    res->tlsOptions = defTLSOptions;
#ifdef HAVE_LIBPAM
    res->pam_service = pstring (PACKAGE_NAME, res);
#endif
  }

  return res;
}

/*
 * Drop every server definition and every listening socket.
 *
 * The servers pool is freed (presumably running free_server() on each
 * server) and reset to NULL, so a later new_server() creates a fresh pool.
 * The listening sockets are then closed via close_server_sockets().
 */
void quit_all_servers (void)
{
  void *pool = servers;

  pfree (pool, NULL);
  servers = NULL;
  close_server_sockets ();
}

/*
 * Stop listening: unregister every descriptor in serverSockets from the
 * event loop and close it, in order. Then free the array and reset the
 * bookkeeping to empty.
 */
void close_server_sockets (void)
{
  int idx = 0;

  while (idx < numServerSockets)
  {
    const int fd = serverSockets[idx++];

    remove_read_fd (fd);
    close (fd);
  }

  pfree (serverSockets, NULL);
  serverSockets = NULL;
  numServerSockets = 0;
}

/*
 * Append PORT to SERV's list of ports to listen on.
 *
 * The port array is a pool child of SERV and grows one slot at a time. If
 * the allocation fails, the port is dropped and logged and SERV is left
 * unchanged. The original code bumped numPorts first and then stored
 * through the NULL result.
 */
void add_server_port(server_t *serv, int port)
{
  void *grown;

  if (serv->numPorts)
    grown = prealloc(serv->ports, (serv->numPorts + 1) * sizeof(*serv->ports));
  else
    grown = palloc(sizeof(*serv->ports), serv, NULL);

  if (!grown)
  {
    syslog (LOG_ERR, "add_server_port: out of memory");
    return;
  }

  serv->ports = grown;
  serv->ports[serv->numPorts++] = port;
}

/*
 * Append the host name or wildcard pattern MASK to SERV's bindings.
 *
 * The bindings array is a pool child of SERV, and each copied string is a
 * pool child of the array. On allocation failure the binding is dropped and
 * logged, and SERV's existing bindings are left intact. Before, a NULL
 * array could be written to, or a NULL string stored that later callers
 * pass to strchr().
 */
void add_server_binding(server_t *serv, const char *mask)
{
  void *grown;
  char *copy;

  if (serv->numBindings)
    grown = prealloc(serv->bindings,
          (serv->numBindings + 1) * sizeof(*serv->bindings));
  else
    grown = palloc(sizeof(*serv->bindings), serv, NULL);

  if (!grown)
  {
    syslog (LOG_ERR, "add_server_binding: out of memory");
    return;
  }
  serv->bindings = grown;

  copy = pstring (mask, serv->bindings);
  if (!copy)
  {
    syslog (LOG_ERR, "add_server_binding: out of memory");
    return;
  }
  serv->bindings[serv->numBindings++] = copy;
}

/*
 * Copy the socket address SRC into DST. Only the bytes that SRC's family
 * defines are read (sockaddr_in for AF_INET, sockaddr_in6 for AF_INET6,
 * the generic struct sockaddr otherwise); the rest of DST is zeroed.
 */
static void range_sockaddr_copy (struct sockaddr_storage *dst,
      const struct sockaddr *src)
{
  size_t len;

  switch (src->sa_family)
  {
  case AF_INET:
    len = sizeof (struct sockaddr_in);
    break;
  case AF_INET6:
    len = sizeof (struct sockaddr_in6);
    break;
  default:
    len = sizeof (struct sockaddr);
    break;
  }
  if (len > sizeof (*dst))
    len = sizeof (*dst);

  memset (dst, 0, sizeof (*dst));
  memcpy (dst, src, len);
}

/*
 * Add the client address range ADDR/MASK to SERV.
 *
 * The stored maskSize is the number of set bits in MASK's IPv4 or IPv6
 * address; other families get 0. find_named_server() prefers the server
 * whose matching range has the larger maskSize.
 *
 * ADDR and MASK are no longer read as full sockaddr_storage objects. A
 * caller passing a sockaddr_in or getaddrinfo()'s ai_addr would otherwise
 * have bytes read beyond its object. On allocation failure the range is
 * dropped and logged instead of writing through NULL.
 */
void add_server_range (server_t *serv, struct sockaddr *addr, struct sockaddr *mask)
{
  void *grown;
  const unsigned char *ap, *ep;
  int n = serv->numRanges;
  int sz = 0;

  if (n)
    grown = prealloc (serv->ranges, (n + 1) * sizeof (*serv->ranges));
  else
    grown = palloc (sizeof (*serv->ranges), serv, NULL);

  if (!grown)
  {
    syslog (LOG_ERR, "add_server_range: out of memory");
    return;
  }
  serv->ranges = grown;

  range_sockaddr_copy (&serv->ranges[n].addr, addr);
  range_sockaddr_copy (&serv->ranges[n].mask, mask);

  switch (mask->sa_family)
  {
  case AF_INET:
    ap = (const unsigned char *)&((struct sockaddr_in *)mask)->sin_addr;
    ep = ap + 4;
    break;
  case AF_INET6:
    ap = (const unsigned char *)&((struct sockaddr_in6 *)mask)->sin6_addr;
    ep = ap + 16;
    break;
  default:
    ap = ep = NULL;
    break;
  }

  /* Population count of the mask bytes (unsigned so no sign extension). */
  for (; ap < ep; ap++)
  {
    unsigned bits = *ap;

    for (; bits; bits >>= 1)
      sz += bits & 1u;
  }

  serv->ranges[n].maskSize = sz;
  serv->numRanges = n + 1;
}

/*
 * Map login name ALIAS to USER for SERV, replacing the user of an existing
 * alias with the same (case-sensitive) name. Strings are copied as pool
 * children of SERV's alias array.
 *
 * On allocation failure the change is dropped and logged, and the table is
 * left as it was. When replacing, the new copy is made before the old user
 * is freed, so a failed copy no longer leaves a freed string in the table.
 */
void server_add_alias(server_t *serv, const char *alias, const char *user)
{
  void *grown;
  char *aliasCopy, *userCopy;
  int i;

  for(i = 0; i < serv->numAliases; i++)
  {
    if(!strcmp(serv->aliases[i].alias, alias))
    {
      userCopy = pstring(user, serv->aliases);
      if (!userCopy)
      {
        syslog (LOG_ERR, "server_add_alias: out of memory");
        return;
      }
      pfree(serv->aliases[i].user, serv->aliases);
      serv->aliases[i].user = userCopy;
      return;
    }
  }

  if(serv->numAliases)
    grown = prealloc(serv->aliases, (serv->numAliases + 1) *
          sizeof(*serv->aliases));
  else
    grown = palloc(sizeof(*serv->aliases), serv, NULL);
  if (!grown)
  {
    syslog (LOG_ERR, "server_add_alias: out of memory");
    return;
  }
  serv->aliases = grown;

  aliasCopy = pstring(alias, serv->aliases);
  userCopy = pstring(user, serv->aliases);
  if (!aliasCopy || !userCopy)
  {
    if (aliasCopy)
      pfree(aliasCopy, serv->aliases);
    if (userCopy)
      pfree(userCopy, serv->aliases);
    syslog (LOG_ERR, "server_add_alias: out of memory");
    return;
  }

  serv->aliases[serv->numAliases].alias = aliasCopy;
  serv->aliases[serv->numAliases].user = userCopy;
  serv->numAliases++;
}

/*
 * Look NAME up in SERV's alias table.
 * Returns the user of the first alias equal to NAME (case-sensitive), or
 * NAME itself when no alias matches.
 */
const char *server_expand_alias(const server_t *serv, const char *name)
{
  int idx = 0;

  while (idx < serv->numAliases && strcmp(serv->aliases[idx].alias, name) != 0)
    idx++;

  if (idx < serv->numAliases)
    return serv->aliases[idx].user;
  return name;
}

/* Scratch pool that exists only while find_server() runs. Servers already
   considered, or skipped because of its LAST argument, are attached here
   with pattach() so later passes skip them (plinked()). It is NULL at all
   other times; find_named_server() checks for NULL before using it. */
static void *visited_servers;

/* Nonzero if SERV is configured to listen on PORT. */
static int find_server_on_port (const server_t *serv, int port)
{
  int i;

  for (i = 0; i < serv->numPorts; i++)
    if (serv->ports[i] == port)
      return 1;
  return 0;
}

/* Nonzero if forward-resolving HOST yields an address that same_addr()
   considers equal to ADDR. Performs a (possibly blocking) DNS lookup. */
static int find_server_resolves (const char *host, const struct sockaddr *addr)
{
  struct addrinfo hints, *aires = NULL, *aicurr;
  int found = 0;

  memset (&hints, 0, sizeof (hints));
  hints.ai_socktype = SOCK_STREAM;
  if (getaddrinfo (host, NULL, &hints, &aires) || !aires)
    return 0;
  for (aicurr = aires; aicurr && !found; aicurr = aicurr->ai_next)
    found = same_addr (addr, aicurr->ai_addr, 0) != 0;
  freeaddrinfo (aires);
  return found;
}

/* Candidate handling shared by every pass of find_server(). While *LAST is
   set, RES is only recorded as visited (and *LAST is cleared once RES is
   *LAST itself) and 0 is returned. Otherwise returns nonzero: RES is the
   answer. */
static int find_server_take (server_t *res, const server_t **last)
{
  if (!*last)
    return 1;
  if (res == *last)
    *last = NULL;
  pattach (res, visited_servers);
  return 0;
}

/* Release the visited_servers scratch pool and return RES. */
static server_t *find_server_done (server_t *res)
{
  pfree (visited_servers, NULL);
  visited_servers = NULL;
  return res;
}

/*
 * Choose the server that should handle a connection accepted on local
 * address ADDR (AF_INET or AF_INET6) from remote address RADDR.
 *
 * Only servers configured for ADDR's port are considered. Passes, in order:
 *   1. find_named_server() with ADDR as a numeric host string;
 *   2. find_named_server() with ADDR's reverse-DNS name (NI_NAMEREQD);
 *   3. servers whose literal (non-wildcard) name resolves to ADDR;
 *   4. servers with a literal binding that resolves to ADDR;
 *   5. the first server with allowUnbound set.
 * RADDR is only consulted in passes 1-2 (find_named_server's range checks).
 *
 * LAST: when non-NULL, every candidate up to and including LAST is recorded
 * in visited_servers and skipped, and the next candidate after it is
 * returned (presumably so callers can walk successive matches).
 *
 * Returns the server, or NULL when none matches or ADDR's family is not
 * supported.
 *
 * Fix: in pass 4 a server whose several bindings all resolve to ADDR used to
 * clear LAST on its first binding and then return itself on the next one,
 * handing LAST back to the caller. The binding loop now stops once the
 * server has been handled.
 */
server_t *find_server (const struct sockaddr *addr, const struct sockaddr *raddr, const server_t *last)
{
  static const int nameFlags[] = { NI_NUMERICHOST, NI_NAMEREQD };
  char host[NI_MAXHOST];
  server_t *res;
  socklen_t addrlen;
  int port, i;

  switch (addr->sa_family)
  {
  case AF_INET:
    port = ntohs (((const struct sockaddr_in *)addr)->sin_port);
    addrlen = sizeof (struct sockaddr_in);
    break;
  case AF_INET6:
    port = ntohs (((const struct sockaddr_in6 *)addr)->sin6_port);
    addrlen = sizeof (struct sockaddr_in6);
    break;
  default:
    return NULL;
  }

  /* NOTE(review): proot() failure is not checked; with a NULL pool the
     plinked()/pattach() calls below get a NULL context - confirm safe. */
  visited_servers = proot ();

  /* Passes 1-2: numeric host first to avoid a DNS lookup timeout, then the
     reverse-DNS host name. */
  for (i = 0; i < 2; i++)
  {
    if (getnameinfo (addr, addrlen, host, sizeof (host), NULL, 0, nameFlags[i]))
      continue;
    res = find_named_server (host, port, raddr, last);
    if (res == (server_t *)-1)
      last = NULL;  /* LAST was passed in there; the next match wins. */
    else if (res)
      return find_server_done (res);
  }

  /* Pass 3: configured server names that resolve to ADDR. */
  for (res = pchild (servers, NULL); res; res = pchild (servers, res))
  {
    if (plinked (res, visited_servers) || !find_server_on_port (res, port))
      continue;
    if (strchr (res->name, '*') || strchr (res->name, '?'))
      continue;
    if (find_server_resolves (res->name, addr) && find_server_take (res, &last))
      return find_server_done (res);
  }

  /* Pass 4: configured server bindings that resolve to ADDR. */
  for (res = pchild (servers, NULL); res; res = pchild (servers, res))
  {
    if (plinked (res, visited_servers) || !find_server_on_port (res, port))
      continue;

    for (i = 0; i < res->numBindings; i++)
    {
      const char *binding = res->bindings[i];

      if (strchr (binding, '*') || strchr (binding, '?'))
        continue;
      if (!find_server_resolves (binding, addr))
        continue;
      if (find_server_take (res, &last))
        return find_server_done (res);
      /* RES is now visited; do not let another of its bindings match it
         again (that would return LAST itself once LAST was cleared). */
      break;
    }
  }

  /* Pass 5: first server on the port that allows unbound connections. */
  for (res = pchild (servers, NULL); res; res = pchild (servers, res))
  {
    if (plinked (res, visited_servers) || !res->allowUnbound
        || !find_server_on_port (res, port))
      continue;
    if (find_server_take (res, &last))
      return find_server_done (res);
  }

  /* No server found. */
  return find_server_done (NULL);
}

/*
 * Find the server that matches the host name HOST.
 *
 * First pass: HOST is compared case-insensitively with each server's name;
 * at most one name match is examined. Second pass (only when PORT is
 * non-zero): each server's bindings are matched against HOST with
 * match_pattern(). When PORT is non-zero only servers listening on PORT
 * qualify; PORT 0 skips the port check and the bindings pass.
 *
 * RADDR: when NULL, the first qualifying server is returned at once. LAST
 *   is not consulted on that path. Otherwise a server with address ranges
 *   qualifies only if RADDR passes check_range() for one of them. The
 *   maskSize of the first such range is its score; servers without ranges
 *   score 0. A candidate must beat the current best strictly, so ties keep
 *   the earlier one and a name match wins over an equal binding match.
 * LAST: when non-NULL, qualifying candidates are skipped until LAST itself
 *   has been seen. Skipped ones, including LAST, are attached to
 *   visited_servers when that pool exists (i.e. when called from
 *   find_server()).
 *
 * Returns the best server; or (server_t*)-1 when PORT is non-zero, LAST was
 * passed during this call and no later candidate qualified; otherwise NULL.
 */
server_t *find_named_server (const char *host, int port, const struct sockaddr *raddr, const server_t *last)
{
  server_t *res, *best = NULL;
  int i, bestMask = -1;
  /* Remember whether the caller supplied LAST; LAST itself is cleared below
     once it has been passed. */
  int wl = (last != NULL);
  
  /* First check versus names. */
  for(res = pchild(servers, NULL); res; res = pchild(servers, res))
  {
    if (visited_servers && plinked (res, visited_servers))
      continue;
    
    if (port)
    {
      for(i = 0; i < res->numPorts; i++)
	if(res->ports[i] == port)
	  break;
      if(i == res->numPorts)
	continue;
    }
    
    if (strcasecmp(host, res->name))
      continue;
    
    /* No remote address: no range check, first match wins (LAST ignored). */
    if (!raddr)
      return res;
    
    /* Score = maskSize of the first range containing RADDR; servers without
       ranges score 0, servers whose ranges all miss are skipped. */
    if (res->numRanges)
    {
      for (i = 0; i < res->numRanges; i++)
      {
	if (check_range (raddr, (struct sockaddr*)&res->ranges[i].addr,
		  (struct sockaddr*)&res->ranges[i].mask))
	  break;
      }
      if (i == res->numRanges)
	continue;
      i = res->ranges[i].maskSize;
    }
    else
      i = 0;
    if (i <= bestMask)
      continue;
    
    /* Still before (or at) LAST: mark as visited instead of selecting. */
    if (last)
    {
      if (res == last)
	last = NULL;
      if (visited_servers)
	pattach (res, visited_servers);
    }
    else
    {
      best = res;
      bestMask = i;
    }
    break; // Only one server with a specific name.
  }
  
  if (!port)
    return best;
  
  /* Then check versus bindings. */
  for(res = pchild(servers, NULL); res; res = pchild(servers, res))
  {
    if (visited_servers && plinked (res, visited_servers))
      continue;
    
    for(i = 0; i < res->numPorts; i++)
      if(res->ports[i] == port)
	break;
    if(i == res->numPorts)
      continue;
    
    /* Bindings may contain wildcards; any one matching HOST is enough. */
    for(i = 0; i < res->numBindings; i++)
    {
      if(match_pattern(res->bindings[i], host))
	break;
    }
    if (i == res->numBindings)
      continue;
    
    if (!raddr)
      return res;
    
    /* Same range scoring as in the name pass above. */
    if (res->numRanges)
    {
      for (i = 0; i < res->numRanges; i++)
      {
	if (check_range (raddr, (struct sockaddr*)&res->ranges[i].addr,
		  (struct sockaddr*)&res->ranges[i].mask))
	  break;
      }
      if (i == res->numRanges)
	continue;
      i = res->ranges[i].maskSize;
    }
    else
      i = 0;
    /* Strictly better only: ties keep the earlier (or name-matched) server. */
    if (i <= bestMask)
      continue;
    
    if (last)
    {
      if (res == last)
	last = NULL;
      if (visited_servers)
	pattach (res, visited_servers);
    }
    else
    {
      best = res;
      bestMask = i;
    }
  }
  
  if (best)
    return best;
  /* LAST was supplied and passed, but nothing after it matched: tell
     find_server() to stop skipping candidates. */
  if (wl && !last)
    return (void*)-1;
  return NULL;
}

/* listen() backlog used for every server socket. */
enum { SERVER_LISTEN_BACKLOG = 128 };

/*
 * Append SOCK to serverSockets and register it with the event loop.
 * *MAXSOCKS is the array's capacity and is doubled as needed. If the array
 * cannot be grown, SOCK is closed instead. Previously a failed prealloc
 * overwrote serverSockets with NULL and lost every socket tracked so far,
 * while the new socket stayed registered but untracked.
 */
static void track_server_socket (int sock, int *maxSocks)
{
  if (!serverSockets)
  {
    serverSockets = palloc (sizeof (*serverSockets), NULL, NULL);
    *maxSocks = 1;
    numServerSockets = 0;
  }
  else if (numServerSockets >= *maxSocks)
  {
    int *grown = prealloc (serverSockets,
          sizeof (*serverSockets) * *maxSocks * 2);

    if (grown)
    {
      serverSockets = grown;
      *maxSocks *= 2;
    }
  }

  if (!serverSockets || numServerSockets >= *maxSocks)
  {
    syslog (LOG_ERR, "out of memory tracking listening socket %d", sock);
    close (sock);
    return;
  }

  serverSockets[numServerSockets++] = sock;
  /* NULL user data: several servers may share a listening socket, so the
     server is chosen per connection rather than per socket. */
  add_read_fd (sock, connection_accepter, NULL);
}

/*
 * Resolve BINDING (a literal host name or address) for port PORTBUF and
 * open a listening socket for every address returned. Failures are logged
 * and skipped.
 */
static void listen_on_binding (const char *binding, const char *portBuf,
      int *maxSocks)
{
  struct addrinfo hints, *res, *curr;
  int rc, sock, on = 1;

  memset (&hints, 0, sizeof (hints));
  hints.ai_flags = AI_PASSIVE;
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;

  rc = getaddrinfo (binding, portBuf, &hints, &res);
  if (rc || !res)
  {
    syslog (LOG_DEBUG, "getaddrinfo (%s, %s): %s", binding, portBuf,
          gai_strerror (rc));
    return;
  }

  for (curr = res; curr; curr = curr->ai_next)
  {
    sock = socket (curr->ai_family, curr->ai_socktype, curr->ai_protocol);
    if (sock < 0)
    {
      syslog (LOG_DEBUG, "socket (%d): %m", curr->ai_family);
      continue;
    }

    setsockopt (sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof (on));
#ifdef IPV6_V6ONLY
    /* Only IPv6 please. This will fail on IPv4 naturally. */
    setsockopt (sock, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof (on));
#endif

    if (bind (sock, curr->ai_addr, curr->ai_addrlen))
    {
      syslog (LOG_DEBUG, "bind (%s, %s): %m", binding, portBuf);
      close (sock);
    }
    else if (listen (sock, SERVER_LISTEN_BACKLOG))
    {
      syslog (LOG_ERR, "listen (%s, %s): %m", binding, portBuf);
      close (sock);
    }
    else
      track_server_socket (sock, maxSocks);
  }
  freeaddrinfo (res);
}

/*
 * Open a listening TCP socket on the wildcard address of FAMILY (PF_INET or
 * PF_INET6) at PORT. FAMNAME is only used in log messages. Failures are
 * logged and skipped.
 */
static void listen_on_any (int family, const char *famName, int port,
      int *maxSocks)
{
  struct sockaddr_in addr4;
  struct sockaddr_in6 addr6;
  struct sockaddr *sa;
  socklen_t salen;
  int on = 1;
  int sock = socket (family, SOCK_STREAM, IPPROTO_TCP);

  if (sock < 0)
  {
    syslog (LOG_DEBUG, "socket (%s): %m", famName);
    return;
  }
  setsockopt (sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof (on));

  if (family == PF_INET6)
  {
#ifdef IPV6_V6ONLY
    /* Only IPv6 please; IPv4 gets its own wildcard socket. */
    setsockopt (sock, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof (on));
#endif
    memset (&addr6, 0, sizeof (addr6));  /* all-zero = in6addr_any */
    addr6.sin6_family = AF_INET6;
#ifdef HAVE_STRUCT_SOCKADDR_IN_SIN_LEN
    addr6.sin6_len = sizeof (addr6);
#endif
    addr6.sin6_port = htons (port);
    sa = (struct sockaddr *)&addr6;
    salen = sizeof (addr6);
  }
  else
  {
    memset (&addr4, 0, sizeof (addr4));  /* all-zero = INADDR_ANY */
    addr4.sin_family = AF_INET;
#ifdef HAVE_STRUCT_SOCKADDR_IN_SIN_LEN
    addr4.sin_len = sizeof (addr4);
#endif
    addr4.sin_port = htons (port);
    sa = (struct sockaddr *)&addr4;
    salen = sizeof (addr4);
  }

  if (bind (sock, sa, salen))
  {
    syslog (LOG_ERR, "bind (%s, %d): %m", famName, port);
    close (sock);
  }
  else if (listen (sock, SERVER_LISTEN_BACKLOG))
  {
    syslog (LOG_ERR, "listen (%s, %d): %m", famName, port);
    close (sock);
  }
  else
    track_server_socket (sock, maxSocks);
}

/*
 * Open listening sockets for every configured server.
 *
 * A server that allows unbound connections, or has a wildcard binding, is
 * served by wildcard-address IPv4 and IPv6 sockets on each of its ports;
 * each distinct port is opened once. Any other server gets one socket per
 * address its literal name (unless that name is a pattern) and its bindings
 * resolve to, on each of its ports. New sockets are appended to
 * serverSockets and registered with connection_accepter().
 *
 * Returns 0, or -1 if the temporary port list cannot be allocated.
 */
int create_server_sockets(void)
{
  server_t *serv;
  int *allPorts = NULL;
  int numAllPorts = 0, maxAllPorts = 0;
  int isAll, i, j;
  char portBuf[16];  /* fits any int, unlike the old char[8] + sprintf */
  int maxSocks = numServerSockets ? numServerSockets : 1;

  for (serv = pchild (servers, NULL); serv; serv = pchild (servers, serv))
  {
    /* Wildcards (or allowUnbound) mean we have to listen on all addresses. */
    isAll = serv->allowUnbound != 0;
    for (i = 0; !isAll && i < serv->numBindings; i++)
      if (strchr (serv->bindings[i], '*') || strchr (serv->bindings[i], '?'))
        isAll = 1;

    if (!isAll)
    {
      for (i = 0; i < serv->numPorts; i++)
      {
        snprintf (portBuf, sizeof (portBuf), "%d", serv->ports[i]);
        /* The server's own name first, unless it is a pattern. */
        if (!strchr (serv->name, '*') && !strchr (serv->name, '?'))
          listen_on_binding (serv->name, portBuf, &maxSocks);
        for (j = 0; j < serv->numBindings; j++)
          listen_on_binding (serv->bindings[j], portBuf, &maxSocks);
      }
      continue;
    }

    /* Collect distinct ports for the wildcard sockets opened below. */
    for (i = 0; i < serv->numPorts; i++)
    {
      for (j = 0; j < numAllPorts && allPorts[j] != serv->ports[i]; j++)
        ;
      if (j < numAllPorts)
        continue;

      if (numAllPorts >= maxAllPorts)
      {
        /* Start high, only temporary. */
        int newMax = maxAllPorts ? maxAllPorts * 2 : 10;
        int *grown = maxAllPorts
          ? trealloc (allPorts, newMax * sizeof (*allPorts))
          : talloc (newMax * sizeof (*allPorts));

        if (!grown)
        {
          syslog (LOG_ERR, "talloc: %m");
          return -1;
        }
        allPorts = grown;
        maxAllPorts = newMax;
      }
      allPorts[numAllPorts++] = serv->ports[i];
    }
  }

  for (i = 0; i < numAllPorts; i++)
  {
    listen_on_any (PF_INET, "PF_INET", allPorts[i], &maxSocks);
    listen_on_any (PF_INET6, "PF_INET6", allPorts[i], &maxSocks);
  }

  return 0;
}

/*
 * Read callback for a listening socket (create_server_sockets() registers
 * it with add_read_fd()). Accepts one pending connection on SOCK and hands
 * it to new_connection() together with USER. URGENT is not used here.
 *
 * With forkConnections set, two accounter connections are opened and the
 * process forks:
 *   - parent: closes its copies of the client and accounter sockets and
 *     returns 0;
 *   - child: closes the listening sockets, calls close_accounter(1), reopens
 *     syslog with LOG_PID and calls events_init(). It then adopts the second
 *     accounter connection as accClientSock, starts the client connection,
 *     frees all server definitions and returns 1;
 *   - fork failure: with selfFailedFork the client is served in this
 *     process; otherwise it gets a "421 <strerror>" reply and is closed.
 * Without forkConnections the client is served in this process with no
 * accounter socket (-1).
 *
 * Returns 1 only in a forked child, 0 otherwise. NOTE(review): presumably
 * the caller's event loop uses the 1 to notice it now runs in the child -
 * confirm against the caller.
 */
int connection_accepter(int sock, void *user, int urgent)
{
  int csock = accept(sock, NULL, NULL);
  
  if(csock >= 0)
  {
    if (forkConnections)
    {
      /* accSock goes to new_connection(); accCSock becomes the child's
         accClientSock. Failures of connect_accounter() are not checked. */
      int accSock = connect_accounter ();
      int accCSock = connect_accounter ();
      const char *err;
      
      switch(fork())
      {
      case -1:
        // Error: fork failed.
	if (selfFailedFork)
	{
          /* Serve the client in this process instead. */
          syslog (LOG_WARNING, "fork: %m");
	  new_connection (csock, user, accSock);
        }
        else
        {
          /* Capture the error text before syslog() runs; the 421 reply
             is best-effort (write() results are ignored). */
          err = strerror (errno);
          syslog (LOG_ERR, "fork: %m");
          write (csock, "421 ", 4);
          write (csock, err, strlen (err));
          write (csock, "\r\n", 2);
          close (csock);
          close (accSock);
        }
	close (accCSock);
	return 0;
      case 0:
        // Child: continue below.
	break;
      default:
        // Parent: the child owns these descriptors now.
	close(csock);
	close (accSock);
	close (accCSock);
	return 0;
      }
      /* Child from here on. */
      fakeChroot = doFakeChroot;
      /* Stop listening; only the parent accepts new connections. */
      close_server_sockets ();
      close_accounter (1);
      /* Reopen syslog so messages carry the child's own pid. */
      closelog ();
      openlog (getprogname (), LOG_PID | LOG_NDELAY | (debug? LOG_PERROR : 0),
	    LOG_FTP);
      /* Presumably resets the event loop inherited from the parent - confirm. */
      events_init ();
      if (accClientSock != -1)
	close (accClientSock);
      accClientSock = accCSock;
      if (accClientSock >= 0)
      {
	add_read_fd (accClientSock, accounter_master_reply, NULL);
	accounter (accClientSock, "SET PID %d\n", (int)getpid ());
      }
      new_connection (csock, user, accSock);
      /* NOTE(review): all server definitions are freed after
         new_connection(); this assumes the connection keeps its own
         reference (e.g. via pattach) to any server it uses - confirm. */
      quit_all_servers ();
      return 1;
    }
    else
      new_connection (csock, user, -1);
  }
  else
    syslog (LOG_ERR, "accept: %m");
  return 0;
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */