diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c index d9a29dcd0..7b6efb1cc 100644 --- a/library/ssl_tls12_server.c +++ b/library/ssl_tls12_server.c @@ -4115,10 +4115,6 @@ static int ssl_parse_client_key_exchange( mbedtls_ssl_context *ssl ) MBEDTLS_PUT_UINT16_BE( zlen, psm, 0 ); psm += zlen_size + zlen; - /* opaque psk<0..2^16-1>; */ - if( psm_end - psm < 2 ) - return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - const unsigned char *psk = NULL; size_t psk_len = 0; @@ -4130,13 +4126,14 @@ static int ssl_parse_client_key_exchange( mbedtls_ssl_context *ssl ) */ return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); + /* opaque psk<0..2^16-1>; */ + if( (size_t)( psm_end - psm ) < ( 2 + psk_len ) ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + /* Write the PSK length as uint16 */ MBEDTLS_PUT_UINT16_BE( psk_len, psm, 0 ); psm += 2; - if( psm_end < psm || (size_t)( psm_end - psm ) < psk_len ) - return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - /* Write the PSK itself */ memcpy( psm, psk, psk_len ); psm += psk_len;