#include	<stdio.h>
#include	<stdlib.h>
#include	<string.h>
#include	<alloc.h>
#include	<sys/time.h>
#include	<bios.h>

#include	"smbhdr.h"
#include	"smbpd.h"
#include	"netio.h"
#include	"lpt.h"
#include	"util.h"

/* key definitions: BIOS extended scan codes tested by check_key() */
#define	CF1	0x5e				/* CTRL-F1 */
#define	CF2	0x5f				/* presumably CTRL-F2 */
#define	CF3	0x60				/* presumably CTRL-F3 */

/* true if a has no host bits under the local subnet mask (a names a network) */
#define			is_net_addr(a)	(((a) & ~sin_mask) == 0)
/* true if the network part of a (under the local subnet mask) equals n */
#define			is_on_net(a,n)	(((a) & sin_mask) == (n))

char			my_host_name[65] = "";		/* filled by gethostname() in init_net() */
char			my_work_group[17] = "workgroup";
unsigned long		my_ip_addr_n;	/* net order version of my_ip_addr */
time_t			my_currenttime;	/* time(0) taken in init_net() */
struct tm		*my_ts;		/* -> localtime struct */
time_t			my_tzsec;	/* add to localtime to get GMT */
char			*loghost;	/* syslog host; 0 means logging disabled (see init_log) */
int			check_subnet = 1;	/* nonzero: same-subnet clients are always allowed */
int			nallow, ndeny;		/* number of entries used in allow[] / deny[] */
longword		allow[MAXALLOW], deny[MAXDENY];	/* hosts or networks, matched by in_list() */
int			nconn = 0;		/* number of conn[] slots allocated by init_net() */
struct smbconn		*conn[MAXCONN];		/* per-connection state, allocated in init_net() */
unsigned char		*insmbreq;	/* MAXBUFFER-byte buffer from farmalloc() in init_net() */
static udp_Socket	log;		/* UDP socket to loghost, opened by init_log() */
static char		logbuf[256];	/* scratch buffer for formatting log datagrams */
static udp_Socket	namein;		/* listener on NSPORT for name-service requests */
static udp_Socket	nameout;	/* reopened for each reply by write_ns() */

extern longword		my_ip_addr;	/* host order (init_net() applies htonl() to it) */
extern longword		sin_mask;	/* local subnet mask, used by same_subnet() and the macros above */

void check_key(void)
{
	int		key, i;

	if (((key = bioskey(0)) & 0xff) != 0)
		return;			/* ASCII key */
	key = (key >> 8) & 0xff;
	if (key < CF1 || key > CF3)
		return;
	key -= CF1;		/* which printer? */
	for (i = 0; i < nconn; ++i)
	{
		if (conn[i]->printer == key && conn[i]->state != WANT_CONN)
		{
			if (reinit)
			{
				(void)printf("%s Aborting job and reinitialising %s\n",
					ptime(), lpt[key].name);
				(void)biosprint(1, 0, key);
				sleep(2);		/* let printer settle */
			}
			else
			{
				(void)printf("%s Aborting job on %s\n",
					ptime(), lpt[key].name);
			}
			conn[i]->state = ABORT;		/* abort job */
			return;
		}
	}
}

/*
 * Parse a comma-separated list of host names/addresses into table[].
 *
 * adjective  word used in messages and the summary line
 *            ("... <adjective> access")
 * s          the list; modified in place (commas become NULs)
 * table      destination array of resolved addresses
 * nentries   in/out: number of entries already stored in table
 * maxtable   capacity of table
 *
 * Empty entries are skipped. Entries that do not fit in the table, or
 * that resolve() cannot turn into an address, are reported and skipped.
 * Finally every address in the table is printed, followed by
 * "<adjective> access" when the table is non-empty.
 */
