From aa9ffc5e98f1c375efb6787e12d6537da2e99327 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Manuel=20P=C3=A9gouri=C3=A9-Gonnard?= <mpg@elzevir.fr>
Date: Tue, 3 Sep 2013 16:19:22 +0200
Subject: [PATCH] Split tag handling out of cipher_finish()

---
 include/polarssl/cipher.h               | 35 +++++++++---
 library/cipher.c                        | 73 ++++++++++++++++---------
 library/pkcs12.c                        |  5 +-
 library/pkcs5.c                         |  5 +-
 programs/aes/crypt_and_hash.c           |  4 +-
 tests/suites/test_suite_cipher.function | 20 +++----
 6 files changed, 88 insertions(+), 54 deletions(-)

diff --git a/include/polarssl/cipher.h b/include/polarssl/cipher.h
index 7dea1e214..dc5a41ca3 100644
--- a/include/polarssl/cipher.h
+++ b/include/polarssl/cipher.h
@@ -519,11 +519,6 @@ int cipher_update( cipher_context_t *ctx, const unsigned char *input, size_t ile
  * \param ctx           Generic cipher context
  * \param output        buffer to write data to. Needs block_size available.
  * \param olen          length of the data written to the output buffer.
- * \param tag           Ignore by non-AEAD ciphers. For AEAD ciphers:
- *                      - on encryption: buffer to write the tag;
- *                      - on decryption: tag to verify.
- *                      May be NULL if tag_len is zero.
- * \param tag_len       Length of the tag to write/check for AEAD ciphers.
  *
  * \returns             0 on success, POLARSSL_ERR_CIPHER_BAD_INPUT_DATA if
  *                      parameter verification fails,
@@ -533,8 +528,34 @@ int cipher_update( cipher_context_t *ctx, const unsigned char *input, size_t ile
  *                      while decrypting or a cipher specific error code.
  */
 int cipher_finish( cipher_context_t *ctx,
-                   unsigned char *output, size_t *olen,
-                   unsigned char *tag, size_t tag_len );
+                   unsigned char *output, size_t *olen );
+
+/**
+ * \brief               Write tag for AEAD ciphers.
+ *                      No effect for other ciphers.
+ *                      Must be called after cipher_finish().
+ *
+ * \param tag           buffer to write the tag
+ * \param tag_len       Length of the tag to write
+ *
+ * \return              0 on success, or a specific error code.
+ */
+int cipher_write_tag( cipher_context_t *ctx,
+                      unsigned char *tag, size_t tag_len );
+
+/**
+ * \brief               Check tag for AEAD ciphers.
+ *                      No effect for other ciphers.
+ *                      Calling time depends on the cipher:
+ *                      for GCM, must be called after cipher_finish().
+ *
+ * \param tag           Buffer holding the tag
+ * \param tag_len       Length of the tag to check
+ *
+ * \return              0 on success, or a specific error code.
+ */
+int cipher_check_tag( cipher_context_t *ctx,
+                      const unsigned char *tag, size_t tag_len );
 
 /**
  * \brief          Checkup routine
diff --git a/library/cipher.c b/library/cipher.c
index f8e2841d2..a90e2dcd9 100644
--- a/library/cipher.c
+++ b/library/cipher.c
@@ -777,8 +777,7 @@ static int get_no_padding( unsigned char *input, size_t input_len,
 }
 
 int cipher_finish( cipher_context_t *ctx,
-                   unsigned char *output, size_t *olen,
-                   unsigned char *tag, size_t tag_len )
+                   unsigned char *output, size_t *olen )
 {
     int ret = 0;
 
@@ -797,10 +796,6 @@ int cipher_finish( cipher_context_t *ctx,
 #if defined(POLARSSL_GCM_C)
     if( POLARSSL_MODE_GCM == ctx->cipher_info->mode )
     {
-        unsigned char check_tag[16];
-        size_t i;
-        int diff;
-
         if( 0 != ( ret = gcm_update( ctx->cipher_ctx,
                         ctx->unprocessed_len, ctx->unprocessed_data,
                         output ) ) )
@@ -810,29 +805,8 @@ int cipher_finish( cipher_context_t *ctx,
 
         *olen += ctx->unprocessed_len;
 
-        if( 0 != ( ret = gcm_finish( ctx->cipher_ctx, check_tag, tag_len ) ) )
-            return( ret );
-
-        /* On encryption, write the tag */
-        if( POLARSSL_ENCRYPT == ctx->operation )
-        {
-            if( tag_len != 0 )
-                memcpy( tag, check_tag, tag_len );
-            return( 0 );
-        }
-
-        /* On decryption, check the tag (in "constant-time") */
-        for( diff = 0, i = 0; i < tag_len; i++ )
-            diff |= tag[i] ^ check_tag[i];
-
-        if( diff != 0 )
-            return( POLARSSL_ERR_GCM_AUTH_FAILED );
-
         return( 0 );
     }
