Skip to content

Commit 93616ad

Browse files
committed
Add support for custom OAuth functions
1 parent 497ef2a commit 93616ad

File tree

4 files changed

+256
-82
lines changed

4 files changed

+256
-82
lines changed

src/confluent_kafka/schema_registry/schema_registry_client.py

Lines changed: 91 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from enum import Enum
3131
from threading import Lock
3232
from typing import List, Dict, Type, TypeVar, \
33-
cast, Optional, Union, Any, Tuple
33+
cast, Optional, Union, Any, Tuple, Callable
3434

3535
from cachetools import TTLCache, LRUCache
3636
from httpx import Response
@@ -62,7 +62,26 @@ def _urlencode(value: str) -> str:
6262
VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO']
6363

6464

65-
class _OAuthClient:
65+
class _BearerFieldProvider:
66+
def get_bearer_fields(self) -> dict:
67+
raise NotImplementedError
68+
69+
70+
class _StaticFieldProvider(_BearerFieldProvider):
71+
def get_bearer_fields(self) -> dict:
72+
return {}
73+
74+
75+
class _CustomOAuthClient(_BearerFieldProvider):
76+
def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict):
77+
self.custom_function = custom_function
78+
self.custom_config = custom_config
79+
80+
def get_bearer_fields(self) -> dict:
81+
return self.custom_function(self.custom_config)
82+
83+
84+
class _OAuthClient(_BearerFieldProvider):
6685
def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str,
6786
max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
6887
self.token = None
@@ -73,7 +92,10 @@ def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoin
7392
self.retries_max_wait_ms = retries_max_wait_ms
7493
self.token_expiry_threshold = 0.8
7594

76-
def token_expired(self):
95+
def get_bearer_fields(self) -> dict:
96+
return {'bearer.auth.token': self.get_access_token()}
97+
98+
def token_expired(self) -> bool:
7799
expiry_window = self.token['expires_in'] * self.token_expiry_threshold
78100

79101
return self.token['expires_at'] < time.time() + expiry_window
@@ -84,7 +106,7 @@ def get_access_token(self) -> str:
84106

85107
return self.token['access_token']
86108

87-
def generate_access_token(self):
109+
def generate_access_token(self) -> None:
88110
for i in range(self.max_retries + 1):
89111
try:
90112
self.token = self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials')
@@ -206,23 +228,28 @@ def __init__(self, conf: dict):
206228
+ str(type(retries_max_wait_ms)))
207229
self.retries_max_wait_ms = retries_max_wait_ms
208230

209-
self.oauth_client = None
231+
self.bearer_field_provider = None
232+
self.bearer_token = None
233+
self.logical_cluster = None
234+
self.identity_pool_id = None
210235
self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None)
211236
if self.bearer_auth_credentials_source is not None:
212237
self.auth = None
213-
headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id']
214-
missing_headers = [header for header in headers if header not in conf_copy]
215-
if missing_headers:
216-
raise ValueError("Missing required bearer configuration properties: {}"
217-
.format(", ".join(missing_headers)))
218238

219-
self.logical_cluster = conf_copy.pop('bearer.auth.logical.cluster')
220-
if not isinstance(self.logical_cluster, str):
221-
raise TypeError("logical cluster must be a str, not " + str(type(self.logical_cluster)))
239+
if self.bearer_auth_credentials_source in {"OAUTHBEARER", "STATIC_TOKEN"}:
240+
headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id']
241+
missing_headers = [header for header in headers if header not in conf_copy]
242+
if missing_headers:
243+
raise ValueError("Missing required bearer configuration properties: {}"
244+
.format(", ".join(missing_headers)))
245+
246+
self.logical_cluster = conf_copy.pop('bearer.auth.logical.cluster')
247+
if not isinstance(self.logical_cluster, str):
248+
raise TypeError("logical cluster must be a str, not " + str(type(self.logical_cluster)))
222249

223-
self.identity_pool_id = conf_copy.pop('bearer.auth.identity.pool.id')
224-
if not isinstance(self.identity_pool_id, str):
225-
raise TypeError("identity pool id must be a str, not " + str(type(self.identity_pool_id)))
250+
self.identity_pool_id = conf_copy.pop('bearer.auth.identity.pool.id')
251+
if not isinstance(self.identity_pool_id, str):
252+
raise TypeError("identity pool id must be a str, not " + str(type(self.identity_pool_id)))
226253

227254
if self.bearer_auth_credentials_source == 'OAUTHBEARER':
228255
properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope',
@@ -249,15 +276,37 @@ def __init__(self, conf: dict):
249276
raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not "
250277
+ str(type(self.token_endpoint)))
251278

