diff --git a/scripts/mbedtls_dev/test_generation.py b/scripts/mbedtls_dev/test_generation.py index b825df07b..aeb551d05 100644 --- a/scripts/mbedtls_dev/test_generation.py +++ b/scripts/mbedtls_dev/test_generation.py @@ -25,7 +25,7 @@ import posixpath import re from abc import ABCMeta, abstractmethod -from typing import Callable, Dict, Iterable, List, Type, TypeVar +from typing import Callable, Dict, Iterable, Iterator, List, Type, TypeVar from mbedtls_dev import build_tree from mbedtls_dev import test_case @@ -91,16 +91,31 @@ class BaseTarget(metaclass=ABCMeta): return tc @classmethod - def generate_tests(cls): - """Generate test cases for the target subclasses. + @abstractmethod + def generate_function_tests(cls) -> Iterator[test_case.TestCase]: + """Generate test cases for the test function. - During generation, each class will iterate over any subclasses, calling - this method in each. - In abstract classes, no tests will be generated, as there is no - function to generate tests for. - In classes which do implement a test function, this should be overridden - and a means to use `create_test_case()` should be added. + This will be called in classes where `test_function` is set. + Implementations should yield TestCase objects, by creating instances + of the class with appropriate input data, and then calling + `create_test_case()` on each. """ + pass + + @classmethod + def generate_tests(cls) -> Iterator[test_case.TestCase]: + """Generate test cases for the class and its subclasses. + + In classes with `test_function` set, `generate_function_tests()` is + used to generate test cases first. + In all classes, this method will iterate over its subclasses, and + yield from `generate_tests()` in each. + + Calling this method on a class X will yield test cases from all classes + derived from X. + """ + if cls.test_function: + yield from cls.generate_function_tests() for subclass in sorted(cls.__subclasses__(), key=lambda c: c.__name__): yield from subclass.generate_tests() diff --git a/tests/scripts/generate_bignum_tests.py b/tests/scripts/generate_bignum_tests.py index 3f556ce29..1f6448528 100755 --- a/tests/scripts/generate_bignum_tests.py +++ b/tests/scripts/generate_bignum_tests.py @@ -160,14 +160,10 @@ class BignumOperation(BignumTarget, metaclass=ABCMeta): yield from cls.input_cases @classmethod - def generate_tests(cls) -> Iterator[test_case.TestCase]: - if cls.test_function: - # Generate tests for the current class - for l_value, r_value in cls.get_value_pairs(): - cur_op = cls(l_value, r_value) - yield cur_op.create_test_case() - # Once current class completed, check descendants - yield from super().generate_tests() + def generate_function_tests(cls) -> Iterator[test_case.TestCase]: + for l_value, r_value in cls.get_value_pairs(): + cur_op = cls(l_value, r_value) + yield cur_op.create_test_case() class BignumCmp(BignumOperation):