void make_list(char *adjective, char *s,
	longword table[], int *nentries, int maxtable)
{
	char		*next;
	longword	ip;
	int		i;

	for ( ; *s != '\0'; s = next)
	{
		/* isolate this entry and remember where the next one starts */
		if ((next = strchr(s, ',')) != 0)
			*next++ = '\0';
		else
			next = s + strlen(s);
		if (*s == '\0')
			continue;		/* empty entry, e.g. "a,,b" */
		/* check capacity before resolving to avoid a pointless lookup */
		if (*nentries >= maxtable)
		{
			(void)printf("Too many %s entries, ignoring %s\n",
				adjective, s);
			continue;
		}
		if ((ip = resolve(s)) == (longword)0)
		{
			(void)printf("Cannot resolve %s, ignoring\n", s);
			continue;
		}
		table[(*nentries)++] = ip;
	}
	for (i = 0; i < *nentries; ++i)
		(void)printf("%s ", Inet_ntoa(table[i]));
	if (*nentries > 0)
		(void)printf("%s access\n", adjective);
}

/*
 * Return nonzero if client has our network number under the local
 * subnet mask, i.e. the masked bits of client and my_ip_addr agree.
 * Assumes longword supports bitwise operations.
 */
static int same_subnet(longword client)
{
	return ((client & sin_mask) == (my_ip_addr & sin_mask));
}

/*
 * Look up ip in the first min(entries, maxtable) elements of table.
 * An element matches when it equals ip exactly, or when it is a network
 * address (no host bits under sin_mask) and ip lies on that network.
 * Returns answer on a match and !answer otherwise, so the same routine
 * serves allow lists (answer = 1) and deny lists (answer = 0).
 */
static int in_list(longword ip, longword table[], int entries, int maxtable, int answer)
{
	int		n, k;
	longword	entry;

	n = (entries < maxtable) ? entries : maxtable;
	for (k = 0; k < n; k++)
	{
		entry = table[k];
		if (ip == entry)
			return (answer);
		if (is_net_addr(entry) && is_on_net(ip, entry))
			return (answer);
	}
	return (!answer);
}

/*
 * Decide whether a client at ip may connect; nonzero means allowed.
 * Policy, first rule that applies wins:
 *   1. if check_subnet is set, hosts on our own subnet are allowed;
 *   2. otherwise, if an allow list exists, only hosts/networks on it;
 *   3. otherwise, if a deny list exists, everyone not on it;
 *   4. otherwise nobody.
 */
static int check_access(longword ip)
{
	int	verdict = 0;		/* default: denied */

	if (check_subnet && same_subnet(ip))
		verdict = 1;
	else if (nallow > 0)
		verdict = in_list(ip, allow, nallow, MAXALLOW, 1);
	else if (ndeny > 0)
		verdict = in_list(ip, deny, ndeny, MAXDENY, 0);
	return (verdict);
}

void show_stats(struct smbconn *client)
{
	time_t		elapsed;

	if (client->printer < 0)
		return;
	elapsed = time(0) - client->starttime;
	(void)printf("%s %s: %ld bytes %ld seconds", ptime(),
		lpt[client->printer].name, client->joblen, elapsed);
	if (elapsed > 0)
		(void)printf(" %ld bytes/second", client->joblen / elapsed);
	(void)printf("\n");
}

/*
 * Return the dotted-quad text for ipaddr, produced by the network
 * library's two-argument inet_ntoa() (WATTCP-style, presumably).
 * The text lives in a static buffer that the next call overwrites, so
 * consume it before calling again (e.g. not twice in one printf).
 */
char *Inet_ntoa(unsigned long ipaddr)
{
	static char	text[32];
	char		*result;

	result = inet_ntoa(text, ipaddr);
	return (result);
}

/*
 * Check whether at least to_read bytes are waiting in socket's receive
 * buffer. Returns to_read if they are, 0 otherwise.
 */
static int have_bytes(tcp_Socket *socket, int to_read)
{
	if (sock_rbused(socket) >= to_read)
		return (to_read);
	return (0);
}

/*
 * Read the 4-byte length header of the next SMB message into buffer.
 * poll_conn() calls this only after have_bytes() reports 4 bytes ready.
 * Returns 4 on success, or -1 if fewer than 4 bytes were read. Checking
 * the read count is what lets the caller's "< 0" test detect a short
 * read; previously 4 was returned unconditionally.
 */
