From 44c5d58d05a1afbee11903d7c40f84b68f8bb888 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Manuel=20P=C3=A9gouri=C3=A9-Gonnard?=
 <manuel.pegourie-gonnard@arm.com>
Date: Mon, 10 Dec 2018 16:56:14 +0100
Subject: [PATCH] Document AES functions and fix free() functions

---
 include/mbedtls/aes.h                | 18 ++++++++++--------
 library/aes.c                        | 11 +++++++----
 tests/suites/helpers.function        | 27 +++++++++++++++++++++++++++
 tests/suites/test_suite_aes.function |  6 ++++++
 4 files changed, 50 insertions(+), 12 deletions(-)

diff --git a/include/mbedtls/aes.h b/include/mbedtls/aes.h
index cfb20c4fc..da7ab5496 100644
--- a/include/mbedtls/aes.h
+++ b/include/mbedtls/aes.h
@@ -121,14 +121,14 @@ typedef struct mbedtls_aes_xts_context
  *                 It must be the first API called before using
  *                 the context.
  *
- * \param ctx      The AES context to initialize.
+ * \param ctx      The AES context to initialize. Must not be NULL.
  */
 void mbedtls_aes_init( mbedtls_aes_context *ctx );
 
 /**
  * \brief          This function releases and clears the specified AES context.
  *
- * \param ctx      The AES context to clear.
+ * \param ctx      The AES context to clear. If NULL, no action is taken.
  */
 void mbedtls_aes_free( mbedtls_aes_context *ctx );
 
@@ -139,14 +139,14 @@ void mbedtls_aes_free( mbedtls_aes_context *ctx );
  *                 It must be the first API called before using
  *                 the context.
  *
- * \param ctx      The AES XTS context to initialize.
+ * \param ctx      The AES XTS context to initialize. Must not be NULL.
  */
 void mbedtls_aes_xts_init( mbedtls_aes_xts_context *ctx );
 
 /**
  * \brief          This function releases and clears the specified AES XTS context.
  *
- * \param ctx      The AES XTS context to clear.
+ * \param ctx      The AES XTS context to clear. If NULL, no action is taken.
  */
 void mbedtls_aes_xts_free( mbedtls_aes_xts_context *ctx );
 #endif /* MBEDTLS_CIPHER_MODE_XTS */
@@ -154,8 +154,9 @@ void mbedtls_aes_xts_free( mbedtls_aes_xts_context *ctx );
 /**
  * \brief          This function sets the encryption key.
  *
- * \param ctx      The AES context to which the key should be bound.
- * \param key      The encryption key.
+ * \param ctx      The AES context to which the key should be bound. Must not
+ *                 be NULL.
+ * \param key      The encryption key. Must not be NULL.
  * \param keybits  The size of data passed in bits. Valid options are:
  *                 <ul><li>128 bits</li>
  *                 <li>192 bits</li>
@@ -170,8 +171,9 @@ int mbedtls_aes_setkey_enc( mbedtls_aes_context *ctx, const unsigned char *key,
 /**
  * \brief          This function sets the decryption key.
  *
- * \param ctx      The AES context to which the key should be bound.
- * \param key      The decryption key.
+ * \param ctx      The AES context to which the key should be bound. Must not
+ *                 be NULL.
+ * \param key      The decryption key. Must not be NULL.
  * \param keybits  The size of data passed. Valid options are:
  *                 <ul><li>128 bits</li>
  *                 <li>192 bits</li>
diff --git a/library/aes.c b/library/aes.c
index 6ff39d74c..cc1e5ceb4 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -58,7 +58,7 @@
 
 /* Parameter validation macros based on platform_util.h */
 #define AES_VALIDATE_RET( cond )    \
-    MBEDTLS_INTERNAL_VALIDATE_RET( cond, MBEDTLS_ERR_AES_BAD_INPUT_DATA)
+    MBEDTLS_INTERNAL_VALIDATE_RET( cond, MBEDTLS_ERR_AES_BAD_INPUT_DATA )
 #define AES_VALIDATE( cond )        \
     MBEDTLS_INTERNAL_VALIDATE( cond )
 
