/*
 * Copyright 2013 Google Inc.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 */
/*
 * Author: ncardwell@google.com (Neal Cardwell)
 *
 * TCP connection handling for remote on-the-wire testing using a real NIC.
 */

#include "wire_conn.h"

#include <errno.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <stdlib.h>
#include <unistd.h>

#include "logging.h"
#include "tcp.h"

/* Upper bound, in bytes, on the payload length (message length minus
 * the wire header) that wire_conn_read() accepts. Anything larger is
 * rejected before a buffer is allocated, so the remote side can't OOM us.
 */
#define MAX_MESSAGE_BYTES (10*1000*1000)

/* Allocate a zero-initialized wire_conn with no socket attached yet
 * (fd == -1). The caller owns the result and releases it with
 * wire_conn_free(). Exits the process if the allocation fails, so the
 * returned pointer is never NULL.
 */
struct wire_conn *wire_conn_new(void)
{
	struct wire_conn *wire_conn;

	DEBUGP("wire_conn_new\n");

	wire_conn = calloc(1, sizeof(*wire_conn));
	if (wire_conn == NULL) {
		/* Fail loudly instead of dereferencing NULL just below. */
		perror("calloc");
		exit(EXIT_FAILURE);
	}

	/* calloc left fd as 0, which is a valid descriptor; mark "no socket". */
	wire_conn->fd = -1;

	return wire_conn;
}

/* Release a connection: close its socket if one is open, free the
 * receive buffer, and free the struct itself. Passing NULL is a no-op,
 * mirroring free(), so cleanup paths need no guard of their own. The
 * pointer must not be used after this call.
 */
void wire_conn_free(struct wire_conn *conn)
{
	if (conn == NULL)
		return;

	if (conn->fd != -1)
		close(conn->fd);
	free(conn->in.buf);

	/* paranoia: scrub the struct so stale users of it are easier to catch */
	memset(conn, 0, sizeof(*conn));
	free(conn);
}

/* Open an IPv4 TCP stream socket and store its descriptor in conn->fd.
 * The connection must not already own a socket (fd == -1). A socket()
 * failure is reported through die_perror().
 */
static void create_tcp_socket(struct wire_conn *conn)
{
	int sock_fd;

	assert(conn->fd == -1);

	sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	conn->fd = sock_fd;
	if (sock_fd < 0)
		die_perror("socket");
}

/* Set default TCP socket options for decent performance on conn->fd:
 * Nagle disabled, and 128 KiB requested for both the receive and the
 * send buffer. Any setsockopt() failure is reported through
 * die_perror().
 */
static void set_default_tcp_options(struct wire_conn *conn)
{
	const int fd = conn->fd;
	int nodelay = 1;
	int buf_bytes = 128*1024;

	DEBUGP("set_default_tcp_options fd %d\n", fd);

	/* Disable Nagle algorithm so packets go out ASAP regardless of size.
	 * IPPROTO_TCP is the portable option level for TCP_NODELAY; SOL_TCP
	 * is a Linux-specific name for the same value and is not defined on
	 * the BSD-family systems this file otherwise caters for.
	 */
	if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY,
		       &nodelay, sizeof(nodelay)) < 0) {
		die_perror("setsockopt TCP_NODELAY");
	}

	/* Set receive buffer to allow high throughput. */
	if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF,
		       &buf_bytes, sizeof(buf_bytes)) < 0) {
		die_perror("setsockopt SO_RCVBUF");
	}

	/* Set send buffer to allow high throughput and avoid blocking. */
	if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF,
		       &buf_bytes, sizeof(buf_bytes)) < 0) {
		die_perror("setsockopt SO_SNDBUF");
	}
}

/* Create a TCP socket for conn, apply the default socket options, and
 * do a blocking connect to the wire server at ip:port. A connect()
 * failure is reported through die() with the address and the reason.
 */
