diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index 45ea26ae2..173b6d58b 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -264,6 +264,57 @@ #define MBEDTLS_TLS_EXT_SUPPORTED_POINT_FORMATS_PRESENT (1 << 0) #define MBEDTLS_TLS_EXT_ECJPAKE_KKPP_OK (1 << 1) +/* + * Helpers for code specific to TLS or DTLS. + * + * Goals for these helpers: + * - generate minimal code, eg don't test if mode is DTLS in a DTLS-only build + * - make the flow clear to the compiler, ie that in dual-mode builds, + * when there are two branchs, exactly one of them is taken + * - preserve readability + * + * There are three macros: + * - MBEDTLS_SSL_TRANSPORT_IS_TLS( transport ) + * - MBEDTLS_SSL_TRANSPORT_IS_DTLS( transport ) + * - MBEDTLS_SSL_TRANSPORT_ELSE + * + * The first two are macros rather than static inline functions because some + * compilers (eg arm-none-eabi-gcc 5.4.1 20160919) don't propagate constants + * well enough for us with static inline functions. + * + * Usage 1 (can replace DTLS with TLS): + * #if defined(MBEDTLS_SSL_PROTO_DTLS) + * if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( transport ) ) + * // DTLS-specific code + * #endif + * + * Usage 2 (can swap DTLS and TLS); + * #if defined(MBEDTLS_SSL_PROTO_DTLS) + * if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( transport ) ) + * // DTLS-specific code + * MBEDTLS_SSL_TRANSPORT_ELSE + * #endif + * #if defined(MBEDTLS_SSL_PROTO_TLS) + * // TLS-specific code + * #endif + */ +#if defined(MBEDTLS_SSL_PROTO_DTLS) && defined(MBEDTLS_SSL_PROTO_TLS) /* both */ +#define MBEDTLS_SSL_TRANSPORT__BOTH /* shorcut for future tests */ +#define MBEDTLS_SSL_TRANSPORT_IS_TLS( transport ) \ + ( (transport) == MBEDTLS_SSL_TRANSPORT_STREAM ) +#define MBEDTLS_SSL_TRANSPORT_IS_DTLS( transport ) \ + ( (transport) == MBEDTLS_SSL_TRANSPORT_DATAGRAM ) +#define MBEDTLS_SSL_TRANSPORT_ELSE else +#elif defined(MBEDTLS_SSL_PROTO_DTLS) /* DTLS only */ +#define MBEDTLS_SSL_TRANSPORT_IS_TLS( transport ) 0 +#define MBEDTLS_SSL_TRANSPORT_IS_DTLS( transport ) 1 +#define MBEDTLS_SSL_TRANSPORT_ELSE /* empty: no other branch */ +#else /* TLS only */ +#define MBEDTLS_SSL_TRANSPORT_IS_TLS( transport ) 1 +#define MBEDTLS_SSL_TRANSPORT_IS_DTLS( transport ) 0 +#define MBEDTLS_SSL_TRANSPORT_ELSE /* empty: no other branch */ +#endif /* TLS and/or DTLS */ + #ifdef __cplusplus extern "C" { #endif @@ -905,12 +956,14 @@ static inline size_t mbedtls_ssl_out_hdr_len( const mbedtls_ssl_context *ssl ) static inline size_t mbedtls_ssl_hs_hdr_len( const mbedtls_ssl_context *ssl ) { -#if defined(MBEDTLS_SSL_PROTO_DTLS) - if( ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM ) - return( 12 ); -#else +#if !defined(MBEDTLS_SSL_PROTO__BOTH) ((void) ssl); #endif + +#if defined(MBEDTLS_SSL_PROTO_DTLS) + if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) ) + return( 12 ); +#endif return( 4 ); } diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 9afcc96e8..4cfef3f57 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -3023,7 +3023,7 @@ int mbedtls_ssl_fetch_input( mbedtls_ssl_context *ssl, size_t nb_want ) } #if defined(MBEDTLS_SSL_PROTO_DTLS) - if( ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM ) + if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) ) { uint32_t timeout; @@ -3164,8 +3164,9 @@ int mbedtls_ssl_fetch_input( mbedtls_ssl_context *ssl, size_t nb_want ) ssl->in_left = ret; } - else -#endif + MBEDTLS_SSL_TRANSPORT_ELSE +#endif /* MBEDTLS_SSL_PROTO_DTLS */ +#if defined(MBEDTLS_SSL_PROTO_TLS) { MBEDTLS_SSL_DEBUG_MSG( 2, ( "in_left: %d, nb_want: %d", ssl->in_left, nb_want ) ); @@ -3212,6 +3213,7 @@ int mbedtls_ssl_fetch_input( mbedtls_ssl_context *ssl, size_t nb_want ) ssl->in_left += ret; } } +#endif /* MBEDTLS_SSL_PROTO_TLS */ MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= fetch input" ) );