Change mbedtls_mpi_cmp_mpi_ct to check less than

The signature of mbedtls_mpi_cmp_mpi_ct() meant to support using it in
place of mbedtls_mpi_cmp_mpi(). This meant full comparison functionality
and a signed result.

To make the function more universal and friendly to constant time
coding, we change the result type to unsigned. Theoretically, we could
encode the comparison result in an unsigned value, but it would be less
intuitive.

Therefore we won't be able to represent the result as unsigned anymore
and the functionality will be constrained to checking if the first
operand is less than the second. This is sufficient to support the
current use case and to check any relationship between MPIs.

The only drawback is that we need to call the function twice when
checking for equality, but this can be optimised later if an when it is
needed.
This commit is contained in:
Janos Follath 2019-10-11 14:21:53 +01:00
parent 1fc97594da
commit 0e5532d6cf
5 changed files with 71 additions and 70 deletions
library

View file

@ -1148,7 +1148,8 @@ int mbedtls_mpi_cmp_mpi( const mbedtls_mpi *X, const mbedtls_mpi *Y )
return( 0 );
}
static int ct_lt_mpi_uint( const mbedtls_mpi_uint x, const mbedtls_mpi_uint y )
static unsigned ct_lt_mpi_uint( const mbedtls_mpi_uint x,
const mbedtls_mpi_uint y )
{
mbedtls_mpi_uint ret;
mbedtls_mpi_uint cond;
@ -1175,16 +1176,11 @@ static int ct_lt_mpi_uint( const mbedtls_mpi_uint x, const mbedtls_mpi_uint y )
return ret;
}
static int ct_bool_get_mask( unsigned int b )
{
return ~( b - 1 );
}
/*
* Compare signed values in constant time
*/
int mbedtls_mpi_cmp_mpi_ct( const mbedtls_mpi *X, const mbedtls_mpi *Y,
int *ret )
int mbedtls_mpi_lt_mpi_ct( const mbedtls_mpi *X, const mbedtls_mpi *Y,
unsigned *ret )
{
size_t i;
unsigned int cond, done, sign_X, sign_Y;
@ -1197,45 +1193,49 @@ int mbedtls_mpi_cmp_mpi_ct( const mbedtls_mpi *X, const mbedtls_mpi *Y,
return MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
/*
* if( X->s > 0 && Y->s < 0 )
* {
* *ret = 1;
* done = 1;
* }
* else if( Y->s > 0 && X->s < 0 )
* {
* *ret = -1;
* done = 1;
* }
* Get sign bits of the signs.
*/
sign_X = X->s;
sign_X = sign_X >> ( sizeof( unsigned int ) * 8 - 1 );
sign_Y = Y->s;
cond = ( ( sign_X ^ sign_Y ) >> ( sizeof( unsigned int ) * 8 - 1 ) );
*ret = ct_bool_get_mask( cond ) & X->s;
sign_Y = sign_Y >> ( sizeof( unsigned int ) * 8 - 1 );
/*
* If the signs are different, then the positive operand is the bigger.
* That is if X is negative (sign bit 1), then X < Y is true and it is false
* if X is positive (sign bit 0).
*/
cond = ( sign_X ^ sign_Y );
*ret = cond & sign_X;
/*
* This is a constant time function, we might have the result, but we still
* need to go through the loop. Record if we have the result already.
*/
done = cond;
for( i = X->n; i > 0; i-- )
{
/*
* if( ( X->p[i - 1] > Y->p[i - 1] ) && !done )
* {
* done = 1;
* *ret = X->s;
* }
* If Y->p[i - 1] < X->p[i - 1] and both X and Y are negative, then
* X < Y.
*
* Again even if we can make a decision, we just mark the result and
* the fact that we are done and continue looping.
*/
cond = ct_lt_mpi_uint( Y->p[i - 1], X->p[i - 1] );
*ret |= ct_bool_get_mask( cond & ( 1 - done ) ) & X->s;
cond = ct_lt_mpi_uint( Y->p[i - 1], X->p[i - 1] ) & sign_X;
*ret |= cond & ( 1 - done );
done |= cond & ( 1 - done );
/*
* if( ( X->p[i - 1] < Y->p[i - 1] ) && !done )
* {
* done = 1;
* *ret = -X->s;
* }
* If X->p[i - 1] < Y->p[i - 1] and both X and Y are positive, then
* X < Y.
*
* Again even if we can make a decision, we just mark the result and
* the fact that we are done and continue looping.
*/
cond = ct_lt_mpi_uint( X->p[i - 1], Y->p[i - 1] );
*ret |= ct_bool_get_mask( cond & ( 1 - done ) ) & -X->s;
cond = ct_lt_mpi_uint( X->p[i - 1], Y->p[i - 1] ) & ( 1 - sign_X );
*ret |= cond & ( 1 - done );
done |= cond & ( 1 - done );
}