Refactor key generation

Remove the key builder and use iterator instead of lists.

Signed-off-by: gabor-mezei-arm <gabor.mezei@arm.com>
This commit is contained in:
gabor-mezei-arm 2021-06-28 19:26:55 +02:00
parent 5df1dee0c6
commit 340fbf3735

231
tests/scripts/generate_psa_tests.py Normal file → Executable file
View file

@ -252,14 +252,6 @@ class StorageKey(psa_storage.Key):
self.usage = psa_storage.as_expr(expected_usage) if expected_usage is not None else\ self.usage = psa_storage.as_expr(expected_usage) if expected_usage is not None else\
self.original_usage #type: psa_storage.Expr self.original_usage #type: psa_storage.Expr
class StorageKeyBuilder:
# pylint: disable=too-few-public-methods
def __init__(self, usage_extension: bool) -> None:
self.usage_extension = usage_extension #type: bool
def build(self, **kwargs) -> StorageKey:
return StorageKey(implicit_usage=self.usage_extension, **kwargs)
class StorageFormat: class StorageFormat:
"""Storage format stability test cases.""" """Storage format stability test cases."""
@ -276,7 +268,6 @@ class StorageFormat:
self.constructors = info.constructors #type: macro_collector.PSAMacroEnumerator self.constructors = info.constructors #type: macro_collector.PSAMacroEnumerator
self.version = version #type: int self.version = version #type: int
self.forward = forward #type: bool self.forward = forward #type: bool
self.key_builder = StorageKeyBuilder(usage_extension=True) #type: StorageKeyBuilder
def make_test_case(self, key: StorageKey) -> test_case.TestCase: def make_test_case(self, key: StorageKey) -> test_case.TestCase:
"""Construct a storage format test case for the given key. """Construct a storage format test case for the given key.
@ -329,19 +320,17 @@ class StorageFormat:
r'', short) r'', short)
short = re.sub(r'PSA_KEY_[A-Z]+_', r'', short) short = re.sub(r'PSA_KEY_[A-Z]+_', r'', short)
description = 'lifetime: ' + short description = 'lifetime: ' + short
key = self.key_builder.build(version=self.version, return StorageKey(version=self.version,
id=1, lifetime=lifetime, id=1, lifetime=lifetime,
type='PSA_KEY_TYPE_RAW_DATA', bits=8, type='PSA_KEY_TYPE_RAW_DATA', bits=8,
usage='PSA_KEY_USAGE_EXPORT', alg=0, alg2=0, usage='PSA_KEY_USAGE_EXPORT', alg=0, alg2=0,
material=b'L', material=b'L',
description=description) description=description)
return key
def all_keys_for_lifetimes(self) -> List[StorageKey]: def all_keys_for_lifetimes(self) -> Iterator[StorageKey]:
"""Generate test keys covering lifetimes.""" """Generate test keys covering lifetimes."""
lifetimes = sorted(self.constructors.lifetimes) lifetimes = sorted(self.constructors.lifetimes)
expressions = self.constructors.generate_expressions(lifetimes) expressions = self.constructors.generate_expressions(lifetimes)
keys = [] #type List[StorageKey]
for lifetime in expressions: for lifetime in expressions:
# Don't attempt to create or load a volatile key in storage # Don't attempt to create or load a volatile key in storage
if 'VOLATILE' in lifetime: if 'VOLATILE' in lifetime:
@ -350,67 +339,67 @@ class StorageFormat:
# but do attempt to load one. # but do attempt to load one.
if 'READ_ONLY' in lifetime and self.forward: if 'READ_ONLY' in lifetime and self.forward:
continue continue
keys.append(self.key_for_lifetime(lifetime)) yield self.key_for_lifetime(lifetime)
return keys
def key_for_usage_flags( def key_for_usage_flags(
self, self,
usage_flags: List[str], usage_flags: List[str],
short: Optional[str] = None, short: Optional[str] = None,
extra_desc: Optional[str] = None test_implicit_usage: Optional[bool] = False
) -> StorageKey: ) -> Iterator[StorageKey]:
"""Construct a test key for the given key usage.""" """Construct a test key for the given key usage."""
usage = ' | '.join(usage_flags) if usage_flags else '0' usage = ' | '.join(usage_flags) if usage_flags else '0'
if short is None: if short is None:
short = re.sub(r'\bPSA_KEY_USAGE_', r'', usage) short = re.sub(r'\bPSA_KEY_USAGE_', r'', usage)
extra_desc = ' ' + extra_desc if extra_desc else '' extra_desc = ' with implication' if test_implicit_usage else ''
description = 'usage' + extra_desc + ': ' + short description = 'usage' + extra_desc + ': ' + short
return self.key_builder.build(version=self.version, yield StorageKey(version=self.version,
id=1, lifetime=0x00000001, id=1, lifetime=0x00000001,
type='PSA_KEY_TYPE_RAW_DATA', bits=8, type='PSA_KEY_TYPE_RAW_DATA', bits=8,
usage=usage, alg=0, alg2=0, usage=usage, alg=0, alg2=0,
material=b'K', material=b'K',
description=description) description=description,
implicit_usage=True)
if test_implicit_usage:
description = 'usage without implication' + ': ' + short
yield StorageKey(version=self.version,
id=1, lifetime=0x00000001,
type='PSA_KEY_TYPE_RAW_DATA', bits=8,
usage=usage, alg=0, alg2=0,
material=b'K',
description=description,
implicit_usage=False)
def generate_keys_for_usage_flags(
self, def generate_keys_for_usage_flags(self, **kwargs) -> Iterator[StorageKey]:
extra_desc: Optional[str] = None
) -> List[StorageKey]:
"""Generate test keys covering usage flags.""" """Generate test keys covering usage flags."""
known_flags = sorted(self.constructors.key_usage_flags) known_flags = sorted(self.constructors.key_usage_flags)
keys = [] #type List[StorageKey] yield from self.key_for_usage_flags(['0'], **kwargs)
keys.append(self.key_for_usage_flags(['0'], extra_desc=extra_desc)) for usage_flag in known_flags:
keys += [self.key_for_usage_flags([usage_flag], extra_desc=extra_desc) yield from self.key_for_usage_flags([usage_flag], **kwargs)
for usage_flag in known_flags] for flag1, flag2 in zip(known_flags,
keys += [self.key_for_usage_flags([flag1, flag2], extra_desc=extra_desc) known_flags[1:] + [known_flags[0]]):
for flag1, flag2 in zip(known_flags, yield from self.key_for_usage_flags([flag1, flag2], **kwargs)
known_flags[1:] + [known_flags[0]])]
return keys
def generate_key_for_all_usage_flags(self) -> StorageKey: def generate_key_for_all_usage_flags(self) -> Iterator[StorageKey]:
known_flags = sorted(self.constructors.key_usage_flags) known_flags = sorted(self.constructors.key_usage_flags)
return self.key_for_usage_flags(known_flags, short='all known') yield from self.key_for_usage_flags(known_flags, short='all known')
def all_keys_for_usage_flags( def all_keys_for_usage_flags(self) -> Iterator[StorageKey]:
self, yield from self.generate_keys_for_usage_flags()
extra_desc: Optional[str] = None yield from self.generate_key_for_all_usage_flags()
) -> List[StorageKey]:
keys = self.generate_keys_for_usage_flags(extra_desc=extra_desc)
keys.append(self.generate_key_for_all_usage_flags())
return keys
def keys_for_type( def keys_for_type(
self, self,
key_type: str, key_type: str,
params: Optional[Iterable[str]] = None params: Optional[Iterable[str]] = None
) -> List[StorageKey]: ) -> Iterator[StorageKey]:
"""Generate test keys for the given key type. """Generate test keys for the given key type.
For key types that depend on a parameter (e.g. elliptic curve family), For key types that depend on a parameter (e.g. elliptic curve family),
`param` is the parameter to pass to the constructor. Only a single `param` is the parameter to pass to the constructor. Only a single
parameter is supported. parameter is supported.
""" """
keys = [] #type: List[StorageKey]
kt = crypto_knowledge.KeyType(key_type, params) kt = crypto_knowledge.KeyType(key_type, params)
for bits in kt.sizes_to_test(): for bits in kt.sizes_to_test():
usage_flags = 'PSA_KEY_USAGE_EXPORT' usage_flags = 'PSA_KEY_USAGE_EXPORT'
@ -421,22 +410,20 @@ class StorageFormat:
r'', r'',
kt.expression) kt.expression)
description = 'type: {} {}-bit'.format(short_expression, bits) description = 'type: {} {}-bit'.format(short_expression, bits)
keys.append(self.key_builder.build(version=self.version, yield StorageKey(version=self.version,
id=1, lifetime=0x00000001, id=1, lifetime=0x00000001,
type=kt.expression, bits=bits, type=kt.expression, bits=bits,
usage=usage_flags, alg=alg, alg2=alg2, usage=usage_flags, alg=alg, alg2=alg2,
material=key_material, material=key_material,
description=description)) description=description)
return keys
def all_keys_for_types(self) -> List[StorageKey]: def all_keys_for_types(self) -> Iterator[StorageKey]:
"""Generate test keys covering key types and their representations.""" """Generate test keys covering key types and their representations."""
key_types = sorted(self.constructors.key_types) key_types = sorted(self.constructors.key_types)
return [key for key_type in self.constructors.generate_expressions(key_types):
for key_type in self.constructors.generate_expressions(key_types) yield from self.keys_for_type(key_type)
for key in self.keys_for_type(key_type)]
def keys_for_algorithm(self, alg: str) -> List[StorageKey]: def keys_for_algorithm(self, alg: str) -> Iterator[StorageKey]:
"""Generate test keys for the specified algorithm.""" """Generate test keys for the specified algorithm."""
# For now, we don't have information on the compatibility of key # For now, we don't have information on the compatibility of key
# types and algorithms. So we just test the encoding of algorithms, # types and algorithms. So we just test the encoding of algorithms,
@ -444,26 +431,24 @@ class StorageFormat:
descr = re.sub(r'PSA_ALG_', r'', alg) descr = re.sub(r'PSA_ALG_', r'', alg)
descr = re.sub(r',', r', ', re.sub(r' +', r'', descr)) descr = re.sub(r',', r', ', re.sub(r' +', r'', descr))
usage = 'PSA_KEY_USAGE_EXPORT' usage = 'PSA_KEY_USAGE_EXPORT'
key1 = self.key_builder.build(version=self.version, yield StorageKey(version=self.version,
id=1, lifetime=0x00000001, id=1, lifetime=0x00000001,
type='PSA_KEY_TYPE_RAW_DATA', bits=8, type='PSA_KEY_TYPE_RAW_DATA', bits=8,
usage=usage, alg=alg, alg2=0, usage=usage, alg=alg, alg2=0,
material=b'K', material=b'K',
description='alg: ' + descr) description='alg: ' + descr)
key2 = self.key_builder.build(version=self.version, yield StorageKey(version=self.version,
id=1, lifetime=0x00000001, id=1, lifetime=0x00000001,
type='PSA_KEY_TYPE_RAW_DATA', bits=8, type='PSA_KEY_TYPE_RAW_DATA', bits=8,
usage=usage, alg=0, alg2=alg, usage=usage, alg=0, alg2=alg,
material=b'L', material=b'L',
description='alg2: ' + descr) description='alg2: ' + descr)
return [key1, key2]
def all_keys_for_algorithms(self) -> List[StorageKey]: def all_keys_for_algorithms(self) -> Iterator[StorageKey]:
"""Generate test keys covering algorithm encodings.""" """Generate test keys covering algorithm encodings."""
algorithms = sorted(self.constructors.algorithms) algorithms = sorted(self.constructors.algorithms)
return [key for alg in self.constructors.generate_expressions(algorithms):
for alg in self.constructors.generate_expressions(algorithms) yield from self.keys_for_algorithm(alg)
for key in self.keys_for_algorithm(alg)]
def generate_all_keys(self) -> List[StorageKey]: def generate_all_keys(self) -> List[StorageKey]:
"""Generate all keys for the test cases.""" """Generate all keys for the test cases."""
@ -474,18 +459,19 @@ class StorageFormat:
keys += self.all_keys_for_algorithms() keys += self.all_keys_for_algorithms()
return keys return keys
def all_test_cases(self) -> List[test_case.TestCase]: def all_test_cases(self) -> Iterator[test_case.TestCase]:
"""Generate all storage format test cases.""" """Generate all storage format test cases."""
# First build a list of all keys, then construct all the corresponding # First build a list of all keys, then construct all the corresponding
# test cases. This allows all required information to be obtained in # test cases. This allows all required information to be obtained in
# one go, which is a significant performance gain as the information # one go, which is a significant performance gain as the information
# includes numerical values obtained by compiling a C program. # includes numerical values obtained by compiling a C program.
keys = self.generate_all_keys() for key in self.generate_all_keys():
if key.location_value() != 0:
# Skip keys with a non-default location, because they # Skip keys with a non-default location, because they
# require a driver and we currently have no mechanism to # require a driver and we currently have no mechanism to
# determine whether a driver is available. # determine whether a driver is available.
return [self.make_test_case(key) for key in keys if key.location_value() == 0] continue
yield self.make_test_case(key)
class StorageFormatForward(StorageFormat): class StorageFormatForward(StorageFormat):
"""Storage format stability test cases for forward compatibility.""" """Storage format stability test cases for forward compatibility."""
@ -499,29 +485,10 @@ class StorageFormatV0(StorageFormat):
def __init__(self, info: Information) -> None: def __init__(self, info: Information) -> None:
super().__init__(info, 0, False) super().__init__(info, 0, False)
def all_keys_for_usage_flags( def all_keys_for_usage_flags(self) -> Iterator[StorageKey]:
self,
extra_desc: Optional[str] = None
) -> List[StorageKey]:
"""Generate test keys covering usage flags.""" """Generate test keys covering usage flags."""
# First generate keys without usage policy extension for yield from self.generate_keys_for_usage_flags(test_implicit_usage=True)
# compatibility testing, then generate the keys with extension yield from self.generate_key_for_all_usage_flags()
# to check the extension is working. Finally generate key for all known
# usage flag which needs to be separted because it is not affected by
# usage extension.
keys = [] #type: List[StorageKey]
prev_builder = self.key_builder
self.key_builder = StorageKeyBuilder(usage_extension=False)
keys += self.generate_keys_for_usage_flags(extra_desc='without extension')
self.key_builder = StorageKeyBuilder(usage_extension=True)
keys += self.generate_keys_for_usage_flags(extra_desc='with extension')
keys.append(self.generate_key_for_all_usage_flags())
self.key_builder = prev_builder
return keys
def keys_for_implicit_usage( def keys_for_implicit_usage(
self, self,
@ -529,12 +496,11 @@ class StorageFormatV0(StorageFormat):
alg: str, alg: str,
key_type: str, key_type: str,
params: Optional[Iterable[str]] = None params: Optional[Iterable[str]] = None
) -> List[StorageKey]: ) -> StorageKey:
# pylint: disable=too-many-locals # pylint: disable=too-many-locals
"""Generate test keys for the specified implicit usage flag, """Generate test keys for the specified implicit usage flag,
algorithm and key type combination. algorithm and key type combination.
""" """
keys = [] #type: List[StorageKey]
kt = crypto_knowledge.KeyType(key_type, params) kt = crypto_knowledge.KeyType(key_type, params)
bits = kt.sizes_to_test()[0] bits = kt.sizes_to_test()[0]
implicit_usage = StorageKey.IMPLICIT_USAGE_FLAGS[implyer_usage] implicit_usage = StorageKey.IMPLICIT_USAGE_FLAGS[implyer_usage]
@ -551,15 +517,15 @@ class StorageFormatV0(StorageFormat):
kt.expression) kt.expression)
description = 'implied by {}: {} {} {}-bit'.format( description = 'implied by {}: {} {} {}-bit'.format(
usage_expression, alg_expression, key_type_expression, bits) usage_expression, alg_expression, key_type_expression, bits)
keys.append(self.key_builder.build(version=self.version, return StorageKey(version=self.version,
id=1, lifetime=0x00000001, id=1, lifetime=0x00000001,
type=kt.expression, bits=bits, type=kt.expression, bits=bits,
usage=material_usage_flags, usage=material_usage_flags,
expected_usage=expected_usage_flags, expected_usage=expected_usage_flags,
alg=alg, alg2=alg2, alg=alg, alg2=alg2,
material=key_material, material=key_material,
description=description)) description=description,
return keys implicit_usage=False)
def gather_key_types_for_sign_alg(self) -> Dict[str, List[str]]: def gather_key_types_for_sign_alg(self) -> Dict[str, List[str]]:
# pylint: disable=too-many-locals # pylint: disable=too-many-locals
@ -609,29 +575,20 @@ class StorageFormatV0(StorageFormat):
alg_with_keys[alg] = [key_type] alg_with_keys[alg] = [key_type]
return alg_with_keys return alg_with_keys
def all_keys_for_implicit_usage(self) -> List[StorageKey]: def all_keys_for_implicit_usage(self) -> Iterator[StorageKey]:
"""Generate test keys for usage flag extensions.""" """Generate test keys for usage flag extensions."""
# Generate a key type and algorithm pair for each extendable usage # Generate a key type and algorithm pair for each extendable usage
# flag to generate a valid key for exercising. The key is generated # flag to generate a valid key for exercising. The key is generated
# without usage extension to check the extension compatiblity. # without usage extension to check the extension compatiblity.
keys = [] #type: List[StorageKey]
prev_builder = self.key_builder
# Generate the keys without usage extension
self.key_builder = StorageKeyBuilder(usage_extension=False)
alg_with_keys = self.gather_key_types_for_sign_alg() alg_with_keys = self.gather_key_types_for_sign_alg()
key_filter = StorageKey.IMPLICIT_USAGE_FLAGS_KEY_RESTRICTION key_filter = StorageKey.IMPLICIT_USAGE_FLAGS_KEY_RESTRICTION
# Walk through all combintion. The key types must be filtered to fit for usage in sorted(StorageKey.IMPLICIT_USAGE_FLAGS, key=str):
# the specific usage flag. for alg in sorted(alg_with_keys):
keys += [key for key_type in sorted(alg_with_keys[alg]):
for usage in sorted(StorageKey.IMPLICIT_USAGE_FLAGS, key=str) # The key types must be filtered to fit the specific usage flag.
for alg in sorted(alg_with_keys) if re.match(key_filter[usage], key_type):
for key_type in sorted(alg_with_keys[alg]) if re.match(key_filter[usage], key_type) yield self.keys_for_implicit_usage(usage, alg, key_type)
for key in self.keys_for_implicit_usage(usage, alg, key_type)]
self.key_builder = prev_builder
return keys
def generate_all_keys(self) -> List[StorageKey]: def generate_all_keys(self) -> List[StorageKey]:
keys = super().generate_all_keys() keys = super().generate_all_keys()