15
15
# See the License for the specific language governing permissions and
16
16
# limitations under the License.
17
17
#
18
-
18
+ import abc
19
19
import json
20
20
import logging
21
21
import random
@@ -62,14 +62,21 @@ def _urlencode(value: str) -> str:
62
62
VALID_AUTH_PROVIDERS = ['URL' , 'USER_INFO' ]
63
63
64
64
65
- class _BearerFieldProvider :
65
+ class _BearerFieldProvider (metaclass = abc .ABCMeta ):
66
+ @abc .abstractmethod
66
67
def get_bearer_fields (self ) -> dict :
67
68
raise NotImplementedError
68
69
69
70
70
71
class _StaticFieldProvider (_BearerFieldProvider ):
72
+ def __init__ (self , token : str , logical_cluster : str , identity_pool : str ):
73
+ self .token = token
74
+ self .logical_cluster = logical_cluster
75
+ self .identity_pool = identity_pool
76
+
71
77
def get_bearer_fields (self ) -> dict :
72
- return {}
78
+ return {'bearer.auth.token' : self .token , 'bearer.auth.logical.cluster' : self .logical_cluster ,
79
+ 'bearer.auth.identity.pool.id' : self .identity_pool }
73
80
74
81
75
82
class _CustomOAuthClient (_BearerFieldProvider ):
@@ -82,9 +89,11 @@ def get_bearer_fields(self) -> dict:
82
89
83
90
84
91
class _OAuthClient (_BearerFieldProvider ):
85
- def __init__ (self , client_id : str , client_secret : str , scope : str , token_endpoint : str ,
86
- max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
92
+ def __init__ (self , client_id : str , client_secret : str , scope : str , token_endpoint : str , logical_cluster : str ,
93
+ identity_pool : str , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
87
94
self .token = None
95
+ self .logical_cluster = logical_cluster
96
+ self .identity_pool = identity_pool
88
97
self .client = OAuth2Client (client_id = client_id , client_secret = client_secret , scope = scope )
89
98
self .token_endpoint = token_endpoint
90
99
self .max_retries = max_retries
@@ -93,7 +102,8 @@ def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoin
93
102
self .token_expiry_threshold = 0.8
94
103
95
104
def get_bearer_fields (self ) -> dict :
96
- return {'bearer.auth.token' : self .get_access_token ()}
105
+ return {'bearer.auth.token' : self .get_access_token (), 'bearer.auth.logical.cluster' : self .logical_cluster ,
106
+ 'bearer.auth.identity.pool.id' : self .identity_pool }
97
107
98
108
def token_expired (self ) -> bool :
99
109
expiry_window = self .token ['expires_in' ] * self .token_expiry_threshold
@@ -229,27 +239,26 @@ def __init__(self, conf: dict):
229
239
self .retries_max_wait_ms = retries_max_wait_ms
230
240
231
241
self .bearer_field_provider = None
232
- self .bearer_token = None
233
- self .logical_cluster = None
234
- self .identity_pool_id = None
242
+ logical_cluster = None
243
+ identity_pool = None
235
244
self .bearer_auth_credentials_source = conf_copy .pop ('bearer.auth.credentials.source' , None )
236
245
if self .bearer_auth_credentials_source is not None :
237
246
self .auth = None
238
247
239
- if self .bearer_auth_credentials_source in {" OAUTHBEARER" , " STATIC_TOKEN" }:
248
+ if self .bearer_auth_credentials_source in {' OAUTHBEARER' , ' STATIC_TOKEN' }:
240
249
headers = ['bearer.auth.logical.cluster' , 'bearer.auth.identity.pool.id' ]
241
250
missing_headers = [header for header in headers if header not in conf_copy ]
242
251
if missing_headers :
243
252
raise ValueError ("Missing required bearer configuration properties: {}"
244
253
.format (", " .join (missing_headers )))
245
254
246
- self . logical_cluster = conf_copy .pop ('bearer.auth.logical.cluster' )
247
- if not isinstance (self . logical_cluster , str ):
248
- raise TypeError ("logical cluster must be a str, not " + str (type (self . logical_cluster )))
255
+ logical_cluster = conf_copy .pop ('bearer.auth.logical.cluster' )
256
+ if not isinstance (logical_cluster , str ):
257
+ raise TypeError ("logical cluster must be a str, not " + str (type (logical_cluster )))
249
258
250
- self . identity_pool_id = conf_copy .pop ('bearer.auth.identity.pool.id' )
251
- if not isinstance (self . identity_pool_id , str ):
252
- raise TypeError ("identity pool id must be a str, not " + str (type (self . identity_pool_id )))
259
+ identity_pool = conf_copy .pop ('bearer.auth.identity.pool.id' )
260
+ if not isinstance (identity_pool , str ):
261
+ raise TypeError ("identity pool id must be a str, not " + str (type (identity_pool )))
253
262
254
263
if self .bearer_auth_credentials_source == 'OAUTHBEARER' :
255
264
properties_list = ['bearer.auth.client.id' , 'bearer.auth.client.secret' , 'bearer.auth.scope' ,
@@ -277,16 +286,17 @@ def __init__(self, conf: dict):
277
286
+ str (type (self .token_endpoint )))
278
287
279
288
self .bearer_field_provider = _OAuthClient (self .client_id , self .client_secret , self .scope ,
280
- self .token_endpoint , self .max_retries , self .retries_wait_ms ,
289
+ self .token_endpoint , logical_cluster , identity_pool ,
290
+ self .max_retries , self .retries_wait_ms ,
281
291
self .retries_max_wait_ms )
282
292
elif self .bearer_auth_credentials_source == 'STATIC_TOKEN' :
283
293
if 'bearer.auth.token' not in conf_copy :
284
294
raise ValueError ("Missing bearer.auth.token" )
285
- self . bearer_token = conf_copy .pop ('bearer.auth.token' )
286
- self .bearer_field_provider = _StaticFieldProvider ()
287
- if not isinstance (self . bearer_token , string_type ):
288
- raise TypeError ("bearer.auth.token must be a str, not " + str (type (self . bearer_token )))
289
- elif self .bearer_auth_credentials_source == " CUSTOM" :
295
+ static_token = conf_copy .pop ('bearer.auth.token' )
296
+ self .bearer_field_provider = _StaticFieldProvider (static_token , logical_cluster , identity_pool )
297
+ if not isinstance (static_token , string_type ):
298
+ raise TypeError ("bearer.auth.token must be a str, not " + str (type (static_token )))
299
+ elif self .bearer_auth_credentials_source == ' CUSTOM' :
290
300
custom_bearer_properties = ['bearer.auth.custom.provider.function' ,
291
301
'bearer.auth.custom.provider.config' ]
292
302
missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy ]
@@ -349,29 +359,21 @@ def __init__(self, conf: dict):
349
359
350
360
def handle_bearer_auth (self , headers : dict ) -> None :
351
361
bearer_fields = self .bearer_field_provider .get_bearer_fields ()
352
- token = bearer_fields ['bearer.auth.token' ] if 'bearer.auth.token' in bearer_fields else self .bearer_token
353
-
354
- headers ["Authorization" ] = "Bearer {}" .format (token )
355
- headers ['Confluent-Identity-Pool-Id' ] = bearer_fields ['bearer.auth.identity.pool.id' ] \
356
- if ('bearer.auth.identity.pool.id' in bearer_fields ) else self .identity_pool_id
357
- headers ['target-sr-cluster' ] = bearer_fields ['bearer.auth.logical.cluster' ] \
358
- if ('bearer.auth.logical.cluster' in bearer_fields ) else self .logical_cluster
362
+ required_fields = ['bearer.auth.token' , 'bearer.auth.identity.pool.id' , 'bearer.auth.logical.cluster' ]
359
363
360
364
missing_fields = []
361
-
362
- if not token :
363
- missing_fields .append ('bearer.auth.token' )
364
-
365
- if not headers ['Confluent-Identity-Pool-Id' ]:
366
- missing_fields .append ('bearer.auth.identity.pool.id' )
367
-
368
- if not headers ['target-sr-cluster' ]:
369
- missing_fields .append ('bearer.auth.logical.cluster' )
365
+ for field in required_fields :
366
+ if field not in bearer_fields :
367
+ missing_fields .append (field )
370
368
371
369
if missing_fields :
372
370
raise ValueError ("Missing required bearer auth fields, needs to be set in config or custom function: {}"
373
371
.format (", " .join (missing_fields )))
374
372
373
+ headers ["Authorization" ] = "Bearer {}" .format (bearer_fields ['bearer.auth.token' ])
374
+ headers ['Confluent-Identity-Pool-Id' ] = bearer_fields ['bearer.auth.identity.pool.id' ]
375
+ headers ['target-sr-cluster' ] = bearer_fields ['bearer.auth.logical.cluster' ]
376
+
375
377
def get (self , url : str , query : Optional [dict ] = None ) -> Any :
376
378
return self .send_request (url , method = 'GET' , query = query )
377
379
0 commit comments