From 4d1122bd6cbbba3a0b19215596ae3ae18f5f56f0 Mon Sep 17 00:00:00 2001
From: Neal Cardwell <ncardwell@google.com>
Date: Fri, 25 Oct 2013 16:04:54 -0400
Subject: [PATCH] net-test: packetdrill encap support: main encapsulation
 engine

The main engine for handling encapsualted packets in packetdrill,
including:

- parsing scripts with encapsulation
- parsing packets with encapsulation
- injecting packets with encapsulation

Encapsulated packets look like the following, which has an IP header,
GRE header, and then regular IP/TCP header:

+0 > ipv4 1.1.1.1 > 2.2.2.2: gre: . 1:1001(1000) ack 1

See other patches in the series for more detailed examples.


Change-Id: Ia8c7580e268f5b0347bd9d3fdc73b096f751a79d
---
 gtests/net/packetdrill/Makefile.common   |   2 +-
 gtests/net/packetdrill/icmp_packet.c     |  28 ++--
 gtests/net/packetdrill/ip.h              |   5 +
 gtests/net/packetdrill/ip_packet.c       |  70 +++++++-
 gtests/net/packetdrill/ip_packet.h       |  17 +-
 gtests/net/packetdrill/lexer.l           |   6 +-
 gtests/net/packetdrill/packet.c          | 201 ++++++++++++++++++++---
 gtests/net/packetdrill/packet.h          |  74 +++++++--
 gtests/net/packetdrill/packet_checksum.c |  10 +-
 gtests/net/packetdrill/packet_parser.c   | 147 +++++++++++++----
 gtests/net/packetdrill/parser.y          |  81 ++++++---
 gtests/net/packetdrill/run_packet.c      |   2 +-
 gtests/net/packetdrill/tcp_packet.c      |  12 +-
 gtests/net/packetdrill/udp_packet.c      |  13 +-
 14 files changed, 535 insertions(+), 133 deletions(-)

diff --git a/gtests/net/packetdrill/Makefile.common b/gtests/net/packetdrill/Makefile.common
index a61e86e5..1f2d45ab 100644
--- a/gtests/net/packetdrill/Makefile.common
+++ b/gtests/net/packetdrill/Makefile.common
@@ -18,7 +18,7 @@ packetdrill-lib := checksum.o code.o config.o hash.o hash_map.o ip_address.o \
          symbols_freebsd.o \
          symbols_openbsd.o \
          symbols_netbsd.o \
-         icmp_packet.o ip_packet.o tcp_packet.o udp_packet.o \
+         gre_packet.o icmp_packet.o ip_packet.o tcp_packet.o udp_packet.o \
          run.o run_command.o run_packet.o run_system_call.o \
          script.o socket.o system.o \
          tcp_options.o tcp_options_iterator.o tcp_options_to_string.o \
diff --git a/gtests/net/packetdrill/icmp_packet.c b/gtests/net/packetdrill/icmp_packet.c
index 84c1d806..0aec9f00 100644
--- a/gtests/net/packetdrill/icmp_packet.c
+++ b/gtests/net/packetdrill/icmp_packet.c
@@ -189,19 +189,27 @@ static int set_icmpv6_header(struct icmpv6 *icmpv6,
 }
 
 /* Populate ICMP header fields. */
