Skip to content

Commit ffee512

Browse files
committed
Update token, logical cluster, pool into BearerFieldProvider
1 parent 93616ad commit ffee512

File tree

3 files changed

+58
-56
lines changed

3 files changed

+58
-56
lines changed

src/confluent_kafka/schema_registry/schema_registry_client.py

+40-38
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
#
18-
18+
import abc
1919
import json
2020
import logging
2121
import random
@@ -62,14 +62,21 @@ def _urlencode(value: str) -> str:
6262
VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO']
6363

6464

65-
class _BearerFieldProvider:
65+
class _BearerFieldProvider(metaclass=abc.ABCMeta):
66+
@abc.abstractmethod
6667
def get_bearer_fields(self) -> dict:
6768
raise NotImplementedError
6869

6970

7071
class _StaticFieldProvider(_BearerFieldProvider):
72+
def __init__(self, token: str, logical_cluster: str, identity_pool: str):
73+
self.token = token
74+
self.logical_cluster = logical_cluster
75+
self.identity_pool = identity_pool
76+
7177
def get_bearer_fields(self) -> dict:
72-
return {}
78+
return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster,
79+
'bearer.auth.identity.pool.id': self.identity_pool}
7380

7481

7582
class _CustomOAuthClient(_BearerFieldProvider):
@@ -82,9 +89,11 @@ def get_bearer_fields(self) -> dict:
8289

8390

8491
class _OAuthClient(_BearerFieldProvider):
85-
def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str,
86-
max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
92+
def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str,
93+
identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
8794
self.token = None
95+
self.logical_cluster = logical_cluster
96+
self.identity_pool = identity_pool
8897
self.client = OAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope)
8998
self.token_endpoint = token_endpoint
9099
self.max_retries = max_retries
@@ -93,7 +102,8 @@ def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoin
93102
self.token_expiry_threshold = 0.8
94103

95104
def get_bearer_fields(self) -> dict:
96-
return {'bearer.auth.token': self.get_access_token()}
105+
return {'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster,
106+
'bearer.auth.identity.pool.id': self.identity_pool}
97107

98108
def token_expired(self) -> bool:
99109
expiry_window = self.token['expires_in'] * self.token_expiry_threshold
@@ -229,27 +239,26 @@ def __init__(self, conf: dict):
229239
self.retries_max_wait_ms = retries_max_wait_ms
230240

231241
self.bearer_field_provider = None
232-
self.bearer_token = None
233-
self.logical_cluster = None
234-
self.identity_pool_id = None
242+
logical_cluster = None
243+
identity_pool = None
235244
self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None)
236245
if self.bearer_auth_credentials_source is not None:
237246
self.auth = None
238247

239-
if self.bearer_auth_credentials_source in {"OAUTHBEARER", "STATIC_TOKEN"}:
248+
if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}:
240249
headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id']
241250
missing_headers = [header for header in headers if header not in conf_copy]
242251
if missing_headers:
243252
raise ValueError("Missing required bearer configuration properties: {}"
244253
.format(", ".join(missing_headers)))
245254

246-
self.logical_cluster = conf_copy.pop('bearer.auth.logical.cluster')
247-
if not isinstance(self.logical_cluster, str):
248-
raise TypeError("logical cluster must be a str, not " + str(type(self.logical_cluster)))
255+
logical_cluster = conf_copy.pop('bearer.auth.logical.cluster')
256+
if not isinstance(logical_cluster, str):
257+
raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster)))
249258

250-
self.identity_pool_id = conf_copy.pop('bearer.auth.identity.pool.id')
251-
if not isinstance(self.identity_pool_id, str):
252-
raise TypeError("identity pool id must be a str, not " + str(type(self.identity_pool_id)))
259+
identity_pool = conf_copy.pop('bearer.auth.identity.pool.id')
260+
if not isinstance(identity_pool, str):
261+
raise TypeError("identity pool id must be a str, not " + str(type(identity_pool)))
253262

254263
if self.bearer_auth_credentials_source == 'OAUTHBEARER':
255264
properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope',
@@ -277,16 +286,17 @@ def __init__(self, conf: dict):
277286
+ str(type(self.token_endpoint)))
278287

