From 3a4fe104a06b73fd19c550546e7c65a59ff2afe3 Mon Sep 17 00:00:00 2001
From: Guus Sliepen <guus@tinc-vpn.org>
Date: Sun, 18 Mar 2012 16:42:02 +0100
Subject: [PATCH] Add datagram mode to the SPTPS protocol.

* Everything is identical except the headers of the records.
* Instead of sending explicit message length and having an implicit sequence
  number, datagram mode has an implicit message length and an explicit sequence
  number.
* The sequence number is used to set the most significant bytes of the counter.
---
 src/protocol_auth.c |   2 +-
 src/sptps.c         | 124 +++++++++++++++++++++++++++++++++++++++++---
 src/sptps.h         |   3 +-
 src/sptps_test.c    |  13 +++--
 4 files changed, 130 insertions(+), 12 deletions(-)

diff --git a/src/protocol_auth.c b/src/protocol_auth.c
index d3eef21e..3bf18b21 100644
--- a/src/protocol_auth.c
+++ b/src/protocol_auth.c
@@ -144,7 +144,7 @@ bool id_h(connection_t *c, char *request) {
 		else
 			snprintf(label, sizeof label, "tinc TCP key expansion %s %s", c->name, myself->name);
 
-		return sptps_start(&c->sptps, c, c->outgoing, myself->connection->ecdsa, c->ecdsa, label, sizeof label, send_meta_sptps, receive_meta_sptps);
+		return sptps_start(&c->sptps, c, c->outgoing, false, myself->connection->ecdsa, c->ecdsa, label, sizeof label, send_meta_sptps, receive_meta_sptps);
 	} else {
 		return send_metakey(c);
 	}
diff --git a/src/sptps.c b/src/sptps.c
index 395c92fc..fa1594db 100644
--- a/src/sptps.c
+++ b/src/sptps.c
@@ -54,8 +54,41 @@ static bool error(sptps_t *s, int s_errno, const char *msg) {
 	return false;
 }
 