@@ -541,7 +541,8 @@ void mbedtls_aes_xts_init( mbedtls_aes_xts_context *ctx )
 
 void mbedtls_aes_xts_free( mbedtls_aes_xts_context *ctx )
 {
-    AES_VALIDATE( ctx != NULL );
+    if( ctx == NULL )
+        return;
 
     mbedtls_aes_free( &ctx->crypt );
     mbedtls_aes_free( &ctx->tweak );
@@ -558,7 +559,8 @@ int mbedtls_aes_setkey_enc( mbedtls_aes_context *ctx, const unsigned char *key,
     unsigned int i;
     uint32_t *RK;
 
-    AES_VALIDATE_RET( ctx != NULL && key != NULL );
+    AES_VALIDATE_RET( ctx != NULL );
+    AES_VALIDATE_RET( key != NULL );
 
     switch( keybits )
     {
@@ -676,7 +678,8 @@ int mbedtls_aes_setkey_dec( mbedtls_aes_context *ctx, const unsigned char *key,
     uint32_t *RK;
     uint32_t *SK;
 
-    AES_VALIDATE_RET( ctx != NULL && key != NULL );
+    AES_VALIDATE_RET( ctx != NULL );
+    AES_VALIDATE_RET( key != NULL );
 
     mbedtls_aes_init( &cty );
 
diff --git a/tests/suites/helpers.function b/tests/suites/helpers.function
index 71390ecfe..57bc25913 100644
--- a/tests/suites/helpers.function
+++ b/tests/suites/helpers.function
@@ -173,6 +173,33 @@ typedef enum
         memcpy(param_fail_jmp, jmp_tmp, sizeof(jmp_buf));                   \
     } while( 0 )
 
+/**
+ * \brief   This macro tests the statement passed to it as a test step or
+ *          individual test in a test case. The macro assumes the test will not fail.
+ *
+ *          It assumes the library function under test cannot return a value and
+ *          assumes errors can only be indicated by calls to
+ *          MBEDTLS_PARAM_FAILED().
+ *
+ *          When MBEDTLS_CHECK_PARAMS is enabled, calls to the parameter failure
+ *          callback, MBEDTLS_PARAM_FAILED(), are assumed to indicate the
+ *          expected failure. If MBEDTLS_CHECK_PARAMS is not enabled, no test
+ *          can be made.
+ *
+ *          This macro is intended to test that function that return void
+ *          accept all of the parameter values they're supposed to accept - eg
+ *          that they don't call MBEDTLS_PARAM_FAILED() when a parameter
+ *          that's allowed to be NULL happends to be NULL.
+ *
+ *          Note: for functions that return something other that void,
+ *          checking that they accept all the parameters they're supposed to
+ *          accept is best done by using TEST_ASSERT() and checking the return
+ *          value as well.
+ *
+ * \param   TEST                The test expression to be tested.
+ */
+#define TEST_VALID_PARAM( TEST )                                    \
+    TEST_ASSERT( ( TEST, 1 ) );
 #endif /* MBEDTLS_CHECK_PARAMS && !MBEDTLS_PARAM_FAILED_ALT */
 
 #define assert(a) if( !( a ) )                                      \
diff --git a/tests/suites/test_suite_aes.function b/tests/suites/test_suite_aes.function
index 7dab01b47..f61f71c3e 100644
--- a/tests/suites/test_suite_aes.function
+++ b/tests/suites/test_suite_aes.function
@@ -379,6 +379,8 @@ void aes_invalid_param( )
 
     TEST_INVALID_PARAM( mbedtls_aes_init( NULL ) );
 
+    TEST_INVALID_PARAM( mbedtls_aes_xts_init( NULL ) );
+
     /* mbedtls_aes_setkey_enc() */
     TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
                             mbedtls_aes_setkey_enc( NULL, key, 128 ) );
@@ -393,6 +395,10 @@ void aes_invalid_param( )
     TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
                             mbedtls_aes_setkey_dec( &dummy_ctx, NULL, 128 ) );
 
+    /* These calls accept NULL */
+    TEST_VALID_PARAM( mbedtls_aes_free( NULL ) );
+    TEST_VALID_PARAM( mbedtls_aes_xts_free( NULL ) );
+
 exit:
     return;
 }