-static int set_packet_icmp_header(struct packet *packet, void *icmp_header,
-				  int address_family,
+static int set_packet_icmp_header(struct packet *packet, void *icmp,
+				  int address_family, int icmp_bytes,
 				  u8 type, u8 code, s64 mtu, char **error)
 {
+	struct header *icmp_header = NULL;
+
 	if (address_family == AF_INET) {
-		struct icmpv4 *icmpv4 = (struct icmpv4 *) icmp_header;
+		struct icmpv4 *icmpv4 = (struct icmpv4 *) icmp;
 		packet->icmpv4 = icmpv4;
 		assert(packet->icmpv6 == NULL);
+		icmp_header = packet_append_header(packet, HEADER_ICMPV4,
+						   sizeof(*icmpv4));
+		icmp_header->total_bytes = icmp_bytes;
 		return set_icmpv4_header(icmpv4, type, code, mtu, error);
 	} else if (address_family == AF_INET6) {
-		struct icmpv6 *icmpv6 = (struct icmpv6 *) icmp_header;
+		struct icmpv6 *icmpv6 = (struct icmpv6 *) icmp;
 		packet->icmpv6 = icmpv6;
 		assert(packet->icmpv4 == NULL);
+		icmp_header = packet_append_header(packet, HEADER_ICMPV6,
+						   sizeof(*icmpv6));
+		icmp_header->total_bytes = icmp_bytes;
 		return set_icmpv6_header(icmpv6, type, code, mtu, error);
 	} else {
 		assert(!"bad ip_version in config");
@@ -287,7 +295,7 @@ struct packet *new_icmp_packet(int address_family,
 	 * header and the first 8 bytes after that (which will
 	 * typically have the port info needed to demux the message).
 	 */
-	const int ip_fixed_bytes = ip_header_len(address_family);
+	const int ip_fixed_bytes = ip_header_min_len(address_family);
 	const int ip_option_bytes = 0;
 	const int ip_header_bytes = ip_fixed_bytes + ip_option_bytes;
 	const int echoed_bytes = ip_fixed_bytes + ICMP_ECHO_BYTES;
@@ -313,7 +321,6 @@ struct packet *new_icmp_packet(int address_family,
 	/* Allocate and zero out a packet object of the desired size */
 	packet = packet_new(ip_bytes);
 	memset(packet->buffer, 0, ip_bytes);
-	packet->ip_bytes = ip_bytes;
 
 	packet->direction = direction;
 	packet->flags = 0;
@@ -321,13 +328,13 @@ struct packet *new_icmp_packet(int address_family,
 
 	/* Set IP header fields */
 	const enum ip_ecn_t ecn = ECN_NONE;
-	set_packet_ip_header(packet, address_family, ip_bytes, direction, ecn,
+	set_packet_ip_header(packet, address_family, ip_bytes, ecn,
 			     icmp_protocol(address_family));
 
 	/* Find the start of the ICMP header and then populate common fields. */
-	void *icmp_header = packet_start(packet) + ip_header_bytes;
+	void *icmp_header = ip_start(packet) + ip_header_bytes;
 	if (set_packet_icmp_header(packet, icmp_header, address_family,
-				   type, code, mtu, error))
+				   icmp_bytes, type, code, mtu, error))
 		goto error_out;
 
 	/* All ICMP message types currently supported by this tool
@@ -343,12 +350,13 @@ struct packet *new_icmp_packet(int address_family,
 				     layer4_header_len(protocol) +
 				     payload_bytes);
 	set_ip_header(echoed_ip, address_family, echoed_ip_bytes,
-		      reverse_direction(direction), ecn, protocol);
+		      ecn, protocol);
 	if (protocol == IPPROTO_TCP) {
 		u32 *seq = packet_echoed_tcp_seq(packet);
 		*seq = htonl(tcp_start_sequence);
 	}
 
+	packet->ip_bytes = ip_bytes;
 	return packet;
 
 error_out:
diff --git a/gtests/net/packetdrill/ip.h b/gtests/net/packetdrill/ip.h
index 650b704a..c2bd926a 100644
--- a/gtests/net/packetdrill/ip.h
+++ b/gtests/net/packetdrill/ip.h
@@ -89,6 +89,11 @@ static inline u8 ipv4_ecn_bits(const struct ipv4 *ipv4)
 	return ipv4->tos & IP_ECN_MASK;
 }
 
+static inline int ipv4_header_len(const struct ipv4 *ipv4)
+{
+	return ipv4->ihl * sizeof(u32);
+}
+
 /* IP fragmentation bit flags */
 #define IP_RF		0x8000	/* reserved fragment flag */
 #define IP_DF		0x4000	/* don't fragment flag */
diff --git a/gtests/net/packetdrill/ip_packet.c b/gtests/net/packetdrill/ip_packet.c
index 3af76912..7c0aad16 100644
--- a/gtests/net/packetdrill/ip_packet.c
+++ b/gtests/net/packetdrill/ip_packet.c
@@ -24,6 +24,7 @@
 
 #include "ip_packet.h"
 
+#include "checksum.h"
 #include "ip.h"
 #include "ipv6.h"
 
@@ -43,7 +44,6 @@ static u8 ip_ecn_bits(enum ip_ecn_t ecn)
 /* Fill in IPv4 header fields. */
 static void set_ipv4_header(struct ipv4 *ipv4,
 			    u16 ip_bytes,
-			    enum direction_t direction,
 			    enum ip_ecn_t ecn, u8 protocol)
 {
 	ipv4->version = 4;
@@ -64,7 +64,6 @@ static void set_ipv4_header(struct ipv4 *ipv4,
 /* Fill in IPv6 header fields. */
 static void set_ipv6_header(struct ipv6 *ipv6,
 			    u16 ip_bytes,
-			    enum direction_t direction,
 			    enum ip_ecn_t ecn, u8 protocol)
 {
 	ipv6->version = 6;
@@ -85,13 +84,12 @@ static void set_ipv6_header(struct ipv6 *ipv6,
 void set_ip_header(void *ip_header,
 		   int address_family,
 		   u16 ip_bytes,
-		   enum direction_t direction,
 		   enum ip_ecn_t ecn, u8 protocol)
 {
 	if (address_family == AF_INET)
-		set_ipv4_header(ip_header, ip_bytes, direction, ecn, protocol);
+		set_ipv4_header(ip_header, ip_bytes, ecn, protocol);
 	else if (address_family == AF_INET6)
-		set_ipv6_header(ip_header, ip_bytes, direction, ecn, protocol);
+		set_ipv6_header(ip_header, ip_bytes, ecn, protocol);
 	else
 		assert(!"bad ip_version in config");
 }
@@ -99,20 +97,76 @@ void set_ip_header(void *ip_header,
 void set_packet_ip_header(struct packet *packet,
 			  int address_family,
 			  u16 ip_bytes,
-			  enum direction_t direction,
 			  enum ip_ecn_t ecn, u8 protocol)
 {
+	struct header *ip_header = NULL;
+
 	if (address_family == AF_INET) {
 		struct ipv4 *ipv4 = (struct ipv4 *) packet->buffer;
 		packet->ipv4 = ipv4;
 		assert(packet->ipv6 == NULL);
-		set_ipv4_header(ipv4, ip_bytes, direction, ecn, protocol);
+		ip_header = packet_append_header(packet, HEADER_IPV4,
+						 sizeof(*ipv4));
+		ip_header->total_bytes = ip_bytes;
+		set_ipv4_header(ipv4, ip_bytes, ecn, protocol);
 	} else if (address_family == AF_INET6) {
 		struct ipv6 *ipv6 = (struct ipv6 *) packet->buffer;
 		packet->ipv6 = ipv6;
 		assert(packet->ipv4 == NULL);
-		set_ipv6_header(ipv6, ip_bytes, direction, ecn, protocol);
+		ip_header = packet_append_header(packet, HEADER_IPV6,
+						 sizeof(*ipv6));
+		ip_header->total_bytes = ip_bytes;
+		set_ipv6_header(ipv6, ip_bytes, ecn, protocol);
 	} else {
 		assert(!"bad ip_version in config");
 	}
 }
+
+int ipv4_header_append(struct packet *packet,
+		       const char *ip_src,
+		       const char *ip_dst,
+		       char **error)
+{
+	struct header *header = NULL;
+	const int ipv4_bytes = sizeof(struct ipv4);
+	struct ipv4 *ipv4 = NULL;
+
+	header = packet_append_header(packet, HEADER_IPV4, ipv4_bytes);
+	if (header == NULL) {
+		asprintf(error, "too many headers");
+		return STATUS_ERR;
+	}
+
+	ipv4 = header->h.ipv4;
+	set_ip_header(ipv4, AF_INET, 0, ECN_NONE, 0);
+
+	if (inet_pton(AF_INET, ip_src, &ipv4->src_ip) != 1) {
+		asprintf(error, "bad IPv4 src address: '%s'\n", ip_src);
+		return STATUS_ERR;
+	}
+
+	if (inet_pton(AF_INET, ip_dst, &ipv4->dst_ip) != 1) {
+		asprintf(error, "bad IPv4 dst address: '%s'\n", ip_dst);
+		return STATUS_ERR;
+	}
+
+	return STATUS_OK;
+}
+
+int ipv4_header_finish(struct packet *packet,
+		       struct header *header, struct header *next_inner)
+{
+	struct ipv4 *ipv4 = header->h.ipv4;
+	int ip_bytes = sizeof(struct ipv4) + next_inner->total_bytes;
+
+	ipv4->tot_len = htons(ip_bytes);
+	ipv4->protocol = header_type_info(next_inner->type)->ip_proto;
+
+	/* Fill in IPv4 header checksum. */
+	ipv4->check = 0;
+	ipv4->check = ipv4_checksum(ipv4, ipv4->ihl * sizeof(u32));
+
+	header->total_bytes = ip_bytes;
+
+	return STATUS_OK;
+}
diff --git a/gtests/net/packetdrill/ip_packet.h b/gtests/net/packetdrill/ip_packet.h
index 7f3e268e..ae39e578 100644
--- a/gtests/net/packetdrill/ip_packet.h
+++ b/gtests/net/packetdrill/ip_packet.h
@@ -33,14 +33,27 @@
 extern void set_ip_header(void *ip_header,
 			  int address_family,
 			  u16 ip_bytes,
-			  enum direction_t direction,
 			  enum ip_ecn_t ecn, u8 protocol);
 
 /* Set the packet's IP header pointer and then populate the IP header fields. */
 extern void set_packet_ip_header(struct packet *packet,
 				 int address_family,
 				 u16 ip_bytes,
-				 enum direction_t direction,
 				 enum ip_ecn_t ecn, u8 protocol);
 
+/* Append an IPv4 header to the end of the given packet and fill in
+ * src/dst.  On success, return STATUS_OK; on error return STATUS_ERR
+ * and fill in a malloc-allocated error message in *error.
+ */
+extern int ipv4_header_append(struct packet *packet,
+			      const char *ip_src,
+			      const char *ip_dst,
+			      char **error);
+
+/* Finalize the IPV4 header by filling in all necessary fields that
+ * were not filled in at parse time.
+ */
+extern int ipv4_header_finish(struct packet *packet,
+			      struct header *header, struct header *next_inner);
+
 #endif /* __IP_PACKET_H__ */
diff --git a/gtests/net/packetdrill/lexer.l b/gtests/net/packetdrill/lexer.l
index b4ad49ee..759f9bd0 100644
--- a/gtests/net/packetdrill/lexer.l
+++ b/gtests/net/packetdrill/lexer.l
@@ -134,7 +134,7 @@ code		\%\{(([^}])|(\}[^\%]))*\}\%
 /* A regular experssion for an IP address
  * TODO(ncardwell): IPv6
  */
-ip_addr		[0-9]+[.][0-9]+[.][0-9]+[.][0-9]+
+ipv4_addr		[0-9]+[.][0-9]+[.][0-9]+[.][0-9]+
 
 %%
 sa_family		return SA_FAMILY;
@@ -149,8 +149,10 @@ revents			return REVENTS;
 onoff			return ONOFF;
 linger			return LINGER;
 htons			return _HTONS_;
+ipv4			return IPV4;
 icmp			return ICMP;
 udp			return UDP;
+gre			return GRE;
 inet_addr		return INET_ADDR;
 ack			return ACK;
 eol			return EOL;
@@ -183,5 +185,5 @@ ce			return CE;
 {cpp_comment}		/* ignore C++-style comment */;
 {c_comment}		/* ignore C-style comment */;
 {code}			yylval.string = code(yytext);   return CODE;
-{ip_addr}		yylval.string = strdup(yytext); return IP_ADDR;
+{ipv4_addr}		yylval.string = strdup(yytext); return IPV4_ADDR;
 %%
diff --git a/gtests/net/packetdrill/packet.c b/gtests/net/packetdrill/packet.c
index d608da34..8f814650 100644
--- a/gtests/net/packetdrill/packet.c
+++ b/gtests/net/packetdrill/packet.c
@@ -28,6 +28,22 @@
 #include <assert.h>
 #include <stdlib.h>
 #include <string.h>
+#include "ethernet.h"
+#include "gre_packet.h"
+#include "ip_packet.h"
+
+
+/* Info for all types of header we support. */
+struct header_type_info header_types[HEADER_NUM_TYPES] = {
+	{ "NONE",   0,			0,		NULL },
+	{ "IPV4",   IPPROTO_IPIP,	ETHERTYPE_IP,	ipv4_header_finish },
+	{ "IPV6",   IPPROTO_IPV6,	ETHERTYPE_IPV6, NULL },
+	{ "GRE",    IPPROTO_GRE,	0,		gre_header_finish },
+	{ "TCP",    IPPROTO_TCP,	0,		NULL },
+	{ "UDP",    IPPROTO_UDP,	0,		NULL },
+	{ "ICMPV4", IPPROTO_ICMP,	0,		NULL },
+	{ "ICMPV6", IPPROTO_ICMPV6,	0,		NULL },
+};
 
 struct packet *packet_new(u32 buffer_bytes)
 {
@@ -44,40 +60,173 @@ void packet_free(struct packet *packet)
 	free(packet);
 }
 
-struct packet *packet_copy(struct packet *old_packet)
+int packet_header_count(const struct packet *packet)
+{
+	int i;
+
+	for (i = 0; i < ARRAY_SIZE(packet->headers); ++i) {
+		if (packet->headers[i].type == HEADER_NONE)
+			break;
+	}
+	return i;
+}
+
+/* Copy any header info from old_packet to new_packet. */
+static void packet_copy_headers(struct packet *new_packet,
+				struct packet *old_packet,
+				int bytes_headroom)
+{
+	int i;
+	u8 *base = new_packet->buffer + bytes_headroom;
+
+	for (i = 0; i < ARRAY_SIZE(old_packet->headers); ++i) {
+		struct header *old_header = &old_packet->headers[i];
+		struct header *new_header = &new_packet->headers[i];
+		int offset = 0;
+
+		if (old_header->type == HEADER_NONE)
+			break;
+		offset = old_header->h.ptr - old_packet->buffer;
+		new_header->h.ptr		= base + offset;
+		new_header->header_bytes	= old_header->header_bytes;
+		new_header->total_bytes		= old_header->total_bytes;
+		new_header->type		= old_header->type;
+	}
+}
+
+struct header *packet_append_header(struct packet *packet,
+				    enum header_t header_type,
+				    int header_bytes)
 {
-	int offset;
+	struct header *header = NULL;
+	int num_headers = packet_header_count(packet);
 
+	assert(num_headers <= PACKET_MAX_HEADERS);
+	if (num_headers == PACKET_MAX_HEADERS)
+		return NULL;
+
+	header = &packet->headers[num_headers];
+
+	if (packet->ip_bytes + header_bytes > packet->buffer_bytes)
+		return NULL;
+	header->h.ptr = packet->buffer + packet->ip_bytes;
+	packet->ip_bytes += header_bytes;
+
+	header->type = header_type;
+	header->header_bytes = header_bytes;
+	header->total_bytes = 0;
+	return header;
+}
+
+/* Map a pointer to a packet offset from an old base to a new base. */
+static void *offset_ptr(u8 *old_base, u8* new_base, void *old_ptr)
+{
+	u8 *old = (u8*)old_ptr;
+
+	return (old == NULL) ? NULL : (new_base + (old - old_base));
+}
+
+/* Make a copy of the given old packet, but in the new copy reserve the
+ * given number of bytes of headroom at the start of the packet->buffer.
+ * This empty headroom can later be filled with outer packet headers.
+ * A slow but simple model.
+ */
+static struct packet *packet_copy_with_headroom(struct packet *old_packet,
+						int bytes_headroom)
+{
 	/* Allocate a new packet and copy link layer header and IP datagram. */
-	const u32 bytes_used = packet_end(old_packet) - old_packet->buffer;
-	struct packet *packet = packet_new(bytes_used);
-	memcpy(packet->buffer, old_packet->buffer, bytes_used);
+	const int bytes_used = packet_end(old_packet) - old_packet->buffer;
+	assert(bytes_used >= 0);
+	assert(bytes_used <= 128*1024);
+	struct packet *packet = packet_new(bytes_headroom + bytes_used);
+	u8 *old_base = old_packet->buffer;
+	u8 *new_base = packet->buffer + bytes_headroom;
+
+	memcpy(new_base, old_base, bytes_used);
+
+	packet->ip_bytes	= old_packet->ip_bytes;
+	packet->direction	= old_packet->direction;
+	packet->time_usecs	= old_packet->time_usecs;
+	packet->flags		= old_packet->flags;
+	packet->ecn		= old_packet->ecn;
 
-	packet->ip_bytes = old_packet->ip_bytes;
+	packet_copy_headers(packet, old_packet, bytes_headroom);
 
 	/* Set up layer 3 header pointer. */
-	if (old_packet->ipv4 != NULL) {
-		offset = (u8 *) old_packet->ipv4 - old_packet->buffer;
-		packet->ipv4 = (struct ipv4 *) (packet->buffer + offset);
-	} else if (old_packet->ipv6 != NULL) {
-		offset = (u8 *) old_packet->ipv6 - old_packet->buffer;
-		packet->ipv6 = (struct ipv6 *) (packet->buffer + offset);
-	}
+	packet->ipv4	= offset_ptr(old_base, new_base, old_packet->ipv4);
+	packet->ipv6	= offset_ptr(old_base, new_base, old_packet->ipv6);
+	packet->tcp	= offset_ptr(old_base, new_base, old_packet->tcp);
+	packet->udp	= offset_ptr(old_base, new_base, old_packet->udp);
+	packet->icmpv4	= offset_ptr(old_base, new_base, old_packet->icmpv4);
+	packet->icmpv6	= offset_ptr(old_base, new_base, old_packet->icmpv6);
+
+	packet->tcp_ts_val	= offset_ptr(old_base, new_base,
+					     old_packet->tcp_ts_val);
+	packet->tcp_ts_ecr	= offset_ptr(old_base, new_base,
+					     old_packet->tcp_ts_ecr);
+
+	return packet;
+}
+
+struct packet *packet_copy(struct packet *old_packet)
+{
+	return packet_copy_with_headroom(old_packet, 0);
+}
+
+/* Finalize all the headers once we know what's inside inner layers. */
+static void packet_finish_encapsulation_headers(struct packet *packet)
+{
+	int i;
+	struct header *header = NULL, *next = NULL;
 
-	/* Set up layer 4 header pointer. */
-	if (old_packet->tcp != NULL) {
-		offset = (u8 *)old_packet->tcp - old_packet->buffer;
-		packet->tcp = (struct tcp *)(packet->buffer + offset);
-	} else if (old_packet->udp != NULL) {
-		offset = (u8 *)old_packet->udp - old_packet->buffer;
-		packet->udp = (struct udp *)(packet->buffer + offset);
-	} else if (old_packet->icmpv4 != NULL) {
-		offset = (u8 *)old_packet->icmpv4 - old_packet->buffer;
-		packet->icmpv4 = (struct icmpv4 *)(packet->buffer + offset);
-	} else if (old_packet->icmpv6 != NULL) {
-		offset = (u8 *)old_packet->icmpv6 - old_packet->buffer;
-		packet->icmpv6 = (struct icmpv6 *)(packet->buffer + offset);
+	/* Proceed from inner to outer. */
+	for (i = ARRAY_SIZE(packet->headers) - 1; i >= 0; --i, next = header) {
+		struct header_type_info *type_info = NULL;
+
+		header = &packet->headers[i];
+		if (header->type == HEADER_NONE)
+			continue;
+
+		type_info = header_type_info(header->type);
+		if (type_info->finish != NULL)
+			type_info->finish(packet, header, next);
 	}
+}
+
+struct packet *packet_encapsulate(struct packet *outer, struct packet *inner)
+{
+	struct packet *packet = NULL;
+	const int outer_headers = packet_header_count(outer);
+	const int inner_headers = packet_header_count(inner);
+
+	assert(outer_headers + inner_headers <= PACKET_MAX_HEADERS);
+
+	/* Copy the inner packet bits and header metadata. */
+	packet = packet_copy_with_headroom(inner, outer->ip_bytes);
+
+	/* Copy over the bits in the outer headers. */
+	memcpy(packet->buffer, outer->buffer, outer->ip_bytes);
+
+	/* Move the inner header metadata to make room for the outer. */
+	memmove(packet->headers + outer_headers, packet->headers + 0,
+		inner_headers * sizeof(struct header));
+
+	/* Copy over the metadata about the outer headers. */
+	packet_copy_headers(packet, outer, 0);
+
+	assert(packet_header_count(packet) == outer_headers + inner_headers);
+
+	packet_finish_encapsulation_headers(packet);
+
+	packet->ip_bytes = outer->ip_bytes + inner->ip_bytes;
 
 	return packet;
 }
+
+struct header_type_info *header_type_info(enum header_t header_type)
+{
+	assert(header_type > HEADER_NONE);
+	assert(header_type < HEADER_NUM_TYPES);
+	assert(ARRAY_SIZE(header_types) == HEADER_NUM_TYPES);
+	return &header_types[header_type];
+}
diff --git a/gtests/net/packetdrill/packet.h b/gtests/net/packetdrill/packet.h
index f6f0313d..3ecb94f4 100644
--- a/gtests/net/packetdrill/packet.h
+++ b/gtests/net/packetdrill/packet.h
@@ -30,6 +30,8 @@
 
 #include <assert.h>
 #include <sys/time.h>
+#include "gre.h"
+#include "header.h"
 #include "icmp.h"
 #include "icmpv6.h"
 #include "ip.h"
@@ -52,6 +54,12 @@
  */
 static const int PACKET_READ_BYTES = 64 * 1024;
 
+/* Maximum number of headers. */
+#define PACKET_MAX_HEADERS	6
+
+/* Maximum number of bytes of headers. */
+#define PACKET_MAX_HEADER_BYTES	256
+
 /* TCP/UDP/IPv4 packet, including IPv4 header, TCP/UDP header, and data. There
  * may also be a link layer header between the 'buffer' and 'ip'
  * pointers, but we typically ignore that. The 'buffer_bytes' field
@@ -61,12 +69,20 @@ static const int PACKET_READ_BYTES = 64 * 1024;
 struct packet {
 	u8 *buffer;		/* data buffer: full contents of packet */
 	u32 buffer_bytes;	/* bytes of space in data buffer */
-	u32 ip_bytes;		/* bytes on wire: IP, TCP/UDP header, payload */
+	u32 ip_bytes;		/* bytes on wire: outermost IP hdrs/payload */
 	enum direction_t direction;	/* direction packet is traveling */
 
-	/* The following pointers point into the 'buffer' area.
-	 * Each pointer may be NULL if there is no header of that
-	 * type present in the packet.
+	/* Metadata about all the headers in the packet, including all
+	 * layers of encapsulation, from outer to inner, starting from
+	 * the outermost IP header at headers[0].
+	 */
+	struct header headers[PACKET_MAX_HEADERS];
+
+	/* The following pointers point into the 'buffer' area. Each
+	 * pointer may be NULL if there is no header of that type
+	 * present in the packet. In each case these are pointers to
+	 * the innermost header of that kind, since that is where most
+	 * of the interesting TCP/UDP/IP action is.
 	 */
 
 	/* Layer 3 */
@@ -100,6 +116,33 @@ extern void packet_free(struct packet *packet);
 /* Create a packet that is a copy of the contents of the given packet. */
 extern struct packet *packet_copy(struct packet *old_packet);
 
+/* Return the number of headers in the given packet. */
+extern int packet_header_count(const struct packet *packet);
+
+/* Attempt to append a new header to the given packet. Return a
+ * pointer to the new header metadata, or NULL if we can't add the
+ * header.
+ */
+extern struct header *packet_append_header(struct packet *packet,
+					   enum header_t header_type,
+					   int header_bytes);
+
+/* Return a newly-allocated packet that is a copy of the given inner packet
+ * but with the given outer packet prepended.
+ */
+extern struct packet *packet_encapsulate(struct packet *outer,
+					 struct packet *inner);
+
+/* Encapsulate a packet and free the original outer and inner packets. */
+static inline struct packet *packet_encapsulate_and_free(struct packet *outer,
+							 struct packet *inner)
+{
+	struct packet *packet = packet_encapsulate(outer, inner);
+	packet_free(outer);
+	packet_free(inner);
+	return packet;
+}
+
 /* Return the direction in which the given packet is traveling. */
 static inline enum direction_t packet_direction(const struct packet *packet)
 {
@@ -118,8 +161,16 @@ static inline int packet_address_family(const struct packet *packet)
 	return AF_UNSPEC;
 }
 
-/* Return a pointer to the first byte of the IP header. */
+/* Return a pointer to the first byte of the outermost IP header. */
 static inline u8 *packet_start(struct packet *packet)
+{
+	u8 *start = packet->headers[0].h.ptr;
+	assert(start != NULL);
+	return start;
+}
+
+/* Return a pointer to the first byte of the innermost IP header. */
+static inline u8 *ip_start(struct packet *packet)
 {
 	if (packet->ipv4 != NULL)
 		return (u8 *)packet->ipv4;
@@ -133,7 +184,7 @@ static inline u8 *packet_start(struct packet *packet)
 /* Return the length in bytes of the IP header for packets of the
  * given address family, assuming no IP options.
  */
-static inline int ip_header_len(int address_family)
+static inline int ip_header_min_len(int address_family)
 {
 	if (address_family == AF_INET)
 		return sizeof(struct ipv4);
@@ -143,17 +194,6 @@ static inline int ip_header_len(int address_family)
 		assert(!"bad ip_version in config");
 }
 
-/* Return the length of the IP header (includes options for IPv4 case). */
-static inline int packet_ip_header_len(struct packet *packet)
-{
-	if (packet->ipv4 != NULL)
-		return packet->ipv4->ihl * sizeof(u32);
-	if (packet->ipv6 != NULL)
-		return sizeof(*packet->ipv6);
-	assert(!"bad address family");
-	return 0;
-}
-
 /* Return the layer4 protocol of the packet. */
 static inline int packet_ip_protocol(const struct packet *packet)
 {
diff --git a/gtests/net/packetdrill/packet_checksum.c b/gtests/net/packetdrill/packet_checksum.c
index 24550e6c..d5164b34 100644
--- a/gtests/net/packetdrill/packet_checksum.c
+++ b/gtests/net/packetdrill/packet_checksum.c
@@ -37,11 +37,11 @@ static void checksum_ipv4_packet(struct packet *packet)
 
 	/* Fill in IPv4 header checksum. */
 	ipv4->check = 0;
-	ipv4->check = ipv4_checksum(ipv4, packet_ip_header_len(packet));
-	assert(packet->ip_bytes == ntohs(ipv4->tot_len));
+	ipv4->check = ipv4_checksum(ipv4, ipv4_header_len(ipv4));
+	assert(packet->ip_bytes >= ntohs(ipv4->tot_len));
 
 	/* Find the length of layer 4 header, options, and payload. */
-	const int l4_bytes = packet->ip_bytes - packet_ip_header_len(packet);
+	const int l4_bytes = ntohs(ipv4->tot_len) - ipv4_header_len(ipv4);
 	assert(l4_bytes > 0);
 
 	/* Fill in IPv4-based layer 4 checksum. */
@@ -72,10 +72,10 @@ static void checksum_ipv6_packet(struct packet *packet)
 
 	/* IPv6 has no header checksum. */
 	/* For now we do not support IPv6 extension headers. */
-	assert(packet->ip_bytes == sizeof(*ipv6) + ntohs(ipv6->payload_len));
+	assert(packet->ip_bytes >= sizeof(*ipv6) + ntohs(ipv6->payload_len));
 
 	/* Find the length of layer 4 header, options, and payload. */
-	const int l4_bytes = packet->ip_bytes - packet_ip_header_len(packet);
+	const int l4_bytes = ntohs(ipv6->payload_len);
 	assert(l4_bytes > 0);
 
 	/* Fill in IPv6-based layer 4 checksum. */
diff --git a/gtests/net/packetdrill/packet_parser.c b/gtests/net/packetdrill/packet_parser.c
index 6e7f7fbb..02b7ae33 100644
--- a/gtests/net/packetdrill/packet_parser.c
+++ b/gtests/net/packetdrill/packet_parser.c
@@ -37,6 +37,7 @@
 
 #include "checksum.h"
 #include "ethernet.h"
+#include "gre.h"
 #include "ip.h"
 #include "ip_address.h"
 #include "logging.h"
@@ -47,9 +48,12 @@ static int parse_ipv4(struct packet *packet, u8 *header_start, u8 *packet_end,
 		      char **error);
 static int parse_ipv6(struct packet *packet, u8 *header_start, u8 *packet_end,
 		      char **error);
+static int parse_layer3_packet_by_proto(struct packet *packet,
+					u16 proto, u8 *header_start,
+					u8 *packet_end, char **error);
 static int parse_layer4(struct packet *packet, u8 *header_start,
 			int layer4_protocol, int layer4_bytes,
-			u8 *packet_end, char **error);
+			u8 *packet_end, bool *is_inner, char **error);
 
 static int parse_layer2_packet(struct packet *packet, int in_bytes,
 				       char **error)
@@ -67,7 +71,20 @@ static int parse_layer2_packet(struct packet *packet, int in_bytes,
 	ether = (struct ether_header *)p;
 	p += sizeof(*ether);
 
-	if (ntohs(ether->ether_type) == ETHERTYPE_IP) {
+	return parse_layer3_packet_by_proto(packet, ntohs(ether->ether_type),
+					    p, packet_end, error);
+
+error_out:
+	return PACKET_BAD;
+}
+
+static int parse_layer3_packet_by_proto(struct packet *packet,
+					u16 proto, u8 *header_start,
+					u8 *packet_end, char **error)
+{
+	u8 *p = header_start;
+
+	if (proto == ETHERTYPE_IP) {
 		struct ipv4 *ip = NULL;
 
 		/* Examine IPv4 header. */
@@ -86,7 +103,7 @@ static int parse_layer2_packet(struct packet *packet, int in_bytes,
 			asprintf(error, "Bad IP version for ETHERTYPE_IP");
 			goto error_out;
 		}
-	} else if (ntohs(ether->ether_type) == ETHERTYPE_IPV6) {
+	} else if (proto == ETHERTYPE_IPV6) {
 		struct ipv6 *ip = NULL;
 
 		/* Examine IPv6 header. */
@@ -146,7 +163,7 @@ int parse_packet(struct packet *packet, int in_bytes,
 	assert(in_bytes <= packet->buffer_bytes);
 	char *message = NULL;		/* human-readable error summary */
 	char *hex = NULL;		/* hex dump of bad packet */
-	enum packet_parse_result_t result;
+	enum packet_parse_result_t result = PACKET_BAD;
 
 	if (layer == PACKET_LAYER_2_ETHERNET)
 		result = parse_layer2_packet(packet, in_bytes, error);
@@ -175,13 +192,16 @@ int parse_packet(struct packet *packet, int in_bytes,
 static int parse_ipv4(struct packet *packet, u8 *header_start, u8 *packet_end,
 		      char **error)
 {
+	struct header *ip_header = NULL;
 	u8 *p = header_start;
+	const bool is_outer = (packet->ip_bytes == 0);
+	bool is_inner = false;
+	enum packet_parse_result_t result = PACKET_BAD;
+	struct ipv4 *ipv4 = (struct ipv4 *) (p);
 
-	packet->ipv4 = (struct ipv4 *) (p);
-
-	const int ip_header_bytes = packet_ip_header_len(packet);
+	const int ip_header_bytes = ipv4_header_len(ipv4);
 	assert(ip_header_bytes >= 0);
-	if (ip_header_bytes < sizeof(*packet->ipv4)) {
+	if (ip_header_bytes < sizeof(*ipv4)) {
 		asprintf(error, "IP header too short");
 		goto error_out;
 	}
@@ -189,8 +209,8 @@ static int parse_ipv4(struct packet *packet, u8 *header_start, u8 *packet_end,
 		asprintf(error, "Full IP header overflows packet");
 		goto error_out;
 	}
-	const int ip_total_bytes = ntohs(packet->ipv4->tot_len);
-	packet->ip_bytes = ip_total_bytes;
+	const int ip_total_bytes = ntohs(ipv4->tot_len);
+
 	if (p + ip_total_bytes > packet_end) {
 		asprintf(error, "IP payload overflows packet");
 		goto error_out;
@@ -199,20 +219,27 @@ static int parse_ipv4(struct packet *packet, u8 *header_start, u8 *packet_end,
 		asprintf(error, "IP header bigger than datagram");
 		goto error_out;
 	}
-	if (ntohs(packet->ipv4->frag_off) & IP_MF) {	/* more fragments? */
+	if (ntohs(ipv4->frag_off) & IP_MF) {	/* more fragments? */
 		asprintf(error, "More fragments remaining");
 		goto error_out;
 	}
-	if (ntohs(packet->ipv4->frag_off) & IP_OFFMASK) {  /* fragment offset */
+	if (ntohs(ipv4->frag_off) & IP_OFFMASK) {  /* fragment offset */
 		asprintf(error, "Non-zero fragment offset");
 		goto error_out;
 	}
-	const u16 checksum = ipv4_checksum(packet->ipv4, ip_header_bytes);
+	const u16 checksum = ipv4_checksum(ipv4, ip_header_bytes);
 	if (checksum != 0) {
 		asprintf(error, "Bad IP checksum");
 		goto error_out;
 	}
 
+	ip_header = packet_append_header(packet, HEADER_IPV4, ip_header_bytes);
+	if (ip_header == NULL) {
+		asprintf(error, "Too many nested headers at IPv4 header");
+		goto error_out;
+	}
+	ip_header->total_bytes = ip_total_bytes;
+
 	/* Move on to the header inside. */
 	p += ip_header_bytes;
 	assert(p <= packet_end);
@@ -221,17 +248,26 @@ static int parse_ipv4(struct packet *packet, u8 *header_start, u8 *packet_end,
 		char src_string[ADDR_STR_LEN];
 		char dst_string[ADDR_STR_LEN];
 		struct ip_address src_ip, dst_ip;
-		ip_from_ipv4(&packet->ipv4->src_ip, &src_ip);
-		ip_from_ipv4(&packet->ipv4->dst_ip, &dst_ip);
+		ip_from_ipv4(&ipv4->src_ip, &src_ip);
+		ip_from_ipv4(&ipv4->dst_ip, &dst_ip);
 		DEBUGP("src IP: %s\n", ip_to_string(&src_ip, src_string));
 		DEBUGP("dst IP: %s\n", ip_to_string(&dst_ip, dst_string));
 	}
 
 	/* Examine the L4 header. */
 	const int layer4_bytes = ip_total_bytes - ip_header_bytes;
-	const int layer4_protocol = packet->ipv4->protocol;
-	return parse_layer4(packet, p, layer4_protocol, layer4_bytes,
-			    packet_end, error);
+	const int layer4_protocol = ipv4->protocol;
+	result = parse_layer4(packet, p, layer4_protocol, layer4_bytes,
+			      packet_end, &is_inner, error);
+
+	/* If this is the innermost IP header then this is the primary. */
+	if (is_inner)
+		packet->ipv4 = ipv4;
+	/* If this is the outermost IP header then this is the packet length. */
+	if (is_outer)
+		packet->ip_bytes = ip_total_bytes;
+
+	return result;
 
 error_out:
 	return PACKET_BAD;
@@ -245,14 +281,15 @@ error_out:
 static int parse_ipv6(struct packet *packet, u8 *header_start, u8 *packet_end,
 		      char **error)
 {
+	struct header *ip_header = NULL;
 	u8 *p = header_start;
-
-	packet->ipv6 = (struct ipv6 *) (p);
+	const bool is_outer = (packet->ip_bytes == 0);
+	bool is_inner = false;
+	struct ipv6 *ipv6 = (struct ipv6 *) (p);
+	enum packet_parse_result_t result = PACKET_BAD;
 
 	/* Check that header fits in sniffed packet. */
-	const int ip_header_bytes = packet_ip_header_len(packet);
-	assert(ip_header_bytes >= 0);
-	assert(ip_header_bytes == sizeof(*packet->ipv6));
+	const int ip_header_bytes = sizeof(*ipv6);
 	if (p + ip_header_bytes > packet_end) {
 		asprintf(error, "IPv6 header overflows packet");
 		goto error_out;
@@ -260,14 +297,21 @@ static int parse_ipv6(struct packet *packet, u8 *header_start, u8 *packet_end,
 
 	/* Check that payload fits in sniffed packet. */
 	const int ip_total_bytes = (ip_header_bytes +
-				    ntohs(packet->ipv6->payload_len));
-	packet->ip_bytes = ip_total_bytes;
+				    ntohs(ipv6->payload_len));
+
 	if (p + ip_total_bytes > packet_end) {
 		asprintf(error, "IPv6 payload overflows packet");
 		goto error_out;
 	}
 	assert(ip_header_bytes <= ip_total_bytes);
 
+	ip_header = packet_append_header(packet, HEADER_IPV6, ip_header_bytes);
+	if (ip_header == NULL) {
+		asprintf(error, "Too many nested headers at IPv6 header");
+		goto error_out;
+	}
+	ip_header->total_bytes = ip_total_bytes;
+
 	/* Move on to the header inside. */
 	p += ip_header_bytes;
 	assert(p <= packet_end);
@@ -276,17 +320,26 @@ static int parse_ipv6(struct packet *packet, u8 *header_start, u8 *packet_end,
 		char src_string[ADDR_STR_LEN];
 		char dst_string[ADDR_STR_LEN];
 		struct ip_address src_ip, dst_ip;
-		ip_from_ipv6(&packet->ipv6->src_ip, &src_ip);
-		ip_from_ipv6(&packet->ipv6->dst_ip, &dst_ip);
+		ip_from_ipv6(&ipv6->src_ip, &src_ip);
+		ip_from_ipv6(&ipv6->dst_ip, &dst_ip);
 		DEBUGP("src IP: %s\n", ip_to_string(&src_ip, src_string));
 		DEBUGP("dst IP: %s\n", ip_to_string(&dst_ip, dst_string));
 	}
 
 	/* Examine the L4 header. */
 	const int layer4_bytes = ip_total_bytes - ip_header_bytes;
-	const int layer4_protocol = packet->ipv6->next_header;
-	return parse_layer4(packet, p, layer4_protocol, layer4_bytes,
-			    packet_end, error);
+	const int layer4_protocol = ipv6->next_header;
+	result = parse_layer4(packet, p, layer4_protocol, layer4_bytes,
+			      packet_end, &is_inner, error);
+
+	/* If this is the innermost IP header then this is the primary. */
+	if (is_inner)
+		packet->ipv6 = ipv6;
+	/* If this is the outermost IP header then this is the packet length. */
+	if (is_outer)
+		packet->ip_bytes = ip_total_bytes;
+
+	return result;
 
 error_out:
 	return PACKET_BAD;
@@ -296,6 +349,7 @@ error_out:
 static int parse_tcp(struct packet *packet, u8 *layer4_start, int layer4_bytes,
 		     u8 *packet_end, char **error)
 {
+	struct header *tcp_header = NULL;
 	u8 *p = layer4_start;
 
 	assert(layer4_bytes >= 0);
@@ -314,6 +368,13 @@ static int parse_tcp(struct packet *packet, u8 *layer4_start, int layer4_bytes,
 		goto error_out;
 	}
 
+	tcp_header = packet_append_header(packet, HEADER_TCP, tcp_header_len);
+	if (tcp_header == NULL) {
+		asprintf(error, "Too many nested headers at TCP header");
+		goto error_out;
+	}
+	tcp_header->total_bytes = layer4_bytes;
+
 	p += layer4_bytes;
 	assert(p <= packet_end);
 
@@ -329,6 +390,7 @@ error_out:
 static int parse_udp(struct packet *packet, u8 *layer4_start, int layer4_bytes,
 		     u8 *packet_end, char **error)
 {
+	struct header *udp_header = NULL;
 	u8 *p = layer4_start;
 
 	assert(layer4_bytes >= 0);
@@ -338,7 +400,8 @@ static int parse_udp(struct packet *packet, u8 *layer4_start, int layer4_bytes,
 	}
 	packet->udp = (struct udp *) p;
 	const int udp_len = ntohs(packet->udp->len);
-	if (udp_len < sizeof(struct udp)) {
+	const int udp_header_len = sizeof(struct udp);
+	if (udp_len < udp_header_len) {
 		asprintf(error, "UDP datagram length too small for UDP header");
 		goto error_out;
 	}
@@ -351,6 +414,13 @@ static int parse_udp(struct packet *packet, u8 *layer4_start, int layer4_bytes,
 		goto error_out;
 	}
 
+	udp_header = packet_append_header(packet, HEADER_UDP, udp_header_len);
+	if (udp_header == NULL) {
+		asprintf(error, "Too many nested headers at UDP header");
+		goto error_out;
+	}
+	udp_header->total_bytes = layer4_bytes;
+
 	p += layer4_bytes;
 	assert(p <= packet_end);
 
@@ -364,13 +434,22 @@ error_out:
 
 static int parse_layer4(struct packet *packet, u8 *layer4_start,
 			int layer4_protocol, int layer4_bytes,
-			u8 *packet_end, char **error)
+			u8 *packet_end, bool *is_inner, char **error)
 {
-	if (layer4_protocol == IPPROTO_TCP)
+	if (layer4_protocol == IPPROTO_TCP) {
+		*is_inner = true;	/* found inner-most layer 4 */
 		return parse_tcp(packet, layer4_start, layer4_bytes, packet_end,
 				 error);
-	else if (layer4_protocol == IPPROTO_UDP)
+	} else if (layer4_protocol == IPPROTO_UDP) {
+		*is_inner = true;	/* found inner-most layer 4 */
 		return parse_udp(packet, layer4_start, layer4_bytes, packet_end,
 				 error);
+	} else if (layer4_protocol == IPPROTO_IPIP) {
+		*is_inner = false;
+		return parse_ipv4(packet, layer4_start, packet_end, error);
+	} else if (layer4_protocol == IPPROTO_IPV6) {
+		*is_inner = false;
+		return parse_ipv6(packet, layer4_start, packet_end, error);
+	}
 	return PACKET_UNKNOWN_L4;
 }
diff --git a/gtests/net/packetdrill/parser.y b/gtests/net/packetdrill/parser.y
index 0b3c5867..636e2e01 100644
--- a/gtests/net/packetdrill/parser.y
+++ b/gtests/net/packetdrill/parser.y
@@ -90,7 +90,9 @@
 #include <sys/types.h>
 #include <sys/stat.h>
 #include <unistd.h>
+#include "gre_packet.h"
 #include "ip.h"
+#include "ip_packet.h"
 #include "icmp_packet.h"
 #include "logging.h"
 #include "tcp_packet.h"
@@ -471,11 +473,11 @@ static struct tcp_option *new_tcp_fast_open_option(const char *cookie_string,
 %token <reserved> ACK ECR EOL MSS NOP SACK SACKOK TIMESTAMP VAL WIN WSCALE PRO
 %token <reserved> FAST_OPEN
 %token <reserved> ECT0 ECT1 CE ECT01 NO_ECN
-%token <reserved> ICMP UDP MTU
+%token <reserved> IPV4 ICMP UDP GRE MTU
 %token <reserved> OPTION
 %token <floating> FLOAT
 %token <integer> INTEGER HEX_INTEGER
-%token <string> WORD STRING BACK_QUOTED CODE IP_ADDR
+%token <string> WORD STRING BACK_QUOTED CODE IPV4_ADDR
 %type <direction> direction
 %type <ip_ecn> opt_ip_info
 %type <ip_ecn> ip_ecn
@@ -483,6 +485,7 @@ static struct tcp_option *new_tcp_fast_open_option(const char *cookie_string,
 %type <event> event events event_time action
 %type <time_usecs> time opt_end_time
 %type <packet> packet_spec tcp_packet_spec udp_packet_spec icmp_packet_spec
+%type <packet> packet_prefix
 %type <syscall> syscall_spec
 %type <command> command_spec
 %type <code> code_spec
@@ -546,7 +549,8 @@ option_value
 : INTEGER	{ $$ = strdup(yytext); }
 | WORD		{ $$ = $1; }
 | STRING	{ $$ = $1; }
-| IP_ADDR	{ $$ = $1; }
+| IPV4_ADDR	{ $$ = $1; }
+| IPV4		{ $$ = strdup("ipv4"); }
 ;
 
 opt_init_command
@@ -662,58 +666,97 @@ packet_spec
 ;
 
 tcp_packet_spec
-: direction opt_ip_info flags seq opt_ack opt_window opt_tcp_options {
+: packet_prefix opt_ip_info flags seq opt_ack opt_window opt_tcp_options {
 	char *error = NULL;
+	struct packet *outer = $1, *inner = NULL;
+	enum direction_t direction = outer->direction;
 
-	if (($7 == NULL) && ($1 != DIRECTION_OUTBOUND)) {
+	if (($7 == NULL) && (direction != DIRECTION_OUTBOUND)) {
 		yylineno = @7.first_line;
 		semantic_error("<...> for TCP options can only be used with "
 			       "outbound packets");
 	}
 
-	$$ = new_tcp_packet(in_config->wire_protocol,
-			    $1, $2, $3, $4.start_sequence, $4.payload_bytes,
-	                    $5, $6, $7, &error);
+	inner = new_tcp_packet(in_config->wire_protocol,
+			       direction, $2, $3,
+			       $4.start_sequence, $4.payload_bytes,
+			       $5, $6, $7, &error);
 	free($3);
 	free($7);
-	if ($$ == NULL) {
+	if (inner == NULL) {
 		assert(error != NULL);
 		semantic_error(error);
 		free(error);
 	}
+
+	$$ = packet_encapsulate_and_free(outer, inner);
 }
 ;
 
 udp_packet_spec
-: direction UDP '(' INTEGER ')' {
+: packet_prefix UDP '(' INTEGER ')' {
 	char *error = NULL;
+	struct packet *outer = $1, *inner = NULL;
+	enum direction_t direction = outer->direction;
+
 	if (!is_valid_u16($4)) {
 		semantic_error("UDP payload size out of range");
 	}
 
-	$$ = new_udp_packet(in_config->wire_protocol, $1, $4, &error);
-	if ($$ == NULL) {
+	inner = new_udp_packet(in_config->wire_protocol, direction, $4, &error);
+	if (inner == NULL) {
 		assert(error != NULL);
 		semantic_error(error);
 		free(error);
 	}
+
+	$$ = packet_encapsulate_and_free(outer, inner);
 }
 ;
 
-////////////////////
 icmp_packet_spec
-: direction opt_icmp_echoed ICMP icmp_type opt_icmp_code opt_icmp_mtu {
+: packet_prefix opt_icmp_echoed ICMP icmp_type opt_icmp_code opt_icmp_mtu {
 	char *error = NULL;
-	$$ = new_icmp_packet(in_config->wire_protocol,
-			     $1, $4, $5,
-			     $2.protocol, $2.start_sequence, $2.payload_bytes,
-			     $6, &error);
+	struct packet *outer = $1, *inner = NULL;
+	enum direction_t direction = outer->direction;
+
+	inner = new_icmp_packet(in_config->wire_protocol, direction, $4, $5,
+				$2.protocol, $2.start_sequence,
+				$2.payload_bytes, $6, &error);
 	free($4);
 	free($5);
-	if ($$ == NULL) {
+	if (inner == NULL) {
 		semantic_error(error);
 		free(error);
 	}
+
+	$$ = packet_encapsulate_and_free(outer, inner);
+}
+;
+
+
+packet_prefix
+: direction {
+	$$ = packet_new(PACKET_MAX_HEADER_BYTES);
+	$$->direction = $1;
+}
+| packet_prefix IPV4 IPV4_ADDR '>' IPV4_ADDR ':' {
+	char *error = NULL;
+	struct packet *packet = $1;
+	char *ip_src = $3;
+	char *ip_dst = $5;
+	if (ipv4_header_append(packet, ip_src, ip_dst, &error))
+		semantic_error(error);
+	free(ip_src);
+	free(ip_dst);
+	$$ = packet;
+}
+| packet_prefix GRE ':' {
+	char *error = NULL;
+	struct packet *packet = $1;
+	if (gre_header_append(packet, &error))
+		semantic_error(error);
+	$$ = packet;
 }
 ;
 
diff --git a/gtests/net/packetdrill/run_packet.c b/gtests/net/packetdrill/run_packet.c
index 205da99b..0932db41 100644
--- a/gtests/net/packetdrill/run_packet.c
+++ b/gtests/net/packetdrill/run_packet.c
@@ -616,7 +616,7 @@ static int verify_outbound_live_checksums(struct packet *live_packet,
 	/* Verify IP header checksum. */
 	if ((live_packet->ipv4 != NULL) &&
 	    ipv4_checksum(live_packet->ipv4,
-			  packet_ip_header_len(live_packet))) {
+			  ipv4_header_len(live_packet->ipv4))) {
 		asprintf(error, "bad outbound IP checksum");
 		return STATUS_ERR;
 	}
diff --git a/gtests/net/packetdrill/tcp_packet.c b/gtests/net/packetdrill/tcp_packet.c
index 51d95997..03eeda1a 100644
--- a/gtests/net/packetdrill/tcp_packet.c
+++ b/gtests/net/packetdrill/tcp_packet.c
@@ -62,10 +62,11 @@ struct packet *new_tcp_packet(int address_family,
 			       char **error)
 {
 	struct packet *packet = NULL;  /* the newly-allocated result packet */
+	struct header *tcp_header = NULL;  /* the TCP header info */
 	/* Calculate lengths in bytes of all sections of the packet */
 	const int ip_option_bytes = 0;
 	const int tcp_option_bytes = tcp_options ? tcp_options->length : 0;
-	const int ip_header_bytes = (ip_header_len(address_family) +
+	const int ip_header_bytes = (ip_header_min_len(address_family) +
 				     ip_option_bytes);
 	const int tcp_header_bytes = sizeof(struct tcp) + tcp_option_bytes;
 	const int ip_bytes =
@@ -104,18 +105,20 @@ struct packet *new_tcp_packet(int address_family,
 	/* Allocate and zero out a packet object of the desired size */
 	packet = packet_new(ip_bytes);
 	memset(packet->buffer, 0, ip_bytes);
-	packet->ip_bytes = ip_bytes;
 
 	packet->direction = direction;
 	packet->flags = 0;
 	packet->ecn = ecn;
 
 	/* Set IP header fields */
-	set_packet_ip_header(packet, address_family, ip_bytes, direction, ecn,
+	set_packet_ip_header(packet, address_family, ip_bytes, ecn,
 			     IPPROTO_TCP);
 
+	tcp_header = packet_append_header(packet, HEADER_TCP, tcp_header_bytes);
+	tcp_header->total_bytes = tcp_header_bytes + tcp_payload_bytes;
+
 	/* Find the start of TCP sections of the packet */
-	packet->tcp = (struct tcp *) (packet_start(packet) + ip_header_bytes);
+	packet->tcp = (struct tcp *) (ip_start(packet) + ip_header_bytes);
 	u8 *tcp_option_start = (u8 *) (packet->tcp + 1);
 
 	/* Set TCP header fields */
@@ -154,5 +157,6 @@ struct packet *new_tcp_packet(int address_family,
 		       tcp_options->length);
 	}
 
+	packet->ip_bytes = ip_bytes;
 	return packet;
 }
diff --git a/gtests/net/packetdrill/udp_packet.c b/gtests/net/packetdrill/udp_packet.c
index 72283568..3efd408d 100644
--- a/gtests/net/packetdrill/udp_packet.c
+++ b/gtests/net/packetdrill/udp_packet.c
@@ -33,9 +33,10 @@ struct packet *new_udp_packet(int address_family,
 			       char **error)
 {
 	struct packet *packet = NULL;  /* the newly-allocated result packet */
+	struct header *udp_header = NULL;  /* the UDP header info */
 	/* Calculate lengths in bytes of all sections of the packet */
 	const int ip_option_bytes = 0;
-	const int ip_header_bytes = (ip_header_len(address_family) +
+	const int ip_header_bytes = (ip_header_min_len(address_family) +
 				     ip_option_bytes);
 	const int udp_header_bytes = sizeof(struct udp);
 	const int ip_bytes =
@@ -59,18 +60,21 @@ struct packet *new_udp_packet(int address_family,
 	/* Allocate and zero out a packet object of the desired size */
 	packet = packet_new(ip_bytes);
 	memset(packet->buffer, 0, ip_bytes);
-	packet->ip_bytes = ip_bytes;
 
 	packet->direction = direction;
 	packet->flags = 0;
 	packet->ecn = ECN_NONE;
 
 	/* Set IP header fields */
-	set_packet_ip_header(packet, address_family, ip_bytes, direction,
+	set_packet_ip_header(packet, address_family, ip_bytes,
 			     packet->ecn, IPPROTO_UDP);
 
+	udp_header = packet_append_header(packet, HEADER_UDP,
+					  sizeof(struct udp));
+	udp_header->total_bytes = udp_header_bytes + udp_payload_bytes;
+
 	/* Find the start of UDP section of the packet */
-	packet->udp = (struct udp *) (packet_start(packet) + ip_header_bytes);
+	packet->udp = (struct udp *) (ip_start(packet) + ip_header_bytes);
 
 	/* Set UDP header fields */
 	packet->udp->src_port	= htons(0);
@@ -78,5 +82,6 @@ struct packet *new_udp_packet(int address_family,
 	packet->udp->len	= htons(udp_header_bytes + udp_payload_bytes);
 	packet->udp->check	= 0;
 
+	packet->ip_bytes = ip_bytes;
 	return packet;
 }
-- 
GitLab