void wire_conn_connect(struct wire_conn *conn,
		       const struct ip_address *ip, u16 port)
{
	struct sockaddr_storage sa;
	socklen_t length = 0;

	DEBUGP("wire_conn_connect\n");

	create_tcp_socket(conn);
	set_default_tcp_options(conn);

	/* Do a blocking connect to the server. */
	ip_to_sockaddr(ip, port, (struct sockaddr *)&sa, &length);
	if (connect(conn->fd, (struct sockaddr *)&sa, length) < 0) {
		/* Capture errno right away: argument evaluation order is
		 * unspecified, so ip_to_string() could otherwise run (and
		 * possibly modify errno) before strerror() reads it.
		 */
		const int connect_errno = errno;
		char ip_string[ADDR_STR_LEN];

		die("error connecting to wire server at %s:%d: %s\n",
		    ip_to_string(ip, ip_string), port,
		    strerror(connect_errno));
	}
}

/* Turn listen_conn into a listening socket: create the TCP socket,
 * enable SO_REUSEADDR, bind to the IPv4 wildcard address on the given
 * port, and start listening with a backlog of 100. Each failing step
 * is reported through die_perror().
 */
void wire_conn_bind_listen(struct wire_conn *listen_conn, u16 port)
{
	struct sockaddr_in bind_addr;
	int reuse_addr = 1;

	DEBUGP("wire_conn_bind_listen\n");

	create_tcp_socket(listen_conn);

	if (setsockopt(listen_conn->fd, SOL_SOCKET, SO_REUSEADDR,
		       &reuse_addr, sizeof(reuse_addr)) < 0)
		die_perror("setsockopt SO_REUSEADDR");

	/* Wildcard IPv4 address, port in network byte order. */
	memset(&bind_addr, 0, sizeof(bind_addr));
	bind_addr.sin_family = AF_INET;
	bind_addr.sin_addr.s_addr = INADDR_ANY;
	bind_addr.sin_port = htons(port);
#if defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) || defined(__APPLE__)
	/* These platforms carry an explicit length field in the sockaddr. */
	bind_addr.sin_len = sizeof(bind_addr);
#endif

	if (bind(listen_conn->fd, (struct sockaddr *)&bind_addr,
		 sizeof(bind_addr)) < 0)
		die_perror("bind");

	if (listen(listen_conn->fd, 100) < 0)
		die_perror("listen");
}

/* Block until a client connects to listen_conn, then wrap the accepted
 * socket in a freshly allocated wire_conn stored in *accepted_conn and
 * apply the default TCP options to it. The peer address is not
 * retrieved. An accept() failure is reported through die_perror().
 */
void wire_conn_accept(struct wire_conn *listen_conn,
		      struct wire_conn **accepted_conn)
{
	struct wire_conn *peer;
	int peer_fd;

	DEBUGP("wire_conn_accept\n");

	peer_fd = accept(listen_conn->fd, NULL, NULL);
	if (peer_fd < 0)
		die_perror("accept");

	DEBUGP("accepted fd %d\n", peer_fd);

	peer = wire_conn_new();
	*accepted_conn = peer;
	peer->fd = peer_fd;

	set_default_tcp_options(peer);
}

/* Do blocking writes until all buf_len bytes of buf are written to
 * conn->fd.  Given our large socket buffer size and typically small
 * write sizes, in practice all the writes should complete in one call.
 * Writes interrupted by a signal (EINTR) or reporting EAGAIN are
 * retried.
 *
 * Returns STATUS_OK on success, or STATUS_ERR (after a perror()
 * message) on any other write failure.
 */
static int write_bytes(struct wire_conn *conn,
		       const void *buf, int buf_len)
{
	/* Walk the buffer through a byte pointer: arithmetic on void * is
	 * a compiler extension, not standard C.
	 */
	const char *cursor = buf;

	while (buf_len > 0) {
		/* write() returns ssize_t; keep the full width until checked. */
		ssize_t bytes_written = write(conn->fd, cursor, buf_len);

		if (bytes_written < 0) {
			if (errno == EINTR || errno == EAGAIN)
				continue;
			perror("TCP socket write");
			return STATUS_ERR;
		}
		assert(bytes_written <= buf_len);
		buf_len -= (int)bytes_written;
		cursor += bytes_written;
	}
	return STATUS_OK;
}

/* Send one message on conn: a wire_header whose length field covers
 * the header plus buf_len payload bytes and whose op field is the
 * given op (both in network byte order), followed by the payload in
 * buf.
 *
 * Returns STATUS_OK if both the header and the payload were written,
 * otherwise STATUS_ERR. The payload is not attempted when the header
 * write fails.
 */
