diff --git a/include/mbedtls/aes.h b/include/mbedtls/aes.h
index 11edc0fab..197d4db10 100644
--- a/include/mbedtls/aes.h
+++ b/include/mbedtls/aes.h
@@ -197,8 +197,10 @@ int mbedtls_aes_setkey_dec( mbedtls_aes_context *ctx, const unsigned char *key,
  *                 sets the encryption key.
  *
  * \param ctx      The AES XTS context to which the key should be bound.
+ *                 It must be initialized.
  * \param key      The encryption key. This is comprised of the XTS key1
  *                 concatenated with the XTS key2.
+ *                 This must be a readable buffer of size \p keybits bits.
  * \param keybits  The size of \p key passed in bits. Valid options are:
  *                 <ul><li>256 bits (each of key1 and key2 is a 128-bit key)</li>
  *                 <li>512 bits (each of key1 and key2 is a 256-bit key)</li></ul>
@@ -215,8 +217,10 @@ int mbedtls_aes_xts_setkey_enc( mbedtls_aes_xts_context *ctx,
  *                 sets the decryption key.
  *
  * \param ctx      The AES XTS context to which the key should be bound.
+ *                 It must be initialized.
  * \param key      The decryption key. This is comprised of the XTS key1
  *                 concatenated with the XTS key2.
+ *                 This must be a readable buffer of size \p keybits bits.
  * \param keybits  The size of \p key passed in bits. Valid options are:
  *                 <ul><li>256 bits (each of key1 and key2 is a 128-bit key)</li>
  *                 <li>512 bits (each of key1 and key2 is a 256-bit key)</li></ul>
diff --git a/library/aes.c b/library/aes.c
index cc1e5ceb4..4d9a56a5c 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -771,6 +771,9 @@ int mbedtls_aes_xts_setkey_enc( mbedtls_aes_xts_context *ctx,
     const unsigned char *key1, *key2;
     unsigned int key1bits, key2bits;
 
+    AES_VALIDATE_RET( ctx != NULL );
+    AES_VALIDATE_RET( key != NULL );
+
     ret = mbedtls_aes_xts_decode_keys( key, keybits, &key1, &key1bits,
                                        &key2, &key2bits );
     if( ret != 0 )
@@ -793,6 +796,9 @@ int mbedtls_aes_xts_setkey_dec( mbedtls_aes_xts_context *ctx,
     const unsigned char *key1, *key2;
     unsigned int key1bits, key2bits;
 
+    AES_VALIDATE_RET( ctx != NULL );
+    AES_VALIDATE_RET( key != NULL );
+
     ret = mbedtls_aes_xts_decode_keys( key, keybits, &key1, &key1bits,
                                        &key2, &key2bits );
     if( ret != 0 )
diff --git a/tests/suites/test_suite_aes.function b/tests/suites/test_suite_aes.function
index 131565060..576e5be08 100644
--- a/tests/suites/test_suite_aes.function
+++ b/tests/suites/test_suite_aes.function
@@ -215,7 +215,7 @@ void aes_crypt_xts_size( int size, int retval )
 void aes_crypt_xts_keysize( int size, int retval )
 {
     mbedtls_aes_xts_context ctx;
-    const unsigned char *key = NULL;
+    const unsigned char key[] = { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06 };
     size_t key_len = size;
 
     mbedtls_aes_xts_init( &ctx );
@@ -374,26 +374,38 @@ exit:
 /* BEGIN_CASE depends_on:MBEDTLS_CHECK_PARAMS:!MBEDTLS_PARAM_FAILED_ALT */
 void aes_invalid_param( )
 {
-    mbedtls_aes_context dummy_ctx;
+    mbedtls_aes_context aes_ctx;
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+    mbedtls_aes_xts_context xts_ctx;
+#endif
     const unsigned char key[] = { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06 };
 
     TEST_INVALID_PARAM( mbedtls_aes_init( NULL ) );
-
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
     TEST_INVALID_PARAM( mbedtls_aes_xts_init( NULL ) );
+#endif
 
-    /* mbedtls_aes_setkey_enc() */
     TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
                             mbedtls_aes_setkey_enc( NULL, key, 128 ) );
-
     TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
-                            mbedtls_aes_setkey_enc( &dummy_ctx, NULL, 128 ) );
+                            mbedtls_aes_setkey_enc( &aes_ctx, NULL, 128 ) );
 
-    /* mbedtls_aes_setkey_dec() */
     TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
                             mbedtls_aes_setkey_dec( NULL, key, 128 ) );
+    TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+                            mbedtls_aes_setkey_dec( &aes_ctx, NULL, 128 ) );
+
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+    TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+                            mbedtls_aes_xts_setkey_enc( NULL, key, 128 ) );
+    TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+                            mbedtls_aes_xts_setkey_enc( &xts_ctx, NULL, 128 ) );
 
     TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
-                            mbedtls_aes_setkey_dec( &dummy_ctx, NULL, 128 ) );
+                            mbedtls_aes_xts_setkey_dec( NULL, key, 128 ) );
+    TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+                            mbedtls_aes_xts_setkey_dec( &xts_ctx, NULL, 128 ) );
+#endif
 }
 /* END_CASE */