-#else
-    ((void) tag);
-    ((void) tag_len);
 #endif
 
     if( POLARSSL_MODE_CBC == ctx->cipher_info->mode )
@@ -930,6 +904,51 @@ int cipher_set_padding_mode( cipher_context_t *ctx, cipher_padding_t mode )
     return 0;
 }
 
+int cipher_write_tag( cipher_context_t *ctx,
+                      unsigned char *tag, size_t tag_len )
+{
+    if( NULL == ctx || NULL == ctx->cipher_info || NULL == tag )
+        return POLARSSL_ERR_CIPHER_BAD_INPUT_DATA;
+
+    if( POLARSSL_MODE_GCM != ctx->cipher_info->mode )
+        return( 0 );
+
+    if( POLARSSL_ENCRYPT != ctx->operation )
+        return POLARSSL_ERR_CIPHER_BAD_INPUT_DATA;
+
+    return gcm_finish( ctx->cipher_ctx, tag, tag_len );
+}
+ 
+int cipher_check_tag( cipher_context_t *ctx,
+                      const unsigned char *tag, size_t tag_len )
+{
+    int ret;
+    unsigned char check_tag[16];
+    size_t i;
+    int diff;
+
+    if( NULL == ctx || NULL == ctx->cipher_info )
+        return POLARSSL_ERR_CIPHER_BAD_INPUT_DATA;
+
+    if( POLARSSL_MODE_GCM != ctx->cipher_info->mode )
+        return( 0 );
+
+    if( POLARSSL_DECRYPT != ctx->operation || tag_len > sizeof( check_tag ) )
+        return POLARSSL_ERR_CIPHER_BAD_INPUT_DATA;
+
+    if( 0 != ( ret = gcm_finish( ctx->cipher_ctx, check_tag, tag_len ) ) )
+        return( ret );
+
+    /* On decryption, check the tag (in "constant-time") */
+    for( diff = 0, i = 0; i < tag_len; i++ )
+        diff |= tag[i] ^ check_tag[i];
+
+    if( diff != 0 )
+        return( POLARSSL_ERR_GCM_AUTH_FAILED );
+
+    return( 0 );
+}
+
 #if defined(POLARSSL_SELF_TEST)
 
 #include <stdio.h>
diff --git a/library/pkcs12.c b/library/pkcs12.c
index 98ebd8876..335af7ebf 100644
--- a/library/pkcs12.c
+++ b/library/pkcs12.c
@@ -196,11 +196,8 @@ int pkcs12_pbe( asn1_buf *pbe_params, int mode,
         goto exit;
     }
 
-    if( ( ret = cipher_finish( &cipher_ctx, output + olen, &olen, NULL, 0 ) )
-                != 0 )
-    {
+    if( ( ret = cipher_finish( &cipher_ctx, output + olen, &olen ) ) != 0 )
         ret = POLARSSL_ERR_PKCS12_PASSWORD_MISMATCH;
-    }
 
 exit:
     cipher_free_ctx( &cipher_ctx );
diff --git a/library/pkcs5.c b/library/pkcs5.c
index a27d4fb04..0b9830dc6 100644
--- a/library/pkcs5.c
+++ b/library/pkcs5.c
@@ -199,11 +199,8 @@ int pkcs5_pbes2( asn1_buf *pbe_params, int mode,
         goto exit;
     }
 
-    if( ( ret = cipher_finish( &cipher_ctx, output + olen, &olen, NULL, 0 ) )
-                != 0 )
-    {
+    if( ( ret = cipher_finish( &cipher_ctx, output + olen, &olen ) ) != 0 )
         ret = POLARSSL_ERR_PKCS5_PASSWORD_MISMATCH;
-    }
 
 exit:
     md_free_ctx( &md_ctx );
diff --git a/programs/aes/crypt_and_hash.c b/programs/aes/crypt_and_hash.c
index c46713d15..6caaad872 100644
--- a/programs/aes/crypt_and_hash.c
+++ b/programs/aes/crypt_and_hash.c
@@ -343,7 +343,7 @@ int main( int argc, char *argv[] )
             }
         }
 
