From b747c6cf9ba594f207c0d52b0ed572a875ee034b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Manuel=20P=C3=A9gouri=C3=A9-Gonnard?=
 <manuel.pegourie-gonnard@arm.com>
Date: Sun, 12 Aug 2018 13:28:53 +0200
Subject: [PATCH] Add basic first tests for MTU setting

For now, just check that it causes us to fragment. More tests are coming in
follow-up commits to ensure we respect the exact value set, including when
renegotiating.
---
 library/ssl_tls.c          |  3 ++
 programs/ssl/ssl_client2.c | 15 +++++++-
 programs/ssl/ssl_server2.c | 15 +++++++-
 tests/ssl-opt.sh           | 76 ++++++++++++++++++++++++++++++++++++--
 4 files changed, 103 insertions(+), 6 deletions(-)

diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index ea46d85b3..b05d2883a 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -2905,6 +2905,9 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl )
             const size_t frag_len = rem_len > max_hs_fragment_len
                                   ? max_hs_fragment_len : rem_len;
 
+            if( frag_off == 0 && frag_len != hs_len )
+                MBEDTLS_SSL_DEBUG_MSG( 2, ( "fragmenting handshake message" ) );
+
             /* Messages are stored with handshake headers as if not fragmented,
              * copy beginning of headers then fill fragmentation fields.
              * Handshake headers: type(1) len(3) seq(2) f_off(3) f_len(3) */
diff --git a/programs/ssl/ssl_client2.c b/programs/ssl/ssl_client2.c
index 0dd9e3f7b..7cdc53a54 100644
--- a/programs/ssl/ssl_client2.c
+++ b/programs/ssl/ssl_client2.c
@@ -106,6 +106,7 @@ int main( void )
 #define DFL_TRANSPORT           MBEDTLS_SSL_TRANSPORT_STREAM
 #define DFL_HS_TO_MIN           0
 #define DFL_HS_TO_MAX           0
+#define DFL_DTLS_MTU            -1
 #define DFL_FALLBACK            -1
 #define DFL_EXTENDED_MS         -1
 #define DFL_ETM                 -1
@@ -198,7 +199,8 @@ int main( void )
 #define USAGE_DTLS \
     "    dtls=%%d             default: 0 (TLS)\n"                           \
     "    hs_timeout=%%d-%%d    default: (library default: 1000-60000)\n"    \
-    "                        range of DTLS handshake timeouts in millisecs\n"
+    "                        range of DTLS handshake timeouts in millisecs\n" \
+    "    mtu=%%d              default: (library default: unlimited)\n"
 #else
 #define USAGE_DTLS ""
 #endif
@@ -345,6 +347,7 @@ struct options
     int transport;              /* TLS or DTLS?                             */
     uint32_t hs_to_min;         /* Initial value of DTLS handshake timer    */
     uint32_t hs_to_max;         /* Max value of DTLS handshake timer        */
+    int dtls_mtu;               /* UDP Maximum tranport unit for DTLS       */
     int fallback;               /* is this a fallback connection?           */
     int extended_ms;            /* negotiate extended master secret?        */
     int etm;                    /* negotiate encrypt then mac?              */
@@ -617,6 +620,7 @@ int main( int argc, char *argv[] )
     opt.transport           = DFL_TRANSPORT;
     opt.hs_to_min           = DFL_HS_TO_MIN;
     opt.hs_to_max           = DFL_HS_TO_MAX;
+    opt.dtls_mtu            = DFL_DTLS_MTU;
     opt.fallback            = DFL_FALLBACK;
     opt.extended_ms         = DFL_EXTENDED_MS;
     opt.etm                 = DFL_ETM;
@@ -927,6 +931,12 @@ int main( int argc, char *argv[] )
             if( opt.hs_to_min == 0 || opt.hs_to_max < opt.hs_to_min )
                 goto usage;
         }