static int read_smb_len(tcp_Socket *socket, unsigned char *buffer)
{
	int	got;

	got = sock_fastread(socket, buffer, 4);
	return (got == 4 ? 4 : -1);
}

/*
 * Read exactly length bytes from socket into buffer.
 * Reads in chunks of at most 32767 bytes, because sock_read() takes an
 * int count. Stops early if sock_read() returns 0 or an error.
 * Returns nonzero when the whole request was satisfied.
 */
static int read_smb_rest(tcp_Socket *socket, unsigned char *buffer, long length)
{
	unsigned char	*dst = buffer;
	long		remaining = length;
	int		chunk, got;

	while (remaining > 0)
	{
		chunk = (remaining > 32767) ? 32767 : (int)remaining;
		got = sock_read(socket, dst, chunk);
		if (got <= 0)
			break;		/* read error or connection closed */
		dst += got;
		remaining -= got;
	}
	return (remaining <= 0);
}

/*
 * One-time network and connection-table setup. In order, it:
 *   - initialises the TCP/IP stack and records the host name and the
 *     network-order copy of our IP address;
 *   - computes the local-time to GMT offset in my_tzsec;
 *   - allocates up to MAXCONN zeroed connection slots, with nconn set
 *     to how many succeeded;
 *   - allocates the shared MAXBUFFER request buffer, exiting on
 *     failure;
 *   - opens the name-service listener on NSPORT.
 */
void init_net(void)
{
	int			i;
	struct smbconn		*client;
	struct tm		*ts;

	callback_init();
	dbuginit();
	sock_init();
	my_ip_addr_n = htonl(my_ip_addr);
	gethostname(my_host_name, sizeof(my_host_name)-1);
#ifdef	DEBUG
	(void)printf("Hostname %s workgroup %s address %s\n",
		my_host_name, my_work_group, Inet_ntoa(my_ip_addr));
#endif
	my_currenttime = time(0);
	/* localtime() is called before timezone is read, as before, in
	   case the library sets timezone there. localtime() may return
	   NULL; in that case assume standard (non-DST) time. */
	ts = localtime(&my_currenttime);
	my_tzsec = timezone;
	if (ts != 0 && ts->tm_isdst > 0)
		my_tzsec -= 3600;
	for (i = 0; i < MAXCONN; ++i)
	{
		client = calloc(1, sizeof(*client));
		conn[i] = client;
		if (client == 0)
			break;
		client->connid = i;
		client->state = INIT;
	}
	nconn = i;
	if (nconn < MAXCONN)
		(void)printf("Only %d of %d connections available\n",
			nconn, (int)MAXCONN);
	if ((insmbreq = farmalloc((unsigned long)MAXBUFFER)) == 0)
	{
		(void)printf("Out of heap memory\n");
		exit(1);
	}
	/* open a broadcast listener on NSPORT */
	udp_open(&namein, NSPORT, 0L, 0, 0);
}

void init_log(void)
{
	longword		logip;

	if (loghost == 0 ||
		(logip = resolve(loghost)) == (longword)0 ||
		udp_open(&log, 0, logip, LOGPORT, 0) == 0)
	{
		loghost = 0;
		return;
	}
	(void)sprintf(logbuf, LOG_TAG PROGRAM " " VERSION
		", %d printer(s)\n", nlpt);
	(void)sock_fastwrite(&log, logbuf, strlen(logbuf));
}

/*
 * Advance connection slot i's state machine by one step. Each call
 * handles only the current state of conn[i]. Presumably it is called
 * repeatedly for every slot from the main loop - confirm.
 *   INIT        start a TCP listen on SMBPORT, reset fields, -> WANT_CONN
 *   WANT_CONN   once connected, apply check_access():
 *               allowed -> WANT_LENGTH, refused -> CLOSING
 *   WANT_LENGTH read the 4-byte length header and choose smallbuf or
 *               largebuf for the message, -> WANT_SMB
 *   WANT_SMB    read the rest of the message and pass it to dispatch()
 *   PRINTING    feed data via print_data() while servicing the socket
 *   CLOSING     graceful flush/close, then release the slot
 *   ABORT       hard abort, then release the slot
 * Failures of tcp_tick(), the reads, or dispatch() move the slot to
 * CLOSING. Releasing frees the printer if the slot held it and returns
 * the slot to INIT.
 */
