diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index 97c00e256..80da3ac1a 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -1203,6 +1203,9 @@ void mbedtls_ssl_send_flight_completed( mbedtls_ssl_context *ssl ); void mbedtls_ssl_recv_flight_completed( mbedtls_ssl_context *ssl ); int mbedtls_ssl_resend( mbedtls_ssl_context *ssl ); int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ); +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) +void mbedtls_ssl_immediate_flight_done( mbedtls_ssl_context *ssl ); +#endif #endif /* Visible for testing purposes only */ diff --git a/library/ssl_cli.c b/library/ssl_cli.c index 08d5a7117..7f69b6242 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -1141,11 +1141,17 @@ static int ssl_write_client_hello( mbedtls_ssl_context *ssl ) } #if defined(MBEDTLS_SSL_PROTO_DTLS) - if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) && - ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) ) { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); - return( ret ); +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) + mbedtls_ssl_immediate_flight_done( ssl ); +#else + if( ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); + return( ret ); + } +#endif } #endif /* MBEDTLS_SSL_PROTO_DTLS */ diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 389a24e48..ce92f98dc 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -2743,11 +2743,17 @@ static int ssl_write_hello_verify_request( mbedtls_ssl_context *ssl ) } #if defined(MBEDTLS_SSL_PROTO_DTLS) - if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) && - ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) ) { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); - return( ret ); +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) + mbedtls_ssl_immediate_flight_done( ssl ); +#else + if( ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); + return( ret ); + } +#endif } #endif /* MBEDTLS_SSL_PROTO_DTLS */ @@ -3802,11 +3808,17 @@ static int ssl_write_server_hello_done( mbedtls_ssl_context *ssl ) } #if defined(MBEDTLS_SSL_PROTO_DTLS) - if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) && - ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) ) { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); - return( ret ); +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) + mbedtls_ssl_immediate_flight_done( ssl ); +#else + if( ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); + return( ret ); + } +#endif } #endif /* MBEDTLS_SSL_PROTO_DTLS */ diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 038e581d5..f20faf92a 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -4360,6 +4360,131 @@ int mbedtls_ssl_flush_output( mbedtls_ssl_context *ssl ) * Functions to handle the DTLS retransmission state machine */ #if defined(MBEDTLS_SSL_PROTO_DTLS) +static int ssl_swap_epochs( mbedtls_ssl_context *ssl ); + +static int mbedtls_ssl_flight_transmit_msg( mbedtls_ssl_context *ssl, mbedtls_ssl_flight_item *msg ) +{ + size_t max_frag_len; + int ret = MBEDTLS_ERR_PLATFORM_FAULT_DETECTED; + int const is_retransmitting = + ( ssl->handshake->retransmit_state == MBEDTLS_SSL_RETRANS_SENDING ); + int const is_finished = + ( msg->type == MBEDTLS_SSL_MSG_HANDSHAKE && + msg->p[0] == MBEDTLS_SSL_HS_FINISHED ); + + uint8_t const force_flush = ssl->disable_datagram_packing == 1 ? + SSL_FORCE_FLUSH : SSL_DONT_FORCE_FLUSH; + + /* Swap epochs before sending Finished: we can't do it after + * sending ChangeCipherSpec, in case write returns WANT_READ. + * Must be done before copying, may change out_msg pointer */ + if( is_retransmitting && is_finished && ssl->handshake->cur_msg_p == ( msg->p + 12 ) ) + { + MBEDTLS_SSL_DEBUG_MSG( 2, ( "swap epochs to send finished message" ) ); + if( ( ret = ssl_swap_epochs( ssl ) ) != 0 ) + return( ret ); + } + + ret = ssl_get_remaining_payload_in_datagram( ssl ); + if( ret < 0 ) + return( ret ); + max_frag_len = (size_t) ret; + + /* CCS is copied as is, while HS messages may need fragmentation */ + if( msg->type == MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC ) + { + if( max_frag_len == 0 ) + { + if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) + return( ret ); + + return( 0 ); + } + + mbedtls_platform_memcpy( ssl->out_msg, msg->p, msg->len ); + ssl->out_msglen = msg->len; + ssl->out_msgtype = msg->type; + + /* Update position inside current message */ + ssl->handshake->cur_msg_p += msg->len; + } + else + { + const unsigned char * const p = ssl->handshake->cur_msg_p; + const size_t hs_len = msg->len - 12; + const size_t frag_off = p - ( msg->p + 12 ); + const size_t rem_len = hs_len - frag_off; + size_t cur_hs_frag_len, max_hs_frag_len; + + if( ( max_frag_len < 12 ) || ( max_frag_len == 12 && hs_len != 0 ) ) + { + if( is_finished && is_retransmitting ) + { + if( ( ret = ssl_swap_epochs( ssl ) ) != 0 ) + return( ret ); + } + + if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) + return( ret ); + + return( 0 ); + } + max_hs_frag_len = max_frag_len - 12; + + cur_hs_frag_len = rem_len > max_hs_frag_len ? + max_hs_frag_len : rem_len; + + if( frag_off == 0 && cur_hs_frag_len != hs_len ) + { + MBEDTLS_SSL_DEBUG_MSG( 2, ( "fragmenting handshake message (%u > %u)", + (unsigned) cur_hs_frag_len, + (unsigned) max_hs_frag_len ) ); + } + + /* Messages are stored with handshake headers as if not fragmented, + * copy beginning of headers then fill fragmentation fields. + * Handshake headers: type(1) len(3) seq(2) f_off(3) f_len(3) */ + mbedtls_platform_memcpy( ssl->out_msg, msg->p, 6 ); + + (void)mbedtls_platform_put_uint24_be( &ssl->out_msg[6], frag_off ); + (void)mbedtls_platform_put_uint24_be( &ssl->out_msg[9], + cur_hs_frag_len ); + + MBEDTLS_SSL_DEBUG_BUF( 3, "handshake header", ssl->out_msg, 12 ); + + /* Copy the handshake message content and set records fields */ + mbedtls_platform_memcpy( ssl->out_msg + 12, p, cur_hs_frag_len ); + ssl->out_msglen = cur_hs_frag_len + 12; + ssl->out_msgtype = msg->type; + + /* Update position inside current message */ + ssl->handshake->cur_msg_p += cur_hs_frag_len; + } + + /* If done with the current message move to the next one if any */ + if( ssl->handshake->cur_msg_p >= msg->p + msg->len ) + { + if( msg->next != NULL ) + { + ssl->handshake->cur_msg = msg->next; + ssl->handshake->cur_msg_p = msg->next->p + 12; + } + else + { + ssl->handshake->cur_msg = NULL; + ssl->handshake->cur_msg_p = NULL; + } + } + + /* Actually send the message out */ + if( ( ret = mbedtls_ssl_write_record( ssl, force_flush ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret ); + return( ret ); + } + return( ret ); +} + /* * Append current handshake message to current outgoing flight */ @@ -4402,6 +4527,21 @@ static int ssl_flight_append( mbedtls_ssl_context *ssl ) cur->next = msg; } +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) + ssl->handshake->cur_msg = msg; + ssl->handshake->cur_msg_p = msg->p + 12; + { + int ret = MBEDTLS_ERR_PLATFORM_FAULT_DETECTED; + while( ssl->handshake->cur_msg != NULL ) + { + if( ( ret = mbedtls_ssl_flight_transmit_msg( ssl, ssl->handshake->cur_msg ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit_msg", ret ); + return( ret ); + } + } + } +#endif MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= ssl_flight_append" ) ); return( 0 ); } @@ -4491,6 +4631,24 @@ int mbedtls_ssl_resend( mbedtls_ssl_context *ssl ) return( ret ); } +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) +void mbedtls_ssl_immediate_flight_done( mbedtls_ssl_context *ssl ) +{ + MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> mbedtls_ssl_immediate_flight_done" ) ); + + /* Update state and set timer */ + if( ssl->state == MBEDTLS_SSL_HANDSHAKE_OVER ) + ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_FINISHED; + else + { + ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_WAITING; + ssl_set_timer( ssl, ssl->handshake->retransmit_timeout ); + } + + MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= mbedtls_ssl_immediate_flight_done" ) ); +} +#endif + /* * Transmit or retransmit the current flight of messages. * @@ -4507,138 +4665,19 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ) { MBEDTLS_SSL_DEBUG_MSG( 2, ( "initialise flight transmission" ) ); -#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) - ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_SENDING; - - return( 0 ); -#else - ssl->handshake->cur_msg = ssl->handshake->flight; ssl->handshake->cur_msg_p = ssl->handshake->flight->p + 12; if( ( ret = ssl_swap_epochs( ssl ) ) != 0 ) return( ret ); ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_SENDING; -#endif /* MBEDTLS_IMMEDIATE_TRANSMISSION */ } while( ssl->handshake->cur_msg != NULL ) { - size_t max_frag_len; - const mbedtls_ssl_flight_item * const cur = ssl->handshake->cur_msg; - - int const is_finished = - ( cur->type == MBEDTLS_SSL_MSG_HANDSHAKE && - cur->p[0] == MBEDTLS_SSL_HS_FINISHED ); - - uint8_t const force_flush = ssl->disable_datagram_packing == 1 ? - SSL_FORCE_FLUSH : SSL_DONT_FORCE_FLUSH; - - /* Swap epochs before sending Finished: we can't do it after - * sending ChangeCipherSpec, in case write returns WANT_READ. - * Must be done before copying, may change out_msg pointer */ - if( is_finished && ssl->handshake->cur_msg_p == ( cur->p + 12 ) ) + if( ( ret = mbedtls_ssl_flight_transmit_msg( ssl, ssl->handshake->cur_msg ) ) != 0 ) { - MBEDTLS_SSL_DEBUG_MSG( 2, ( "swap epochs to send finished message" ) ); - if( ( ret = ssl_swap_epochs( ssl ) ) != 0 ) - return( ret ); - } - - ret = ssl_get_remaining_payload_in_datagram( ssl ); - if( ret < 0 ) - return( ret ); - max_frag_len = (size_t) ret; - - /* CCS is copied as is, while HS messages may need fragmentation */ - if( cur->type == MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC ) - { - if( max_frag_len == 0 ) - { - if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) - return( ret ); - - continue; - } - - mbedtls_platform_memcpy( ssl->out_msg, cur->p, cur->len ); - ssl->out_msglen = cur->len; - ssl->out_msgtype = cur->type; - - /* Update position inside current message */ - ssl->handshake->cur_msg_p += cur->len; - } - else - { - const unsigned char * const p = ssl->handshake->cur_msg_p; - const size_t hs_len = cur->len - 12; - const size_t frag_off = p - ( cur->p + 12 ); - const size_t rem_len = hs_len - frag_off; - size_t cur_hs_frag_len, max_hs_frag_len; - - if( ( max_frag_len < 12 ) || ( max_frag_len == 12 && hs_len != 0 ) ) - { - if( is_finished ) - { - if( ( ret = ssl_swap_epochs( ssl ) ) != 0 ) - return( ret ); - } - - if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) - return( ret ); - - continue; - } - max_hs_frag_len = max_frag_len - 12; - - cur_hs_frag_len = rem_len > max_hs_frag_len ? - max_hs_frag_len : rem_len; - - if( frag_off == 0 && cur_hs_frag_len != hs_len ) - { - MBEDTLS_SSL_DEBUG_MSG( 2, ( "fragmenting handshake message (%u > %u)", - (unsigned) cur_hs_frag_len, - (unsigned) max_hs_frag_len ) ); - } - - /* Messages are stored with handshake headers as if not fragmented, - * copy beginning of headers then fill fragmentation fields. - * Handshake headers: type(1) len(3) seq(2) f_off(3) f_len(3) */ - mbedtls_platform_memcpy( ssl->out_msg, cur->p, 6 ); - - (void)mbedtls_platform_put_uint24_be( &ssl->out_msg[6], frag_off ); - (void)mbedtls_platform_put_uint24_be( &ssl->out_msg[9], - cur_hs_frag_len ); - - MBEDTLS_SSL_DEBUG_BUF( 3, "handshake header", ssl->out_msg, 12 ); - - /* Copy the handshake message content and set records fields */ - mbedtls_platform_memcpy( ssl->out_msg + 12, p, cur_hs_frag_len ); - ssl->out_msglen = cur_hs_frag_len + 12; - ssl->out_msgtype = cur->type; - - /* Update position inside current message */ - ssl->handshake->cur_msg_p += cur_hs_frag_len; - } - - /* If done with the current message move to the next one if any */ - if( ssl->handshake->cur_msg_p >= cur->p + cur->len ) - { - if( cur->next != NULL ) - { - ssl->handshake->cur_msg = cur->next; - ssl->handshake->cur_msg_p = cur->next->p + 12; - } - else - { - ssl->handshake->cur_msg = NULL; - ssl->handshake->cur_msg_p = NULL; - } - } - - /* Actually send the message out */ - if( ( ret = mbedtls_ssl_write_record( ssl, force_flush ) ) != 0 ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret ); + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit_msg", ret ); return( ret ); } } @@ -4657,7 +4696,7 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= mbedtls_ssl_flight_transmit" ) ); - return( 0 ); + return( ret ); } /* @@ -4868,14 +4907,6 @@ int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl ) ! ( ssl->out_msgtype == MBEDTLS_SSL_MSG_HANDSHAKE && hs_type == MBEDTLS_SSL_HS_HELLO_REQUEST ) ) { -#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) - if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "ssl_write_record", ret ); - return( ret ); - } -#endif /* MBEDTLS_IMMEDIATE_TRANSMISSION */ - if( ( ret = ssl_flight_append( ssl ) ) != 0 ) { MBEDTLS_SSL_DEBUG_RET( 1, "ssl_flight_append", ret ); @@ -8707,13 +8738,19 @@ int mbedtls_ssl_write_finished( mbedtls_ssl_context *ssl ) } #if defined(MBEDTLS_SSL_PROTO_DTLS) - if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) && - ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) ) { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); - return( ret ); - } +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) + mbedtls_ssl_immediate_flight_done( ssl ); +#else + if( ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); + return( ret ); + } #endif + } +#endif /* MBEDTLS_SSL_PROTO_DTLS */ MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= write finished" ) );