Split generate_tests to reduce code complexity

Previous implementation mixed the test case generation and the
recursive generation calls together. A separate method is added to
generate test cases for the current class' test function. This reduces
the need to override generate_tests().

Signed-off-by: Werner Lewis <werner.lewis@arm.com>
This commit is contained in:
Werner Lewis 2022-08-24 12:42:00 +01:00
parent 699e126942
commit 2b527a394d
2 changed files with 28 additions and 17 deletions

View file

@ -25,7 +25,7 @@ import posixpath
import re import re
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Callable, Dict, Iterable, List, Type, TypeVar from typing import Callable, Dict, Iterable, Iterator, List, Type, TypeVar
from mbedtls_dev import build_tree from mbedtls_dev import build_tree
from mbedtls_dev import test_case from mbedtls_dev import test_case
@ -91,16 +91,31 @@ class BaseTarget(metaclass=ABCMeta):
return tc return tc
@classmethod @classmethod
def generate_tests(cls): @abstractmethod
"""Generate test cases for the target subclasses. def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
"""Generate test cases for the test function.
During generation, each class will iterate over any subclasses, calling This will be called in classes where `test_function` is set.
this method in each. Implementations should yield TestCase objects, by creating instances
In abstract classes, no tests will be generated, as there is no of the class with appropriate input data, and then calling
function to generate tests for. `create_test_case()` on each.
In classes which do implement a test function, this should be overridden
and a means to use `create_test_case()` should be added.
""" """
pass
@classmethod
def generate_tests(cls) -> Iterator[test_case.TestCase]:
"""Generate test cases for the class and its subclasses.
In classes with `test_function` set, `generate_function_tests()` is
used to generate test cases first.
In all classes, this method will iterate over its subclasses, and
yield from `generate_tests()` in each.
Calling this method on a class X will yield test cases from all classes
derived from X.
"""
if cls.test_function:
yield from cls.generate_function_tests()
for subclass in sorted(cls.__subclasses__(), key=lambda c: c.__name__): for subclass in sorted(cls.__subclasses__(), key=lambda c: c.__name__):
yield from subclass.generate_tests() yield from subclass.generate_tests()

View file

@ -160,14 +160,10 @@ class BignumOperation(BignumTarget, metaclass=ABCMeta):
yield from cls.input_cases yield from cls.input_cases
@classmethod @classmethod
def generate_tests(cls) -> Iterator[test_case.TestCase]: def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
if cls.test_function: for l_value, r_value in cls.get_value_pairs():
# Generate tests for the current class cur_op = cls(l_value, r_value)
for l_value, r_value in cls.get_value_pairs(): yield cur_op.create_test_case()
cur_op = cls(l_value, r_value)
yield cur_op.create_test_case()
# Once current class completed, check descendants
yield from super().generate_tests()
class BignumCmp(BignumOperation): class BignumCmp(BignumOperation):