void poll_conn(int i)
{
	struct smbconn		*client;
	struct sockaddr		cliaddr;
	char			*name;
	int			l;

	client = conn[i];
	switch (client->state)
	{
	case INIT:
		tcp_listen(&client->socket, SMBPORT, 0L, 0, 0, 0);
		(void)printf("%s Connection %d listening on TCP port %u\n",
			ptime(), i, SMBPORT);
		client->printer = -1;
		client->largebuf = 0;
		client->state = WANT_CONN;
		break;
	case WANT_CONN:
		tcp_tick(0);		/* drive the stack; no specific socket */
		if (!sock_established(&client->socket))
			break;
		l = sizeof(cliaddr);
		cliaddr.s_ip = 0;	/* stays 0 if getpeername() fails */
		name = getpeername(&client->socket, &cliaddr, &l) == 0 ?
			Inet_ntoa(cliaddr.s_ip) : "?";
		if (check_access(cliaddr.s_ip))
		{
			(void)printf("%s Connection %d from %s\n",
				ptime(), i, name);
			client->to_read = 4;	/* next: 4-byte length header */
			client->state = WANT_LENGTH;
		}
		else
		{
			(void)printf("%s Connection %d from %s refused\n",
				ptime(), i, name);
			client->state = CLOSING;
		}
		break;
	case WANT_LENGTH:
		/* tcp_tick() returning 0 is treated as the connection being gone */
		if (!tcp_tick(&client->socket))
		{
			client->state = CLOSING;
			break;
		}
		if (have_bytes(&client->socket, client->to_read) <= 0)
			break;		/* header not fully arrived yet; retry later */
		if (read_smb_len(&client->socket, client->smallbuf) < 0)
		{
			client->state = CLOSING;
			break;
		}
		client->to_read = client->smblen = get_smb_length(client->smallbuf);
		/* Body won't fit in smallbuf after the 4 header bytes, so it
		   must go into largebuf. largebuf has to be set already
		   (presumably by dispatch() when a print job starts - confirm)
		   and must be big enough.
		   NOTE(review): "client->largebuf != 0 &&" below is redundant. */
		if (client->smblen > sizeof(client->smallbuf) - 4)
		{
			if (client->largebuf == 0 ||
				(client->largebuf != 0 &&
				client->smblen > MAXBUFFER - 4))
			{
				printf("%s Buffer not large enough\n", ptime());
				client->state = CLOSING;
				break;
			}
			/* copy the length bytes over */
			memcpy(client->largebuf, client->smallbuf, 4);
			client->sh = (struct smbhdr *)client->largebuf;
		}
		else
			client->sh = (struct smbhdr *)client->smallbuf;
		client->state = WANT_SMB;
		break;
	case WANT_SMB:
		if (!tcp_tick(&client->socket))
		{
			client->state = CLOSING;
			break;
		}
		/* body goes right after the 4 header bytes already in the buffer */
		if (!read_smb_rest(&client->socket, (unsigned char *)client->sh + 4, client->smblen))
		{
			client->state = CLOSING;
			break;
		}
		if (dispatch(client) < 0)
		{
			client->state = CLOSING;
			break;
		}
		/* dispatch() may have switched state (e.g. to PRINTING); if it
		   did not, go back to waiting for the next message */
		if (client->state == WANT_SMB)
		{
			client->to_read = 4;
			client->state = WANT_LENGTH;
		}
		break;
	case PRINTING:
		/* nonzero from print_data() ends this printing step; resume
		   reading messages. The socket is ticked either way so a
		   dropped connection is noticed. */
		if (print_data(client))
		{
			client->to_read = 4;
			client->state = WANT_LENGTH;
		}
		if (!tcp_tick(&client->socket))
		{
			client->state = CLOSING;
			break;
		}
		break;
	case CLOSING:
		sock_flush(&client->socket);
		sock_close(&client->socket);
		/* share the release code with ABORT; the break is unreachable */
		goto release;
		break;
	case ABORT:
		sock_abort(&client->socket);
release:
		/* free up printer */
		/* NOTE(review): does not check lpt[].client == client */
		if (0 <= client->printer && client->printer < MAXLPT
			&& lpt[client->printer].avail == BUSY)
		{
			lpt[client->printer].avail = FREE;
			lpt[client->printer].client = 0;
		}
		client->fid = 0;
		client->printer = -1;
		client->largebuf = 0;
		client->state = INIT;	/* next call re-listens on this slot */
		break;
	}
}

