diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py index ab3fe07cf..044042bf0 100755 --- a/tests/scripts/generate_psa_tests.py +++ b/tests/scripts/generate_psa_tests.py @@ -115,6 +115,16 @@ def hack_dependencies_not_implemented(dependencies: List[str]) -> None: for dep in dependencies): dependencies.append('DEPENDENCY_NOT_IMPLEMENTED_YET') +# PSA_WANT_KEY_TYPE_xxx_KEY_PAIR symbols have a GENERATE suffix to state that +# they support key generation. +def fix_key_pair_dependencies(dep_list: str, type: str): + # Note: this LEGACY replacement for RSA is temporary and it's going to be + # aligned with ECC one in #7772. + new_list = [re.sub(r'RSA_KEY_PAIR\Z', r'RSA_KEY_PAIR_LEGACY', dep) + for dep in dep_list] + new_list = [re.sub(r'ECC_KEY_PAIR\Z', r'ECC_KEY_PAIR_' + type, dep) + for dep in new_list] + return new_list class Information: """Gather information about PSA constructors.""" @@ -208,13 +218,8 @@ class KeyTypeNotSupported: if kt.name.endswith('_PUBLIC_KEY'): generate_dependencies = [] else: - # PSA_WANT_KEY_TYPE_xxx_KEY_PAIR symbols have a GENERATE and - # IMPORT suffixes to state that they support key generation and - # import, respectively. - generate_dependencies = [re.sub(r'KEY_PAIR\Z', r'KEY_PAIR_GENERATE', dep) - for dep in import_dependencies] - import_dependencies = [re.sub(r'KEY_PAIR\Z', r'KEY_PAIR_IMPORT', dep) - for dep in import_dependencies] + generate_dependencies = fix_key_pair_dependencies(import_dependencies, 'GENERATE') + import_dependencies = fix_key_pair_dependencies(import_dependencies, 'BASIC_IMPORT_EXPORT') for bits in kt.sizes_to_test(): yield test_case_for_key_type_not_supported( 'import', kt.expression, bits, @@ -306,11 +311,7 @@ class KeyGenerate: generate_dependencies = [] result = 'PSA_ERROR_INVALID_ARGUMENT' else: - generate_dependencies = import_dependencies - # PSA_WANT_KEY_TYPE_xxx_KEY_PAIR symbols have a GENERATE suffix - # to state that they support key generation. - generate_dependencies = [re.sub(r'KEY_PAIR\Z', r'KEY_PAIR_GENERATE', dep) - for dep in generate_dependencies] + generate_dependencies = fix_key_pair_dependencies(import_dependencies, 'GENERATE') for bits in kt.sizes_to_test(): yield test_case_for_key_generation( kt.expression, bits, @@ -379,6 +380,7 @@ class OpFail: pretty_reason, ' with ' + pretty_type if pretty_type else '')) dependencies = automatic_dependencies(alg.base_expression, key_type) + dependencies = fix_key_pair_dependencies(dependencies, 'BASIC_IMPORT_EXPORT') for i, dep in enumerate(dependencies): if dep in not_deps: dependencies[i] = '!' + dep @@ -602,6 +604,7 @@ class StorageFormat: ) dependencies = finish_family_dependencies(dependencies, key.bits) dependencies += generate_key_dependencies(key.description) + dependencies = fix_key_pair_dependencies(dependencies, 'BASIC_IMPORT_EXPORT') tc.set_dependencies(dependencies) tc.set_function('key_storage_' + verb) if self.forward: