diff --git a/src/confluent_kafka/schema_registry/schema_registry_client.py b/src/confluent_kafka/schema_registry/schema_registry_client.py index 3368e53e5..1fb719d2e 100644 --- a/src/confluent_kafka/schema_registry/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/schema_registry_client.py @@ -33,6 +33,7 @@ cast, Optional, Union, Any, Tuple from cachetools import TTLCache, LRUCache +from httpx import Response from .error import SchemaRegistryError @@ -70,9 +71,15 @@ def __init__(self, conf: dict): raise ValueError("Missing required configuration property url") if not isinstance(base_url, string_type): raise TypeError("url must be a str, not " + str(type(base_url))) - if not base_url.startswith('http') and not base_url.startswith('mock'): - raise ValueError("Invalid url {}".format(base_url)) - self.base_url = base_url.rstrip('/') + base_urls = [] + for url in base_url.split(','): + url = url.strip().rstrip('/') + if not url.startswith('http') and not url.startswith('mock'): + raise ValueError("Invalid url {}".format(url)) + base_urls.append(url) + if not base_urls: + raise ValueError("Missing required configuration property url") + self.base_urls = base_urls self.verify = True ca = conf_copy.pop('ssl.ca.location', None) @@ -93,7 +100,7 @@ def __init__(self, conf: dict): raise ValueError("ssl.certificate.location required when" " configuring ssl.key.location") - parsed = urlparse(base_url) + parsed = urlparse(self.base_urls[0]) try: userinfo = (unquote(parsed.username), unquote(parsed.password)) except (AttributeError, TypeError): @@ -219,7 +226,7 @@ def send_request( query: dict = None ) -> Any: """ - Sends HTTP request to the SchemaRegistry. + Sends HTTP request to the SchemaRegistry, trying each base URL in turn. All unsuccessful attempts will raise a SchemaRegistryError with the response contents. In most cases this will be accompanied by a @@ -250,21 +257,22 @@ def send_request( 'Content-Type': "application/vnd.schemaregistry.v1+json"} response = None - for i in range(self.max_retries + 1): - response = self.session.request( - method, url="/".join([self.base_url, url]), - headers=headers, data=body, params=query) + for i, base_url in enumerate(self.base_urls): + try: + response = self.send_http_request( + base_url, url, method, headers, body, query) - if (is_success(response.status_code) - or not is_retriable(response.status_code) - or i >= self.max_retries): - break + if is_success(response.status_code): + return response.json() - time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) + if not is_retriable(response.status_code) or i == len(self.base_urls) - 1: + break + except Exception as e: + if i == len(self.base_urls) - 1: + # Raise the exception since we have no more urls to try + raise e try: - if 200 <= response.status_code <= 299: - return response.json() raise SchemaRegistryError(response.status_code, response.json().get('error_code'), response.json().get('message')) @@ -275,6 +283,48 @@ def send_request( "Unknown Schema Registry Error: " + str(response.content)) + def send_http_request( + self, base_url: str, url: str, method: str, headers: dict, + body: Optional[str] = None, query: dict = None + ) -> Response: + """ + Sends HTTP request to the SchemaRegistry. + + All unsuccessful attempts will raise a SchemaRegistryError with the + response contents. In most cases this will be accompanied by a + Schema Registry supplied error code. + + In the event the response is malformed an error_code of -1 will be used. + + Args: + base_url (str): Schema Registry base URL + + url (str): Request path + + method (str): HTTP method + + headers (dict): Headers + + body (str): Request content + + query (dict): Query params to attach to the URL + + Returns: + Response: Schema Registry response content. + """ + for i in range(self.max_retries + 1): + response = self.session.request( + method, url="/".join([base_url, url]), + headers=headers, data=body, params=query) + + if is_success(response.status_code): + return response + + if not is_retriable(response.status_code) or i >= self.max_retries: + return response + + time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) + def is_success(status_code: int) -> bool: return 200 <= status_code <= 299 @@ -495,7 +545,7 @@ class SchemaRegistryClient(object): +------------------------------+------+-------------------------------------------------+ | Property name | type | Description | +==============================+======+=================================================+ - | ``url`` * | str | Schema Registry URL. | + | ``url`` * | str | Comma-separated list of Schema Registry URLs. | +------------------------------+------+-------------------------------------------------+ | | | Path to CA certificate file used | | ``ssl.ca.location`` | str | to verify the Schema Registry's | diff --git a/src/confluent_kafka/schema_registry/serde.py b/src/confluent_kafka/schema_registry/serde.py index 004509bc5..5462dc924 100644 --- a/src/confluent_kafka/schema_registry/serde.py +++ b/src/confluent_kafka/schema_registry/serde.py @@ -311,7 +311,7 @@ def _execute_rules( for index in range(len(rules)): rule = rules[index] - if rule.disabled: + if self._is_disabled(rule): continue if rule.mode == RuleMode.WRITEREAD: if rule_mode != RuleMode.READ and rule_mode != RuleMode.WRITE: diff --git a/tests/schema_registry/test_avro_serdes.py b/tests/schema_registry/test_avro_serdes.py index be2c4b3f9..08377f91a 100644 --- a/tests/schema_registry/test_avro_serdes.py +++ b/tests/schema_registry/test_avro_serdes.py @@ -632,13 +632,13 @@ def test_avro_cel_field_transform_disable(): registry = RuleRegistry() registry.register_rule_executor(CelFieldExecutor()) registry.register_override(RuleOverride("CEL_FIELD", None, None, True)) - ser = AvroSerializer(client, schema_str=None, conf=ser_conf) + ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_registry=registry) ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = ser(obj, ser_ctx) deser = AvroDeserializer(client) newobj = deser(obj_bytes, ser_ctx) - assert obj == newobj + assert "hi" == newobj['stringField'] def test_avro_cel_field_transform_complex(): diff --git a/tests/schema_registry/test_config.py b/tests/schema_registry/test_config.py index 9a0798993..ac2ba5d0f 100644 --- a/tests/schema_registry/test_config.py +++ b/tests/schema_registry/test_config.py @@ -54,7 +54,7 @@ def test_config_url_none(): def test_config_url_trailing_slash(): conf = {'url': 'http://SchemaRegistry:65534/'} test_client = SchemaRegistryClient(conf) - assert test_client._rest_client.base_url == TEST_URL + assert test_client._rest_client.base_urls == [TEST_URL] def test_config_ssl_key_no_certificate():