diff --git a/include/psa/crypto_struct.h b/include/psa/crypto_struct.h index 47012fdd0..a1182c48d 100644 --- a/include/psa/crypto_struct.h +++ b/include/psa/crypto_struct.h @@ -75,6 +75,8 @@ extern "C" { #include "mbedtls/cmac.h" #include "mbedtls/gcm.h" +#include "mbedtls/ccm.h" +#include "mbedtls/chachapoly.h" /* Include the context definition for the compiled-in drivers for the primitive * algorithms. */ @@ -153,17 +155,27 @@ struct psa_aead_operation_s { psa_algorithm_t alg; unsigned int key_set : 1; - unsigned int iv_set : 1; - uint8_t iv_size; - uint8_t block_size; + unsigned int nonce_set : 1; + + uint8_t tag_length; + union { unsigned dummy; /* Enable easier initializing of the union. */ - mbedtls_cipher_context_t cipher; +#if defined(MBEDTLS_PSA_BUILTIN_ALG_CCM) + mbedtls_ccm_context ccm; +#endif /* MBEDTLS_PSA_BUILTIN_ALG_CCM */ +#if defined(MBEDTLS_PSA_BUILTIN_ALG_GCM) + mbedtls_gcm_context gcm; +#endif /* MBEDTLS_PSA_BUILTIN_ALG_GCM */ +#if defined(MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305) + mbedtls_chachapoly_context chachapoly; +#endif /* MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305 */ + } ctx; }; -#define PSA_AEAD_OPERATION_INIT {0, 0, 0, 0, 0, {0}} +#define PSA_AEAD_OPERATION_INIT {0, 0, 0, 0, {0}} static inline struct psa_aead_operation_s psa_aead_operation_init( void ) { const struct psa_aead_operation_s v = PSA_AEAD_OPERATION_INIT; diff --git a/library/psa_crypto_aead.c b/library/psa_crypto_aead.c index 356679c38..07c52d433 100644 --- a/library/psa_crypto_aead.c +++ b/library/psa_crypto_aead.c @@ -30,30 +30,10 @@ #include "mbedtls/cipher.h" #include "mbedtls/gcm.h" -typedef struct -{ - union - { - unsigned dummy; /* Make the union non-empty even with no supported algorithms. */ -#if defined(MBEDTLS_PSA_BUILTIN_ALG_CCM) - mbedtls_ccm_context ccm; -#endif /* MBEDTLS_PSA_BUILTIN_ALG_CCM */ -#if defined(MBEDTLS_PSA_BUILTIN_ALG_GCM) - mbedtls_gcm_context gcm; -#endif /* MBEDTLS_PSA_BUILTIN_ALG_GCM */ -#if defined(MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305) - mbedtls_chachapoly_context chachapoly; -#endif /* MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305 */ - } ctx; - psa_algorithm_t core_alg; - uint8_t tag_length; -} aead_operation_t; -#define AEAD_OPERATION_INIT {{0}, 0, 0} - -static void psa_aead_abort_internal( aead_operation_t *operation ) +static void psa_aead_abort_internal( psa_aead_operation_t *operation ) { - switch( operation->core_alg ) + switch( operation->alg ) { #if defined(MBEDTLS_PSA_BUILTIN_ALG_CCM) case PSA_ALG_CCM: @@ -74,7 +54,7 @@ static void psa_aead_abort_internal( aead_operation_t *operation ) } static psa_status_t psa_aead_setup( - aead_operation_t *operation, + psa_aead_operation_t *operation, const psa_key_attributes_t *attributes, const uint8_t *key_buffer, psa_algorithm_t alg ) @@ -97,7 +77,7 @@ static psa_status_t psa_aead_setup( { #if defined(MBEDTLS_PSA_BUILTIN_ALG_CCM) case PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 0 ): - operation->core_alg = PSA_ALG_CCM; + operation->alg = PSA_ALG_CCM; full_tag_length = 16; /* CCM allows the following tag lengths: 4, 6, 8, 10, 12, 14, 16. * The call to mbedtls_ccm_encrypt_and_tag or @@ -116,7 +96,7 @@ static psa_status_t psa_aead_setup( #if defined(MBEDTLS_PSA_BUILTIN_ALG_GCM) case PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_GCM, 0 ): - operation->core_alg = PSA_ALG_GCM; + operation->alg = PSA_ALG_GCM; full_tag_length = 16; /* GCM allows the following tag lengths: 4, 8, 12, 13, 14, 15, 16. * The call to mbedtls_gcm_crypt_and_tag or @@ -135,7 +115,7 @@ static psa_status_t psa_aead_setup( #if defined(MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305) case PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CHACHA20_POLY1305, 0 ): - operation->core_alg = PSA_ALG_CHACHA20_POLY1305; + operation->alg = PSA_ALG_CHACHA20_POLY1305; full_tag_length = 16; /* We only support the default tag length. */ if( alg != PSA_ALG_CHACHA20_POLY1305 ) @@ -176,7 +156,7 @@ psa_status_t mbedtls_psa_aead_encrypt( uint8_t *ciphertext, size_t ciphertext_size, size_t *ciphertext_length ) { psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; - aead_operation_t operation = AEAD_OPERATION_INIT; + psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT; uint8_t *tag; (void) key_buffer_size; @@ -194,7 +174,7 @@ psa_status_t mbedtls_psa_aead_encrypt( tag = ciphertext + plaintext_length; #if defined(MBEDTLS_PSA_BUILTIN_ALG_CCM) - if( operation.core_alg == PSA_ALG_CCM ) + if( operation.alg == PSA_ALG_CCM ) { status = mbedtls_to_psa_error( mbedtls_ccm_encrypt_and_tag( &operation.ctx.ccm, @@ -208,7 +188,7 @@ psa_status_t mbedtls_psa_aead_encrypt( else #endif /* MBEDTLS_PSA_BUILTIN_ALG_CCM */ #if defined(MBEDTLS_PSA_BUILTIN_ALG_GCM) - if( operation.core_alg == PSA_ALG_GCM ) + if( operation.alg == PSA_ALG_GCM ) { status = mbedtls_to_psa_error( mbedtls_gcm_crypt_and_tag( &operation.ctx.gcm, @@ -222,7 +202,7 @@ psa_status_t mbedtls_psa_aead_encrypt( else #endif /* MBEDTLS_PSA_BUILTIN_ALG_GCM */ #if defined(MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305) - if( operation.core_alg == PSA_ALG_CHACHA20_POLY1305 ) + if( operation.alg == PSA_ALG_CHACHA20_POLY1305 ) { if( nonce_length != 12 || operation.tag_length != 16 ) { @@ -286,7 +266,7 @@ psa_status_t mbedtls_psa_aead_decrypt( uint8_t *plaintext, size_t plaintext_size, size_t *plaintext_length ) { psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; - aead_operation_t operation = AEAD_OPERATION_INIT; + psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT; const uint8_t *tag = NULL; (void) key_buffer_size; @@ -301,7 +281,7 @@ psa_status_t mbedtls_psa_aead_decrypt( goto exit; #if defined(MBEDTLS_PSA_BUILTIN_ALG_CCM) - if( operation.core_alg == PSA_ALG_CCM ) + if( operation.alg == PSA_ALG_CCM ) { status = mbedtls_to_psa_error( mbedtls_ccm_auth_decrypt( &operation.ctx.ccm, @@ -315,7 +295,7 @@ psa_status_t mbedtls_psa_aead_decrypt( else #endif /* MBEDTLS_PSA_BUILTIN_ALG_CCM */ #if defined(MBEDTLS_PSA_BUILTIN_ALG_GCM) - if( operation.core_alg == PSA_ALG_GCM ) + if( operation.alg == PSA_ALG_GCM ) { status = mbedtls_to_psa_error( mbedtls_gcm_auth_decrypt( &operation.ctx.gcm, @@ -329,7 +309,7 @@ psa_status_t mbedtls_psa_aead_decrypt( else #endif /* MBEDTLS_PSA_BUILTIN_ALG_GCM */ #if defined(MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305) - if( operation.core_alg == PSA_ALG_CHACHA20_POLY1305 ) + if( operation.alg == PSA_ALG_CHACHA20_POLY1305 ) { if( nonce_length != 12 || operation.tag_length != 16 ) {