279288
self.bearer_field_provider = _OAuthClient(self.client_id, self.client_secret, self.scope,
280-
self.token_endpoint, self.max_retries, self.retries_wait_ms,
289+
self.token_endpoint, logical_cluster, identity_pool,
290+
self.max_retries, self.retries_wait_ms,
281291
self.retries_max_wait_ms)
282292
elif self.bearer_auth_credentials_source == 'STATIC_TOKEN':
283293
if 'bearer.auth.token' not in conf_copy:
284294
raise ValueError("Missing bearer.auth.token")
285-
self.bearer_token = conf_copy.pop('bearer.auth.token')
286-
self.bearer_field_provider = _StaticFieldProvider()
287-
if not isinstance(self.bearer_token, string_type):
288-
raise TypeError("bearer.auth.token must be a str, not " + str(type(self.bearer_token)))
289-
elif self.bearer_auth_credentials_source == "CUSTOM":
295+
static_token = conf_copy.pop('bearer.auth.token')
296+
self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool)
297+
if not isinstance(static_token, string_type):
298+
raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token)))
299+
elif self.bearer_auth_credentials_source == 'CUSTOM':
290300
custom_bearer_properties = ['bearer.auth.custom.provider.function',
291301
'bearer.auth.custom.provider.config']
292302
missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy]
@@ -349,29 +359,21 @@ def __init__(self, conf: dict):
349359

350360
def handle_bearer_auth(self, headers: dict) -> None:
351361
bearer_fields = self.bearer_field_provider.get_bearer_fields()
352-
token = bearer_fields['bearer.auth.token'] if 'bearer.auth.token' in bearer_fields else self.bearer_token
353-
354-
headers["Authorization"] = "Bearer {}".format(token)
355-
headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] \
356-
if ('bearer.auth.identity.pool.id' in bearer_fields) else self.identity_pool_id
357-
headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] \
358-
if ('bearer.auth.logical.cluster' in bearer_fields) else self.logical_cluster
362+
required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster']
359363

360364
missing_fields = []
361-
362-
if not token:
363-
missing_fields.append('bearer.auth.token')
364-
365-
if not headers['Confluent-Identity-Pool-Id']:
366-
missing_fields.append('bearer.auth.identity.pool.id')
367-
368-
if not headers['target-sr-cluster']:
369-
missing_fields.append('bearer.auth.logical.cluster')
365+
for field in required_fields:
366+
if field not in bearer_fields:
367+
missing_fields.append(field)
370368

371369
if missing_fields:
372370
raise ValueError("Missing required bearer auth fields, needs to be set in config or custom function: {}"
373371
.format(", ".join(missing_fields)))
374372

373+
headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token'])
374+
headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id']
375+
headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster']
376+
375377
def get(self, url: str, query: Optional[dict] = None) -> Any:
376378
return self.send_request(url, method='GET', query=query)
377379

tests/schema_registry/test_bearer_field_provider.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -33,23 +33,25 @@ def custom_oauth_function(config: dict) -> dict:
3333
return config
3434

3535

36-
CUSTOM_FUNCTION = custom_oauth_function
37-
CUSTOM_CONFIG = {'bearer.auth.token': '123', 'bearer.auth.logical.cluster': 'lsrc-cluster',
38-
'bearer.auth.identity.pool.id': 'pool-id'}
36+
TEST_TOKEN = 'token123'
37+
TEST_CLUSTER = 'lsrc-cluster'
38+
TEST_POOL = 'pool-id'
39+
TEST_FUNCTION = custom_oauth_function
40+
TEST_CONFIG = {'bearer.auth.token': TEST_TOKEN, 'bearer.auth.logical.cluster': TEST_CLUSTER,
41+
'bearer.auth.identity.pool.id': TEST_POOL}
3942
TEST_URL = 'http://SchemaRegistry:65534'
4043

4144

4245
def test_expiry():
43-
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 2, 1000, 20000)
46+
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000)
4447
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1}
4548
assert not oauth_client.token_expired()
4649
time.sleep(1.5)
4750
assert oauth_client.token_expired()
4851

4952

