/* 
 * Copyright 2011 Stefan Lankes, Chair for Operating Systems,
 *                               RWTH Aachen University
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * This file is part of MetalSVM. 
 */

#include <metalsvm/stddef.h>
#include <metalsvm/stdio.h>
#include <metalsvm/time.h>
#include <metalsvm/tasks.h>
#include <metalsvm/syscall.h>
#include <metalsvm/errno.h>

/*
 * This implements a netio server.
 * The client sends a command word (4 bytes) followed by a data length word
 * (4 bytes), both in network byte order.
 * If the command is "receive" (CMD_C2S), the server repeatedly reads blocks of
 * "data length" bytes into a reused buffer until a block arrives whose first
 * byte is non-zero; it then waits for the next command/length pair.
 * If the command is "send" (CMD_S2C), the server repeatedly sends blocks of
 * "data length" bytes whose first byte is zero until "some time" (6 seconds in
 * the current netio131.zip download) has passed, then sends one final block
 * whose first byte is non-zero. It then waits for the next command/length pair.
 */

/* See http://www.nwlab.net/art/netio/netio.html to get the netio tool */

#ifdef CONFIG_LWIP
#include <lwip/sockets.h>

/*
 * Control message exchanged with the netio client: two 32-bit words that
 * travel in network byte order and are converted with ntohl() on receipt.
 */
typedef struct {
	uint32_t cmd;	/* command code (CMD_QUIT, CMD_C2S, CMD_S2C, ...) */
	uint32_t data;	/* block length in bytes for the transfer commands */
} CONTROL;

#define CMD_QUIT  0
#define CMD_C2S   1
#define CMD_S2C   2
#define CMD_RES   3

#define CTLSIZE sizeof(CONTROL)
#define DEFAULTPORT 0x494F /* "IO" */
#define TMAXSIZE 65536

/* TCP port the server listens on. */
static int nPort = DEFAULTPORT;
/* Requested SO_RCVBUF / SO_SNDBUF size in bytes for the server and client sockets. */
static const int sobufsize = 131072;
/* Local address the server binds to; netio_init() sets it to INADDR_ANY. */
static struct in_addr addr_local;

/*
 * Send exactly `size` bytes from `buffer` on `socket`, retrying after short
 * writes so the whole buffer always reaches the peer.
 *
 * Returns 0 when all bytes were sent, -1 when send() reports an error,
 * and 1 when send() stops making progress (returns 0) before completion.
 */
static int send_data(int socket, void *buffer, size_t size, int flags)
{
	char *pos = buffer;
	size_t left = size;

	while (left > 0) {
		int rc = send(socket, pos, left, flags);

		if (rc < 0) {
			kprintf("send failed: %d\n", rc);
			return -1;
		}

		/* no progress: give up instead of spinning */
		if (rc == 0)
			return 1;

		pos += rc;
		left -= (size_t) rc;
	}

	return 0;
}

/*
 * Receive exactly `size` bytes from `socket` into `buffer`, looping over
 * short reads (TCP may deliver a message in several pieces).
 * Assumes `flags` does not contain MSG_PEEK, since the loop consumes data.
 *
 * Returns 0 when `size` bytes were received, -1 when recv() reports an
 * error, and 1 when the peer closes the connection before `size` bytes
 * arrived.
 */
static int recv_data(int socket, void *buffer, size_t size, int flags)
{
	char *pos = buffer;
	size_t left = size;

	while (left > 0) {
		/* must be signed: recv() signals errors with a negative value */
		int rc = recv(socket, pos, left, flags);

		if (rc < 0) {
			kprintf("recv failed: %d\n", rc);
			return -1;
		}

		/* orderly shutdown by the peer */
		if (rc == 0)
			return 1;

		pos += rc;
		left -= (size_t) rc;
	}

	return 0;
}

/*
 * Allocate an nSize-byte transfer buffer with kmalloc(), fill it with 0xFF
 * and clear its first byte (a zero first byte marks a non-final block).
 *
 * Returns the buffer (release with kfree(buf, nSize)), or NULL if the
 * allocation failed.
 */
static char *InitBuffer(size_t nSize)
{
	char *cBuffer = kmalloc(nSize);

	/* the caller reports the out-of-memory case; don't touch a NULL pointer */
	if (cBuffer == NULL)
		return NULL;

	memset(cBuffer, 0xFF, nSize);
	if (nSize > 0)
		cBuffer[0] = 0;

	return cBuffer;
}

