From 2fddd3765ea998bb9f40b52dc1baaf843b9889bf Mon Sep 17 00:00:00 2001
From: Hanno Becker <hanno.becker@arm.com>
Date: Wed, 10 Jul 2019 14:37:41 +0100
Subject: [PATCH] Check same-port-reconnect from client outside of record hdr
 parsing

Previously, `ssl_handle_possible_reconnect()` was part of
`ssl_parse_record_header()`, which was required to return a non-zero error
code to indicate a record which should not be further processed because it
was invalid, unexpected, duplicate, .... In this case, some error codes
would lead to some actions to be taken, e.g. `MBEDTLS_ERR_SSL_EARLY_MESSAGE`
to potential buffering of the record, but eventually, the record would be
dropped regardless of the precise value of the error code. The error code
`MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED` returned from
`ssl_handle_possible_reconnect()` did not receive any special treatment and
lead to silent dopping of the record - in particular, it was never returned
to the user.

In the new logic this commit introduces, `ssl_handle_possible_reconnect()` is
part of `ssl_check_client_reconnect()` which is triggered _after_
`ssl_parse_record_header()` found an unexpected record, which is already in
the code-path eventually dropping the record; we want to leave this code-path
only if a valid cookie has been found and we want to reset, but do nothing
otherwise. That's why `ssl_handle_possible_reconnect()` now returns `0` unless
a valid cookie has been found or a fatal error occurred.
---
 library/ssl_tls.c | 101 ++++++++++++++++++++++++++--------------------
 1 file changed, 57 insertions(+), 44 deletions(-)

diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index e5881da74..204fa43e4 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -4691,9 +4691,6 @@ static int ssl_check_dtls_clihlo_cookie(
     size_t sid_len, cookie_len;
     unsigned char *p;
 
-    if( f_cookie_write == NULL || f_cookie_check == NULL )
-        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
-
     /*
      * Structure of ClientHello with record and handshake headers,
      * and expected values. We don't need to check a lot, more checks will be
@@ -4819,6 +4816,14 @@ static int ssl_handle_possible_reconnect( mbedtls_ssl_context *ssl )
     int ret;
     size_t len;
 
+    if( ssl->conf->f_cookie_write == NULL ||
+        ssl->conf->f_cookie_check == NULL )
+    {
+        /* If we can't use cookies to verify reachability of the peer,
+         * drop the record. */
+        return( 0 );
+    }
+
     ret = ssl_check_dtls_clihlo_cookie(
             ssl->conf->f_cookie_write,
             ssl->conf->f_cookie_check,
@@ -4835,8 +4840,7 @@ static int ssl_handle_possible_reconnect( mbedtls_ssl_context *ssl )
          * If the error is permanent we'll catch it later,
          * if it's not, then hopefully it'll work next time. */
         (void) ssl->f_send( ssl->p_bio, ssl->out_buf, len );
-
-        return( MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED );
+        ret = 0;
     }
 
     if( ret == 0 )
@@ -4991,49 +4995,22 @@ static int ssl_parse_record_header( mbedtls_ssl_context *ssl )
     {
         unsigned int rec_epoch = ( ssl->in_ctr[0] << 8 ) | ssl->in_ctr[1];
 
-        /* Check epoch (and sequence number) with DTLS */
-        if( rec_epoch != ssl->in_epoch )
+        if( rec_epoch == (unsigned) ssl->in_epoch + 1 )
+        {
+            /* Consider buffering the record. */
+            MBEDTLS_SSL_DEBUG_MSG( 2, ( "Consider record for buffering" ) );
+            return( MBEDTLS_ERR_SSL_EARLY_MESSAGE );
+        }
+        else if( rec_epoch != ssl->in_epoch )
         {
             MBEDTLS_SSL_DEBUG_MSG( 1, ( "record from another epoch: "
                                         "expected %d, received %d",
                                         ssl->in_epoch, rec_epoch ) );
-
-#if defined(MBEDTLS_SSL_DTLS_CLIENT_PORT_REUSE) && defined(MBEDTLS_SSL_SRV_C)
-            /*
-             * Check for an epoch 0 ClientHello. We can't use in_msg here to
-             * access the first byte of record content (handshake type), as we
-             * have an active transform (possibly iv_len != 0), so use the
-             * fact that the record header len is 13 instead.
-             */
-            if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER &&
-                ssl->state == MBEDTLS_SSL_HANDSHAKE_OVER &&
-                rec_epoch == 0 &&
-                ssl->in_msgtype == MBEDTLS_SSL_MSG_HANDSHAKE &&
-                ssl->in_left > 13 &&
-                ssl->in_buf[13] == MBEDTLS_SSL_HS_CLIENT_HELLO )
-            {
-                MBEDTLS_SSL_DEBUG_MSG( 1, ( "possible client reconnect "
-                                            "from the same port" ) );
-                return( ssl_handle_possible_reconnect( ssl ) );
-            }
-            else
-#endif /* MBEDTLS_SSL_DTLS_CLIENT_PORT_REUSE && MBEDTLS_SSL_SRV_C */
-            {
-                /* Consider buffering the record. */
-                if( rec_epoch == (unsigned int) ssl->in_epoch + 1 )
-                {
-                    MBEDTLS_SSL_DEBUG_MSG( 2, ( "Consider record for buffering" ) );
-                    return( MBEDTLS_ERR_SSL_EARLY_MESSAGE );
-                }
-
-                return( MBEDTLS_ERR_SSL_UNEXPECTED_RECORD );
-            }
+            return( MBEDTLS_ERR_SSL_UNEXPECTED_RECORD );
         }
