Connection tracking and filtering in TCP streams

Paul Rusty Russell Paul.Russell@rustcorp.com.au
Wed, 29 Sep 1999 14:30:14 +0930


In message <E11Vw3T-0004sH-00@taurus.cus.cam.ac.uk> you write:
> If I was doing this with IP chains, the obvious way to do it would be
> to REDIRECT the connection to a transparent proxy, and to use the
> bind-to-a-foreign-address hack to make the connection on the other
> side. I suspect you'll disagree, but I think the model is quite nice.
> It works well with connections that are low bandwidth and where
> latency is not an issue.

The *model* is quite nice, and in fact you can use the NAT code to do
exactly the same thing.  The implementation sucked.

> > Reading generically is possible.
> 
> How does this interact with the case of the client sending a
> packet twice, but with different data?

Well, you have to keep a sliding map of data in the current window
anyway, and a pointer to the last match you're looking for.  It's like
this:

1) Checksum packet.  If bad TCP checksum, drop it.
2) If it's outside current TCP window, drop it.
3) If the packet's data doesn't match what's already in the map, drop it.
4) Place packet in map.
5) If there are `holes' in the map before this packet, drop it.
6) Search from last match pointer to first hole: for each match,
   flag it (return, call matchfn, whatever your library does) and
   update match pointer.
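The steps above can be sketched on a toy byte map. This is only an illustration under stated assumptions, not libpktmatch: the names `window_map`, `map_record()`, `map_first_hole()`, `map_in_order()` and the fixed 64-byte `WINDOW` are hypothetical. It covers steps 2, 4 and 5 and leaves out checksums, the data comparison of step 3 and the search of step 6.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical toy window map for one direction of a TCP stream. */
#define WINDOW 64

struct window_map {
	uint32_t start;			/* sequence number of data[0] */
	unsigned char data[WINDOW];	/* stream bytes seen so far */
	unsigned char have[WINDOW];	/* 1 where data[] is valid */
};

/* "a is before b", modulo 2^32 so it survives sequence wraparound. */
static int seq_lt(uint32_t a, uint32_t b)
{
	return (int32_t)(a - b) < 0;
}

/* Steps 2 and 4: refuse a segment outside the window, otherwise copy
   it into the map.  Returns 0 if the segment should be dropped. */
static int map_record(struct window_map *m, uint32_t seq,
		      const unsigned char *buf, size_t len)
{
	assert(m != NULL && (buf != NULL || len == 0));
	if (len > WINDOW || seq_lt(seq, m->start)
	    || seq_lt(m->start + WINDOW, seq + (uint32_t)len))
		return 0;
	memcpy(m->data + (seq - m->start), buf, len);
	memset(m->have + (seq - m->start), 1, len);
	return 1;
}

/* Offset of the first hole: how much in-order data the map holds. */
static size_t map_first_hole(const struct window_map *m)
{
	const unsigned char *hole = memchr(m->have, 0, WINDOW);

	return hole ? (size_t)(hole - m->have) : WINDOW;
}

/* Step 5: a segment starting at 'seq' may pass only if no hole
   precedes it in the map. */
static int map_in_order(const struct window_map *m, uint32_t seq)
{
	return seq - m->start <= map_first_hole(m);
}
```

Note that an out-of-order segment is still recorded by `map_record()` even though `map_in_order()` rejects it; that is the "map of future packets" the rest of this message relies on. Step 3 would slot into `map_record()` just before the `memcpy()`.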

When you see a reply the other way which acks data, you can shift the
window map along.
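Shifting the map along on an ack might look like this in the same toy setting. Again `window_map`, `map_ack()` and `WINDOW` are hypothetical names; the point is only that everything before the acked sequence number is forgotten and the newly exposed tail is marked as empty.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define WINDOW 64	/* hypothetical map size */

struct window_map {
	uint32_t start;			/* sequence number of data[0] */
	unsigned char data[WINDOW];	/* stream bytes seen so far */
	unsigned char have[WINDOW];	/* 1 where data[] is valid */
};

/* The receiver has acked everything before 'ack', so slide the map
   down by that much.  Returns 0 for an ack outside the map. */
static int map_ack(struct window_map *m, uint32_t ack)
{
	uint32_t shift;

	assert(m != NULL);
	shift = ack - m->start;
	if ((int32_t)shift < 0 || shift > WINDOW)
		return 0;
	memmove(m->data, m->data + shift, WINDOW - shift);
	memmove(m->have, m->have + shift, WINDOW - shift);
	/* Nothing is known yet about the bytes now at the end. */
	memset(m->have + WINDOW - shift, 0, shift);
	m->start = ack;
	return 1;
}
```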

Performance will suck, of course.  But note that this implementation
*almost* handles the writable case, because it keeps a map of future
packets at #4 (even though it dropped them at #5), so if we're looking
for "Rusty's bad poetry" and packet 1 ends with "Rusty's bad", we can
add rule 7: if match goes over end of the current packet, drop the
packet.

This gives:

1 (...Rusty's bad) => DROP
2 (poetry...)      => MATCH (record replacement), do it and ACCEPT
1 (...Rusty's bad) => MATCH (retransmit), do replacement and ACCEPT
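Rule 7 and the three-packet sequence above can be replayed on a toy, fully in-order stream. This is a sketch under simplifying assumptions: `rule7()` is a hypothetical helper that ignores the window, the mask and the replacement itself, and only asks whether some match that starts before the end of the current packet is still incomplete at the end of the data seen so far. The space between "bad" and "poetry" is placed at the start of packet 2.

```c
#include <assert.h>
#include <string.h>

enum verdict { DROP, ACCEPT };

/* 'stream' holds the first 'have' bytes of the connection (already
   including this packet); the packet ends at offset 'pkt_end'.  If a
   match of 'pat' starting before pkt_end runs off the end of the known
   data, we can't yet say what this packet must become: drop it and
   wait for the retransmit. */
static enum verdict rule7(const char *stream, size_t have,
			  size_t pkt_end, const char *pat)
{
	size_t patlen = strlen(pat), p;

	assert(pkt_end <= have && patlen > 0);
	for (p = 0; p < pkt_end; p++) {
		size_t avail = have - p;

		/* A proper prefix of the pattern at the end of the data. */
		if (avail < patlen && memcmp(stream + p, pat, avail) == 0)
			return DROP;
	}
	return ACCEPT;
}
```

Packet 1 is the first 14 bytes of `"...Rusty's bad poetry..."` and packet 2 the remaining 10; feeding them in the order of the table gives DROP, ACCEPT, ACCEPT, as in the test below.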

You have to keep track of all the replacements being done within the
current window, and even this assumes that the TCP window is bigger
than one segment, so that you'll get the next packet.
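Keeping track of the replacements matters because each one that changes the length shifts every later sequence number: outgoing seqs must be moved down, and acks coming back must be moved up again before the sender sees them. A minimal sketch, assuming (as the libpktmatch.c code in the PS does) that replacements never grow the data and are kept sorted and non-overlapping; `struct repl`, `seq_out()` and `ack_in()` are hypothetical names.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Original bytes [start, end) were replaced by 'newlen' bytes. */
struct repl {
	uint32_t start, end;
	uint32_t newlen;
};

/* Original seq -> rewritten seq: subtract the bytes removed by every
   replacement ending at or before it.  (A seq that falls inside a
   replacement is not handled by this sketch.) */
static uint32_t seq_out(const struct repl *r, size_t n, uint32_t seq)
{
	uint32_t out = seq;
	size_t i;

	for (i = 0; i < n; i++) {
		assert(r[i].newlen <= r[i].end - r[i].start);
		if ((int32_t)(seq - r[i].end) >= 0)
			out -= (r[i].end - r[i].start) - r[i].newlen;
	}
	return out;
}

/* Rewritten ack (from the receiver) -> original ack (for the sender):
   add back whatever was removed before it. */
static uint32_t ack_in(const struct repl *r, size_t n, uint32_t ack)
{
	uint32_t removed = 0;
	size_t i;

	for (i = 0; i < n; i++) {
		uint32_t delta = (r[i].end - r[i].start) - r[i].newlen;
		/* Where this replacement ends in the rewritten stream. */
		uint32_t end_out = r[i].end - removed - delta;

		if ((int32_t)(ack - end_out) < 0)
			break;
		removed += delta;
	}
	return ack + removed;
}
```

For example, replacing the 18 bytes of "Rusty's bad poetry" at seq 1010 with 6 bytes moves everything after 1028 down by 12, and an ack of 1016 from the receiver goes back to the sender as 1028.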

> > ICK.  Connection splicing is something that is THEORETICALLY possible,
> > but once again, what if the TCP options are incompatible?  
> 
> You could create a similar effect by keeping two separate connections,
> with separate buffers, and send data from one connection to the other
> without ever touching user-space. That design would avoid the IP
> options/window size issues but would have similar efficiency to a
> spliced connection.

But you should be able to write a kernel module to do that now; if
khttpd exists, then this is hardly any more unholy.  You'll still be
doing a context switch though, just not to userspace.

BTW, you're wrong about the performance: doing a double copy on every
packet is *way* worse than tweaking the SYN, ACK and checksum.  Coming
from Australia, I find it hard to believe `need for speed' arguments,
anyway 8-).
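To put "tweaking the checksum" in perspective: rewriting a 32-bit ack or seq field needs no pass over the payload at all, because the Internet checksum can be patched incrementally (RFC 1624, eqn. 3: HC' = ~(~HC + ~m + m')). Note this is a different technique from the libpktmatch.c code below, which simply recomputes the whole checksum; the helper names here are hypothetical.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* One's-complement add of two 16-bit values, end-around carry. */
static uint16_t oc_add(uint16_t a, uint16_t b)
{
	uint32_t s = (uint32_t)a + b;

	return (uint16_t)((s & 0xFFFF) + (s >> 16));
}

/* RFC 1624 eqn. 3: one covered 16-bit word changes from m to m_new. */
static uint16_t csum_update16(uint16_t hc, uint16_t m, uint16_t m_new)
{
	return (uint16_t)~oc_add(oc_add((uint16_t)~hc, (uint16_t)~m), m_new);
}

/* A 32-bit field such as the TCP ack number is two 16-bit words. */
static uint16_t csum_update32(uint16_t hc, uint32_t old, uint32_t new_)
{
	hc = csum_update16(hc, (uint16_t)(old >> 16), (uint16_t)(new_ >> 16));
	return csum_update16(hc, (uint16_t)(old & 0xFFFF),
			     (uint16_t)(new_ & 0xFFFF));
}

/* Full recomputation over 16-bit words, for comparison. */
static uint16_t csum_full(const uint16_t *w, size_t n)
{
	uint16_t s = 0;
	size_t i;

	assert(w != NULL || n == 0);
	for (i = 0; i < n; i++)
		s = oc_add(s, w[i]);
	return (uint16_t)~s;
}
```

Four additions per 32-bit field, against touching every byte twice for a copy in and a copy out.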

> The reason why I'm strongly in favour of such a design is that I'm
> _much_ more confident that it protects the machine inside a firewall
> from external clients playing games with unusual or invalid packets.

Yes.  Transparent proxy everything.  Performance simply ISN'T THAT BAD.

> > You sound like the man to write the userspace library... think, you'll
> > be top of the scoreboard (7 for rsh extension, 7+7 bonus for the
> > library, 1 for the doco, 1 for a patch in correct form with Changelog
> > = 23 points + any bugfixes along the way).
> 
> I think you undervalue documentation :-).

True.  But I have to set my priorities somewhere, and at the moment
it's not on the doco.

> > #include "netfilter_pjblib.h"
> [snip]
> 
> I'll keep reading your code until I can understand the design -
> there's a lot of concepts to get one's head around. FWIW, so far I've
> been quite impressed.

Have you looked at the netfilter-hacking-HOWTO?  If that is
insufficient, please advise.

Rusty.
PS. Dug up old unfinished libpktmatch.c code.  Enjoy:
================
#include <sys/types.h>
#include <stdint.h>
#include <arpa/inet.h>
#include <netinet/ip.h>
#include <netinet/tcp.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "libpktmatch.h"

/* I offer the following code as proof that mangling data inside TCP
   packets is an unholy thing to do. */

/* Shrinking packets is probably OK.  Growing them is a pain.  We're
 * not supposed to make them larger than MSS (that's the RULE).  We're
 * not supposed to make them larger than smallest MTU if DF is set
 * (ie. they're doing Path MTU discovery); imagine their confusion if
 * an ICMP Fragment Needed packet comes back telling them that their
 * packet WASN'T bigger than MTU...
 *
 * So I don't grow packets, and hope the retransmit code can handle
 * it.
 */

#ifdef DEBUG
#define DEBUGP(format, args...)  \
      fprintf (stderr, format , ## args)
#else
#define DEBUGP(format, args...)
#endif

/* Linked list of modifications. */
struct pkm_seq_modifications
{
	struct pkm_seq_modifications *next;

	/* Cumulative seq number offset up to this point. */
	u_int32_t seq_off;

	/* This is the exclusive sequence range in the original packet. */
	u_int32_t modification_start;
	u_int32_t modification_end;

	/* This is the length of the replacement. */
	size_t replacement_length;
	/* Replacement hangs off end. */
	unsigned char replacement[0];
};

enum initstate
{
	PKM_UNINITIALIZED,
	PKM_OUT_INITIALIZED,
	PKM_REPLY_INITIALIZED,
	PKM_INITIALIZED
};

struct pkm_internal 
{
	enum initstate state;

	u_int32_t seq_start_of_data;

	size_t maximum_record;
	size_t match_offset;
	unsigned char *data;
	unsigned char *mask;

	u_int32_t seq_off;
	struct pkm_seq_modifications *mods;
};

/* Allocate a 'type' and initialize it from 'contents' (GNU C). */
#define NEW(type, contents...) 				\
({ 							\
	type *_n = malloc(sizeof(type)); 		\
	if (_n) *_n = (type){ contents };		\
	_n;						\
})

static u_int16_t 
csum_partial(void *buffer, unsigned int len, u_int16_t prevsum)
{
	u_int32_t sum = 0;
	u_int16_t *ptr = buffer;

	while (len > 1)  {
		sum += *ptr++;
		len -= 2;
	}
	if (len) {
		union {
			u_int8_t byte;
			u_int16_t wyde;
		} odd;
		odd.wyde = 0;
		odd.byte = *((u_int8_t *)ptr);
		sum += odd.wyde;
	}
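	/* Fold the carries back in.  Adding prevsum can carry again,
	   and so can that fold, hence the last end-around add. */
	sum = (sum >> 16) + (sum & 0xFFFF);
	sum += prevsum;
	sum = (sum >> 16) + (sum & 0xFFFF);
	sum += sum >> 16;
	return sum;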
}

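/* Sequence number comparisons, modulo 2^32 so they survive wraparound
   (casting to long doesn't wrap on 64-bit machines). */
#define seq_after(a,b)		((int32_t)((u_int32_t)(b) - (u_int32_t)(a)) < 0)
#define seq_before(a,b)		seq_after(b,a)

#define seq_after_eq(a,b)	((int32_t)((u_int32_t)(a) - (u_int32_t)(b)) >= 0)
#define seq_before_eq(a,b)	seq_after_eq(b,a)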

/* FIXME: Compare data changes. --RR */
static int
record_packet(const unsigned char *data,
	      size_t data_length,
	      u_int32_t seq,
	      struct pkm_internal *info)
{
	/* Slot it in: we always record unmodified data stream. */

	/* Outside window?  Drop it. */
	if (seq_before(seq, info->seq_start_of_data)
	    || seq_after(seq + data_length,
			 info->seq_start_of_data + info->maximum_record)) {
		fprintf(stderr, "TCP %u (len %zu) out of range %u-%u.\n",
			seq, data_length,
			info->seq_start_of_data,
			(u_int32_t)(info->seq_start_of_data
				    + info->maximum_record));
		return 0;
	}

	memcpy(info->data + (seq - info->seq_start_of_data),
	       data, data_length);
	memset(info->mask + (seq - info->seq_start_of_data), 0xFF, 
	       data_length);

	return 1;
}

static unsigned char *
does_match(const unsigned char *data, 
	   const unsigned char *mask, 
	   const struct pkm_pattern *pat,
	   size_t length,
	   size_t *match_len,
	   int *is_complete)
{
	size_t i;
	unsigned char *next;

	if (!pat) {
		*match_len = 0;
		*is_complete = 1;
		return (unsigned char *)data;
	}
	for (i = 0; i < pat->minimum_repeats; i++) {
		size_t j;

		/* Ran off the end of the map: at best a partial match. */
		if (i == length) {
			*match_len = i;
			*is_complete = 0;
			return (unsigned char *)data;
		}
		/* Holes (mask 0) match anything: the data may yet arrive. */
		if (mask[i]) {
			for (j = 0; j < pat->num_possibilities; j++) {
				if (data[i] == pat->possibilities[j])
					break;
			}
			if (j == pat->num_possibilities) {
				*is_complete = 0;
				return NULL;
			}
		}
	}

	/* It matched the minimum. Each character (up to maximum, and
	   within the map) can be this one or the next */
	for (; i < pat->maximum_repeats && i < length; i++) {
		if (mask[i]) {
			size_t j;

			for (j = 0; j < pat->num_possibilities; j++) {
				if (data[i] == pat->possibilities[j])
					break;
			}
			if (j == pat->num_possibilities)
				break;
		}
	}

	/* We're out.  Try next pattern. */
	next = does_match(data + i, mask + i, pat->next,
			  length - i, match_len, is_complete);
	if (next) {
		*match_len += i;
		next -= i;
	}
	return next;
}

/* Returns pointer to first (possibly partial) match.  Fills in
   match_len and is_complete */
static const unsigned char *
find_first_match(const unsigned char *data, 
		 const unsigned char *mask, 
		 const struct pkm_pattern *pattern,
		 size_t length,
		 size_t *match_len,
		 int *is_complete)
{
	size_t i;
	const unsigned char *match;

	for (i = 0; i < length; i++) {
		if ((match = does_match(data + i, mask + i, pattern, 
					length - i, match_len, is_complete))
		    != NULL) {
			/* We have a match; if it claims completeness,
                           check for holes. */
			if (*is_complete
			    && memchr(mask + i, 0, *match_len))
				*is_complete = 0;
			return data + i;
		}
	}

	*is_complete = 0;
	return NULL;
}

static void
move_buffer_down(struct pkm_internal *info, size_t offcut)
{
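	/* Slide the map down: byte 'offcut' becomes byte 0. */
	memmove(info->data,
		info->data + offcut,
		info->maximum_record - offcut);

	memmove(info->mask,
		info->mask + offcut,
		info->maximum_record - offcut);

	/* The newly exposed tail holds no data yet. */
	memset(info->mask + info->maximum_record - offcut,
	       0, offcut);

	info->seq_start_of_data += offcut;

	/* match_offset is relative to the map, so it moves too. */
	info->match_offset = info->match_offset > offcut
		? info->match_offset - offcut : 0;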
}

/* The checksum fields hold the one's complement of the sum. */
void recalc_tcp_checksum(struct iphdr *iph)
{
	struct tcphdr *tcph = (struct tcphdr *)((char *)iph + iph->ihl * 4);
	u_int16_t tcplen = ntohs(iph->tot_len) - iph->ihl * 4;
	struct {
		u_int32_t srcip, dstip;
		u_int8_t mbz, proto;
		u_int16_t tcplen;
	} pseudo_header = { iph->saddr, iph->daddr, 0, 
			    IPPROTO_TCP, 
			    htons(tcplen) };

	tcph->check = 0;
	tcph->check = (u_int16_t)~csum_partial(tcph, tcplen,
					       csum_partial(&pseudo_header,
							    sizeof(pseudo_header),
							    0));
}

void recalc_ip_checksum(struct iphdr *iph)
{
	iph->check = 0;
	iph->check = (u_int16_t)~csum_partial(iph, iph->ihl * 4, 0);
}

/* Blank an option out with NOPs rather than cutting it out: the TCP
   header length (doff) must stay a multiple of 4 bytes, and this way
   the payload doesn't have to move.  Returns the packet size. */
static size_t 
remove_option(struct iphdr *iph, unsigned char *opt, size_t len)
{
	memset(opt, TCPOPT_NOP, len);
	recalc_tcp_checksum(iph);
	return ntohs(iph->tot_len);
}

/* We simply take the window scale and SACK permitted options out of
   the first SYN. */
/* Returns new packet size, or 0 if the options are malformed. */
static size_t parse_syn_options(struct iphdr *iph, 
				size_t packet_length)
{
	struct tcphdr *tcph = (struct tcphdr *)((char *)iph + iph->ihl * 4);
	unsigned char *opt = (unsigned char *)(tcph + 1);
	size_t optlen, i;

	/* doff counts 32-bit words and includes the fixed header. */
	if (tcph->doff < 5)
		return 0;
	optlen = tcph->doff * 4 - sizeof(struct tcphdr);

	for (i = 0; i < optlen; ) {
		if (opt[i] == TCPOPT_EOL)
			break;
		if (opt[i] == TCPOPT_NOP) {
			i++;
			continue;
		}
		/* Any other option has a length byte, which counts the
		   kind and length bytes too. */
		if (i + 1 >= optlen || opt[i+1] < 2
		    || i + opt[i+1] > optlen) {
			fprintf(stderr, "Bad TCP option %u\n", opt[i]);
			return 0;
		}
		switch (opt[i]) {
		case TCPOPT_SACK_PERMITTED:
			if (opt[i+1] != TCPOLEN_SACK_PERMITTED) {
				fprintf(stderr, "Bad SACK_PERM len: %u\n",
					opt[i+1]);
				return 0;
			}
			/* Now NOPs: the next pass steps over them. */
			packet_length = remove_option(iph, opt+i, 
						      TCPOLEN_SACK_PERMITTED);
			break;

		case TCPOPT_WINDOW:
			if (opt[i+1] != TCPOLEN_WINDOW) {
				fprintf(stderr, "Bad WINDOW len: %u\n",
					opt[i+1]);
				return 0;
			}
			packet_length = remove_option(iph, opt+i, 
						      TCPOLEN_WINDOW);
			break;

		default:
			i += opt[i+1];
		}
	}
	return packet_length;
}

static void record_ack(struct iphdr *iph, 
		       struct tcphdr *tcph, 
		       struct pkm_internal *info)
{
	struct pkm_seq_modifications *m;
	u_int32_t ack = ntohl(tcph->ack_seq);	/* rewritten seq space */
	u_int32_t seq_off = info->seq_off;

	/* Find the offset in force at this ack: that after the last
	   modification whose end (as the receiver sees it) is acked. */
	for (m = info->mods; m; m = m->next) {
		u_int32_t off_after = m->seq_off + m->replacement_length
			- (m->modification_end - m->modification_start);

		if (seq_before(ack, m->modification_end + off_after))
			break;
		seq_off = off_after;
	}
	/* Translate back into the sender's (original) sequence space. */
	ack -= seq_off;

	if (seq_before(ack, info->seq_start_of_data)
	    || seq_after(ack, info->seq_start_of_data
			 + info->maximum_record)) {
		fprintf(stderr, "Ack %u out of range %u-%u.\n",
			ack, info->seq_start_of_data,
			(u_int32_t)(info->seq_start_of_data
				    + info->maximum_record));
	} else {
		/* Move packet map. */
		move_buffer_down(info, ack - info->seq_start_of_data);

		/* Replacements preceding this ack, we can drop; their
		   offset becomes the base offset. */
		while (info->mods
		       && seq_after_eq(ack, info->mods->modification_end)) {
			struct pkm_seq_modifications *freeme = info->mods;

			info->seq_off = freeme->seq_off
				+ freeme->replacement_length
				- (freeme->modification_end
				   - freeme->modification_start);
			info->mods = freeme->next;
			free(freeme);
		}
	}

	tcph->ack_seq = htonl(ack);
	recalc_tcp_checksum(iph);
}

static void
record_replacement(u_int32_t match_offset,
		   size_t match_length,
		   struct pkm_internal *info,
		   unsigned char *(*replace)(unsigned char *pattern_start,
					     size_t pattern_length,
					     size_t *replacement_length,
					     void *info),
		   void *replacement_info)
{
	struct pkm_seq_modifications **i, *mod;
	u_int32_t seq_off = info->seq_off;
	size_t replacement_length;
	unsigned char *replacement;

	/* Walk to the tail, accumulating the offset each earlier
	   modification introduces. */
	for (i = &info->mods; *i; i = &(*i)->next) {
		seq_off = (*i)->seq_off 
			+ (*i)->replacement_length
			- ((*i)->modification_end - (*i)->modification_start);
	}

	replacement = replace(info->data + match_offset, 
			      match_length,
			      &replacement_length,
			      replacement_info);
	if (!replacement)
		return;

	/* We never grow packets (see above), so never grow the stream. */
	if (replacement_length > match_length)
		replacement_length = match_length;

	mod = malloc(sizeof(*mod) + replacement_length);
	if (!mod)
		return;
	mod->next = NULL;
	mod->seq_off = seq_off;
	mod->modification_start = info->seq_start_of_data + match_offset;
	mod->modification_end = mod->modification_start + match_length;
	mod->replacement_length = replacement_length;
	memcpy(mod->replacement, replacement, replacement_length);
	*i = mod;
}

#define seq_between(seq, min, max) \
(seq_after_eq((seq), (min)) && seq_before((seq), (max)))

#define MIN(a,b) ((a)<(b)?(a):(b))

/* Iterate through, alter packet.  A modification maps original byte k
   of [modification_start, modification_end) to replacement byte k
   while k < replacement_length, and deletes the rest, so a match
   split across packets is handled a piece at a time. */
static void
do_replacements(struct iphdr *iph, struct tcphdr *tcph, 
		struct pkm_internal *info)
{
	unsigned char *payload = (unsigned char *)tcph + tcph->doff * 4;
	u_int32_t seq = ntohl(tcph->seq);
	u_int32_t len = ntohs(iph->tot_len) - iph->ihl * 4 - tcph->doff * 4;
	u_int32_t seq_off = info->seq_off;
	u_int32_t removed = 0;		/* bytes cut from this packet */
	struct pkm_seq_modifications *i;

	for (i = info->mods; i; i = i->next) {
		u_int32_t mlen = i->modification_end - i->modification_start;
		u_int32_t rlen = i->replacement_length;
		u_int32_t kstart, kend, keep, drop, pos;

		/* Note: we never grow packet, only shrink */
		if (rlen > mlen)
			rlen = mlen;

		/* Wholly before this packet: just count its offset. */
		if (seq_before_eq(i->modification_end, seq)) {
			seq_off = i->seq_off + rlen - mlen;
			continue;
		}
		/* Wholly after this packet: so are all later ones. */
		if (seq_after_eq(i->modification_start, seq + len))
			break;

		/* The overlap, as offsets into the modification. */
		kstart = seq_after(seq, i->modification_start)
			? seq - i->modification_start : 0;
		kend = seq_before(seq + len, i->modification_end)
			? seq + len - i->modification_start : mlen;

		/* Starting inside a modification: bytes it already
		   deleted shift this packet's sequence number. */
		if (kstart > 0)
			seq_off = i->seq_off
				- (kstart > rlen ? kstart - rlen : 0);

		keep = kstart < rlen ? MIN(kend, rlen) - kstart : 0;
		drop = (kend - kstart) - keep;
		pos = i->modification_start + kstart - seq - removed;

		/* Write our share of the replacement, close the gap. */
		if (keep)
			memcpy(payload + pos, i->replacement + kstart, keep);
		memmove(payload + pos + keep, payload + pos + keep + drop,
			len - removed - (pos + keep + drop));
		removed += drop;
	}

	tcph->seq = htonl(seq + seq_off);
	iph->tot_len = htons(ntohs(iph->tot_len) - removed);

	recalc_ip_checksum(iph);
	recalc_tcp_checksum(iph);
}

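/* The main loop: returns 0 on fail, 1 on successful completion. */
/* pattern: data pattern to look for.
   get_packet: function to get packet, returns -1 on error or pkt len.
	It sets *dir nonzero for packets in the direction being matched
	(the side that sent the first SYN) and zero for replies, and
	*local nonzero if the checksums should not be verified.
   get_packet_info: last arg to get_packet.
   replacement: function to get replacement for given pattern; returns
	pointer to buffer and sets replacement_length.
   replacement_info: last arg to replacement.
   reinject_packet: function to let packet continue, returns 1 on success.
   reinject_info: last arg to reinject_packet.
   discard_packet: function called with each packet that is dropped.
   discard_info: last arg to discard_packet.
*/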
int pkm_loop(const struct pkm_pattern *pattern,
	     int (*get_packet)(struct iphdr *pkt, void *info, int *dir,
			       int *local),
	     void *get_packet_info,
	     unsigned char *(*replacement)(unsigned char *pattern_start,
					   size_t pattern_length,
					   size_t *replacement_length,
					   void *info),
	     void *replacement_info,
	     int (*reinject_packet)(struct iphdr *pkt, size_t len, void *info),
	     void *reinject_info,
	     int (*discard_packet)(struct iphdr *pkt, size_t len, void *info),
	     void *discard_info)
{
	/* FIXME: Check pattern here. --RR */
	int packet_length;
	unsigned char packet[65536];
	struct iphdr *iph = (struct iphdr *)packet;
	struct pkm_internal info;
	int dir, local;

	info.state = PKM_UNINITIALIZED;

	while ((packet_length = get_packet(iph, get_packet_info, &dir, &local))
	       > 0) {
		int is_complete;
		const unsigned char *match;
		size_t match_length;
		struct tcphdr *tcph;

		/* Check packet: is it malformed? */
		if (packet_length < sizeof(struct iphdr)
		    || iph->ihl < 5
		    || ntohs(iph->tot_len) != packet_length
		    || packet_length < iph->ihl * 4 + sizeof(struct tcphdr)) {
			DEBUGP("Malformed packet (%d bytes): discarding\n",
			       packet_length);
			goto discard_packet;
		}

		tcph = (struct tcphdr *)(packet + iph->ihl * 4);
		if (tcph->doff < 5
		    || packet_length < iph->ihl * 4 + tcph->doff * 4) {
			DEBUGP("TCP Packet too short (%d): discarding\n",
			       packet_length);
			goto discard_packet;
		}

		/* IP Checksum must be OK. */
		if (!local) {
			DEBUGP("IP checksum = %u\n", iph->check);
			if (csum_partial(iph, iph->ihl * 4, 0) == 0xFFFF) {
				/* TCP Checksum must be OK. */
				struct {
					u_int32_t srcip, dstip;
					u_int8_t mbz, proto;
					u_int16_t tcplen;
				} pseudo_header = { iph->saddr, iph->daddr, 0, 
						    IPPROTO_TCP, 
						    htons(packet_length 
							  - iph->ihl * 4)
				};
				if (csum_partial(tcph, 
						 packet_length - iph->ihl * 4, 
						 csum_partial(&pseudo_header,
							      sizeof(pseudo_header),
							      0)) != 0xFFFF) {
					DEBUGP("TCP checksum failed\n");
					goto discard_packet;
				}
			} else {
				DEBUGP("IP checksum failed: (%u) discarding\n",
				       csum_partial(iph, iph->ihl*4, 0));
				goto discard_packet;
			}
		}

		/* Remove options from syns. */
		if (tcph->syn) {
			packet_length = parse_syn_options(iph, 
							  packet_length);
			if (!packet_length) {
				DEBUGP("Parsing SYN opts failed\n");
				goto discard_packet;
			}
		}

		if (info.state != PKM_INITIALIZED) {
			if (!tcph->syn
			    || packet_length != (iph->ihl + tcph->doff) * 4) {
				DEBUGP("Non-SYN packet on uninit conn\n");
				goto discard_packet;
			}

			if (dir) {
				info.seq_start_of_data = ntohl(tcph->seq)+1;
				if (info.state == PKM_REPLY_INITIALIZED)
					info.state = PKM_INITIALIZED;
				else
					info.state = PKM_OUT_INITIALIZED;
			} else {
				/* Reply SYN gives window size. */
				info.maximum_record = ntohs(tcph->window);
				info.match_offset = 0;
				info.data = malloc(info.maximum_record * 2);
				if (!info.data)
					break;
				info.mask = info.data + info.maximum_record;
				info.seq_off = 0;
				info.mods = NULL;
				if (info.state == PKM_OUT_INITIALIZED)
					info.state = PKM_INITIALIZED;
				else
					info.state = PKM_REPLY_INITIALIZED;
			}
			if (!reinject_packet(iph, packet_length,
					     reinject_info))
				break;
			continue;
		}

		/* Do the ack pruning thing. */
		if (!dir) {
			record_ack(iph, tcph, &info);
			if (!reinject_packet(iph, packet_length,
					     reinject_info))
				break;
			continue;
		}

		if (!record_packet((unsigned char *)tcph + tcph->doff * 4,
				   packet_length - (iph->ihl + tcph->doff) * 4,
				   ntohl(tcph->seq),
				   &info)) {
			DEBUGP("Recording packet failed\n");
			goto discard_packet;
		}

		while ((match = find_first_match(info.data + info.match_offset,
						 info.mask + info.match_offset,
						 pattern,
						 info.maximum_record
						 - info.match_offset,
						 &match_length,
						 &is_complete)) != NULL) {
			/* Woohoo.  A match. */
			info.match_offset = match - info.data;

			/* If it's complete, do the match thing. */
			if (is_complete) {
				record_replacement(info.match_offset,
						   match_length,
						   &info,
						   replacement,
						   replacement_info);
				info.match_offset += match_length;
			}
			else
				break;
		}

		/* Packet must end at or before match_offset to be passed
		   (a packet ending exactly at the first hole is fine). */
		if (seq_after(ntohl(tcph->seq) 
			      + packet_length - (iph->ihl + tcph->doff) * 4,
			      info.seq_start_of_data + info.match_offset)) {
			DEBUGP("Packet does not precede match_offset\n");
			goto discard_packet;
		}

		do_replacements(iph, tcph, &info);
		/* Replacements may have shrunk the packet. */
		packet_length = ntohs(iph->tot_len);
		if (!reinject_packet(iph, packet_length, reinject_info))
			break;

		continue;

		/* FIXME: End loop nicely if we see ACK for FIN, or RST */

		discard_packet:
		discard_packet(iph, packet_length, discard_info);
	}
	return 0;
}
--
Hacking time.