diff --git a/tests/scripts/audit-validity-dates.py b/tests/scripts/audit-validity-dates.py index 400066840..d74c6f826 100755 --- a/tests/scripts/audit-validity-dates.py +++ b/tests/scripts/audit-validity-dates.py @@ -171,9 +171,9 @@ class Auditor: """A base class for audit.""" def __init__(self, logger): self.logger = logger - self.default_files = [] + self.default_files = [] # type: typing.List[str] # A list to store the parsed audit_data. - self.audit_data = [] + self.audit_data = [] # type: typing.List[AuditData] self.parser = X509Parser({ DataType.CRT: { DataFormat.PEM: x509.load_pem_x509_certificate, @@ -354,7 +354,11 @@ def main(): help=('not valid after this date (UTC, YYYY-MM-DD). ' 'Default: not-before'), metavar='DATE') - parser.add_argument('files', nargs='*', help='files to audit', + parser.add_argument('--data-files', action='append', nargs='*', + help='data files to audit', + metavar='FILE') + parser.add_argument('--suite-data-files', action='append', nargs='*', + help='suite data files to audit', metavar='FILE') args = parser.parse_args() @@ -368,22 +372,29 @@ def main(): td_auditor = TestDataAuditor(logger) sd_auditor = SuiteDataAuditor(logger) - if args.files: - data_files = args.files - suite_data_files = args.files - else: + data_files = [] + suite_data_files = [] + if args.data_files is None and args.suite_data_files is None: data_files = td_auditor.default_files suite_data_files = sd_auditor.default_files + else: + if args.data_files is not None: + data_files = [x for l in args.data_files for x in l] + if args.suite_data_files is not None: + suite_data_files = [x for l in args.suite_data_files for x in l] + # validity period start date if args.not_before: not_before_date = datetime.datetime.fromisoformat(args.not_before) else: not_before_date = datetime.datetime.today() + # validity period end date if args.not_after: not_after_date = datetime.datetime.fromisoformat(args.not_after) else: not_after_date = not_before_date + # go through all the files td_auditor.walk_all(data_files) sd_auditor.walk_all(suite_data_files) audit_results = td_auditor.audit_data + sd_auditor.audit_data @@ -396,6 +407,7 @@ def main(): if args.all: filter_func = None + # filter and output the results for d in filter(filter_func, audit_results): list_all(d)