+        else if( strcmp( p, "mtu" ) == 0 )
+        {
+            opt.dtls_mtu = atoi( q );
+            if( opt.dtls_mtu < 0 )
+                goto usage;
+        }
         else if( strcmp( p, "recsplit" ) == 0 )
         {
             opt.recsplit = atoi( q );
@@ -1327,6 +1337,9 @@ int main( int argc, char *argv[] )
     if( opt.hs_to_min != DFL_HS_TO_MIN || opt.hs_to_max != DFL_HS_TO_MAX )
         mbedtls_ssl_conf_handshake_timeout( &conf, opt.hs_to_min,
                                             opt.hs_to_max );
+
+    if( opt.dtls_mtu != DFL_DTLS_MTU )
+        mbedtls_ssl_conf_mtu( &conf, opt.dtls_mtu );
 #endif /* MBEDTLS_SSL_PROTO_DTLS */
 
 #if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH)
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 7654a6446..484f84fdd 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -150,6 +150,7 @@ int main( void )
 #define DFL_ANTI_REPLAY         -1
 #define DFL_HS_TO_MIN           0
 #define DFL_HS_TO_MAX           0
+#define DFL_DTLS_MTU            -1
 #define DFL_BADMAC_LIMIT        -1
 #define DFL_EXTENDED_MS         -1
 #define DFL_ETM                 -1
@@ -297,7 +298,8 @@ int main( void )
 #define USAGE_DTLS \
     "    dtls=%%d             default: 0 (TLS)\n"                           \
     "    hs_timeout=%%d-%%d    default: (library default: 1000-60000)\n"    \
-    "                        range of DTLS handshake timeouts in millisecs\n"
+    "                        range of DTLS handshake timeouts in millisecs\n" \
+    "    mtu=%%d              default: (library default: unlimited)\n"
 #else
 #define USAGE_DTLS ""
 #endif
@@ -470,6 +472,7 @@ struct options
     int anti_replay;            /* Use anti-replay for DTLS? -1 for default */
     uint32_t hs_to_min;         /* Initial value of DTLS handshake timer    */
     uint32_t hs_to_max;         /* Max value of DTLS handshake timer        */
+    int dtls_mtu;               /* UDP Maximum tranport unit for DTLS       */
     int badmac_limit;           /* Limit of records with bad MAC            */
 } opt;
 
@@ -1338,6 +1341,7 @@ int main( int argc, char *argv[] )
     opt.anti_replay         = DFL_ANTI_REPLAY;
     opt.hs_to_min           = DFL_HS_TO_MIN;
     opt.hs_to_max           = DFL_HS_TO_MAX;
+    opt.dtls_mtu            = DFL_DTLS_MTU;
     opt.badmac_limit        = DFL_BADMAC_LIMIT;
     opt.extended_ms         = DFL_EXTENDED_MS;
     opt.etm                 = DFL_ETM;
@@ -1684,6 +1688,12 @@ int main( int argc, char *argv[] )
             if( opt.hs_to_min == 0 || opt.hs_to_max < opt.hs_to_min )
                 goto usage;
         }