252-
self.oauth_client = _OAuthClient(self.client_id, self.client_secret, self.scope, self.token_endpoint,
253-
self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms)
254-
279+
self.bearer_field_provider = _OAuthClient(self.client_id, self.client_secret, self.scope,
280+
self.token_endpoint, self.max_retries, self.retries_wait_ms,
281+
self.retries_max_wait_ms)
255282
elif self.bearer_auth_credentials_source == 'STATIC_TOKEN':
256283
if 'bearer.auth.token' not in conf_copy:
257284
raise ValueError("Missing bearer.auth.token")
258285
self.bearer_token = conf_copy.pop('bearer.auth.token')
286+
self.bearer_field_provider = _StaticFieldProvider()
259287
if not isinstance(self.bearer_token, string_type):
260288
raise TypeError("bearer.auth.token must be a str, not " + str(type(self.bearer_token)))
289+
elif self.bearer_auth_credentials_source == "CUSTOM":
290+
custom_bearer_properties = ['bearer.auth.custom.provider.function',
291+
'bearer.auth.custom.provider.config']
292+
missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy]
293+
if missing_custom_properties:
294+
raise ValueError("Missing required custom OAuth configuration properties: {}".
295+
format(", ".join(missing_custom_properties)))
296+
297+
custom_function = conf_copy.pop('bearer.auth.custom.provider.function')
298+
if not callable(custom_function):
299+
raise TypeError("bearer.auth.custom.provider.function must be a callable, not "
300+
+ str(type(custom_function)))
301+
302+
custom_config = conf_copy.pop('bearer.auth.custom.provider.config')
303+
if not isinstance(custom_config, dict):
304+
raise TypeError("bearer.auth.custom.provider.config must be a dict, not "
305+
+ str(type(custom_config)))
306+
307+
self.bearer_field_provider = _CustomOAuthClient(custom_function, custom_config)
308+
else:
309+
raise ValueError('Unrecognized bearer.auth.credentials.source')
261310

262311
# Any leftover keys are unknown to _RestClient
263312
if len(conf_copy) > 0:
@@ -298,13 +347,30 @@ def __init__(self, conf: dict):
298347
timeout=self.timeout
299348
)
300349

301-
def handle_bearer_auth(self, headers: dict):
302-
token = self.bearer_token
303-
if self.oauth_client:
304-
token = self.oauth_client.get_access_token()
350+
def handle_bearer_auth(self, headers: dict) -> None:
351+
bearer_fields = self.bearer_field_provider.get_bearer_fields()
352+
token = bearer_fields['bearer.auth.token'] if 'bearer.auth.token' in bearer_fields else self.bearer_token
353+
305354
headers["Authorization"] = "Bearer {}".format(token)
306-
headers['Confluent-Identity-Pool-Id'] = self.identity_pool_id
307-
headers['target-sr-cluster'] = self.logical_cluster
355+
headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] \
356+
if ('bearer.auth.identity.pool.id' in bearer_fields) else self.identity_pool_id
357+
headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] \
358+
if ('bearer.auth.logical.cluster' in bearer_fields) else self.logical_cluster
359+
360+
missing_fields = []
361+
362+
if not token:
363+
missing_fields.append('bearer.auth.token')
364+
365+
if not headers['Confluent-Identity-Pool-Id']:
366+
missing_fields.append('bearer.auth.identity.pool.id')
367+
368+
if not headers['target-sr-cluster']:
369+
missing_fields.append('bearer.auth.logical.cluster')
370+
371+
if missing_fields:
372+
raise ValueError("Missing required bearer auth fields, needs to be set in config or custom function: {}"
373+
.format(", ".join(missing_fields)))
308374