+// Send a record (datagram version, accepts all record types, handles encryption and authentication).
+static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
+	char buffer[len + 23UL];
+
+	// Create header with sequence number, length and record type
+	uint32_t seqno = htonl(s->outseqno++);
+	uint16_t netlen = htons(len);
+
+	memcpy(buffer, &netlen, 2);
+	memcpy(buffer + 2, &seqno, 4);
+	buffer[6] = type;
+
+	// Add plaintext (TODO: avoid unnecessary copy)
+	memcpy(buffer + 7, data, len);
+
+	if(s->outstate) {
+		// If first handshake has finished, encrypt and HMAC
+		cipher_set_counter(&s->outcipher, &seqno, sizeof seqno);
+		if(!cipher_counter_xor(&s->outcipher, buffer + 6, len + 1UL, buffer + 6))
+			return false;
+
+		if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
+			return false;
+
+		return s->send_data(s->handle, buffer + 2, len + 21UL);
+	} else {
+		// Otherwise send as plaintext
+		return s->send_data(s->handle, buffer + 2, len + 5UL);
+	}
+}
 // Send a record (private version, accepts all record types, handles encryption and authentication).
 static bool send_record_priv(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
+	if(s->datagram)
+		return send_record_priv_datagram(s, type, data, len);
+
 	char buffer[len + 23UL];
 
 	// Create header with sequence number, length and record type
@@ -102,6 +135,8 @@ static bool send_kex(sptps_t *s) {
 	size_t keylen = ECDH_SIZE;
 
 	// Make room for our KEX message, which we will keep around since send_sig() needs it.
+	if(s->mykex)
+		abort();
 	s->mykex = realloc(s->mykex, 1 + 32 + keylen);
 	if(!s->mykex)
 		return error(s, errno, strerror(errno));
@@ -219,6 +254,8 @@ static bool receive_kex(sptps_t *s, const char *data, uint16_t len) {
 	// Ignore version number for now.
 
 	// Make a copy of the KEX message, send_sig() and receive_sig() need it
+	if(s->hiskex)
+		abort();
 	s->hiskex = realloc(s->hiskex, len);
 	if(!s->hiskex)
 		return error(s, errno, strerror(errno));
@@ -315,7 +352,6 @@ static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) {
 			// If we already sent our secondary public ECDH key, we expect the peer to send his.
 			if(!receive_sig(s, data, len))
 				return false;
-			// s->state = SPTPS_ACK;
 			s->state = SPTPS_ACK;
 			return true;
 		case SPTPS_ACK:
@@ -331,8 +367,79 @@ static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) {
 	}
 }
 
+// Receive incoming data, datagram version.
+static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len) {
+	if(len < (s->instate ? 21 : 5))
+		return error(s, EIO, "Received short packet");
+
+	uint32_t seqno;
+	memcpy(&seqno, data, 4);
+	seqno = ntohl(seqno);
+
+	if(!s->instate) {
+		if(seqno != s->inseqno) {
+			fprintf(stderr, "Received invalid packet seqno: %d != %d\n", seqno, s->inseqno);
+			return error(s, EIO, "Invalid packet seqno");
+		}
+
+		s->inseqno = seqno + 1;
+
+		uint8_t type = data[4];
+
+		if(type != SPTPS_HANDSHAKE)
+			return error(s, EIO, "Application record received before handshake finished");
+
+		return receive_handshake(s, data + 5, len - 5);
+	}
+
+	if(seqno < s->inseqno) {
+		fprintf(stderr, "Received late or replayed packet: %d < %d\n", seqno, s->inseqno);
+		return true;
+	}
+
+	if(seqno > s->inseqno)
+		fprintf(stderr, "Missed %d packets\n", seqno - s->inseqno);
+
+	s->inseqno = seqno + 1;
+
+	uint16_t netlen = htons(len - 21);
+
+	char buffer[len + 23];
+
+	memcpy(buffer, &netlen, 2);
+	memcpy(buffer + 2, data, len);
+
+	memcpy(&seqno, buffer + 2, 4);
+
+	// Check HMAC and decrypt.
+	if(!digest_verify(&s->indigest, buffer, len - 14, buffer + len - 14))
+		return error(s, EIO, "Invalid HMAC");
+
+	cipher_set_counter(&s->incipher, &seqno, sizeof seqno);
+	if(!cipher_counter_xor(&s->incipher, buffer + 6, len - 4, buffer + 6))
+		return false;
+
+	// Append a NULL byte for safety.
+	buffer[len - 14] = 0;
+
+	uint8_t type = buffer[6];
+
+	if(type < SPTPS_HANDSHAKE) {
+		if(!s->instate)
+			return error(s, EIO, "Application record received before handshake finished");
+		if(!s->receive_record(s->handle, type, buffer + 7, len - 21))
+			return false;
+	} else {
+		return error(s, EIO, "Invalid record type");
+	}
+
+	return true;
+}
 // Receive incoming data. Check if it contains a complete record, if so, handle it.
 bool sptps_receive_data(sptps_t *s, const char *data, size_t len) {
+	if(s->datagram)
+		return sptps_receive_data_datagram(s, data, len);
+
 	while(len) {
 		// First read the 2 length bytes.
 		if(s->buflen < 6) {
@@ -422,12 +529,13 @@ bool sptps_receive_data(sptps_t *s, const char *data, size_t len) {
 }
 
 // Start a SPTPS session.
-bool sptps_start(sptps_t *s, void *handle, bool initiator, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) {
+bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) {
 	// Initialise struct sptps
 	memset(s, 0, sizeof *s);
 
 	s->handle = handle;
 	s->initiator = initiator;
+	s->datagram = datagram;
 	s->mykey = mykey;
 	s->hiskey = hiskey;
 
@@ -435,11 +543,13 @@ bool sptps_start(sptps_t *s, void *handle, bool initiator, ecdsa_t mykey, ecdsa_
 	if(!s->label)
 		return error(s, errno, strerror(errno));
 
-	s->inbuf = malloc(7);
-	if(!s->inbuf)
-		return error(s, errno, strerror(errno));
-	s->buflen = 4;
-	memset(s->inbuf, 0, 4);
+	if(!datagram) {
+		s->inbuf = malloc(7);
+		if(!s->inbuf)
+			return error(s, errno, strerror(errno));
+		s->buflen = 4;
+		memset(s->inbuf, 0, 4);
+	}
 
 	memcpy(s->label, label, labellen);
 	s->labellen = labellen;
diff --git a/src/sptps.h b/src/sptps.h
index 065c6a09..3854ec24 100644
--- a/src/sptps.h
+++ b/src/sptps.h
@@ -45,6 +45,7 @@ typedef bool (*receive_record_t)(void *handle, uint8_t type, const char *data, u
 
 typedef struct sptps {
 	bool initiator;
+	bool datagram;
 	int state;
 
 	char *inbuf;
@@ -76,7 +77,7 @@ typedef struct sptps {
 	receive_record_t receive_record;
 } sptps_t;
 
-extern bool sptps_start(sptps_t *s, void *handle, bool initiator, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record);
+extern bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record);
 extern bool sptps_stop(sptps_t *s);
 extern bool sptps_send_record(sptps_t *s, uint8_t type, const char *data, uint16_t len);
 extern bool sptps_receive_data(sptps_t *s, const char *data, size_t len);
diff --git a/src/sptps_test.c b/src/sptps_test.c
index 56dcc886..3ee7ab69 100644
--- a/src/sptps_test.c
+++ b/src/sptps_test.c
@@ -51,9 +51,16 @@ static bool receive_record(void *handle, uint8_t type, const char *data, uint16_
 
 int main(int argc, char *argv[]) {
 	bool initiator = false;
+	bool datagram = false;
 
-	if(argc < 3) {
-		fprintf(stderr, "Usage: %s my_ecdsa_key_file his_ecdsa_key_file [host] port\n", argv[0]);
+	if(argc > 1 && !strcmp(argv[1], "-d")) {
+		datagram = true;
+		argc--;
+		argv++;
+	}
+
+	if(argc < 4) {
+		fprintf(stderr, "Usage: %s [-d] my_ecdsa_key_file his_ecdsa_key_file [host] port\n", argv[0]);
 		return 1;
 	}
 
@@ -123,7 +130,7 @@ int main(int argc, char *argv[]) {
 	fprintf(stderr, "Keys loaded\n");
 
 	sptps_t s;
-	if(!sptps_start(&s, &sock, initiator, mykey, hiskey, "sptps_test", 10, send_data, receive_record))
+	if(!sptps_start(&s, &sock, initiator, datagram, mykey, hiskey, "sptps_test", 10, send_data, receive_record))
 		return 1;
 
 	while(true) {