diff --git a/src/sptps.c b/src/sptps.c index 7d6293c0..3fbd8540 100644 --- a/src/sptps.c +++ b/src/sptps.c @@ -370,14 +370,74 @@ static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) { } } +static bool sptps_check_seqno(sptps_t *s, uint32_t seqno, bool update_state) { + // Replay protection using a sliding window of configurable size. + // s->inseqno is expected sequence number + // seqno is received sequence number + // s->late[] is a circular buffer, a 1 bit means a packet has not been received yet + // The circular buffer contains bits for sequence numbers from s->inseqno - s->replaywin * 8 to (but excluding) s->inseqno. + if(s->replaywin) { + if(seqno != s->inseqno) { + if(seqno >= s->inseqno + s->replaywin * 8) { + // Prevent packets that jump far ahead of the queue from causing many others to be dropped. + bool farfuture = s->farfuture < s->replaywin >> 2; + if (update_state) + s->farfuture++; + if(farfuture) + return error(s, EIO, "Packet is %d seqs in the future, dropped (%u)\n", seqno - s->inseqno, s->farfuture); + + // Unless we have seen lots of them, in which case we consider the others lost. + warning(s, "Lost %d packets\n", seqno - s->inseqno); + if (update_state) { + // Mark all packets in the replay window as being late. + memset(s->late, 255, s->replaywin); + } + } else if (seqno < s->inseqno) { + // If the sequence number is farther in the past than the bitmap goes, or if the packet was already received, drop it. + if((s->inseqno >= s->replaywin * 8 && seqno < s->inseqno - s->replaywin * 8) || !(s->late[(seqno / 8) % s->replaywin] & (1 << seqno % 8))) + return error(s, EIO, "Received late or replayed packet, seqno %d, last received %d\n", seqno, s->inseqno); + } else if (update_state) { + // We missed some packets. Mark them in the bitmap as being late. + for(int i = s->inseqno; i < seqno; i++) + s->late[(i / 8) % s->replaywin] |= 1 << i % 8; + } + } + + if (update_state) { + // Mark the current packet as not being late. + s->late[(seqno / 8) % s->replaywin] &= ~(1 << seqno % 8); + s->farfuture = 0; + } + } + + if (update_state) { + if(seqno >= s->inseqno) + s->inseqno = seqno + 1; + + if(!s->inseqno) + s->received = 0; + else + s->received++; + } + + return true; +} + // Check datagram for valid HMAC bool sptps_verify_datagram(sptps_t *s, const char *data, size_t len) { if(!s->instate || len < 21) return error(s, EIO, "Received short packet"); - // TODO: just decrypt without updating the replay window + uint32_t seqno; + memcpy(&seqno, data, 4); + seqno = ntohl(seqno); - return true; + char buffer[len]; + size_t outlen; + if(!chacha_poly1305_decrypt(s->incipher, seqno, data + 4, len - 4, buffer, &outlen)) + return false; + + return sptps_check_seqno(s, seqno, false); } // Receive incoming data, datagram version. @@ -412,45 +472,8 @@ static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len if(!chacha_poly1305_decrypt(s->incipher, seqno, data + 4, len - 4, buffer, &outlen)) return error(s, EIO, "Failed to decrypt and verify packet"); - // Replay protection using a sliding window of configurable size. - // s->inseqno is expected sequence number - // seqno is received sequence number - // s->late[] is a circular buffer, a 1 bit means a packet has not been received yet - // The circular buffer contains bits for sequence numbers from s->inseqno - s->replaywin * 8 to (but excluding) s->inseqno. - if(s->replaywin) { - if(seqno != s->inseqno) { - if(seqno >= s->inseqno + s->replaywin * 8) { - // Prevent packets that jump far ahead of the queue from causing many others to be dropped. - if(s->farfuture++ < s->replaywin >> 2) - return error(s, EIO, "Packet is %d seqs in the future, dropped (%u)\n", seqno - s->inseqno, s->farfuture); - - // Unless we have seen lots of them, in which case we consider the others lost. - warning(s, "Lost %d packets\n", seqno - s->inseqno); - // Mark all packets in the replay window as being late. - memset(s->late, 255, s->replaywin); - } else if (seqno < s->inseqno) { - // If the sequence number is farther in the past than the bitmap goes, or if the packet was already received, drop it. - if((s->inseqno >= s->replaywin * 8 && seqno < s->inseqno - s->replaywin * 8) || !(s->late[(seqno / 8) % s->replaywin] & (1 << seqno % 8))) - return error(s, EIO, "Received late or replayed packet, seqno %d, last received %d\n", seqno, s->inseqno); - } else { - // We missed some packets. Mark them in the bitmap as being late. - for(int i = s->inseqno; i < seqno; i++) - s->late[(i / 8) % s->replaywin] |= 1 << i % 8; - } - } - - // Mark the current packet as not being late. - s->late[(seqno / 8) % s->replaywin] &= ~(1 << seqno % 8); - s->farfuture = 0; - } - - if(seqno >= s->inseqno) - s->inseqno = seqno + 1; - - if(!s->inseqno) - s->received = 0; - else - s->received++; + if(!sptps_check_seqno(s, seqno, true)) + return false; // Append a NULL byte for safety. buffer[len - 20] = 0;