5053
def test_get_token():
51-
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 2, 1000, 20000)
52-
assert not oauth_client.token
54+
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000)
5355

5456
def update_token1():
5557
oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'}
@@ -73,7 +75,7 @@ def update_token2():
7375

7476

7577
def test_generate_token_retry_logic():
76-
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 5, 1000, 20000)
78+
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 5, 1000, 20000)
7779

7880
with (patch("confluent_kafka.schema_registry.schema_registry_client.time.sleep") as mock_sleep,
7981
patch("confluent_kafka.schema_registry.schema_registry_client.full_jitter") as mock_jitter):
@@ -86,14 +88,14 @@ def test_generate_token_retry_logic():
8688

8789

8890
def test_static_field_provider():
89-
static_field_provider = _StaticFieldProvider()
91+
static_field_provider = _StaticFieldProvider(TEST_TOKEN, TEST_CLUSTER, TEST_POOL)
9092
bearer_fields = static_field_provider.get_bearer_fields()
9193

92-
assert not bearer_fields
94+
assert bearer_fields == TEST_CONFIG
9395

9496

9597
def test_custom_oauth_client():
96-
custom_oauth_client = _CustomOAuthClient(CUSTOM_FUNCTION, CUSTOM_CONFIG)
98+
custom_oauth_client = _CustomOAuthClient(TEST_FUNCTION, TEST_CONFIG)
9799

98100
assert custom_oauth_client.get_bearer_fields() == custom_oauth_client.get_bearer_fields()
99101

@@ -105,7 +107,7 @@ def empty_custom(config):
105107
conf = {'url': TEST_URL,
106108
'bearer.auth.credentials.source': 'CUSTOM',
107109
'bearer.auth.custom.provider.function': empty_custom,
108-
'bearer.auth.custom.provider.config': CUSTOM_CONFIG}
110+
'bearer.auth.custom.provider.config': TEST_CONFIG}
109111

110112
headers = {'Accept': "application/vnd.schemaregistry.v1+json,"
111113
" application/vnd.schemaregistry+json,"
@@ -121,8 +123,8 @@ def empty_custom(config):
121123
def test_bearer_field_headers_valid():
122124
conf = {'url': TEST_URL,
123125
'bearer.auth.credentials.source': 'CUSTOM',
124-
'bearer.auth.custom.provider.function': CUSTOM_FUNCTION,
125-
'bearer.auth.custom.provider.config': CUSTOM_CONFIG}
126+
'bearer.auth.custom.provider.function': TEST_FUNCTION,
127+
'bearer.auth.custom.provider.config': TEST_CONFIG}
126128

127129
client = SchemaRegistryClient(conf)
128130

@@ -135,6 +137,6 @@ def test_bearer_field_headers_valid():
135137
assert 'Authorization' in headers
136138
assert 'Confluent-Identity-Pool-Id' in headers
137139
assert 'target-sr-cluster' in headers
138-
assert headers['Authorization'] == "Bearer {}".format(CUSTOM_CONFIG['bearer.auth.token'])
139-
assert headers['Confluent-Identity-Pool-Id'] == CUSTOM_CONFIG['bearer.auth.identity.pool.id']
140-
assert headers['target-sr-cluster'] == CUSTOM_CONFIG['bearer.auth.logical.cluster']
140+
assert headers['Authorization'] == "Bearer {}".format(TEST_CONFIG['bearer.auth.token'])
141+
assert headers['Confluent-Identity-Pool-Id'] == TEST_CONFIG['bearer.auth.identity.pool.id']
142+
assert headers['target-sr-cluster'] == TEST_CONFIG['bearer.auth.logical.cluster']

tests/schema_registry/test_config.py

-2
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,6 @@ def test_oauth_bearer_config_valid():
212212

213213
client = SchemaRegistryClient(conf)
214214

215-
assert client._rest_client.logical_cluster == TEST_CLUSTER
216-
assert client._rest_client.identity_pool_id == TEST_POOL
217215
assert client._rest_client.client_id == TEST_USERNAME
218216
assert client._rest_client.client_secret == TEST_USER_PASSWORD
219217
assert client._rest_client.scope == TEST_SCOPE

0 commit comments

Comments
 (0)