diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c index bf49af21b..bbc853579 100644 --- a/library/ssl_tls13_server.c +++ b/library/ssl_tls13_server.c @@ -106,7 +106,7 @@ MBEDTLS_CHECK_RETURN_CRITICAL static int ssl_tls13_offered_psks_check_identity_match( mbedtls_ssl_context *ssl, const unsigned char *identity, - uint16_t identity_len ) + size_t identity_len ) { /* Check identity with external configured function */ if( ssl->conf->f_psk != NULL ) @@ -121,7 +121,8 @@ static int ssl_tls13_offered_psks_check_identity_match( MBEDTLS_SSL_DEBUG_BUF( 5, "identity", identity, identity_len ); /* Check identity with pre-configured psk */ - if( identity_len == ssl->conf->psk_identity_len && + if( ssl->conf->psk_identity != NULL && + identity_len == ssl->conf->psk_identity_len && mbedtls_ct_memcmp( ssl->conf->psk_identity, identity, identity_len ) == 0 ) { @@ -134,7 +135,7 @@ static int ssl_tls13_offered_psks_check_identity_match( MBEDTLS_CHECK_RETURN_CRITICAL static int ssl_tls13_get_psk( mbedtls_ssl_context *ssl, - const unsigned char **psk, + unsigned char **psk, size_t *psk_len ) { #if defined(MBEDTLS_USE_PSA_CRYPTO) @@ -150,7 +151,7 @@ static int ssl_tls13_get_psk( mbedtls_ssl_context *ssl, return( psa_ssl_status_to_mbedtls( status ) ); } - *psk_len = PSA_BITS_TO_BYTES(psa_get_key_bits( &key_attributes ) ); + *psk_len = PSA_BITS_TO_BYTES( psa_get_key_bits( &key_attributes ) ); *psk = mbedtls_calloc( 1, *psk_len ); if( *psk == NULL ) { @@ -174,22 +175,32 @@ static int ssl_tls13_get_psk( mbedtls_ssl_context *ssl, MBEDTLS_CHECK_RETURN_CRITICAL static int ssl_tls13_offered_psks_check_binder_match( mbedtls_ssl_context *ssl, const unsigned char *binder, - uint16_t binder_len ) + size_t binder_len ) { int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; int psk_type; - mbedtls_md_type_t md_alg = - binder_len == 32 ? MBEDTLS_MD_SHA256 : MBEDTLS_MD_SHA384 ; - psa_algorithm_t psa_md_alg = mbedtls_psa_translate_md( md_alg ); + mbedtls_md_type_t md_alg; + psa_algorithm_t psa_md_alg; unsigned char transcript[PSA_HASH_MAX_SIZE]; size_t transcript_len; - const unsigned char *psk; + unsigned char *psk; size_t psk_len; unsigned char server_computed_binder[PSA_HASH_MAX_SIZE]; psk_type = MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL; - + switch( binder_len ) + { + case 32: + md_alg = MBEDTLS_MD_SHA256; + break; + case 48: + md_alg = MBEDTLS_MD_SHA384; + break; + default: + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); + } + psa_md_alg = mbedtls_psa_translate_md( md_alg ); /* Get current state of handshake transcript. */ ret = mbedtls_ssl_get_handshake_transcript( ssl, md_alg, transcript, sizeof( transcript ), @@ -215,7 +226,7 @@ static int ssl_tls13_offered_psks_check_binder_match( mbedtls_ssl_context *ssl, } MBEDTLS_SSL_DEBUG_BUF( 3, "psk binder ( computed ): ", - server_computed_binder, binder_len ); + server_computed_binder, transcript_len ); MBEDTLS_SSL_DEBUG_BUF( 3, "psk binder ( received ): ", binder, binder_len ); if( mbedtls_ct_memcmp( server_computed_binder, binder, binder_len ) == 0 ) @@ -262,7 +273,7 @@ static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl, int matched_identity = -1; int identity_id = -1; - MBEDTLS_SSL_DEBUG_BUF( 3, "pre_shared_key extesion", buf, end - buf ); + MBEDTLS_SSL_DEBUG_BUF( 3, "pre_shared_key extension", buf, end - buf ); /* identities_len 2 bytes * identities_data >= 7 bytes