diff --git a/tests/scripts/check_names.py b/tests/scripts/check_names.py index ce03b8a66..37a8be325 100755 --- a/tests/scripts/check_names.py +++ b/tests/scripts/check_names.py @@ -179,8 +179,11 @@ class NameCheck(): self.return_code = 0 self.setup_logger(verbose) + # Memo for storing "glob expression": set(filepaths) + self.files = {} + # Globally excluded filenames - self.excluded_files = ["bn_mul", "compat-2.x.h"] + self.excluded_files = ["**/bn_mul", "**/compat-2.x.h"] # Will contain the parse result after a comprehensive parse self.parse_result = {} @@ -212,23 +215,46 @@ class NameCheck(): self.log.setLevel(logging.INFO) self.log.addHandler(logging.StreamHandler()) - def get_files(self, wildcard): + def get_files(self, include_wildcards, exclude_wildcards): """ - Get all files that match a UNIX-style wildcard recursively. While the - script is designed only for use on UNIX/macOS (due to nm), this function - would work fine on Windows even with forward slashes in the wildcard. + Get all files that match any of the UNIX-style wildcards. While the + check_names script is designed only for use on UNIX/macOS (due to nm), + this function alone would work fine on Windows even with forward slashes + in the wildcard. Args: - * wildcard: shell-style wildcards to match filepaths against. + * include_wildcards: a List of shell-style wildcards to match filepaths. + * exclude_wildcards: a List of shell-style wildcards to exclude. Returns a List of relative filepaths. """ - accumulator = [] + accumulator = set() - for filepath in glob.iglob(wildcard, recursive=True): - if os.path.basename(filepath) not in self.excluded_files: - accumulator.append(filepath) - return accumulator + # exclude_wildcards may be None. Also, consider the global exclusions. + exclude_wildcards = (exclude_wildcards or []) + self.excluded_files + + # Perform set union on the glob results. Memoise individual sets. 
+ for include_wildcard in include_wildcards: + if include_wildcard not in self.files: + self.files[include_wildcard] = set(glob.glob( + include_wildcard, + recursive=True + )) + + accumulator = accumulator.union(self.files[include_wildcard]) + + # Perform set difference to exclude. Also use the same memo since their + # behaviour is pretty much identical and it can benefit from the cache. + for exclude_wildcard in exclude_wildcards: + if exclude_wildcard not in self.files: + self.files[exclude_wildcard] = set(glob.glob( + exclude_wildcard, + recursive=True + )) + + accumulator = accumulator.difference(self.files[exclude_wildcard]) + + return list(accumulator) def parse_names_in_source(self): """ @@ -243,31 +269,37 @@ class NameCheck(): .format(str(self.excluded_files)) ) - m_headers = self.get_files("include/mbedtls/*.h") - p_headers = self.get_files("include/psa/*.h") - t_headers = [ + all_macros = self.parse_macros([ + "include/mbedtls/*.h", + "include/psa/*.h", + "library/*.h", + "tests/include/test/drivers/*.h", "3rdparty/everest/include/everest/everest.h", "3rdparty/everest/include/everest/x25519.h" - ] - d_headers = self.get_files("tests/include/test/drivers/*.h") - l_headers = self.get_files("library/*.h") - libraries = self.get_files("library/*.c") + [ + ]) + enum_consts = self.parse_enum_consts([ + "include/mbedtls/*.h", + "library/*.h", + "3rdparty/everest/include/everest/everest.h", + "3rdparty/everest/include/everest/x25519.h" + ]) + identifiers = self.parse_identifiers([ + "include/mbedtls/*.h", + "include/psa/*.h", + "library/*.h", + "3rdparty/everest/include/everest/everest.h", + "3rdparty/everest/include/everest/x25519.h" + ]) + mbed_words = self.parse_mbed_words([ + "include/mbedtls/*.h", + "include/psa/*.h", + "library/*.h", + "3rdparty/everest/include/everest/everest.h", + "3rdparty/everest/include/everest/x25519.h", + "library/*.c", "3rdparty/everest/library/everest.c", "3rdparty/everest/library/x25519.c" - ] - - all_macros = self.parse_macros( 
- m_headers + p_headers + t_headers + l_headers + d_headers - ) - enum_consts = self.parse_enum_consts( - m_headers + l_headers + t_headers - ) - identifiers = self.parse_identifiers( - m_headers + p_headers + t_headers + l_headers - ) - mbed_words = self.parse_mbed_words( - m_headers + p_headers + t_headers + l_headers + libraries - ) + ]) symbols = self.parse_symbols() # Remove identifier macros like mbedtls_printf or mbedtls_calloc @@ -284,7 +316,6 @@ class NameCheck(): self.log.debug(" {} Identifiers".format(len(identifiers))) self.log.debug(" {} Exported Symbols".format(len(symbols))) self.log.info("Analysing...") - self.parse_result = { "macros": actual_macros, "enum_consts": enum_consts, @@ -293,12 +324,13 @@ class NameCheck(): "mbed_words": mbed_words } - def parse_macros(self, files): + def parse_macros(self, include, exclude=None): """ Parse all macros defined by #define preprocessor directives. Args: - * files: A List of filepaths to look through. + * include: A List of glob expressions to look for files through. + * exclude: A List of glob expressions for excluding files. Returns a List of Match objects for the found macros. """ @@ -307,11 +339,9 @@ class NameCheck(): "asm", "inline", "EMIT", "_CRT_SECURE_NO_DEPRECATE", "MULADDC_" ) - self.log.debug("Looking for macros in {} files".format(len(files))) - macros = [] - for header_file in files: + for header_file in self.get_files(include, exclude): with open(header_file, "r", encoding="utf-8") as header: for line_no, line in enumerate(header): for macro in macro_regex.finditer(line): @@ -326,13 +356,14 @@ class NameCheck(): return macros - def parse_mbed_words(self, files): + def parse_mbed_words(self, include, exclude=None): """ Parse all words in the file that begin with MBED, in and out of macros, comments, anything. Args: - * files: a List of filepaths to look through. + * include: A List of glob expressions to look for files through. + * exclude: A List of glob expressions for excluding files. 
Returns a List of Match objects for words beginning with MBED. """ @@ -340,11 +371,9 @@ class NameCheck(): mbed_regex = re.compile(r"\bMBED.+?_[A-Z0-9_]*") exclusions = re.compile(r"// *no-check-names|#error") - self.log.debug("Looking for MBED names in {} files".format(len(files))) - mbed_words = [] - for filename in files: + for filename in self.get_files(include, exclude): with open(filename, "r", encoding="utf-8") as fp: for line_no, line in enumerate(fp): if exclusions.search(line): @@ -360,23 +389,19 @@ class NameCheck(): return mbed_words - def parse_enum_consts(self, files): + def parse_enum_consts(self, include, exclude=None): """ Parse all enum value constants that are declared. Args: - * files: A List of filepaths to look through. + * include: A List of glob expressions to look for files through. + * exclude: A List of glob expressions for excluding files. Returns a List of Match objects for the findings. """ - self.log.debug( - "Looking for enum consts in {} files" - .format(len(files)) - ) - enum_consts = [] - for header_file in files: + for header_file in self.get_files(include, exclude): # Emulate a finite state machine to parse enum declarations. # 0 = not in enum # 1 = inside enum @@ -408,7 +433,7 @@ class NameCheck(): return enum_consts - def parse_identifiers(self, files): + def parse_identifiers(self, include, exclude=None): """ Parse all lines of a header where a function identifier is declared, based on some huersitics. Highly dependent on formatting style. @@ -416,7 +441,8 @@ class NameCheck(): .search() checks throughout. Args: - * files: A List of filepaths to look through. + * include: A List of glob expressions to look for files through. + * exclude: A List of glob expressions for excluding files. Returns a List of Match objects with identifiers. 
""" @@ -445,15 +471,9 @@ class NameCheck(): r"#" r")" ) - - self.log.debug( - "Looking for identifiers in {} files" - .format(len(files)) - ) - identifiers = [] - for header_file in files: + for header_file in self.get_files(include, exclude): with open(header_file, "r", encoding="utf-8") as header: in_block_comment = False # The previous line varibale is used for concatenating lines diff --git a/tests/scripts/list_internal_identifiers.py b/tests/scripts/list_internal_identifiers.py index d58cb3f05..75b1646aa 100755 --- a/tests/scripts/list_internal_identifiers.py +++ b/tests/scripts/list_internal_identifiers.py @@ -45,12 +45,10 @@ def main(): try: name_check = NameCheck() - internal_headers = ( - name_check.get_files("include/mbedtls/*_internal.h") + - name_check.get_files("library/*.h") - ) - - result = name_check.parse_identifiers(internal_headers) + result = name_check.parse_identifiers([ + "include/mbedtls/*_internal.h", + "library/*.h" + ]) identifiers = ["{}\n".format(match.name) for match in result] with open("_identifiers", "w", encoding="utf-8") as f: