30
30
from enum import Enum
31
31
from threading import Lock
32
32
from typing import List , Dict , Type , TypeVar , \
33
- cast , Optional , Union , Any , Tuple
33
+ cast , Optional , Union , Any , Tuple , Callable
34
34
35
35
from cachetools import TTLCache , LRUCache
36
36
from httpx import Response
@@ -62,7 +62,26 @@ def _urlencode(value: str) -> str:
62
62
VALID_AUTH_PROVIDERS = ['URL' , 'USER_INFO' ]
63
63
64
64
65
- class _OAuthClient :
65
+ class _BearerFieldProvider :
66
+ def get_bearer_fields (self ) -> dict :
67
+ raise NotImplementedError
68
+
69
+
70
+ class _StaticFieldProvider (_BearerFieldProvider ):
71
+ def get_bearer_fields (self ) -> dict :
72
+ return {}
73
+
74
+
75
+ class _CustomOAuthClient (_BearerFieldProvider ):
76
+ def __init__ (self , custom_function : Callable [[Dict ], Dict ], custom_config : dict ):
77
+ self .custom_function = custom_function
78
+ self .custom_config = custom_config
79
+
80
+ def get_bearer_fields (self ) -> dict :
81
+ return self .custom_function (self .custom_config )
82
+
83
+
84
+ class _OAuthClient (_BearerFieldProvider ):
66
85
def __init__ (self , client_id : str , client_secret : str , scope : str , token_endpoint : str ,
67
86
max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
68
87
self .token = None
@@ -73,7 +92,10 @@ def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoin
73
92
self .retries_max_wait_ms = retries_max_wait_ms
74
93
self .token_expiry_threshold = 0.8
75
94
76
- def token_expired (self ):
95
+ def get_bearer_fields (self ) -> dict :
96
+ return {'bearer.auth.token' : self .get_access_token ()}
97
+
98
+ def token_expired (self ) -> bool :
77
99
expiry_window = self .token ['expires_in' ] * self .token_expiry_threshold
78
100
79
101
return self .token ['expires_at' ] < time .time () + expiry_window
@@ -84,7 +106,7 @@ def get_access_token(self) -> str:
84
106
85
107
return self .token ['access_token' ]
86
108
87
- def generate_access_token (self ):
109
+ def generate_access_token (self ) -> None :
88
110
for i in range (self .max_retries + 1 ):
89
111
try :
90
112
self .token = self .client .fetch_token (url = self .token_endpoint , grant_type = 'client_credentials' )
@@ -206,23 +228,28 @@ def __init__(self, conf: dict):
206
228
+ str (type (retries_max_wait_ms )))
207
229
self .retries_max_wait_ms = retries_max_wait_ms
208
230
209
- self .oauth_client = None
231
+ self .bearer_field_provider = None
232
+ self .bearer_token = None
233
+ self .logical_cluster = None
234
+ self .identity_pool_id = None
210
235
self .bearer_auth_credentials_source = conf_copy .pop ('bearer.auth.credentials.source' , None )
211
236
if self .bearer_auth_credentials_source is not None :
212
237
self .auth = None
213
- headers = ['bearer.auth.logical.cluster' , 'bearer.auth.identity.pool.id' ]
214
- missing_headers = [header for header in headers if header not in conf_copy ]
215
- if missing_headers :
216
- raise ValueError ("Missing required bearer configuration properties: {}"
217
- .format (", " .join (missing_headers )))
218
238
219
- self .logical_cluster = conf_copy .pop ('bearer.auth.logical.cluster' )
220
- if not isinstance (self .logical_cluster , str ):
221
- raise TypeError ("logical cluster must be a str, not " + str (type (self .logical_cluster )))
239
+ if self .bearer_auth_credentials_source in {"OAUTHBEARER" , "STATIC_TOKEN" }:
240
+ headers = ['bearer.auth.logical.cluster' , 'bearer.auth.identity.pool.id' ]
241
+ missing_headers = [header for header in headers if header not in conf_copy ]
242
+ if missing_headers :
243
+ raise ValueError ("Missing required bearer configuration properties: {}"
244
+ .format (", " .join (missing_headers )))
245
+
246
+ self .logical_cluster = conf_copy .pop ('bearer.auth.logical.cluster' )
247
+ if not isinstance (self .logical_cluster , str ):
248
+ raise TypeError ("logical cluster must be a str, not " + str (type (self .logical_cluster )))
222
249
223
- self .identity_pool_id = conf_copy .pop ('bearer.auth.identity.pool.id' )
224
- if not isinstance (self .identity_pool_id , str ):
225
- raise TypeError ("identity pool id must be a str, not " + str (type (self .identity_pool_id )))
250
+ self .identity_pool_id = conf_copy .pop ('bearer.auth.identity.pool.id' )
251
+ if not isinstance (self .identity_pool_id , str ):
252
+ raise TypeError ("identity pool id must be a str, not " + str (type (self .identity_pool_id )))
226
253
227
254
if self .bearer_auth_credentials_source == 'OAUTHBEARER' :
228
255
properties_list = ['bearer.auth.client.id' , 'bearer.auth.client.secret' , 'bearer.auth.scope' ,
@@ -249,15 +276,37 @@ def __init__(self, conf: dict):
249
276
raise TypeError ("bearer.auth.issuer.endpoint.url must be a str, not "
250
277
+ str (type (self .token_endpoint )))
251
278
252
- self .oauth_client = _OAuthClient (self .client_id , self .client_secret , self .scope , self . token_endpoint ,
253
- self .max_retries , self .retries_wait_ms , self .retries_max_wait_ms )
254
-
279
+ self .bearer_field_provider = _OAuthClient (self .client_id , self .client_secret , self .scope ,
280
+ self .token_endpoint , self .max_retries , self .retries_wait_ms ,
281
+ self . retries_max_wait_ms )
255
282
elif self .bearer_auth_credentials_source == 'STATIC_TOKEN' :
256
283
if 'bearer.auth.token' not in conf_copy :
257
284
raise ValueError ("Missing bearer.auth.token" )
258
285
self .bearer_token = conf_copy .pop ('bearer.auth.token' )
286
+ self .bearer_field_provider = _StaticFieldProvider ()
259
287
if not isinstance (self .bearer_token , string_type ):
260
288
raise TypeError ("bearer.auth.token must be a str, not " + str (type (self .bearer_token )))
289
+ elif self .bearer_auth_credentials_source == "CUSTOM" :
290
+ custom_bearer_properties = ['bearer.auth.custom.provider.function' ,
291
+ 'bearer.auth.custom.provider.config' ]
292
+ missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy ]
293
+ if missing_custom_properties :
294
+ raise ValueError ("Missing required custom OAuth configuration properties: {}" .
295
+ format (", " .join (missing_custom_properties )))
296
+
297
+ custom_function = conf_copy .pop ('bearer.auth.custom.provider.function' )
298
+ if not callable (custom_function ):
299
+ raise TypeError ("bearer.auth.custom.provider.function must be a callable, not "
300
+ + str (type (custom_function )))
301
+
302
+ custom_config = conf_copy .pop ('bearer.auth.custom.provider.config' )
303
+ if not isinstance (custom_config , dict ):
304
+ raise TypeError ("bearer.auth.custom.provider.config must be a dict, not "
305
+ + str (type (custom_config )))
306
+
307
+ self .bearer_field_provider = _CustomOAuthClient (custom_function , custom_config )
308
+ else :
309
+ raise ValueError ('Unrecognized bearer.auth.credentials.source' )
261
310
262
311
# Any leftover keys are unknown to _RestClient
263
312
if len (conf_copy ) > 0 :
@@ -298,13 +347,30 @@ def __init__(self, conf: dict):
298
347
timeout = self .timeout
299
348
)
300
349
301
- def handle_bearer_auth (self , headers : dict ):
302
- token = self .bearer_token
303
- if self .oauth_client :
304
- token = self . oauth_client . get_access_token ()
350
+ def handle_bearer_auth (self , headers : dict ) -> None :
351
+ bearer_fields = self .bearer_field_provider . get_bearer_fields ()
352
+ token = bearer_fields [ 'bearer.auth.token' ] if 'bearer.auth.token' in bearer_fields else self .bearer_token
353
+
305
354
headers ["Authorization" ] = "Bearer {}" .format (token )
306
- headers ['Confluent-Identity-Pool-Id' ] = self .identity_pool_id
307
- headers ['target-sr-cluster' ] = self .logical_cluster
355
+ headers ['Confluent-Identity-Pool-Id' ] = bearer_fields ['bearer.auth.identity.pool.id' ] \
356
+ if ('bearer.auth.identity.pool.id' in bearer_fields ) else self .identity_pool_id
357
+ headers ['target-sr-cluster' ] = bearer_fields ['bearer.auth.logical.cluster' ] \
358
+ if ('bearer.auth.logical.cluster' in bearer_fields ) else self .logical_cluster
359
+
360
+ missing_fields = []
361
+
362
+ if not token :
363
+ missing_fields .append ('bearer.auth.token' )
364
+
365
+ if not headers ['Confluent-Identity-Pool-Id' ]:
366
+ missing_fields .append ('bearer.auth.identity.pool.id' )
367
+
368
+ if not headers ['target-sr-cluster' ]:
369
+ missing_fields .append ('bearer.auth.logical.cluster' )
370
+
371
+ if missing_fields :
372
+ raise ValueError ("Missing required bearer auth fields, needs to be set in config or custom function: {}"
373
+ .format (", " .join (missing_fields )))
308
374
309
375
def get (self , url : str , query : Optional [dict ] = None ) -> Any :
310
376
return self .send_request (url , method = 'GET' , query = query )
0 commit comments