Bignum test: remove type restrictrion

The special case list type depends on the arity and the subclass. Remove
type restriction to make defining special case lists more flexible and natural.

Signed-off-by: Janos Follath <janos.follath@arm.com>
This commit is contained in:
Janos Follath 2022-11-19 12:48:17 +00:00
parent c4fca5de3e
commit 98edf21bb4
2 changed files with 21 additions and 5 deletions

View file

@ -15,7 +15,8 @@
# limitations under the License. # limitations under the License.
from abc import abstractmethod from abc import abstractmethod
from typing import Iterator, List, Tuple, TypeVar from typing import Iterator, List, Tuple, TypeVar, Any
from itertools import chain
from . import test_case from . import test_case
from . import test_data_generation from . import test_data_generation
@ -90,7 +91,7 @@ class OperationCommon(test_data_generation.BaseTest):
""" """
symbol = "" symbol = ""
input_values = [] # type: List[str] input_values = [] # type: List[str]
input_cases = [] # type: List[Tuple[str, str]] input_cases = [] # type: List[Any]
unique_combinations_only = True unique_combinations_only = True
input_styles = ["variable", "arch_split"] # type: List[str] input_styles = ["variable", "arch_split"] # type: List[str]
input_style = "variable" # type: str input_style = "variable" # type: str
@ -200,7 +201,6 @@ class OperationCommon(test_data_generation.BaseTest):
for a in cls.input_values for a in cls.input_values
for b in cls.input_values for b in cls.input_values
) )
yield from cls.input_cases
@classmethod @classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]: def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
@ -212,14 +212,20 @@ class OperationCommon(test_data_generation.BaseTest):
test_objects = (cls(a, b, bits_in_limb=bil) test_objects = (cls(a, b, bits_in_limb=bil)
for a, b in cls.get_value_pairs() for a, b in cls.get_value_pairs()
for bil in cls.limb_sizes) for bil in cls.limb_sizes)
special_cases = (cls(*args, bits_in_limb=bil) # type: ignore
for args in cls.input_cases
for bil in cls.limb_sizes)
else: else:
test_objects = (cls(a, b) test_objects = (cls(a, b)
for a, b in cls.get_value_pairs()) for a, b in cls.get_value_pairs())
special_cases = (cls(*args) for args in cls.input_cases)
yield from (valid_test_object.create_test_case() yield from (valid_test_object.create_test_case()
for valid_test_object in filter( for valid_test_object in filter(
lambda test_object: test_object.is_valid, lambda test_object: test_object.is_valid,
test_objects chain(test_objects, special_cases)
)) )
)
class ModOperationCommon(OperationCommon): class ModOperationCommon(OperationCommon):

View file

@ -243,6 +243,16 @@ class BignumCoreMLA(BignumCoreOperation):
"\"{:x}\"".format(carry_8) "\"{:x}\"".format(carry_8)
] ]
@classmethod
def get_value_pairs(cls) -> Iterator[Tuple[str, str]]:
"""Generator to yield pairs of inputs.
Combinations are first generated from all input values, and then
specific cases provided.
"""
yield from super().get_value_pairs()
yield from cls.input_cases
@classmethod @classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]: def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
"""Override for additional scalar input.""" """Override for additional scalar input."""