Skip to content

Add TLS support for TCP sockets #211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions bmemcached/client/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ class DistributedClient(ClientMixin):
It tries to distribute keys over the specified servers using `HashRing` consistent hash.
"""
def __init__(self, servers=('127.0.0.1:11211',), username=None, password=None, compression=None,
socket_timeout=SOCKET_TIMEOUT, pickle_protocol=0, pickler=pickle.Pickler, unpickler=pickle.Unpickler):
socket_timeout=SOCKET_TIMEOUT, pickle_protocol=0, pickler=pickle.Pickler, unpickler=pickle.Unpickler,
tls_context=None):
super(DistributedClient, self).__init__(servers, username, password, compression, socket_timeout,
pickle_protocol, pickler, unpickler)
pickle_protocol, pickler, unpickler, tls_context)
self._ring = HashRing(self._servers)

def _get_server(self, key):
Expand Down
8 changes: 7 additions & 1 deletion bmemcached/client/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class ClientMixin(object):
:type pickler: function
:param unpickler: Use this to replace the object deserialization mechanism.
:type unpickler: function
:param tls_context: A TLS context in order to connect to TLS enabled
memcached servers.
:type tls_context: ssl.SSLContext
"""
def __init__(self, servers=('127.0.0.1:11211',),
username=None,
Expand All @@ -36,14 +39,16 @@ def __init__(self, servers=('127.0.0.1:11211',),
socket_timeout=SOCKET_TIMEOUT,
pickle_protocol=PICKLE_PROTOCOL,
pickler=pickle.Pickler,
unpickler=pickle.Unpickler):
unpickler=pickle.Unpickler,
tls_context=None):
Copy link
Owner

@jaysonsantos jaysonsantos Dec 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there, could you also add some info about in on the __doc__?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'm on vacation now and will be back on this task in the first week of January.

self.username = username
self.password = password
self.compression = compression
self.socket_timeout = socket_timeout
self.pickle_protocol = pickle_protocol
self.pickler = pickler
self.unpickler = unpickler
self.tls_context = tls_context
self.set_servers(servers)

@property
Expand Down Expand Up @@ -73,6 +78,7 @@ def set_servers(self, servers):
pickle_protocol=self.pickle_protocol,
pickler=self.pickler,
unpickler=self.unpickler,
tls_context=self.tls_context,
) for server in servers]

def flush_all(self, time=0):
Expand Down
9 changes: 8 additions & 1 deletion bmemcached/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class Protocol(threading.local):
COMPRESSION_THRESHOLD = 128

def __init__(self, server, username=None, password=None, compression=None, socket_timeout=None,
pickle_protocol=None, pickler=None, unpickler=None):
pickle_protocol=None, pickler=None, unpickler=None, tls_context=None):
super(Protocol, self).__init__()
self.server = server
self._username = username
Expand All @@ -112,6 +112,7 @@ def __init__(self, server, username=None, password=None, compression=None, socke
self.pickle_protocol = pickle_protocol
self.pickler = pickler
self.unpickler = unpickler
self.tls_context = tls_context

self.reconnects_deferred_until = None

Expand Down Expand Up @@ -144,6 +145,12 @@ def _open_connection(self):
self.connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.connection.settimeout(self.socket_timeout)
self.connection.connect((self.host, self.port))

if self.tls_context:
self.connection = self.tls_context.wrap_socket(
self.connection,
server_hostname=self.host,
)
else:
self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.connection.connect(self.server)
Expand Down
1 change: 1 addition & 0 deletions requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pytest-cov==2.7.1
mock==2.0.0
flake8==3.7.7
bumpversion==0.5.3
trustme==0.6.0
22 changes: 15 additions & 7 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,38 @@
import pytest


os.environ.setdefault('MEMCACHED_HOST', '127.0.0.1')
os.environ.setdefault("MEMCACHED_HOST", "localhost")


@pytest.yield_fixture(scope='session', autouse=True)
@pytest.yield_fixture(scope="session", autouse=True)
def memcached_standard_port():
p = subprocess.Popen(['memcached'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
p = subprocess.Popen(
["memcached"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
time.sleep(0.1)
yield p
p.kill()
p.wait()


@pytest.yield_fixture(scope='session', autouse=True)
@pytest.yield_fixture(scope="session", autouse=True)
def memcached_other_port():
p = subprocess.Popen(['memcached', '-p5000'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
p = subprocess.Popen(
["memcached", "-p5000"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
time.sleep(0.1)
yield p
p.kill()
p.wait()


@pytest.yield_fixture(scope='session', autouse=True)
@pytest.yield_fixture(scope="session", autouse=True)
def memcached_socket():
p = subprocess.Popen(['memcached', '-s/tmp/memcached.sock'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
p = subprocess.Popen(
["memcached", "-s/tmp/memcached.sock"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
time.sleep(0.1)
yield p
p.kill()
Expand Down
4 changes: 2 additions & 2 deletions test/test_server_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def testNoPortGiven(self):
self.assertEqual(server.port, 11211)

def testInvalidPort(self):
server = bmemcached.protocol.Protocol('127.0.0.1:blah')
server = bmemcached.protocol.Protocol('{}:blah'.format(os.environ['MEMCACHED_HOST']))
self.assertEqual(server.host, os.environ['MEMCACHED_HOST'])
self.assertEqual(server.port, 11211)

def testNonStandardPort(self):
server = bmemcached.protocol.Protocol('127.0.0.1:5000')
server = bmemcached.protocol.Protocol('{}:5000'.format(os.environ['MEMCACHED_HOST']))
self.assertEqual(server.host, os.environ['MEMCACHED_HOST'])
self.assertEqual(server.port, 5000)

Expand Down
59 changes: 59 additions & 0 deletions test/test_tls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import pytest
import subprocess
import ssl
import time
import trustme

import bmemcached
import test_simple_functions


ca = trustme.CA()
server_cert = ca.issue_cert(os.environ["MEMCACHED_HOST"] + u"")


@pytest.yield_fixture(scope="module", autouse=True)
def memcached_tls():
key = server_cert.private_key_pem
cert = server_cert.cert_chain_pems[0]

with cert.tempfile() as c, key.tempfile() as k:
p = subprocess.Popen(
[
"memcached",
"-p5001",
"-Z",
"-o",
"ssl_key={}".format(k),
"-o",
"ssl_chain_cert={}".format(c),
"-o",
"ssl_verify_mode=1",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
time.sleep(0.1)

if p.poll() is not None:
pytest.skip("Memcached server is not built with TLS support.")

yield p
p.kill()
p.wait()


class TLSMemcachedTests(test_simple_functions.MemcachedTests):
"""
Same tests as above, just make sure it works with TLS.
"""

def setUp(self):
ctx = ssl.create_default_context()

ca.configure_trust(ctx)

self.server = "{}:5001".format(os.environ["MEMCACHED_HOST"])
self.client = bmemcached.Client(self.server, tls_context=ctx)
self.reset()