Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions polyapi/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

_sync_client: httpx.Client | None = None
_async_client: httpx.AsyncClient | None = None
_async_client_loop: asyncio.AbstractEventLoop | None = None


def _get_sync_client() -> httpx.Client:
Expand All @@ -13,9 +14,11 @@ def _get_sync_client() -> httpx.Client:


def _get_async_client() -> httpx.AsyncClient:
global _async_client
if _async_client is None:
global _async_client, _async_client_loop
current_loop = asyncio.get_running_loop()
if _async_client is None or _async_client_loop is not current_loop:
_async_client = httpx.AsyncClient(timeout=None)
_async_client_loop = current_loop
return _async_client


Expand Down Expand Up @@ -66,8 +69,15 @@ def close():
_sync_client = None

async def close_async():
global _sync_client, _async_client
global _sync_client, _async_client, _async_client_loop
close()
if _async_client is not None:
await _async_client.aclose()
_async_client = None
client = _async_client
client_loop = _async_client_loop
_async_client = None
_async_client_loop = None
if client is None:
return

current_loop = asyncio.get_running_loop()
if client_loop is current_loop:
await client.aclose()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "polyapi-python"
version = "0.3.15"
version = "0.3.16"
description = "The Python Client for PolyAPI, the IPaaS by Developers for Developers"
authors = [{ name = "Dan Fellin", email = "[email protected]" }]
dependencies = [
Expand Down
86 changes: 86 additions & 0 deletions tests/test_async_proof.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ def setup_method(self):
# Reset singletons so each test starts fresh
http_client._sync_client = None
http_client._async_client = None
http_client._async_client_loop = None

def teardown_method(self):
http_client._sync_client = None
http_client._async_client = None
http_client._async_client_loop = None

@patch.object(httpx.Client, "post", return_value=_fake_response())
def test_sync_post_uses_sync_client(self, mock_post):
Expand All @@ -79,6 +81,90 @@ async def _run():
mock_post.assert_called_once()
assert resp.status_code == 200
assert http_client._async_client is not None
assert http_client._async_client_loop is not None

def test_async_post_reuses_client_within_same_loop(self):
first_client = MagicMock()
first_client.post = AsyncMock(return_value=_fake_response())

with patch("polyapi.http_client.httpx.AsyncClient", return_value=first_client) as mock_async_client:
async def _run():
first_response = await http_client.async_post("https://example.com/first", json={})
second_response = await http_client.async_post("https://example.com/second", json={})
return first_response, second_response, asyncio.get_running_loop()

first_response, second_response, current_loop = asyncio.run(_run())

assert first_response.status_code == 200
assert second_response.status_code == 200
assert mock_async_client.call_count == 1
assert first_client.post.await_count == 2
assert http_client._async_client is first_client
assert http_client._async_client_loop is current_loop

def test_async_post_recreates_client_after_loop_change(self):
first_client = MagicMock()
first_client.post = AsyncMock(return_value=_fake_response())
second_client = MagicMock()
second_client.post = AsyncMock(return_value=_fake_response())

with patch(
"polyapi.http_client.httpx.AsyncClient",
side_effect=[first_client, second_client],
) as mock_async_client:
async def _run_once(url: str):
response = await http_client.async_post(url, json={})
return response, http_client._async_client, asyncio.get_running_loop()

first_response, first_cached_client, first_loop = asyncio.run(_run_once("https://example.com/first"))
second_response, second_cached_client, second_loop = asyncio.run(_run_once("https://example.com/second"))

assert first_response.status_code == 200
assert second_response.status_code == 200
assert mock_async_client.call_count == 2
assert first_client.post.await_count == 1
assert second_client.post.await_count == 1
assert first_cached_client is first_client
assert second_cached_client is second_client
assert first_loop is not second_loop
assert http_client._async_client is second_client
assert http_client._async_client_loop is second_loop

def test_close_async_clears_cached_client_for_current_loop(self):
async def _run():
cached_client = MagicMock()
cached_client.aclose = AsyncMock()
http_client._async_client = cached_client
http_client._async_client_loop = asyncio.get_running_loop()

await http_client.close_async()

return cached_client

cached_client = asyncio.run(_run())

cached_client.aclose.assert_awaited_once()
assert http_client._async_client is None
assert http_client._async_client_loop is None

def test_close_async_drops_stale_client_without_cross_loop_close(self):
stale_client = MagicMock()
stale_client.aclose = AsyncMock()

async def _seed_stale_client():
http_client._async_client = stale_client
http_client._async_client_loop = asyncio.get_running_loop()

asyncio.run(_seed_stale_client())

async def _close_on_new_loop():
await http_client.close_async()

asyncio.run(_close_on_new_loop())

stale_client.aclose.assert_not_awaited()
assert http_client._async_client is None
assert http_client._async_client_loop is None

@patch.object(httpx.Client, "get", return_value=_fake_response())
def test_sync_get(self, mock_get):
Expand Down