diff --git a/scripts/mbedtls_dev/psa_storage.py b/scripts/mbedtls_dev/psa_storage.py index 45f0380e7..4cd3dfe91 100644 --- a/scripts/mbedtls_dev/psa_storage.py +++ b/scripts/mbedtls_dev/psa_storage.py @@ -101,6 +101,12 @@ class Key: LATEST_VERSION = 0 """The latest version of the storage format.""" + EXTENDABLE_USAGE_FLAGS = { + Expr('PSA_KEY_USAGE_SIGN_HASH'): Expr('PSA_KEY_USAGE_SIGN_MESSAGE'), + Expr('PSA_KEY_USAGE_VERIFY_HASH'): Expr('PSA_KEY_USAGE_VERIFY_MESSAGE') + } #type: Dict[Expr, Expr] + """The extendable usage flags with the corresponding extension flags.""" + def __init__(self, *, version: Optional[int] = None, id: Optional[int] = None, #pylint: disable=redefined-builtin @@ -108,18 +114,27 @@ class Key: type: Exprable, #pylint: disable=redefined-builtin bits: int, usage: Exprable, alg: Exprable, alg2: Exprable, - material: bytes #pylint: disable=used-before-assignment + material: bytes, #pylint: disable=used-before-assignment + usage_extension: bool = True ) -> None: self.version = self.LATEST_VERSION if version is None else version self.id = id #pylint: disable=invalid-name #type: Optional[int] self.lifetime = as_expr(lifetime) #type: Expr self.type = as_expr(type) #type: Expr self.bits = bits #type: int - self.usage = as_expr(usage) #type: Expr + self.original_usage = as_expr(usage) #type: Expr + self.updated_usage = self.original_usage #type: Expr self.alg = as_expr(alg) #type: Expr self.alg2 = as_expr(alg2) #type: Expr self.material = material #type: bytes + if usage_extension: + for flag, extension in self.EXTENDABLE_USAGE_FLAGS.items(): + if self.original_usage.value() & flag.value() and \ + self.original_usage.value() & extension.value() == 0: + self.updated_usage = Expr(self.updated_usage.string + + ' | ' + extension.string) + MAGIC = b'PSA\000KEY\000' @staticmethod @@ -151,7 +166,7 @@ class Key: if self.version == 0: attributes = self.pack('LHHLLL', self.lifetime, self.type, self.bits, - self.usage, self.alg, self.alg2) + self.updated_usage, self.alg, self.alg2) material = self.pack('L', len(self.material)) + self.material else: raise NotImplementedError diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py old mode 100644 new mode 100755 index af1cb533e..f1f3c42b0 --- a/tests/scripts/generate_psa_tests.py +++ b/tests/scripts/generate_psa_tests.py @@ -236,12 +236,14 @@ class StorageKey(psa_storage.Key): def __init__(self, *, description: str, **kwargs) -> None: super().__init__(**kwargs) self.description = description #type: str + self.usage = self.original_usage #type: psa_storage.Expr + class StorageKeyBuilder: - def __init__(self) -> None: - pass + def __init__(self, usage_extension: bool) -> None: + self.usage_extension = usage_extension #type: bool def build(self, **kwargs) -> StorageKey: - return StorageKey(**kwargs) + return StorageKey(usage_extension = self.usage_extension, **kwargs) class StorageFormat: """Storage format stability test cases.""" @@ -259,7 +261,7 @@ class StorageFormat: self.constructors = info.constructors #type: macro_collector.PSAMacroEnumerator self.version = version #type: int self.forward = forward #type: bool - self.key_builder = StorageKeyBuilder() #type: StorageKeyBuilder + self.key_builder = StorageKeyBuilder(usage_extension = True) #type: StorageKeyBuilder def make_test_case(self, key: StorageKey) -> test_case.TestCase: """Construct a storage format test case for the given key. @@ -473,6 +475,24 @@ class StorageFormatV0(StorageFormat): def __init__(self, info: Information) -> None: super().__init__(info, 0, False) + def all_keys_for_usage_flags(self) -> List[StorageKey]: + """Generate test keys covering usage flags.""" + # First generate keys without usage policy extension for + # compatibility testing, then generate the keys with extension + # to check the extension is working. + keys = [] #type: List[StorageKey] + prev_builder = self.key_builder + + self.key_builder = StorageKeyBuilder(usage_extension = False) + keys += super().all_keys_for_usage_flags(extra_desc = 'without extension') + + self.key_builder = StorageKeyBuilder(usage_extension = True) + keys += super().all_keys_for_usage_flags(extra_desc = 'with extension') + + self.key_builder = prev_builder + return keys + + class TestGenerator: """Generate test data."""