diff --git a/Lib/test/support/_hypothesis_stubs/__init__.py b/Lib/test/support/_hypothesis_stubs/__init__.py index 6ba5bb814b92f7..6fa013b55b2ac4 100644 --- a/Lib/test/support/_hypothesis_stubs/__init__.py +++ b/Lib/test/support/_hypothesis_stubs/__init__.py @@ -24,7 +24,13 @@ def decorator(f): @functools.wraps(f) def test_function(self): for example_args, example_kwargs in examples: - with self.subTest(*example_args, **example_kwargs): + if len(example_args) < 2: + subtest_args = example_args + else: + # subTest takes up to one positional argument. + # When there are more, display them as a tuple + subtest_args = [example_args] + with self.subTest(*subtest_args, **example_kwargs): f(self, *example_args, **example_kwargs) else: diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py index 409c8c109e885f..dd599515a34908 100644 --- a/Lib/test/test_base64.py +++ b/Lib/test/test_base64.py @@ -1,11 +1,14 @@ import unittest import base64 import binascii +import string import os from array import array from test.support import os_helper from test.support import script_helper +from test.support.hypothesis_helper import hypothesis + class LegacyBase64TestCase(unittest.TestCase): @@ -60,6 +63,13 @@ def test_decodebytes(self): eq(base64.decodebytes(array('B', b'YWJj\n')), b'abc') self.check_type_errors(base64.decodebytes) + @hypothesis.given(payload=hypothesis.strategies.binary()) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz') + def test_bytes_encode_decode_round_trip(self, payload): + encoded = base64.encodebytes(payload) + decoded = base64.decodebytes(encoded) + self.assertEqual(payload, decoded) + def test_encode(self): eq = self.assertEqual from io import BytesIO, StringIO @@ -88,6 +98,19 @@ def test_decode(self): self.assertRaises(TypeError, base64.encode, BytesIO(b'YWJj\n'), StringIO()) self.assertRaises(TypeError, base64.encode, StringIO('YWJj\n'), StringIO()) + @hypothesis.given(payload=hypothesis.strategies.binary()) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz') + def test_legacy_encode_decode_round_trip(self, payload): + from io import BytesIO + payload_file_r = BytesIO(payload) + encoded_file_w = BytesIO() + base64.encode(payload_file_r, encoded_file_w) + encoded_file_r = BytesIO(encoded_file_w.getvalue()) + decoded_file_w = BytesIO() + base64.decode(encoded_file_r, decoded_file_w) + decoded = decoded_file_w.getvalue() + self.assertEqual(payload, decoded) + class BaseXYTestCase(unittest.TestCase): @@ -268,6 +291,44 @@ def test_b64decode_invalid_chars(self): self.assertEqual(base64.b64decode(b'++[[//]]', b'[]'), res) self.assertEqual(base64.urlsafe_b64decode(b'++--//__'), res) + + def _altchars_strategy(): + """Generate 'altchars' for base64 encoding.""" + reserved_chars = (string.digits + string.ascii_letters + "=").encode() + allowed_chars = hypothesis.strategies.sampled_from( + [n for n in range(256) if n not in reserved_chars]) + two_bytes_strategy = hypothesis.strategies.lists( + allowed_chars, min_size=2, max_size=2, unique=True).map(bytes) + return (hypothesis.strategies.none() + | hypothesis.strategies.just(b"_-") + | two_bytes_strategy) + + @hypothesis.given( + payload=hypothesis.strategies.binary(), + altchars=_altchars_strategy(), + validate=hypothesis.strategies.booleans()) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', b"_-", True) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', b"_-", False) + def test_b64_encode_decode_round_trip(self, payload, altchars, validate): + encoded = base64.b64encode(payload, altchars=altchars) + decoded = base64.b64decode(encoded, altchars=altchars, + validate=validate) + self.assertEqual(payload, decoded) + + @hypothesis.given(payload=hypothesis.strategies.binary()) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz') + def test_standard_b64_encode_decode_round_trip(self, payload): + encoded = base64.standard_b64encode(payload) + decoded = base64.standard_b64decode(encoded) + self.assertEqual(payload, decoded) + + @hypothesis.given(payload=hypothesis.strategies.binary()) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz') + def test_urlsafe_b64_encode_decode_round_trip(self, payload): + encoded = base64.urlsafe_b64encode(payload) + decoded = base64.urlsafe_b64decode(encoded) + self.assertEqual(payload, decoded) + def test_b32encode(self): eq = self.assertEqual eq(base64.b32encode(b''), b'') @@ -355,6 +416,19 @@ def test_b32decode_error(self): with self.assertRaises(binascii.Error): base64.b32decode(data.decode('ascii')) + @hypothesis.given( + payload=hypothesis.strategies.binary(), + casefold=hypothesis.strategies.booleans(), + map01=( + hypothesis.strategies.none() + | hypothesis.strategies.binary(min_size=1, max_size=1))) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True, None) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False, None) + def test_b32_encode_decode_round_trip(self, payload, casefold, map01): + encoded = base64.b32encode(payload) + decoded = base64.b32decode(encoded, casefold=casefold, map01=map01) + self.assertEqual(payload, decoded) + def test_b32hexencode(self): test_cases = [ # to_encode, expected @@ -424,6 +498,15 @@ def test_b32hexdecode_error(self): with self.assertRaises(binascii.Error): base64.b32hexdecode(data.decode('ascii')) + @hypothesis.given( + payload=hypothesis.strategies.binary(), + casefold=hypothesis.strategies.booleans()) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False) + def test_b32_hexencode_decode_round_trip(self, payload, casefold): + encoded = base64.b32hexencode(payload) + decoded = base64.b32hexdecode(encoded, casefold=casefold) + self.assertEqual(payload, decoded) def test_b16encode(self): eq = self.assertEqual @@ -461,6 +544,16 @@ def test_b16decode(self): # Incorrect "padding" self.assertRaises(binascii.Error, base64.b16decode, '010') + @hypothesis.given( + payload=hypothesis.strategies.binary(), + casefold=hypothesis.strategies.booleans()) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False) + def test_b16_encode_decode_round_trip(self, payload, casefold): + endoded = base64.b16encode(payload) + decoded = base64.b16decode(endoded, casefold=casefold) + self.assertEqual(payload, decoded) + def test_a85encode(self): eq = self.assertEqual @@ -791,6 +884,61 @@ def test_z85decode_errors(self): self.assertRaises(ValueError, base64.z85decode, b'%nSc') self.assertRaises(ValueError, base64.z85decode, b'%nSc1') + def add_padding(self, payload): + """Add the expected padding for test_?85_encode_decode_round_trip.""" + if len(payload) % 4 != 0: + padding = b"\0" * ((-len(payload)) % 4) + payload = payload + padding + return payload + + @hypothesis.given( + payload=hypothesis.strategies.binary(), + foldspaces=hypothesis.strategies.booleans(), + wrapcol=( + hypothesis.strategies.just(0) + | hypothesis.strategies.integers(1, 1000)), + pad=hypothesis.strategies.booleans(), + adobe=hypothesis.strategies.booleans(), + ) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False, 0, False, False) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False, 20, True, True) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True, 0, False, True) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True, 20, True, False) + def test_a85_encode_decode_round_trip( + self, payload, foldspaces, wrapcol, pad, adobe + ): + encoded = base64.a85encode( + payload, foldspaces=foldspaces, wrapcol=wrapcol, + pad=pad, adobe=adobe, + ) + if wrapcol: + if adobe and wrapcol == 1: + # "adobe" needs wrapcol to be at least 2. + # a85decode quietly uses 2 if 1 is given; it's not worth + # loudly deprecating this behavior. + wrapcol = 2 + for line in encoded.splitlines(keepends=False): + self.assertLessEqual(len(line), wrapcol) + if adobe: + self.assertTrue(encoded.startswith(b'<~')) + self.assertTrue(encoded.endswith(b'~>')) + decoded = base64.a85decode(encoded, foldspaces=foldspaces, adobe=adobe) + if pad: + payload = self.add_padding(payload) + self.assertEqual(payload, decoded) + + @hypothesis.given( + payload=hypothesis.strategies.binary(), + pad=hypothesis.strategies.booleans()) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True) + @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False) + def test_b85_encode_decode_round_trip(self, payload, pad): + encoded = base64.b85encode(payload, pad=pad) + if pad: + payload = self.add_padding(payload) + decoded = base64.b85decode(encoded) + self.assertEqual(payload, decoded) + def test_decode_nonascii_str(self): decode_funcs = (base64.b64decode, base64.standard_b64decode,