-
 #if defined(MBEDTLS_SSL_DTLS_ANTI_REPLAY)
         /* Replay detection only works for the current epoch */
-        if( rec_epoch == ssl->in_epoch &&
-            mbedtls_ssl_dtls_replay_check( ssl ) != 0 )
+        else if( mbedtls_ssl_dtls_replay_check( ssl ) != 0 )
         {
             MBEDTLS_SSL_DEBUG_MSG( 1, ( "replayed record" ) );
             return( MBEDTLS_ERR_SSL_UNEXPECTED_RECORD );
@@ -5045,6 +5022,34 @@ static int ssl_parse_record_header( mbedtls_ssl_context *ssl )
     return( 0 );
 }
 
+
+#if defined(MBEDTLS_SSL_DTLS_CLIENT_PORT_REUSE) && defined(MBEDTLS_SSL_SRV_C)
+static int ssl_check_client_reconnect( mbedtls_ssl_context *ssl )
+{
+    unsigned int rec_epoch = ( ssl->in_ctr[0] << 8 ) | ssl->in_ctr[1];
+
+    /*
+     * Check for an epoch 0 ClientHello. We can't use in_msg here to
+     * access the first byte of record content (handshake type), as we
+     * have an active transform (possibly iv_len != 0), so use the
+     * fact that the record header len is 13 instead.
+     */
+    if( rec_epoch == 0 &&
+        ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER &&
+        ssl->state == MBEDTLS_SSL_HANDSHAKE_OVER &&
+        ssl->in_msgtype == MBEDTLS_SSL_MSG_HANDSHAKE &&
+        ssl->in_left > 13 &&
+        ssl->in_buf[13] == MBEDTLS_SSL_HS_CLIENT_HELLO )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "possible client reconnect "
+                                    "from the same port" ) );
+        return( ssl_handle_possible_reconnect( ssl ) );
+    }
+
+    return( 0 );
+}
+#endif /* MBEDTLS_SSL_DTLS_CLIENT_PORT_REUSE && MBEDTLS_SSL_SRV_C */
+
 /*
  * If applicable, decrypt (and decompress) record content
  */
@@ -5926,8 +5931,7 @@ static int ssl_get_next_record( mbedtls_ssl_context *ssl )
     if( ( ret = ssl_parse_record_header( ssl ) ) != 0 )
     {
 #if defined(MBEDTLS_SSL_PROTO_DTLS)
-        if( ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM &&
-            ret != MBEDTLS_ERR_SSL_CLIENT_RECONNECT )
+        if( ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM )
         {
             if( ret == MBEDTLS_ERR_SSL_EARLY_MESSAGE )
             {
@@ -5941,6 +5945,12 @@ static int ssl_get_next_record( mbedtls_ssl_context *ssl )
 
             if( ret == MBEDTLS_ERR_SSL_UNEXPECTED_RECORD )
             {
+#if defined(MBEDTLS_SSL_DTLS_CLIENT_PORT_REUSE) && defined(MBEDTLS_SSL_SRV_C)
+                ret = ssl_check_client_reconnect( ssl );
+                if( ret != 0 )
+                    return( ret );
+#endif
+
                 /* Skip unexpected record (but not whole datagram) */
                 ssl->next_record_offset = ssl->in_msglen
                                         + mbedtls_ssl_in_hdr_len( ssl );
@@ -5961,8 +5971,11 @@ static int ssl_get_next_record( mbedtls_ssl_context *ssl )
             /* Get next record */
             return( MBEDTLS_ERR_SSL_CONTINUE_PROCESSING );
         }
+        else
 #endif
-        return( ret );
+        {
+            return( ret );
+        }
     }
 
     /*