int wire_conn_write(struct wire_conn *conn,
		    enum wire_op_t op,
		    const void *buf, int buf_len)
{
	struct wire_header header;

	DEBUGP("wire_conn_write -> op: %s\n",
	       wire_op_to_string(op));

	header.op	= htonl(op);
	header.length	= htonl(sizeof(header) + buf_len);

	if (write_bytes(conn, &header, sizeof(header)) ||
	    write_bytes(conn, buf, buf_len))
		return STATUS_ERR;

	return STATUS_OK;
}

/* Do blocking reads from conn->fd until exactly buf_len bytes have
 * been stored in buf. Reads interrupted by a signal (EINTR) or
 * reporting EAGAIN are retried.
 *
 * Returns STATUS_OK on success. Returns STATUS_ERR, after printing a
 * diagnostic to stderr, if a read fails for another reason or the
 * remote side closes the connection before all bytes arrive.
 */
static int read_bytes(struct wire_conn *conn,
		      void *buf, int buf_len)
{
	/* Walk the buffer through a byte pointer: arithmetic on void * is
	 * a compiler extension, not standard C.
	 */
	char *cursor = buf;

	while (buf_len > 0) {
		/* read() returns ssize_t; keep the full width until checked. */
		ssize_t bytes_read = read(conn->fd, cursor, buf_len);

		if (bytes_read < 0) {
			if (errno == EINTR || errno == EAGAIN)
				continue;
			perror("TCP socket read");
			return STATUS_ERR;
		}
		if (bytes_read == 0) {
			/* End of stream with bytes still outstanding. */
			fprintf(stderr, "remote side closed connection\n");
			return STATUS_ERR;
		}
		assert(bytes_read <= buf_len);
		buf_len -= (int)bytes_read;
		cursor += bytes_read;
	}
	return STATUS_OK;
}

/* Read one message from conn: a wire_header followed by its payload.
 *
 * On success returns STATUS_OK with *op set to the message's op,
 * *buf_len to the payload length in bytes, and *buf pointing at the
 * payload. *buf aliases the connection's own receive buffer: it is
 * owned by conn, must not be freed by the caller, and may be replaced
 * by the next wire_conn_read() on the same connection.
 *
 * Returns STATUS_ERR if the header or payload cannot be read, if the
 * advertised length is smaller than a header or its payload exceeds
 * MAX_MESSAGE_BYTES, or if the receive buffer cannot be allocated.
 * *buf and *buf_len are only meaningful when STATUS_OK is returned.
 */
int wire_conn_read(struct wire_conn *conn,
		   enum wire_op_t *op,
		   void **buf, int *buf_len)
{
	struct wire_header header;
	long long payload_len;

	DEBUGP("wire_conn_read\n");

	if (read_bytes(conn, &header, sizeof(header)))
		return STATUS_ERR;

	*op = ntohl(header.op);

	DEBUGP("wire_conn_read -> op: %s\n", wire_op_to_string(*op));

	/* The length comes from the remote side, so validate it in a wide
	 * signed type before narrowing to int. Subtracting in unsigned
	 * arithmetic and then converting to int would depend on
	 * implementation-defined wraparound for bogus lengths.
	 */
	payload_len = (long long)ntohl(header.length) -
		      (long long)sizeof(header);
	if ((payload_len < 0) || (payload_len > MAX_MESSAGE_BYTES)) {
		fprintf(stderr, "invalid length %lld from remote wire conn\n",
			payload_len);
		return STATUS_ERR;
	}
	*buf_len = (int)payload_len;

	if (conn->in.buf_space < *buf_len) {
		/* Grow to twice the needed size to limit reallocations.
		 * The old contents are not needed, so free and allocate
		 * afresh rather than realloc (which would copy them).
		 */
		const int new_space = 2 * *buf_len;

		free(conn->in.buf);
		conn->in.buf = malloc(new_space);
		if (conn->in.buf == NULL) {
			/* Keep buf and buf_space consistent so a later
			 * call retries the allocation instead of reading
			 * into a NULL buffer.
			 */
			conn->in.buf_space = 0;
			fprintf(stderr,
				"out of memory allocating %d bytes "
				"for wire message\n", new_space);
			return STATUS_ERR;
		}
		conn->in.buf_space = new_space;
	}

	*buf = conn->in.buf;

	if (read_bytes(conn, *buf, *buf_len))
		return STATUS_ERR;

	return STATUS_OK;
}