diff --git a/scripts/mbedtls_dev/bignum_mod_raw.py b/scripts/mbedtls_dev/bignum_mod_raw.py index 1127ced8d..19942ed16 100644 --- a/scripts/mbedtls_dev/bignum_mod_raw.py +++ b/scripts/mbedtls_dev/bignum_mod_raw.py @@ -53,40 +53,75 @@ class BignumModRawTarget(test_data_generation.BaseTarget, metaclass=ABCMeta): # BEGIN MERGE SLOT 7 class BignumModRawOperation(bignum_common.OperationCommon, BignumModRawTarget, metaclass=ABCMeta): #pylint: disable=abstract-method - pass + """Target for bignum mod_raw test case generation.""" + + def __init__(self, val_n: str, val_a: str, val_b: str = "0", bits_in_limb: int = 64) -> None: + super().__init__(val_a=val_a, val_b=val_b) + self.val_n = val_n + self.bits_in_limb = bits_in_limb + + @property + def int_n(self) -> int: + return bignum_common.hex_to_int(self.val_n) + + @property + def boundary(self) -> int: + data_in = [self.int_a, self.int_b, self.int_n] + return max([n for n in data_in if n is not None]) + + @property + def limbs(self) -> int: + return bignum_common.limbs_mpi(self.boundary, self.bits_in_limb) + + @property + def hex_digits(self) -> int: + return 2 * (self.limbs * self.bits_in_limb // 8) + + @property + def hex_n(self) -> str: + return "{:x}".format(self.int_n).zfill(self.hex_digits) + + @property + def hex_a(self) -> str: + return "{:x}".format(self.int_a).zfill(self.hex_digits) + + @property + def hex_b(self) -> str: + return "{:x}".format(self.int_b).zfill(self.hex_digits) + + @property + def r(self) -> int: # pylint: disable=invalid-name + l = bignum_common.limbs_mpi(self.int_n, self.bits_in_limb) + return bignum_common.bound_mpi_limbs(l, self.bits_in_limb) + + @property + def r_inv(self) -> int: + return bignum_common.invmod(self.r, self.int_n) + + @property + def r_sqrt(self) -> int: # pylint: disable=invalid-name + return pow(self.r, 2) class BignumModRawOperationArchSplit(BignumModRawOperation): #pylint: disable=abstract-method - """Common features for bignum core operations where the result depends on + """Common features for bignum mod raw operations where the result depends on the limb size.""" - def __init__(self, val_a: str, val_b: str, bits_in_limb: int) -> None: - super().__init__(val_a, val_b) - bound_val = max(self.int_a, self.int_b) - self.bits_in_limb = bits_in_limb - self.bound = bignum_common.bound_mpi(bound_val, self.bits_in_limb) - limbs = bignum_common.limbs_mpi(bound_val, self.bits_in_limb) - byte_len = limbs * self.bits_in_limb // 8 - self.hex_digits = 2 * byte_len - if self.bits_in_limb == 32: - self.dependencies = ["MBEDTLS_HAVE_INT32"] - elif self.bits_in_limb == 64: - self.dependencies = ["MBEDTLS_HAVE_INT64"] - else: - raise ValueError("Invalid number of bits in limb!") - self.arg_a = self.arg_a.zfill(self.hex_digits) - self.arg_b = self.arg_b.zfill(self.hex_digits) - self.arg_a_int = bignum_common.hex_to_int(self.arg_a) - self.arg_b_int = bignum_common.hex_to_int(self.arg_b) + limb_sizes = [32, 64] # type: List[int] - def pad_to_limbs(self, val) -> str: - return "{:x}".format(val).zfill(self.hex_digits) + def __init__(self, val_n: str, val_a: str, val_b: str = "0", bits_in_limb: int = 64) -> None: + super().__init__(val_n=val_n, val_a=val_a, val_b=val_b, bits_in_limb=bits_in_limb) + + if bits_in_limb not in self.limb_sizes: + raise ValueError("Invalid number of bits in limb!") + + self.dependencies = ["MBEDTLS_HAVE_INT{:d}".format(bits_in_limb)] @classmethod def generate_function_tests(cls) -> Iterator[test_case.TestCase]: for a_value, b_value in cls.get_value_pairs(): - yield cls(a_value, b_value, 32).create_test_case() - yield cls(a_value, b_value, 64).create_test_case() + for bil in cls.limb_sizes: + yield cls(a_value, b_value, bits_in_limb=bil).create_test_case() # END MERGE SLOT 7 # BEGIN MERGE SLOT 8