diff --git a/src/meta.c b/src/meta.c index 1c29fe9c..0849d3cd 100644 --- a/src/meta.c +++ b/src/meta.c @@ -159,8 +159,14 @@ bool receive_meta(connection_t *c) { } do { - if(c->protocol_minor >= 2) - return sptps_receive_data(&c->sptps, bufp, inlen); + if(c->protocol_minor >= 2) { + int len = sptps_receive_data(&c->sptps, bufp, inlen); + if(!len) + return false; + bufp += len; + inlen -= len; + continue; + } if(!c->status.decryptin) { endp = memchr(bufp, '\n', inlen); diff --git a/src/sptps.c b/src/sptps.c index 4a9683f2..e5946ed6 100644 --- a/src/sptps.c +++ b/src/sptps.c @@ -495,90 +495,92 @@ static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len } // Receive incoming data. Check if it contains a complete record, if so, handle it. -bool sptps_receive_data(sptps_t *s, const void *data, size_t len) { +size_t sptps_receive_data(sptps_t *s, const void *data, size_t len) { + size_t total_read = 0; + if(!s->state) return error(s, EIO, "Invalid session state zero"); if(s->datagram) - return sptps_receive_data_datagram(s, data, len); + return sptps_receive_data_datagram(s, data, len) ? len : false; - while(len) { - // First read the 2 length bytes. - if(s->buflen < 2) { - size_t toread = 2 - s->buflen; - if(toread > len) - toread = len; - - memcpy(s->inbuf + s->buflen, data, toread); - - s->buflen += toread; - len -= toread; - data += toread; - - // Exit early if we don't have the full length. - if(s->buflen < 2) - return true; - - // Get the length bytes - - memcpy(&s->reclen, s->inbuf, 2); - s->reclen = ntohs(s->reclen); - - // If we have the length bytes, ensure our buffer can hold the whole request. - s->inbuf = realloc(s->inbuf, s->reclen + 19UL); - if(!s->inbuf) - return error(s, errno, strerror(errno)); - - // Exit early if we have no more data to process. - if(!len) - return true; - } - - // Read up to the end of the record. - size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->buflen; + // First read the 2 length bytes. + if(s->buflen < 2) { + size_t toread = 2 - s->buflen; if(toread > len) toread = len; memcpy(s->inbuf + s->buflen, data, toread); + + total_read += toread; s->buflen += toread; len -= toread; data += toread; - // If we don't have a whole record, exit. - if(s->buflen < s->reclen + (s->instate ? 19UL : 3UL)) - return true; + // Exit early if we don't have the full length. + if(s->buflen < 2) + return total_read; - // Update sequence number. + // Get the length bytes - uint32_t seqno = s->inseqno++; + memcpy(&s->reclen, s->inbuf, 2); + s->reclen = ntohs(s->reclen); - // Check HMAC and decrypt. - if(s->instate) { - if(!chacha_poly1305_decrypt(s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL)) - return error(s, EINVAL, "Failed to decrypt and verify record"); - } + // If we have the length bytes, ensure our buffer can hold the whole request. + s->inbuf = realloc(s->inbuf, s->reclen + 19UL); + if(!s->inbuf) + return error(s, errno, strerror(errno)); - // Append a NULL byte for safety. - s->inbuf[s->reclen + 3UL] = 0; - - uint8_t type = s->inbuf[2]; - - if(type < SPTPS_HANDSHAKE) { - if(!s->instate) - return error(s, EIO, "Application record received before handshake finished"); - if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen)) - return false; - } else if(type == SPTPS_HANDSHAKE) { - if(!receive_handshake(s, s->inbuf + 3, s->reclen)) - return false; - } else { - return error(s, EIO, "Invalid record type %d", type); - } - - s->buflen = 0; + // Exit early if we have no more data to process. + if(!len) + return total_read; } - return true; + // Read up to the end of the record. + size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->buflen; + if(toread > len) + toread = len; + + memcpy(s->inbuf + s->buflen, data, toread); + total_read += toread; + s->buflen += toread; + len -= toread; + data += toread; + + // If we don't have a whole record, exit. + if(s->buflen < s->reclen + (s->instate ? 19UL : 3UL)) + return total_read; + + // Update sequence number. + + uint32_t seqno = s->inseqno++; + + // Check HMAC and decrypt. + if(s->instate) { + if(!chacha_poly1305_decrypt(s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL)) + return error(s, EINVAL, "Failed to decrypt and verify record"); + } + + // Append a NULL byte for safety. + s->inbuf[s->reclen + 3UL] = 0; + + uint8_t type = s->inbuf[2]; + + if(type < SPTPS_HANDSHAKE) { + if(!s->instate) + return error(s, EIO, "Application record received before handshake finished"); + if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen)) + return false; + } else if(type == SPTPS_HANDSHAKE) { + if(!receive_handshake(s, s->inbuf + 3, s->reclen)) + return false; + } else { + return error(s, EIO, "Invalid record type %d", type); + } + + s->buflen = 0; + + return total_read; } // Start a SPTPS session. diff --git a/src/sptps.h b/src/sptps.h index a2633bd1..75a95651 100644 --- a/src/sptps.h +++ b/src/sptps.h @@ -88,7 +88,7 @@ extern void (*sptps_log)(sptps_t *s, int s_errno, const char *format, va_list ap extern bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t *mykey, ecdsa_t *hiskey, const void *label, size_t labellen, send_data_t send_data, receive_record_t receive_record); extern bool sptps_stop(sptps_t *s); extern bool sptps_send_record(sptps_t *s, uint8_t type, const void *data, uint16_t len); -extern bool sptps_receive_data(sptps_t *s, const void *data, size_t len); +extern size_t sptps_receive_data(sptps_t *s, const void *data, size_t len); extern bool sptps_force_kex(sptps_t *s); extern bool sptps_verify_datagram(sptps_t *s, const void *data, size_t len);