From 21bc223da9f23005a800bfb6eda8e1ab03f731bf Mon Sep 17 00:00:00 2001
From: Michael Tuexen <tuexen@fh-muenster.de>
Date: Sat, 16 May 2015 14:24:57 +0200
Subject: [PATCH] Inject appropriate HEARTBEAT-ACK chunks in response HEARTBEAT
 chunks.

When outbound HEARTBEAT chunks are observed, inject the appropriate
HEARTBEAT-ACK chunk when scripted.
While there, plug a memory leak and rename a variable for consistency.
This fixes https://github.com/nplab/packetdrill/issues/11
---
 gtests/net/packetdrill/run_packet.c  | 170 ++++++++++++++++++++-------
 gtests/net/packetdrill/sctp_packet.c |   1 +
 gtests/net/packetdrill/socket.c      |   9 +-
 gtests/net/packetdrill/socket.h      |   6 +-
 4 files changed, 139 insertions(+), 47 deletions(-)

diff --git a/gtests/net/packetdrill/run_packet.c b/gtests/net/packetdrill/run_packet.c
index 4278dcb3..4a3fe2f4 100644
--- a/gtests/net/packetdrill/run_packet.c
+++ b/gtests/net/packetdrill/run_packet.c
@@ -1997,12 +1997,14 @@ static int do_outbound_script_packet(
 	struct sctp_chunk *chunk;
 	struct sctp_init_ack_chunk *init_ack;
 	struct sctp_cookie_echo_chunk *cookie_echo;
+	struct sctp_heartbeat_chunk *heartbeat;
+	struct sctp_heartbeat_ack_chunk *heartbeat_ack;
 	struct sctp_parameter *parameter;
 	struct sctp_state_cookie_parameter *state_cookie;
 	int result = STATUS_ERR;		/* return value */
 	struct packet *live_packet = NULL;
 	u16 cookie_length, chunk_length, parameter_length, parameters_length;
-	u16 padding_length;
+	u16 value_length, padding_length;
 
 	DEBUGP("do_outbound_script_packet\n");
 	if ((packet->icmpv4 != NULL) || (packet->icmpv6 != NULL)) {
@@ -2037,17 +2039,22 @@ static int do_outbound_script_packet(
 	if (sniff_outbound_live_packet(state, socket, &live_packet, error))
 		goto out;
 
-	if (socket->state == SOCKET_PASSIVE_PACKET_RECEIVED) {
-		if (packet->tcp && packet->tcp->syn && packet->tcp->ack) {
+	if (packet->tcp) {
+		if ((socket->state == SOCKET_PASSIVE_PACKET_RECEIVED) &&
+		    packet->tcp->syn && packet->tcp->ack) {
 			socket->state = SOCKET_PASSIVE_SYNACK_SENT;
 			socket->live.local_isn = ntohl(live_packet->tcp->seq);
 			DEBUGP("SYNACK live.local_isn: %u\n",
 			       socket->live.local_isn);
 		}
-		if (live_packet->sctp) {
-			chunk = sctp_chunks_begin(live_packet, &chunk_iter, error);
-			if ((*error == NULL) &&
-			    (chunk != NULL) &&
+	}
+	if (live_packet->sctp) {
+		for (chunk = sctp_chunks_begin(live_packet, &chunk_iter, error);
+		     chunk != NULL;
+		     chunk = sctp_chunks_next(&chunk_iter, error)) {
+			if (*error != NULL)
+				goto out;
+			if ((socket->state == SOCKET_PASSIVE_PACKET_RECEIVED) &&
 			    (chunk->type == SCTP_INIT_ACK_CHUNK_TYPE)) {
 				chunk_length = ntohs(chunk->length);
 				if (chunk_length < sizeof(struct sctp_init_ack_chunk)) {
@@ -2083,8 +2090,19 @@ static int do_outbound_script_packet(
 						cookie_echo->length = htons(chunk_length);
 						memcpy(cookie_echo->cookie, state_cookie->cookie, cookie_length);
 						memset(cookie_echo->cookie + cookie_length, 0, padding_length);
-						socket->prepared_state_cookie = cookie_echo;
-						socket->prepared_state_cookie_length = chunk_length + padding_length;
+						if (socket->prepared_cookie_echo != NULL) {
+							 /* paranoia to help catch bugs */
+							memset(socket->prepared_cookie_echo,
+							       0,
+							       socket->prepared_cookie_echo_length);
+							free(socket->prepared_cookie_echo);
+							socket->prepared_cookie_echo = NULL;
+							socket->prepared_cookie_echo_length = 0;
+						}
+						socket->prepared_cookie_echo = cookie_echo;
+						socket->prepared_cookie_echo_length = chunk_length + padding_length;
+						DEBUGP("COOKIE_ECHO of length %u prepeared\n",
+						       chunk_length);
 						break;
 					}
 				}
@@ -2096,6 +2114,38 @@ static int do_outbound_script_packet(
 				DEBUGP("INIT_ACK: live.local_initial_tsn: %u\n",
 				       socket->live.local_initial_tsn);
 			}
+			if (chunk->type == SCTP_HEARTBEAT_CHUNK_TYPE) {
+				heartbeat = (struct sctp_heartbeat_chunk *)chunk;
+				chunk_length = ntohs(heartbeat->length);
+				if (chunk_length < sizeof(struct sctp_heartbeat_chunk)) {
+					asprintf(error, "HEARTBEAT chunk too short (length=%u)", chunk_length);
+					goto out;
+				}
+				value_length = chunk_length - sizeof(struct sctp_heartbeat_chunk);
+				padding_length = chunk_length % 4;
+				if (padding_length > 0) {
+					padding_length = 4 - padding_length;
+				}
+				heartbeat_ack = (struct sctp_heartbeat_ack_chunk *)malloc(chunk_length + padding_length);
+				heartbeat_ack->type = SCTP_HEARTBEAT_ACK_CHUNK_TYPE;
+				heartbeat_ack->flags = 0;
+				heartbeat_ack->length = htons(chunk_length);
+				memcpy(heartbeat_ack->value, heartbeat->value, value_length);
+				memset(heartbeat_ack->value + value_length, 0, padding_length);
+				if (socket->prepared_heartbeat_ack != NULL) {
+					 /* paranoia to help catch bugs */
+					memset(socket->prepared_heartbeat_ack,
+					       0,
+					       socket->prepared_heartbeat_ack_length);
+					free(socket->prepared_heartbeat_ack);
+					socket->prepared_heartbeat_ack = NULL;
+					socket->prepared_heartbeat_ack_length = 0;
+				}
+				socket->prepared_heartbeat_ack = heartbeat_ack;
+				socket->prepared_heartbeat_ack_length = chunk_length + padding_length;
+				DEBUGP("HEARTBEAT-ACK of length %u prepeared\n",
+				       chunk_length);
+			}
 		}
 	}
 
@@ -2142,8 +2192,7 @@ static int do_inbound_script_packet(
 	struct sctp_init_ack_chunk *init_ack;
 	struct sctp_chunk_list_item *item;
 	int result = STATUS_ERR;	/* return value */
-	u16 offset;
-	bool cooie_echo_replaced = false;
+	u16 offset = 0, temp_offset;
 	u16 i;
 
 	DEBUGP("do_inbound_script_packet\n");
@@ -2161,36 +2210,40 @@ static int do_inbound_script_packet(
 		}
 	}
 	if (packet->sctp) {
-		item = packet->chunk_list->first;
-		if ((socket->state == SOCKET_ACTIVE_INIT_SENT) &&
-		    (item != NULL) &&
-		    (item->chunk->type == SCTP_INIT_ACK_CHUNK_TYPE)) {
-			init_ack = (struct sctp_init_ack_chunk *)item->chunk;
-			DEBUGP("Moving socket in SOCKET_ACTIVE_INIT_ACK_RECEIVED\n");
-			socket->state = SOCKET_ACTIVE_INIT_ACK_RECEIVED;
-			socket->script.remote_initiate_tag = ntohl(init_ack->initiate_tag);
-			socket->script.remote_initial_tsn = ntohl(init_ack->initial_tsn);
-			socket->live.remote_initiate_tag = ntohl(init_ack->initiate_tag);
-			socket->live.remote_initial_tsn = ntohl(init_ack->initial_tsn);
-			DEBUGP("remote_initiate_tag %d, remote_initial_tsn %d\n", ntohl(init_ack->initiate_tag), ntohl(init_ack->initial_tsn));
-		}
-		if (socket->state == SOCKET_PASSIVE_INIT_ACK_SENT) {
-			for (; item != NULL; item = item->next) {
-				if (item->chunk->type == SCTP_COOKIE_ECHO_CHUNK_TYPE) {
-					offset = socket->prepared_state_cookie_length - item->length;
-					assert(packet->ip_bytes + offset <= packet->buffer_bytes);
-					memmove((u8 *)item->chunk + item->length + offset,
-					        (u8 *)item->chunk + item->length,
-					        packet_end(packet) - ((u8 *)item->chunk + item->length));
-					memcpy(item->chunk, socket->prepared_state_cookie, socket->prepared_state_cookie_length);
-					item->length = socket->prepared_state_cookie_length;
-					packet->buffer_bytes += offset;
-					packet->ip_bytes += offset;
+		for (item = packet->chunk_list->first;
+		     item != NULL;
+		     item = item->next) {
+			switch (item->chunk->type) {
+			case SCTP_INIT_ACK_CHUNK_TYPE:
+				if (socket->state == SOCKET_ACTIVE_INIT_SENT) {
+					init_ack = (struct sctp_init_ack_chunk *)item->chunk;
+					DEBUGP("Moving socket in SOCKET_ACTIVE_INIT_ACK_RECEIVED\n");
+					socket->state = SOCKET_ACTIVE_INIT_ACK_RECEIVED;
+					socket->script.remote_initiate_tag = ntohl(init_ack->initiate_tag);
+					socket->script.remote_initial_tsn = ntohl(init_ack->initial_tsn);
+					socket->live.remote_initiate_tag = ntohl(init_ack->initiate_tag);
+					socket->live.remote_initial_tsn = ntohl(init_ack->initial_tsn);
+					DEBUGP("remote_initiate_tag %d, remote_initial_tsn %d\n", ntohl(init_ack->initiate_tag), ntohl(init_ack->initial_tsn));
+				}
+				break;
+			case SCTP_COOKIE_ECHO_CHUNK_TYPE:
+				if (socket->state == SOCKET_PASSIVE_INIT_ACK_SENT) {
+					temp_offset = socket->prepared_cookie_echo_length - item->length;
+					assert(packet->ip_bytes + temp_offset <= packet->buffer_bytes);
+					memmove((u8 *)item->chunk + item->length + temp_offset,
+						(u8 *)item->chunk + item->length,
+						packet_end(packet) - ((u8 *)item->chunk + item->length));
+					memcpy(item->chunk,
+					       socket->prepared_cookie_echo,
+					       socket->prepared_cookie_echo_length);
+					item->length = socket->prepared_cookie_echo_length;
+					packet->buffer_bytes += temp_offset;
+					packet->ip_bytes += temp_offset;
 					if (packet->ipv4) {
-						packet->ipv4->tot_len = htons(ntohs(packet->ipv4->tot_len) + offset);
+						packet->ipv4->tot_len = htons(ntohs(packet->ipv4->tot_len) + temp_offset);
 					}
 					if (packet->ipv6) {
-						packet->ipv6->payload_len = htons(ntohs(packet->ipv6->payload_len) + offset);
+						packet->ipv6->payload_len = htons(ntohs(packet->ipv6->payload_len) + temp_offset);
 					}
 					for (i = 0; i < PACKET_MAX_HEADERS; i++) {
 						if ((packet->ipv4 != NULL && packet->headers[i].h.ipv4 == packet->ipv4) ||
@@ -2199,15 +2252,44 @@ static int do_inbound_script_packet(
 						}
 					}
 					assert(packet->headers[i + 1].type == HEADER_SCTP);
-					packet->headers[i].total_bytes += offset;
-					packet->headers[i + 1].total_bytes += offset;
-					cooie_echo_replaced = true;
+					packet->headers[i].total_bytes += temp_offset;
+					packet->headers[i + 1].total_bytes += temp_offset;
 					socket->state = SOCKET_PASSIVE_COOKIE_ECHO_RECEIVED;
-				} else {
-					if (cooie_echo_replaced) {
-						item->chunk = (struct sctp_chunk *)((char *)item->chunk + offset);
+					offset += temp_offset;
+				}
+				break;
+			case SCTP_HEARTBEAT_ACK_CHUNK_TYPE:
+				temp_offset = socket->prepared_heartbeat_ack_length - item->length;
+				assert(packet->ip_bytes + temp_offset <= packet->buffer_bytes);
+				memmove((u8 *)item->chunk + item->length + temp_offset,
+					(u8 *)item->chunk + item->length,
+					packet_end(packet) - ((u8 *)item->chunk + item->length));
+				memcpy(item->chunk,
+				       socket->prepared_heartbeat_ack,
+				       socket->prepared_heartbeat_ack_length);
+				item->length = socket->prepared_heartbeat_ack_length;
+				packet->buffer_bytes += temp_offset;
+				packet->ip_bytes += temp_offset;
+				if (packet->ipv4) {
+					packet->ipv4->tot_len = htons(ntohs(packet->ipv4->tot_len) + temp_offset);
+				}
+				if (packet->ipv6) {
+					packet->ipv6->payload_len = htons(ntohs(packet->ipv6->payload_len) + temp_offset);
+				}
+				for (i = 0; i < PACKET_MAX_HEADERS; i++) {
+					if ((packet->ipv4 != NULL && packet->headers[i].h.ipv4 == packet->ipv4) ||
+					    (packet->ipv6 != NULL && packet->headers[i].h.ipv6 == packet->ipv6)) {
+						break;
 					}
 				}
+				assert(packet->headers[i + 1].type == HEADER_SCTP);
+				packet->headers[i].total_bytes += temp_offset;
+				packet->headers[i + 1].total_bytes += temp_offset;
+				offset += temp_offset;
+				break;
+			default:
+				item->chunk = (struct sctp_chunk *)((char *)item->chunk + offset);
+				break;
 			}
 		}
 	}
diff --git a/gtests/net/packetdrill/sctp_packet.c b/gtests/net/packetdrill/sctp_packet.c
index 8b7817bb..21cbcfc8 100644
--- a/gtests/net/packetdrill/sctp_packet.c
+++ b/gtests/net/packetdrill/sctp_packet.c
@@ -794,6 +794,7 @@ new_sctp_packet(int address_family,
 			case SCTP_HEARTBEAT_CHUNK_TYPE:
 				break;
 			case SCTP_HEARTBEAT_ACK_CHUNK_TYPE:
+				overbook = true;
 				break;
 			case SCTP_ABORT_CHUNK_TYPE:
 				if (item->flags & FLAG_CHUNK_LENGTH_NOCHECK) {
diff --git a/gtests/net/packetdrill/socket.c b/gtests/net/packetdrill/socket.c
index 2f8f9206..87f713d3 100644
--- a/gtests/net/packetdrill/socket.c
+++ b/gtests/net/packetdrill/socket.c
@@ -40,6 +40,13 @@ struct socket *socket_new(struct state *state)
 void socket_free(struct socket *socket)
 {
 	hash_map_free(socket->ts_val_map);
-	memset(socket, 0, sizeof(*socket));  /* paranoia to help catch bugs */
+	 /* paranoia to help catch bugs */
+	memset(socket->prepared_cookie_echo, 0, socket->prepared_cookie_echo_length);
+	free(socket->prepared_cookie_echo);
+	 /* paranoia to help catch bugs */
+	memset(socket->prepared_heartbeat_ack, 0, socket->prepared_heartbeat_ack_length);
+	free(socket->prepared_heartbeat_ack);
+	 /* paranoia to help catch bugs */
+	memset(socket, 0, sizeof(*socket));
 	free(socket);
 }
diff --git a/gtests/net/packetdrill/socket.h b/gtests/net/packetdrill/socket.h
index f38515f8..96c6f6bf 100644
--- a/gtests/net/packetdrill/socket.h
+++ b/gtests/net/packetdrill/socket.h
@@ -116,8 +116,10 @@ struct socket {
 	struct tcp last_injected_tcp_header;
 	u32 last_injected_tcp_payload_len;
 
-	struct sctp_cookie_echo_chunk *prepared_state_cookie;
-	u16 prepared_state_cookie_length;
+	struct sctp_cookie_echo_chunk *prepared_cookie_echo;
+	u16 prepared_cookie_echo_length;
+	struct sctp_heartbeat_ack_chunk *prepared_heartbeat_ack;
+	u16 prepared_heartbeat_ack_length;
 
 	struct socket *next;	/* next in linked list of sockets */
 };
-- 
GitLab