Skip to content

Commit 635a4ba

Browse files
authored
feat: allow injection of httpx client (#591)
1 parent fa90200 commit 635a4ba

14 files changed

+327
-26
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ remove_pytest_asyncio_from_sync:
3737
sed -i 's/@pytest.mark.asyncio//g' tests/_sync/test_client.py
3838
sed -i 's/_async/_sync/g' tests/_sync/test_client.py
3939
sed -i 's/Async/Sync/g' tests/_sync/test_client.py
40+
sed -i 's/Async/Sync/g' postgrest/_sync/request_builder.py
4041
sed -i 's/_client\.SyncClient/_client\.Client/g' tests/_sync/test_client.py
42+
sed -i 's/SyncHTTPTransport/HTTPTransport/g' tests/_sync/test_client.py
43+
sed -i 's/SyncHTTPTransport/HTTPTransport/g' tests/_sync/client.py
4144

4245
sleep:
4346
sleep 2

postgrest/__init__.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,41 @@
2727
from .deprecated_client import Client, PostgrestClient
2828
from .deprecated_get_request_builder import GetRequestBuilder
2929
from .exceptions import APIError
30+
from .types import (
31+
CountMethod,
32+
Filters,
33+
RequestMethod,
34+
ReturnMethod,
35+
)
3036
from .version import __version__
37+
38+
__all__ = [
39+
"AsyncPostgrestClient",
40+
"AsyncFilterRequestBuilder",
41+
"AsyncQueryRequestBuilder",
42+
"AsyncRequestBuilder",
43+
"AsyncRPCFilterRequestBuilder",
44+
"AsyncSelectRequestBuilder",
45+
"AsyncSingleRequestBuilder",
46+
"AsyncMaybeSingleRequestBuilder",
47+
"SyncPostgrestClient",
48+
"SyncFilterRequestBuilder",
49+
"SyncMaybeSingleRequestBuilder",
50+
"SyncQueryRequestBuilder",
51+
"SyncRequestBuilder",
52+
"SyncRPCFilterRequestBuilder",
53+
"SyncSelectRequestBuilder",
54+
"SyncSingleRequestBuilder",
55+
"APIResponse",
56+
"DEFAULT_POSTGREST_CLIENT_HEADERS",
57+
"Client",
58+
"PostgrestClient",
59+
"GetRequestBuilder",
60+
"APIError",
61+
"CountMethod",
62+
"Filters",
63+
"RequestMethod",
64+
"ReturnMethod",
65+
"Timeout",
66+
"__version__",
67+
]

postgrest/_async/client.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from typing import Any, Dict, Optional, Union, cast
4+
from warnings import warn
45

56
from deprecation import deprecated
67
from httpx import Headers, QueryParams, Timeout
@@ -27,18 +28,50 @@ def __init__(
2728
*,
2829
schema: str = "public",
2930
headers: Dict[str, str] = DEFAULT_POSTGREST_CLIENT_HEADERS,
30-
timeout: Union[int, float, Timeout] = DEFAULT_POSTGREST_CLIENT_TIMEOUT,
31-
verify: bool = True,
31+
timeout: Union[int, float, Timeout, None] = None,
32+
verify: Optional[bool] = None,
3233
proxy: Optional[str] = None,
34+
http_client: Optional[AsyncClient] = None,
3335
) -> None:
36+
if timeout is not None:
37+
warn(
38+
"The 'timeout' parameter is deprecated. Please configure it in the http client instead.",
39+
DeprecationWarning,
40+
stacklevel=2,
41+
)
42+
if verify is not None:
43+
warn(
44+
"The 'verify' parameter is deprecated. Please configure it in the http client instead.",
45+
DeprecationWarning,
46+
stacklevel=2,
47+
)
48+
if proxy is not None:
49+
warn(
50+
"The 'proxy' parameter is deprecated. Please configure it in the http client instead.",
51+
DeprecationWarning,
52+
stacklevel=2,
53+
)
54+
55+
self.verify = bool(verify) if verify is not None else True
56+
self.timeout = (
57+
timeout
58+
if isinstance(timeout, Timeout)
59+
else (
60+
int(abs(timeout))
61+
if timeout is not None
62+
else DEFAULT_POSTGREST_CLIENT_TIMEOUT
63+
)
64+
)
65+
3466
BasePostgrestClient.__init__(
3567
self,
3668
base_url,
3769
schema=schema,
3870
headers=headers,
39-
timeout=timeout,
40-
verify=verify,
71+
timeout=self.timeout,
72+
verify=self.verify,
4173
proxy=proxy,
74+
http_client=http_client,
4275
)
4376
self.session = cast(AsyncClient, self.session)
4477

@@ -50,6 +83,15 @@ def create_session(
5083
verify: bool = True,
5184
proxy: Optional[str] = None,
5285
) -> AsyncClient:
86+
http_client = None
87+
if isinstance(self.http_client, AsyncClient):
88+
http_client = self.http_client
89+
90+
if http_client is not None:
91+
http_client.base_url = base_url
92+
http_client.headers.update({**headers})
93+
return http_client
94+
5395
return AsyncClient(
5496
base_url=base_url,
5597
headers=headers,

postgrest/_sync/client.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from typing import Any, Dict, Optional, Union, cast
4+
from warnings import warn
45

56
from deprecation import deprecated
67
from httpx import Headers, QueryParams, Timeout
@@ -27,18 +28,50 @@ def __init__(
2728
*,
2829
schema: str = "public",
2930
headers: Dict[str, str] = DEFAULT_POSTGREST_CLIENT_HEADERS,
30-
timeout: Union[int, float, Timeout] = DEFAULT_POSTGREST_CLIENT_TIMEOUT,
31-
verify: bool = True,
31+
timeout: Union[int, float, Timeout, None] = None,
32+
verify: Optional[bool] = None,
3233
proxy: Optional[str] = None,
34+
http_client: Optional[SyncClient] = None,
3335
) -> None:
36+
if timeout is not None:
37+
warn(
38+
"The 'timeout' parameter is deprecated. Please configure it in the http client instead.",
39+
DeprecationWarning,
40+
stacklevel=2,
41+
)
42+
if verify is not None:
43+
warn(
44+
"The 'verify' parameter is deprecated. Please configure it in the http client instead.",
45+
DeprecationWarning,
46+
stacklevel=2,
47+
)
48+
if proxy is not None:
49+
warn(
50+
"The 'proxy' parameter is deprecated. Please configure it in the http client instead.",
51+
DeprecationWarning,
52+
stacklevel=2,
53+
)
54+
55+
self.verify = bool(verify) if verify is not None else True
56+
self.timeout = (
57+
timeout
58+
if isinstance(timeout, Timeout)
59+
else (
60+
int(abs(timeout))
61+
if timeout is not None
62+
else DEFAULT_POSTGREST_CLIENT_TIMEOUT
63+
)
64+
)
65+
3466
BasePostgrestClient.__init__(
3567
self,
3668
base_url,
3769
schema=schema,
3870
headers=headers,
39-
timeout=timeout,
40-
verify=verify,
71+
timeout=self.timeout,
72+
verify=self.verify,
4173
proxy=proxy,
74+
http_client=http_client,
4275
)
4376
self.session = cast(SyncClient, self.session)
4477

@@ -50,6 +83,15 @@ def create_session(
5083
verify: bool = True,
5184
proxy: Optional[str] = None,
5285
) -> SyncClient:
86+
http_client = None
87+
if isinstance(self.http_client, SyncClient):
88+
http_client = self.http_client
89+
90+
if http_client is not None:
91+
http_client.base_url = base_url
92+
http_client.headers.update({**headers})
93+
return http_client
94+
5395
return SyncClient(
5496
base_url=base_url,
5597
headers=headers,

postgrest/_sync/request_builder.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def select(
287287
*columns: The names of the columns to fetch.
288288
count: The method to use to get the count of rows returned.
289289
Returns:
290-
:class:`AsyncSelectRequestBuilder`
290+
:class:`SyncSelectRequestBuilder`
291291
"""
292292
method, params, headers, json = pre_select(*columns, count=count, head=head)
293293
return SyncSelectRequestBuilder[_ReturnT](
@@ -314,7 +314,7 @@ def insert(
314314
Otherwise, use the default value for the column.
315315
Only applies for bulk inserts.
316316
Returns:
317-
:class:`AsyncQueryRequestBuilder`
317+
:class:`SyncQueryRequestBuilder`
318318
"""
319319
method, params, headers, json = pre_insert(
320320
json,
@@ -350,7 +350,7 @@ def upsert(
350350
not when merging with existing rows under `ignoreDuplicates: false`.
351351
This also only applies when doing bulk upserts.
352352
Returns:
353-
:class:`AsyncQueryRequestBuilder`
353+
:class:`SyncQueryRequestBuilder`
354354
"""
355355
method, params, headers, json = pre_upsert(
356356
json,
@@ -378,7 +378,7 @@ def update(
378378
count: The method to use to get the count of rows returned.
379379
returning: Either 'minimal' or 'representation'
380380
Returns:
381-
:class:`AsyncFilterRequestBuilder`
381+
:class:`SyncFilterRequestBuilder`
382382
"""
383383
method, params, headers, json = pre_update(
384384
json,
@@ -401,7 +401,7 @@ def delete(
401401
count: The method to use to get the count of rows returned.
402402
returning: Either 'minimal' or 'representation'
403403
Returns:
404-
:class:`AsyncFilterRequestBuilder`
404+
:class:`SyncFilterRequestBuilder`
405405
"""
406406
method, params, headers, json = pre_delete(
407407
count=count,

postgrest/base_client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(
2020
timeout: Union[int, float, Timeout],
2121
verify: bool = True,
2222
proxy: Optional[str] = None,
23+
http_client: Union[SyncClient, AsyncClient, None] = None,
2324
) -> None:
2425
if not is_http_url(base_url):
2526
ValueError("base_url must be a valid HTTP URL string")
@@ -33,8 +34,13 @@ def __init__(
3334
self.timeout = timeout
3435
self.verify = verify
3536
self.proxy = proxy
37+
self.http_client = http_client
3638
self.session = self.create_session(
37-
self.base_url, self.headers, self.timeout, self.verify, self.proxy
39+
self.base_url,
40+
self.headers,
41+
self.timeout,
42+
self.verify,
43+
self.proxy,
3844
)
3945

4046
@abstractmethod

postgrest/exceptions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional
1+
from typing import Any, Dict, Optional
22

33
from pydantic import BaseModel
44

@@ -34,7 +34,7 @@ class APIError(Exception):
3434
details: Optional[str]
3535
"""The error details."""
3636

37-
def __init__(self, error: Dict[str, str]) -> None:
37+
def __init__(self, error: Dict[str, Any]) -> None:
3838
self._raw_error = error
3939
self.message = error.get("message")
4040
self.code = error.get("code")

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ furo = ">=2023.9.10,<2025.0.0"
4444

4545
[tool.pytest.ini_options]
4646
asyncio_mode = "auto"
47+
filterwarnings = [
48+
"ignore::DeprecationWarning", # ignore deprecation warnings globally
49+
]
4750

4851
[build-system]
4952
requires = ["poetry-core>=1.0.0"]

tests/_async/client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from httpx import AsyncHTTPTransport, Limits
2+
13
from postgrest import AsyncPostgrestClient
4+
from postgrest.utils import AsyncClient
25

36
REST_URL = "http://127.0.0.1:3000"
47

@@ -7,3 +10,20 @@ def rest_client():
710
return AsyncPostgrestClient(
811
base_url=REST_URL,
912
)
13+
14+
15+
def rest_client_httpx():
16+
transport = AsyncHTTPTransport(
17+
retries=4,
18+
limits=Limits(
19+
max_connections=1,
20+
max_keepalive_connections=1,
21+
keepalive_expiry=None,
22+
),
23+
)
24+
headers = {"x-user-agent": "my-app/0.0.1"}
25+
http_client = AsyncClient(transport=transport, headers=headers)
26+
return AsyncPostgrestClient(
27+
base_url=REST_URL,
28+
http_client=http_client,
29+
)

tests/_async/test_client.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
from unittest.mock import patch
22

33
import pytest
4-
from httpx import BasicAuth, Headers, Request, Response
4+
from httpx import (
5+
AsyncHTTPTransport,
6+
BasicAuth,
7+
Headers,
8+
Limits,
9+
Request,
10+
Response,
11+
Timeout,
12+
)
513

614
from postgrest import AsyncPostgrestClient
715
from postgrest.exceptions import APIError
16+
from postgrest.utils import AsyncClient
817

918

1019
@pytest.fixture
@@ -46,6 +55,32 @@ async def test_custom_headers(self):
4655
assert session.headers.items() >= headers.items()
4756

4857

58+
class TestHttpxClientConstructor:
59+
@pytest.mark.asyncio
60+
async def test_custom_httpx_client(self):
61+
transport = AsyncHTTPTransport(
62+
retries=10,
63+
limits=Limits(
64+
max_connections=1,
65+
max_keepalive_connections=1,
66+
keepalive_expiry=None,
67+
),
68+
)
69+
headers = {"x-user-agent": "my-app/0.0.1"}
70+
http_client = AsyncClient(transport=transport, headers=headers)
71+
async with AsyncPostgrestClient(
72+
"https://example.com", http_client=http_client, timeout=20.0
73+
) as client:
74+
session = client.session
75+
76+
assert session.base_url == "https://example.com"
77+
assert session.timeout == Timeout(
78+
timeout=5.0
79+
) # Should be the default 5 since we use custom httpx client
80+
assert session.headers.get("x-user-agent") == "my-app/0.0.1"
81+
assert isinstance(session, AsyncClient)
82+
83+
4984
class TestAuth:
5085
def test_auth_token(self, postgrest_client: AsyncPostgrestClient):
5186
postgrest_client.auth("s3cr3t")

0 commit comments

Comments
 (0)