diff --git a/scripts/mbedtls_dev/crypto_knowledge.py b/scripts/mbedtls_dev/crypto_knowledge.py index a56c8638e..3029c801d 100644 --- a/scripts/mbedtls_dev/crypto_knowledge.py +++ b/scripts/mbedtls_dev/crypto_knowledge.py @@ -246,6 +246,8 @@ class KeyType: # So a public key object with a key agreement algorithm is not # a valid combination. return False + if alg.is_invalid_key_agreement_with_derivation(): + return False if self.head == 'ECC': assert self.params is not None eccc = EllipticCurveCategory.from_family(self.params[0]) @@ -412,17 +414,38 @@ class Algorithm: self.category = self.determine_category(self.base_expression, self.head) self.is_wildcard = self.determine_wildcard(self.expression) - def is_key_agreement_with_derivation(self) -> bool: - """Whether this is a combined key agreement and key derivation algorithm.""" + def get_key_agreement_derivation(self) -> Optional[str]: + """For a combined key agreement and key derivation algorithm, get the derivation part. + + For anything else, return None. + """ if self.category != AlgorithmCategory.KEY_AGREEMENT: - return False + return None m = re.match(r'PSA_ALG_KEY_AGREEMENT\(\w+,\s*(.*)\)\Z', self.expression) if not m: - return False + return None kdf_alg = m.group(1) # Assume kdf_alg is either a valid KDF or 0. - return not re.match(r'(?:0[Xx])?0+\s*\Z', kdf_alg) + if re.match(r'(?:0[Xx])?0+\s*\Z', kdf_alg): + return None + return kdf_alg + KEY_DERIVATIONS_INCOMPATIBLE_WITH_AGREEMENT = frozenset([ + 'PSA_ALG_TLS12_ECJPAKE_TO_PMS', # secret input in specific format + ]) + def is_valid_key_agreement_with_derivation(self) -> bool: + """Whether this is a valid combined key agreement and key derivation algorithm.""" + kdf_alg = self.get_key_agreement_derivation() + if kdf_alg is None: + return False + return kdf_alg not in self.KEY_DERIVATIONS_INCOMPATIBLE_WITH_AGREEMENT + + def is_invalid_key_agreement_with_derivation(self) -> bool: + """Whether this is an invalid combined key agreement and key derivation algorithm.""" + kdf_alg = self.get_key_agreement_derivation() + if kdf_alg is None: + return False + return kdf_alg in self.KEY_DERIVATIONS_INCOMPATIBLE_WITH_AGREEMENT def short_expression(self, level: int = 0) -> str: """Abbreviate the expression, keeping it human-readable. @@ -515,7 +538,7 @@ class Algorithm: if category == self.category: return True if category == AlgorithmCategory.KEY_DERIVATION and \ - self.is_key_agreement_with_derivation(): + self.is_valid_key_agreement_with_derivation(): return True return False