+        else if( strcmp( p, "mtu" ) == 0 )
+        {
+            opt.dtls_mtu = atoi( q );
+            if( opt.dtls_mtu < 0 )
+                goto usage;
+        }
         else if( strcmp( p, "sni" ) == 0 )
         {
             opt.sni = q;
@@ -2155,6 +2165,9 @@ int main( int argc, char *argv[] )
 #if defined(MBEDTLS_SSL_PROTO_DTLS)
     if( opt.hs_to_min != DFL_HS_TO_MIN || opt.hs_to_max != DFL_HS_TO_MAX )
         mbedtls_ssl_conf_handshake_timeout( &conf, opt.hs_to_min, opt.hs_to_max );
+
+    if( opt.dtls_mtu != DFL_DTLS_MTU )
+        mbedtls_ssl_conf_mtu( &conf, opt.dtls_mtu );
 #endif /* MBEDTLS_SSL_PROTO_DTLS */
 
 #if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH)
diff --git a/tests/ssl-opt.sh b/tests/ssl-opt.sh
index 0cf288f12..3d61ac3a4 100755
--- a/tests/ssl-opt.sh
+++ b/tests/ssl-opt.sh
@@ -4911,7 +4911,7 @@ requires_config_enabled MBEDTLS_SSL_PROTO_DTLS
 requires_config_enabled MBEDTLS_RSA_C
 requires_config_enabled MBEDTLS_ECDSA_C
 requires_config_enabled MBEDTLS_SSL_MAX_FRAGMENT_LENGTH
-run_test    "DTLS fragmenting: server only" \
+run_test    "DTLS fragmenting: server only (max_frag_len)" \
             "$P_SRV dtls=1 debug_level=2 auth_mode=required \
              crt_file=data_files/server7_int-ca.crt \
              key_file=data_files/server7.key \
@@ -4929,7 +4929,7 @@ requires_config_enabled MBEDTLS_SSL_PROTO_DTLS
 requires_config_enabled MBEDTLS_RSA_C
 requires_config_enabled MBEDTLS_ECDSA_C
 requires_config_enabled MBEDTLS_SSL_MAX_FRAGMENT_LENGTH
-run_test    "DTLS fragmenting: server only (more)" \
+run_test    "DTLS fragmenting: server only (more) (max_frag_len)" \
             "$P_SRV dtls=1 debug_level=2 auth_mode=required \
              crt_file=data_files/server7_int-ca.crt \
              key_file=data_files/server7.key \
@@ -4947,7 +4947,7 @@ requires_config_enabled MBEDTLS_SSL_PROTO_DTLS
 requires_config_enabled MBEDTLS_RSA_C
 requires_config_enabled MBEDTLS_ECDSA_C
 requires_config_enabled MBEDTLS_SSL_MAX_FRAGMENT_LENGTH
-run_test    "DTLS fragmenting: client-initiated, server only" \
+run_test    "DTLS fragmenting: client-initiated, server only (max_frag_len)" \
             "$P_SRV dtls=1 debug_level=2 auth_mode=none \
              crt_file=data_files/server7_int-ca.crt \
              key_file=data_files/server7.key \
@@ -4965,7 +4965,7 @@ requires_config_enabled MBEDTLS_SSL_PROTO_DTLS
 requires_config_enabled MBEDTLS_RSA_C
 requires_config_enabled MBEDTLS_ECDSA_C
 requires_config_enabled MBEDTLS_SSL_MAX_FRAGMENT_LENGTH
-run_test    "DTLS fragmenting: client-initiated, both" \
+run_test    "DTLS fragmenting: client-initiated, both (max_frag_len)" \
             "$P_SRV dtls=1 debug_level=2 auth_mode=required \
              crt_file=data_files/server7_int-ca.crt \
              key_file=data_files/server7.key \
@@ -4979,6 +4979,74 @@ run_test    "DTLS fragmenting: client-initiated, both" \
             -c "found fragmented DTLS handshake message" \
             -C "error"
 
+requires_config_enabled MBEDTLS_SSL_PROTO_DTLS
+requires_config_enabled MBEDTLS_RSA_C
+requires_config_enabled MBEDTLS_ECDSA_C
+run_test    "DTLS fragmenting: none (for reference) (MTU)" \
+            "$P_SRV dtls=1 debug_level=2 auth_mode=required \
+             crt_file=data_files/server7_int-ca.crt \
+             key_file=data_files/server7.key \
+             mtu=2048" \
+            "$P_CLI dtls=1 debug_level=2 \
+             crt_file=data_files/server8_int-ca2.crt \
+             key_file=data_files/server8.key \
+             mtu=2048" \
+            0 \
+            -S "found fragmented DTLS handshake message" \
+            -C "found fragmented DTLS handshake message" \
+            -C "error"
+
+requires_config_enabled MBEDTLS_SSL_PROTO_DTLS
+requires_config_enabled MBEDTLS_RSA_C
+requires_config_enabled MBEDTLS_ECDSA_C
+run_test    "DTLS fragmenting: client (MTU)" \
+            "$P_SRV dtls=1 debug_level=2 auth_mode=required \
+             crt_file=data_files/server7_int-ca.crt \
+             key_file=data_files/server7.key \
+             mtu=2048" \
+            "$P_CLI dtls=1 debug_level=2 \
+             crt_file=data_files/server8_int-ca2.crt \
+             key_file=data_files/server8.key \
+             mtu=512" \
+            0 \
+            -s "found fragmented DTLS handshake message" \
+            -C "found fragmented DTLS handshake message" \
+            -C "error"
+
+requires_config_enabled MBEDTLS_SSL_PROTO_DTLS
+requires_config_enabled MBEDTLS_RSA_C
+requires_config_enabled MBEDTLS_ECDSA_C
+run_test    "DTLS fragmenting: server (MTU)" \
+            "$P_SRV dtls=1 debug_level=2 auth_mode=required \
+             crt_file=data_files/server7_int-ca.crt \
+             key_file=data_files/server7.key \
+             mtu=512" \
+            "$P_CLI dtls=1 debug_level=2 \
+             crt_file=data_files/server8_int-ca2.crt \
+             key_file=data_files/server8.key \
+             mtu=2048" \
+            0 \
+            -S "found fragmented DTLS handshake message" \
+            -c "found fragmented DTLS handshake message" \
+            -C "error"
+
+requires_config_enabled MBEDTLS_SSL_PROTO_DTLS
+requires_config_enabled MBEDTLS_RSA_C
+requires_config_enabled MBEDTLS_ECDSA_C
+run_test    "DTLS fragmenting: both (MTU)" \
+            "$P_SRV dtls=1 debug_level=2 auth_mode=required \
+             crt_file=data_files/server7_int-ca.crt \
+             key_file=data_files/server7.key \
+             mtu=512" \
+            "$P_CLI dtls=1 debug_level=2 \
+             crt_file=data_files/server8_int-ca2.crt \
+             key_file=data_files/server8.key \
+             mtu=512" \
+            0 \
+            -s "found fragmented DTLS handshake message" \
+            -c "found fragmented DTLS handshake message" \
+            -C "error"
+
 # Tests for specific things with "unreliable" UDP connection
 
 not_with_valgrind # spurious resend due to timeout