/*
 * Format a packet size for log output.
 * - Sizes that are a multiple of 1024, or one byte short of one (e.g. 32767),
 *   are printed rounded to kilobytes, e.g. "32k".
 * - Any other size is printed as a plain byte count.
 *
 * Returns a pointer to a static buffer that the next call overwrites.
 */
static char *PacketSize(int nSize)
{
	static char label[64];
	const int rest = nSize % 1024;

	if (rest != 0 && rest != 1023) {
		ksprintf(label, "%d", nSize);
		return label;
	}

	ksprintf(label, "%2dk", (nSize + 512) / 1024);
	return label;
}

/*
 * CMD_C2S handler: receive blocks of nSize bytes into cBuffer until a block
 * whose first byte is non-zero arrives, then log the elapsed time.
 * freq is the CPU frequency in MHz, used to convert rdtsc() ticks.
 * nSize must be in 1..TMAXSIZE (validated by the caller).
 *
 * Returns the last recv() result: > 0 on success, 0 if the peer closed the
 * connection mid-transfer, < 0 on a socket error.
 */
static int netio_receive(int client, char *cBuffer, uint32_t nSize, uint32_t freq)
{
	uint64_t start, end;
	uint64_t nData = 0;
	uint32_t nByte;
	int rc = 1;

	start = rdtsc();

	kprintf("\nReceiving from client, packet size %s ... ", PacketSize(nSize));
	cBuffer[0] = 0;

	do {
		for (nByte = 0; nByte < nSize; nByte += rc) {
			rc = recv(client, cBuffer + nByte, nSize - nByte, 0);
			/* 0 means the peer closed: retrying would spin forever */
			if (rc <= 0)
				break;
		}

		if (rc < 0)
			kprintf("recv failed: %d\n", rc);
		if (rc <= 0)
			break;

		nData += nSize;
	} while (cBuffer[0] == 0);

	end = rdtsc();
	kprintf("Time to receive %llu bytes: %llu nsec (ticks %llu)\n", nData, ((end-start)*1000ULL)/freq, end-start);

	return rc;
}

/*
 * CMD_S2C handler: send blocks of nSize bytes whose first byte is zero for
 * about 6 seconds, then one final block whose first byte is 1 so the client
 * knows to stop, and log the elapsed time.
 * freq is the CPU frequency in MHz.
 * nSize must be in 1..TMAXSIZE (validated by the caller).
 *
 * Returns > 0 on success, < 0 on a socket error during the timed phase, and
 * 0 if send() made no progress or the final block could not be sent.
 */
static int netio_send(int client, char *cBuffer, uint32_t nSize, uint32_t freq)
{
	uint64_t start, end;
	uint64_t nData = 0;
	uint32_t nByte;
	int rc = 1;

	start = rdtsc();

	kprintf("\nSending to client, packet size %s ... ", PacketSize(nSize));
	cBuffer[0] = 0;

	do {
		for (nByte = 0; nByte < nSize; nByte += rc) {
			rc = send(client, cBuffer + nByte, nSize - nByte, 0);
			if (rc <= 0)
				break;
		}

		/* stop at once instead of retrying a dead socket until the 6 s expire */
		if (rc < 0)
			kprintf("send failed: %d\n", rc);
		if (rc <= 0)
			return rc;

		nData += nSize;
		end = rdtsc();
	} while ((end - start) / freq < 6000000ULL /* = 6s */);

	cBuffer[0] = 1;

	if (send_data(client, cBuffer, nSize, 0))
		return 0;

	end = rdtsc();
	kprintf("Time to send %llu bytes: %llu nsec (ticks %llu)\n", nData, ((end-start)*1000ULL)/freq, end-start);

	return rc;
}

/*
 * Kernel task implementing the netio TCP server.
 *
 * Listens on nPort/addr_local and serves one client at a time. For each
 * client it handles CMD_C2S / CMD_S2C requests until the client sends
 * CMD_QUIT (or an unknown command), disconnects, or sends an invalid size.
 * A socket error during a transfer shuts the whole server down. arg is unused.
 *
 * Returns 0 on shutdown, -ENOMEM if the transfer buffer cannot be
 * allocated, and -1 if the listening socket cannot be set up.
 */