-        if( cipher_finish( &cipher_ctx, output, &olen, NULL, 0 ) != 0 )
+        if( cipher_finish( &cipher_ctx, output, &olen ) != 0 )
         {
             fprintf( stderr, "cipher_finish() returned error\n" );
             goto exit;
@@ -461,7 +461,7 @@ int main( int argc, char *argv[] )
         /*
          * Write the final block of data
          */
-        cipher_finish( &cipher_ctx, output, &olen, NULL, 0 );
+        cipher_finish( &cipher_ctx, output, &olen );
 
         if( fwrite( output, 1, olen, fout ) != olen )
         {
diff --git a/tests/suites/test_suite_cipher.function b/tests/suites/test_suite_cipher.function
index aa82daae9..5d32bc3d1 100644
--- a/tests/suites/test_suite_cipher.function
+++ b/tests/suites/test_suite_cipher.function
@@ -76,10 +76,11 @@ void enc_dec_buf( int cipher_id, char *cipher_string, int key_len,
                    total_len < length &&
                    total_len + cipher_get_block_size( &ctx_enc ) > length ) );
 
-    TEST_ASSERT( 0 == cipher_finish( &ctx_enc, encbuf + outlen, &outlen,
-                                     tag, 16 ) );
+    TEST_ASSERT( 0 == cipher_finish( &ctx_enc, encbuf + outlen, &outlen ) );
     total_len += outlen;
 
+    TEST_ASSERT( 0 == cipher_write_tag( &ctx_enc, tag, 16 ) );
+
     TEST_ASSERT( total_len == length ||
                  ( total_len % cipher_get_block_size( &ctx_enc ) == 0 &&
                    total_len > length &&
@@ -94,10 +95,11 @@ void enc_dec_buf( int cipher_id, char *cipher_string, int key_len,
                    total_len < length &&
                    total_len + cipher_get_block_size( &ctx_dec ) >= length ) );
 
-    TEST_ASSERT( 0 == cipher_finish( &ctx_dec, decbuf + outlen, &outlen,
-                                     tag, 16 ) );
+    TEST_ASSERT( 0 == cipher_finish( &ctx_dec, decbuf + outlen, &outlen ) );
     total_len += outlen;
 
+    TEST_ASSERT( 0 == cipher_check_tag( &ctx_dec, tag, 16 ) );
+
     TEST_ASSERT( total_len == length );
 
     TEST_ASSERT( 0 == memcmp(inbuf, decbuf, length) );
@@ -145,7 +147,7 @@ void enc_fail( int cipher_id, int pad_mode, int key_len,
 
     /* encode length number of bytes from inbuf */
     TEST_ASSERT( 0 == cipher_update( &ctx, inbuf, length, encbuf, &outlen ) );
-    TEST_ASSERT( ret == cipher_finish( &ctx, encbuf + outlen, &outlen, NULL, 0 ) );
+    TEST_ASSERT( ret == cipher_finish( &ctx, encbuf + outlen, &outlen ) );
 
     /* done */
     TEST_ASSERT( 0 == cipher_free_ctx( &ctx ) );
@@ -192,7 +194,7 @@ void dec_empty_buf()
     TEST_ASSERT( 0 == cipher_update( &ctx_dec, encbuf, 0, decbuf, &outlen ) );
     TEST_ASSERT( 0 == outlen );
     TEST_ASSERT( POLARSSL_ERR_CIPHER_FULL_BLOCK_EXPECTED == cipher_finish(
-                 &ctx_dec, decbuf + outlen, &outlen, NULL, 0 ) );
+                 &ctx_dec, decbuf + outlen, &outlen ) );
     TEST_ASSERT( 0 == outlen );
 
     TEST_ASSERT( 0 == cipher_free_ctx( &ctx_dec ) );
@@ -259,8 +261,7 @@ void enc_dec_buf_multipart( int cipher_id, int key_len, int first_length_val,
                    totaloutlen < length &&
                    totaloutlen + cipher_get_block_size( &ctx_enc ) > length ) );
 
-    TEST_ASSERT( 0 == cipher_finish( &ctx_enc, encbuf + totaloutlen, &outlen,
-                                     NULL, 0 ) );
+    TEST_ASSERT( 0 == cipher_finish( &ctx_enc, encbuf + totaloutlen, &outlen ) );
     totaloutlen += outlen;
     TEST_ASSERT( totaloutlen == length ||
                  ( totaloutlen % cipher_get_block_size( &ctx_enc ) == 0 &&
@@ -276,8 +277,7 @@ void enc_dec_buf_multipart( int cipher_id, int key_len, int first_length_val,
                    totaloutlen < length &&
                    totaloutlen + cipher_get_block_size( &ctx_dec ) >= length ) );
 
-    TEST_ASSERT( 0 == cipher_finish( &ctx_dec, decbuf + outlen, &outlen,
-                                     NULL, 0 ) );
+    TEST_ASSERT( 0 == cipher_finish( &ctx_dec, decbuf + outlen, &outlen ) );
     totaloutlen += outlen;
 
     TEST_ASSERT( totaloutlen == length );