diff --git a/tests/scripts/generate_tls13_compat_tests.py b/tests/scripts/generate_tls13_compat_tests.py index 3d1754642..20dd8c4c5 100755 --- a/tests/scripts/generate_tls13_compat_tests.py +++ b/tests/scripts/generate_tls13_compat_tests.py @@ -349,27 +349,26 @@ class MbedTLSCli(TLSProgram): return ret -SERVER_CLS = {'OpenSSL': OpenSSLServ, 'GnuTLS': GnuTLSServ} -CLIENT_CLS = {'mbedTLS': MbedTLSCli} +SERVER_CLASSES = {'OpenSSL': OpenSSLServ, 'GnuTLS': GnuTLSServ} +CLIENT_CLASSES = {'mbedTLS': MbedTLSCli} -def generate_compat_test(server=None, client=None, cipher=None, # pylint: disable=unused-argument - sig_alg=None, named_group=None, **kwargs): +def generate_compat_test(server=None, client=None, cipher=None, sig_alg=None, named_group=None): """ Generate test case with `ssl-opt.sh` format. """ name = 'TLS 1.3 {client[0]}->{server[0]}: {cipher},{named_group},{sig_alg}'.format( client=client, server=server, cipher=cipher, sig_alg=sig_alg, named_group=named_group) - server = SERVER_CLS[server](cipher, sig_alg, named_group) - client = CLIENT_CLS[client](cipher, sig_alg, named_group) + server_object = SERVER_CLASSES[server](cipher, sig_alg, named_group) + client_object = CLIENT_CLASSES[client](cipher, sig_alg, named_group) cmd = ['run_test "{}"'.format(name), '"{}"'.format( - server.cmd()), '"{}"'.format(client.cmd()), '0'] - cmd += server.post_checks() - cmd += client.post_checks() + server_object.cmd()), '"{}"'.format(client_object.cmd()), '0'] + cmd += server_object.post_checks() + cmd += client_object.post_checks() prefix = ' \\\n' + (' '*9) cmd = prefix.join(cmd) - return '\n'.join(server.pre_checks() + client.pre_checks() + [cmd]) + return '\n'.join(server_object.pre_checks() + client_object.pre_checks() + [cmd]) SSL_OUTPUT_HEADER = '''#!/bin/sh @@ -429,11 +428,11 @@ def main(): parser.add_argument('--list-clients', action='store_true', default=False, help='List supported TLS Clients') - parser.add_argument('server', choices=SERVER_CLS.keys(), nargs='?', - default=list(SERVER_CLS.keys())[0], + parser.add_argument('server', choices=SERVER_CLASSES.keys(), nargs='?', + default=list(SERVER_CLASSES.keys())[0], help='Choose TLS server program for test') - parser.add_argument('client', choices=CLIENT_CLS.keys(), nargs='?', - default=list(CLIENT_CLS.keys())[0], + parser.add_argument('client', choices=CLIENT_CLASSES.keys(), nargs='?', + default=list(CLIENT_CLASSES.keys())[0], help='Choose TLS client program for test') parser.add_argument('cipher', choices=CIPHER_SUITE_IANA_VALUE.keys(), nargs='?', default=list(CIPHER_SUITE_IANA_VALUE.keys())[0], @@ -448,16 +447,18 @@ def main(): args = parser.parse_args() def get_all_test_cases(): - for i in itertools.product(CIPHER_SUITE_IANA_VALUE.keys(), SIG_ALG_IANA_VALUE.keys(), - NAMED_GROUP_IANA_VALUE.keys(), SERVER_CLS.keys(), - CLIENT_CLS.keys()): - yield generate_compat_test(**dict( - zip(['cipher', 'sig_alg', 'named_group', 'server', 'client'], i))) + for cipher, sig_alg, named_group, server, client in \ + itertools.product(CIPHER_SUITE_IANA_VALUE.keys(), SIG_ALG_IANA_VALUE.keys(), + NAMED_GROUP_IANA_VALUE.keys(), SERVER_CLASSES.keys(), + CLIENT_CLASSES.keys()): + yield generate_compat_test(cipher=cipher, sig_alg=sig_alg, named_group=named_group, + server=server, client=client) if args.generate_all_tls13_compat_tests: if args.output: with open(args.output, 'w', encoding="utf-8") as f: - f.write(SSL_OUTPUT_HEADER.format(filename=os.path.basename(args.output))) + f.write(SSL_OUTPUT_HEADER.format( + filename=os.path.basename(args.output))) f.write('\n\n'.join(get_all_test_cases())) f.write('\n') else: @@ -473,12 +474,13 @@ def main(): if args.list_named_groups: print(*NAMED_GROUP_IANA_VALUE.keys()) if args.list_servers: - print(*SERVER_CLS.keys()) + print(*SERVER_CLASSES.keys()) if args.list_clients: - print(*CLIENT_CLS.keys()) + print(*CLIENT_CLASSES.keys()) return 0 - print(generate_compat_test(**vars(args))) + print(generate_compat_test(server=args.server, client=args.client, + sig_alg=args.sig_alg, cipher=args.cipher, named_group=args.named_group)) return 0