static int TCPServer(void* arg)
{
	char *cBuffer;
	CONTROL ctl;
	struct sockaddr_in sa_server, sa_client;
	int server, client;
	socklen_t length;
	struct timeval tv;
	fd_set fds;
	int rc;
	int err;
	uint32_t freq = get_cpu_frequency(); /* in MHz */

	if ((cBuffer = InitBuffer(TMAXSIZE)) == NULL) {
		kprintf("Netio: Not enough memory\n");
		return -ENOMEM;
	}

	if ((server = socket(PF_INET, SOCK_STREAM, 0)) < 0) {
		kprintf("socket failed: %d\n", server);
		kfree(cBuffer, TMAXSIZE);
		return -1;
	}

	setsockopt(server, SOL_SOCKET, SO_RCVBUF, (char *) &sobufsize, sizeof(sobufsize));
	setsockopt(server, SOL_SOCKET, SO_SNDBUF, (char *) &sobufsize, sizeof(sobufsize));

	/* clear sin_zero (and sin_len where present) before filling in the address */
	memset(&sa_server, 0, sizeof(sa_server));
	sa_server.sin_family = AF_INET;
	sa_server.sin_port = htons(nPort);
	sa_server.sin_addr = addr_local;

	if ((err = bind(server, (struct sockaddr *) &sa_server, sizeof(sa_server))) < 0)
	{
		kprintf("bind failed: %d\n", err);
		closesocket(server);
		kfree(cBuffer, TMAXSIZE);
		return -1;
	}

	if ((err = listen(server, 2)) != 0)
	{
		kprintf("listen failed: %d\n", err);
		closesocket(server);
		kfree(cBuffer, TMAXSIZE);
		return -1;
	}

	for (;;)
	{
		kprintf("TCP server listening.\n");

		FD_ZERO(&fds);
		FD_SET(server, &fds);
		tv.tv_sec  = 3600;
		tv.tv_usec = 0;

		if ((rc = select(FD_SETSIZE, &fds, 0, 0, &tv)) < 0)
		{
			kprintf("select failed: %d\n", rc);
			break;
		}

		if (rc == 0 || FD_ISSET(server, &fds) == 0)
			continue;

		length = sizeof(sa_client);
		if ((client = accept(server, (struct sockaddr *) &sa_client, &length)) < 0)
			continue;

		setsockopt(client, SOL_SOCKET, SO_RCVBUF, (char *) &sobufsize, sizeof(sobufsize));
		setsockopt(client, SOL_SOCKET, SO_SNDBUF, (char *) &sobufsize, sizeof(sobufsize));

		kprintf("TCP connection established ... ");

		rc = 1;	/* no transfer error on this connection yet */
		for (;;)
		{
			if (recv_data(client, (void *) &ctl, CTLSIZE, 0))
				break;

			ctl.cmd = ntohl(ctl.cmd);
			ctl.data = ntohl(ctl.data);

			/* CMD_QUIT (or any unknown command) ends the session */
			if (ctl.cmd != CMD_C2S && ctl.cmd != CMD_S2C)
				break;

			/*
			 * ctl.data is client-controlled and used as the transfer length
			 * into/out of cBuffer: reject sizes that would overrun it, and 0,
			 * which would make the receive loop spin forever.
			 */
			if (ctl.data == 0 || ctl.data > TMAXSIZE)
			{
				kprintf("\nInvalid packet size %u\n", ctl.data);
				break;
			}

			if (ctl.cmd == CMD_C2S)
				rc = netio_receive(client, cBuffer, ctl.data, freq);
			else
				rc = netio_send(client, cBuffer, ctl.data, freq);

			if (rc <= 0)
				break;
		}

		kprintf("\nDone.\n");

		closesocket(client);

		/* a socket error during a transfer shuts the whole server down */
		if (rc < 0)
			break;
	}

	closesocket(server);
	kfree(cBuffer, TMAXSIZE);

	return 0;
}

/*
 * Start the netio server: bind to every local address (INADDR_ANY) and
 * spawn TCPServer as a kernel task with NORMAL_PRIO.
 * Returns the result of create_kernel_task().
 */
int netio_init(void)
{
	int ret;

	addr_local.s_addr = INADDR_ANY;
	ret = create_kernel_task(NULL, TCPServer, NULL, NORMAL_PRIO);

	return ret;
}
#endif