309375
def get(self, url: str, query: Optional[dict] = None) -> Any:
310376
return self.send_request(url, method='GET', query=query)
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright 2025 Confluent Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
import pytest
19+
import time
20+
from unittest.mock import Mock, patch
21+
22+
from confluent_kafka.schema_registry.schema_registry_client import (_OAuthClient, _StaticFieldProvider,
23+
_CustomOAuthClient, SchemaRegistryClient)
24+
from confluent_kafka.schema_registry.error import OAuthTokenError
25+
26+
"""
27+
Tests to ensure OAuth client is set up correctly.
28+
29+
"""
30+
31+
32+
def custom_oauth_function(config: dict) -> dict:
33+
return config
34+
35+
36+
CUSTOM_FUNCTION = custom_oauth_function
37+
CUSTOM_CONFIG = {'bearer.auth.token': '123', 'bearer.auth.logical.cluster': 'lsrc-cluster',
38+
'bearer.auth.identity.pool.id': 'pool-id'}
39+
TEST_URL = 'http://SchemaRegistry:65534'
40+
41+
42+
def test_expiry():
43+
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 2, 1000, 20000)
44+
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1}
45+
assert not oauth_client.token_expired()
46+
time.sleep(1.5)
47+
assert oauth_client.token_expired()
48+
49+
50+
def test_get_token():
51+
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 2, 1000, 20000)
52+
assert not oauth_client.token
53+
54+
def update_token1():
55+
oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'}
56+
57+
def update_token2():
58+
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1, 'access_token': '1234'}
59+
60+
oauth_client.generate_access_token = Mock(side_effect=update_token1)
61+
oauth_client.get_access_token()
62+
assert oauth_client.generate_access_token.call_count == 1
63+
assert oauth_client.token['access_token'] == '123'
64+
65+
oauth_client.generate_access_token = Mock(side_effect=update_token2)
66+
oauth_client.get_access_token()
67+
# Call count resets to 1 after reassigning generate_access_token
68+
assert oauth_client.generate_access_token.call_count == 1
69+
assert oauth_client.token['access_token'] == '1234'
70+
71+
oauth_client.get_access_token()
72+
assert oauth_client.generate_access_token.call_count == 1
73+
74+
75+
def test_generate_token_retry_logic():
76+
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 5, 1000, 20000)
77+
78+
with (patch("confluent_kafka.schema_registry.schema_registry_client.time.sleep") as mock_sleep,
79+
patch("confluent_kafka.schema_registry.schema_registry_client.full_jitter") as mock_jitter):
80+
81+
with pytest.raises(OAuthTokenError):
82+
oauth_client.generate_access_token()
83+
84+
assert mock_sleep.call_count == 5
85+
assert mock_jitter.call_count == 5
86+
87+
88+
def test_static_field_provider():
89+
static_field_provider = _StaticFieldProvider()
90+
bearer_fields = static_field_provider.get_bearer_fields()
91+
92+
assert not bearer_fields
93+
94+
95+
def test_custom_oauth_client():
96+
custom_oauth_client = _CustomOAuthClient(CUSTOM_FUNCTION, CUSTOM_CONFIG)
97+
98+
assert custom_oauth_client.get_bearer_fields() == custom_oauth_client.get_bearer_fields()
99+
100+
101+
def test_bearer_field_headers_missing():
102+
def empty_custom(config):
103+
return {}
104+
105+
conf = {'url': TEST_URL,
106+
'bearer.auth.credentials.source': 'CUSTOM',
107+
'bearer.auth.custom.provider.function': empty_custom,
108+
'bearer.auth.custom.provider.config': CUSTOM_CONFIG}
109+
110+
headers = {'Accept': "application/vnd.schemaregistry.v1+json,"
111+
" application/vnd.schemaregistry+json,"
112+
" application/json"}
113+
114+
client = SchemaRegistryClient(conf)
115+
116+
with pytest.raises(ValueError, match=r"Missing required bearer auth fields, "
117+
r"needs to be set in config or custom function: (.*)"):
118+
client._rest_client.handle_bearer_auth(headers)
119+
120+
121+
def test_bearer_field_headers_valid():
122+
conf = {'url': TEST_URL,
123+
'bearer.auth.credentials.source': 'CUSTOM',
124+
'bearer.auth.custom.provider.function': CUSTOM_FUNCTION,
125+
'bearer.auth.custom.provider.config': CUSTOM_CONFIG}
126+
127+
client = SchemaRegistryClient(conf)
128+
129+
headers = {'Accept': "application/vnd.schemaregistry.v1+json,"
130+
" application/vnd.schemaregistry+json,"
131+
" application/json"}
132+
133+
client._rest_client.handle_bearer_auth(headers)
134+
135+
assert 'Authorization' in headers
136+
assert 'Confluent-Identity-Pool-Id' in headers
137+
assert 'target-sr-cluster' in headers
138+
assert headers['Authorization'] == "Bearer {}".format(CUSTOM_CONFIG['bearer.auth.token'])
139+
assert headers['Confluent-Identity-Pool-Id'] == CUSTOM_CONFIG['bearer.auth.identity.pool.id']
140+
assert headers['target-sr-cluster'] == CUSTOM_CONFIG['bearer.auth.logical.cluster']

tests/schema_registry/test_config.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,31 @@ def test_static_bearer_config():
230230
SchemaRegistryClient(conf)
231231

232232

233+
def test_custom_bearer_config():
234+
conf = {'url': TEST_URL,
235+
'bearer.auth.credentials.source': 'CUSTOM'}
236+
237+
with pytest.raises(ValueError, match='Missing required custom OAuth configuration properties:'):
238+
SchemaRegistryClient(conf)
239+
240+
241+
def test_custom_bearer_config_valid():
242+
def custom_function(config: dict):
243+
return {}
244+
245+
custom_config = {}
246+
247+
conf = {'url': TEST_URL,
248+
'bearer.auth.credentials.source': 'CUSTOM',
249+
'bearer.auth.custom.provider.function': custom_function,
250+
'bearer.auth.custom.provider.config': custom_config}
251+
252+
client = SchemaRegistryClient(conf)
253+
254+
assert client._rest_client.bearer_field_provider.custom_function == custom_function
255+
assert client._rest_client.bearer_field_provider.custom_config == custom_config
256+
257+
233258
def test_config_unknown_prop():
234259
conf = {'url': TEST_URL,
235260
'basic.auth.credentials.source': 'SASL_INHERIT',

0 commit comments

Comments
 (0)