/*
 * Report a status change on a printer.
 * The line "<time> <printer name> <msg>" always goes to stdout. When a
 * syslog host is configured (loghost != 0), the same information,
 * prefixed with LOG_TAG, is also sent as a UDP datagram.
 * The printer name and msg are truncated so the formatted text always
 * fits in logbuf, rather than overflowing it as the unbounded %s did.
 */
void report_change(char *msg, int printer)
{
	(void)printf("%s %s %s\n", ptime(), lpt[printer].name, msg);
	if (loghost != 0)		/* log to syslogd */
	{
		int	room, namelen;

		/* logbuf space left after LOG_TAG (a string literal, so sizeof
		   counts its NUL), the separating space, the newline and the
		   terminating NUL */
		room = (int)(sizeof(logbuf) - sizeof(LOG_TAG) - 2);
		if (room < 0)
			room = 0;
		namelen = (int)strlen(lpt[printer].name);
		if (namelen > room)
			namelen = room;
		room -= namelen;
		(void)sprintf(logbuf, LOG_TAG "%.*s %.*s\n",
			namelen, lpt[printer].name, room, msg);
		(void)sock_fastwrite(&log, logbuf, strlen(logbuf));
	}
}

/*
 * Send length bytes from buffer on the client's TCP connection.
 * Returns the count reported by sock_write(). The previous version
 * always returned length, which hid failed or short writes from the
 * caller.
 */
int write_smb(struct smbconn *client, unsigned char *buffer, int length)
{
	return (sock_write(&client->socket, buffer, length));
}

/*
 * Service the name-service listener. Lets the stack process namein,
 * then hands any waiting datagram to ns_request().
 */
void check_ns(void)
{
	(void)tcp_tick(&namein);
	if (!sock_dataready(&namein))
		return;
	ns_request();
}

/*
 * Fetch the pending name-service datagram from namein.
 *
 * buffer, length  destination and its size in bytes
 * fromaddr        set to the sender's address when it can be determined
 *
 * Returns the number of bytes actually copied into buffer, which is
 * never more than length. Returns 0 if the sender could not be
 * determined or the read failed. Previously the bytes-available count
 * was returned, which could exceed what was copied.
 * namein is re-opened ("reprimed") in every case, including failure,
 * so it listens on NSPORT for any host again.
 */
int read_ns(unsigned char *buffer, int length, unsigned long *fromaddr)
{
	int			got, i, known;
	struct sockaddr		sender;

	got = sock_fastread(&namein, buffer, length);
	i = sizeof(sender);
	/* must query the peer before udp_open() resets the socket */
	known = getpeername(&namein, &sender, &i) == 0;
	if (known)
		*fromaddr = sender.s_ip;
	udp_open(&namein, NSPORT, 0L, 0, 0);		/* reprime */
	if (!known || got < 0)
		return (0);
	return (got);
}

/*
 * Send one name-service datagram of length bytes to toaddr on NSPORT.
 * nameout is re-opened for each reply so that it targets the new peer.
 * Returns length once the datagram has been handed to the stack, or 0
 * if udp_open() fails (it returns 0 on failure, e.g. when the
 * destination cannot be reached). Previously the failure was ignored.
 */
int write_ns(unsigned char *buffer, int length, unsigned long toaddr)
{
	if (udp_open(&nameout, 0, toaddr, NSPORT, 0) == 0)
		return (0);
	(void)sock_fastwrite(&nameout, buffer, length);
	return (length);
}
