diff --git a/bmemcached/client/distributed.py b/bmemcached/client/distributed.py index 7ef9d02..3db305c 100644 --- a/bmemcached/client/distributed.py +++ b/bmemcached/client/distributed.py @@ -12,9 +12,10 @@ class DistributedClient(ClientMixin): It tries to distribute keys over the specified servers using `HashRing` consistent hash. """ def __init__(self, servers=('127.0.0.1:11211',), username=None, password=None, compression=None, - socket_timeout=SOCKET_TIMEOUT, pickle_protocol=0, pickler=pickle.Pickler, unpickler=pickle.Unpickler): + socket_timeout=SOCKET_TIMEOUT, pickle_protocol=0, pickler=pickle.Pickler, unpickler=pickle.Unpickler, + tls_context=None): super(DistributedClient, self).__init__(servers, username, password, compression, socket_timeout, - pickle_protocol, pickler, unpickler) + pickle_protocol, pickler, unpickler, tls_context) self._ring = HashRing(self._servers) def _get_server(self, key): diff --git a/bmemcached/client/mixin.py b/bmemcached/client/mixin.py index 656f51d..990f6a8 100644 --- a/bmemcached/client/mixin.py +++ b/bmemcached/client/mixin.py @@ -28,6 +28,9 @@ class ClientMixin(object): :type pickler: function :param unpickler: Use this to replace the object deserialization mechanism. :type unpickler: function + :param tls_context: A TLS context in order to connect to TLS enabled + memcached servers. + :type tls_context: ssl.SSLContext """ def __init__(self, servers=('127.0.0.1:11211',), username=None, @@ -36,7 +39,8 @@ def __init__(self, servers=('127.0.0.1:11211',), socket_timeout=SOCKET_TIMEOUT, pickle_protocol=PICKLE_PROTOCOL, pickler=pickle.Pickler, - unpickler=pickle.Unpickler): + unpickler=pickle.Unpickler, + tls_context=None): self.username = username self.password = password self.compression = compression @@ -44,6 +48,7 @@ def __init__(self, servers=('127.0.0.1:11211',), self.pickle_protocol = pickle_protocol self.pickler = pickler self.unpickler = unpickler + self.tls_context = tls_context self.set_servers(servers) @property @@ -73,6 +78,7 @@ def set_servers(self, servers): pickle_protocol=self.pickle_protocol, pickler=self.pickler, unpickler=self.unpickler, + tls_context=self.tls_context, ) for server in servers] def flush_all(self, time=0): diff --git a/bmemcached/protocol.py b/bmemcached/protocol.py index dd9f94a..3ec99be 100644 --- a/bmemcached/protocol.py +++ b/bmemcached/protocol.py @@ -99,7 +99,7 @@ class Protocol(threading.local): COMPRESSION_THRESHOLD = 128 def __init__(self, server, username=None, password=None, compression=None, socket_timeout=None, - pickle_protocol=None, pickler=None, unpickler=None): + pickle_protocol=None, pickler=None, unpickler=None, tls_context=None): super(Protocol, self).__init__() self.server = server self._username = username @@ -112,6 +112,7 @@ def __init__(self, server, username=None, password=None, compression=None, socke self.pickle_protocol = pickle_protocol self.pickler = pickler self.unpickler = unpickler + self.tls_context = tls_context self.reconnects_deferred_until = None @@ -144,6 +145,12 @@ def _open_connection(self): self.connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.connection.settimeout(self.socket_timeout) self.connection.connect((self.host, self.port)) + + if self.tls_context: + self.connection = self.tls_context.wrap_socket( + self.connection, + server_hostname=self.host, + ) else: self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.connection.connect(self.server) diff --git a/requirements_test.txt b/requirements_test.txt index eae07ad..1e58316 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -3,3 +3,4 @@ pytest-cov==2.7.1 mock==2.0.0 flake8==3.7.7 bumpversion==0.5.3 +trustme==0.6.0 diff --git a/test/conftest.py b/test/conftest.py index 8bb76cc..b4cb3b8 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,30 +5,38 @@ import pytest -os.environ.setdefault('MEMCACHED_HOST', '127.0.0.1') +os.environ.setdefault("MEMCACHED_HOST", "localhost") -@pytest.yield_fixture(scope='session', autouse=True) +@pytest.yield_fixture(scope="session", autouse=True) def memcached_standard_port(): - p = subprocess.Popen(['memcached'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p = subprocess.Popen( + ["memcached"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) time.sleep(0.1) yield p p.kill() p.wait() -@pytest.yield_fixture(scope='session', autouse=True) +@pytest.yield_fixture(scope="session", autouse=True) def memcached_other_port(): - p = subprocess.Popen(['memcached', '-p5000'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p = subprocess.Popen( + ["memcached", "-p5000"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) time.sleep(0.1) yield p p.kill() p.wait() -@pytest.yield_fixture(scope='session', autouse=True) +@pytest.yield_fixture(scope="session", autouse=True) def memcached_socket(): - p = subprocess.Popen(['memcached', '-s/tmp/memcached.sock'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p = subprocess.Popen( + ["memcached", "-s/tmp/memcached.sock"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) time.sleep(0.1) yield p p.kill() diff --git a/test/test_server_parsing.py b/test/test_server_parsing.py index b7fb158..a384395 100644 --- a/test/test_server_parsing.py +++ b/test/test_server_parsing.py @@ -27,12 +27,12 @@ def testNoPortGiven(self): self.assertEqual(server.port, 11211) def testInvalidPort(self): - server = bmemcached.protocol.Protocol('127.0.0.1:blah') + server = bmemcached.protocol.Protocol('{}:blah'.format(os.environ['MEMCACHED_HOST'])) self.assertEqual(server.host, os.environ['MEMCACHED_HOST']) self.assertEqual(server.port, 11211) def testNonStandardPort(self): - server = bmemcached.protocol.Protocol('127.0.0.1:5000') + server = bmemcached.protocol.Protocol('{}:5000'.format(os.environ['MEMCACHED_HOST'])) self.assertEqual(server.host, os.environ['MEMCACHED_HOST']) self.assertEqual(server.port, 5000) diff --git a/test/test_tls.py b/test/test_tls.py new file mode 100644 index 0000000..9091035 --- /dev/null +++ b/test/test_tls.py @@ -0,0 +1,59 @@ +import os +import pytest +import subprocess +import ssl +import time +import trustme + +import bmemcached +import test_simple_functions + + +ca = trustme.CA() +server_cert = ca.issue_cert(os.environ["MEMCACHED_HOST"] + u"") + + +@pytest.yield_fixture(scope="module", autouse=True) +def memcached_tls(): + key = server_cert.private_key_pem + cert = server_cert.cert_chain_pems[0] + + with cert.tempfile() as c, key.tempfile() as k: + p = subprocess.Popen( + [ + "memcached", + "-p5001", + "-Z", + "-o", + "ssl_key={}".format(k), + "-o", + "ssl_chain_cert={}".format(c), + "-o", + "ssl_verify_mode=1", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + time.sleep(0.1) + + if p.poll() is not None: + pytest.skip("Memcached server is not built with TLS support.") + + yield p + p.kill() + p.wait() + + +class TLSMemcachedTests(test_simple_functions.MemcachedTests): + """ + Same tests as above, just make sure it works with TLS. + """ + + def setUp(self): + ctx = ssl.create_default_context() + + ca.configure_trust(ctx) + + self.server = "{}:5001".format(os.environ["MEMCACHED_HOST"]) + self.client = bmemcached.Client(self.server, tls_context=ctx) + self.reset()