6
6
import random
7
7
import socket
8
8
import errno
9
- import ssl
10
9
11
10
__all__ = ['find_available_port' , 'SocketFactory' ]
12
11
@@ -49,14 +48,29 @@ def fatal_exception_message(typ, err) -> (str, None):
49
48
return None
50
49
return getattr (err , 'strerror' , '<strerror not present>' )
51
50
52
- def secure (self , socket : socket .socket ) -> ssl .SSLSocket :
51
+ @property
52
+ def _security_context (self ):
53
+ if self ._security_context_ii is None :
54
+ from ssl import SSLContext , PROTOCOL_TLS_CLIENT
55
+ ctx = self ._security_context_ii = SSLContext (PROTOCOL_TLS_CLIENT )
56
+ ctx .check_hostname = False
57
+
58
+ cf = self .socket_secure .get ('certfile' )
59
+ kf = self .socket_secure .get ('keyfile' )
60
+ if cf is not None :
61
+ self ._security_context_ii .load_cert_chain (cf , keyfile = kf )
62
+
63
+ ca = self .socket_secure .get ('ca_certs' )
64
+ if ca is not None :
65
+ self ._security_context_ii .load_verify_locations (ca )
66
+
67
+ return self ._security_context_ii
68
+
69
+ def secure (self , socket : socket .socket ):
53
70
"""
54
71
Secure a socket with SSL.
55
72
"""
56
- if self .socket_secure is not None :
57
- return ssl .wrap_socket (socket , ** self .socket_secure )
58
- else :
59
- return ssl .wrap_socket (socket )
73
+ return self ._security_context .wrap_socket (socket )
60
74
61
75
def __call__ (self , timeout = None ):
62
76
s = socket .socket (* self .socket_create )
@@ -73,10 +87,12 @@ def __init__(self,
73
87
socket_create ,
74
88
socket_connect ,
75
89
socket_secure = None ,
90
+ socket_security_context = None
76
91
):
92
+ self ._security_context_ii = socket_security_context
77
93
self .socket_create = socket_create
78
94
self .socket_connect = socket_connect
79
- self .socket_secure = socket_secure
95
+ self .socket_secure = socket_secure or {}
80
96
81
97
def __str__ (self ):
82
98
return 'socket' + repr (self .socket_connect )
0 commit comments