|
1 | | -from fastapi import APIRouter, Depends, FastAPI, WebSocket |
| 1 | +import functools |
| 2 | + |
| 3 | +import pytest |
| 4 | +from fastapi import ( |
| 5 | + APIRouter, |
| 6 | + Depends, |
| 7 | + FastAPI, |
| 8 | + Header, |
| 9 | + WebSocket, |
| 10 | + WebSocketDisconnect, |
| 11 | + status, |
| 12 | +) |
| 13 | +from fastapi.middleware import Middleware |
2 | 14 | from fastapi.testclient import TestClient |
3 | 15 |
|
4 | 16 | router = APIRouter() |
@@ -63,9 +75,44 @@ async def router_native_prefix_ws(websocket: WebSocket): |
63 | 75 | await websocket.close() |
64 | 76 |
|
65 | 77 |
|
66 | | -app.include_router(router) |
67 | | -app.include_router(prefix_router, prefix="/prefix") |
68 | | -app.include_router(native_prefix_route) |
| 78 | +async def ws_dependency_err(): |
| 79 | + raise NotImplementedError() |
| 80 | + |
| 81 | + |
| 82 | +@router.websocket("/depends-err/") |
| 83 | +async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)): |
| 84 | + pass # pragma: no cover |
| 85 | + |
| 86 | + |
| 87 | +async def ws_dependency_validate(x_missing: str = Header()): |
| 88 | + pass # pragma: no cover |
| 89 | + |
| 90 | + |
| 91 | +@router.websocket("/depends-validate/") |
| 92 | +async def router_ws_depends_validate( |
| 93 | + websocket: WebSocket, data=Depends(ws_dependency_validate) |
| 94 | +): |
| 95 | + pass # pragma: no cover |
| 96 | + |
| 97 | + |
| 98 | +class CustomError(Exception): |
| 99 | + pass |
| 100 | + |
| 101 | + |
| 102 | +@router.websocket("/custom_error/") |
| 103 | +async def router_ws_custom_error(websocket: WebSocket): |
| 104 | + raise CustomError() |
| 105 | + |
| 106 | + |
| 107 | +def make_app(app=None, **kwargs): |
| 108 | + app = app or FastAPI(**kwargs) |
| 109 | + app.include_router(router) |
| 110 | + app.include_router(prefix_router, prefix="/prefix") |
| 111 | + app.include_router(native_prefix_route) |
| 112 | + return app |
| 113 | + |
| 114 | + |
| 115 | +app = make_app(app) |
69 | 116 |
|
70 | 117 |
|
71 | 118 | def test_app(): |
@@ -125,3 +172,100 @@ def test_router_with_params(): |
125 | 172 | assert data == "path/to/file" |
126 | 173 | data = websocket.receive_text() |
127 | 174 | assert data == "a_query_param" |
| 175 | + |
| 176 | + |
| 177 | +def test_wrong_uri(): |
| 178 | + """ |
| 179 | + Verify that a websocket connection to a non-existent endpoing returns in a shutdown |
| 180 | + """ |
| 181 | + client = TestClient(app) |
| 182 | + with pytest.raises(WebSocketDisconnect) as e: |
| 183 | + with client.websocket_connect("/no-router/"): |
| 184 | + pass # pragma: no cover |
| 185 | + assert e.value.code == status.WS_1000_NORMAL_CLOSURE |
| 186 | + |
| 187 | + |
| 188 | +def websocket_middleware(middleware_func): |
| 189 | + """ |
| 190 | + Helper to create a Starlette pure websocket middleware |
| 191 | + """ |
| 192 | + |
| 193 | + def middleware_constructor(app): |
| 194 | + @functools.wraps(app) |
| 195 | + async def wrapped_app(scope, receive, send): |
| 196 | + if scope["type"] != "websocket": |
| 197 | + return await app(scope, receive, send) # pragma: no cover |
| 198 | + |
| 199 | + async def call_next(): |
| 200 | + return await app(scope, receive, send) |
| 201 | + |
| 202 | + websocket = WebSocket(scope, receive=receive, send=send) |
| 203 | + return await middleware_func(websocket, call_next) |
| 204 | + |
| 205 | + return wrapped_app |
| 206 | + |
| 207 | + return middleware_constructor |
| 208 | + |
| 209 | + |
| 210 | +def test_depend_validation(): |
| 211 | + """ |
| 212 | + Verify that a validation in a dependency invokes the correct exception handler |
| 213 | + """ |
| 214 | + caught = [] |
| 215 | + |
| 216 | + @websocket_middleware |
| 217 | + async def catcher(websocket, call_next): |
| 218 | + try: |
| 219 | + return await call_next() |
| 220 | + except Exception as e: # pragma: no cover |
| 221 | + caught.append(e) |
| 222 | + raise |
| 223 | + |
| 224 | + myapp = make_app(middleware=[Middleware(catcher)]) |
| 225 | + |
| 226 | + client = TestClient(myapp) |
| 227 | + with pytest.raises(WebSocketDisconnect) as e: |
| 228 | + with client.websocket_connect("/depends-validate/"): |
| 229 | + pass # pragma: no cover |
| 230 | + # the validation error does produce a close message |
| 231 | + assert e.value.code == status.WS_1008_POLICY_VIOLATION |
| 232 | + # and no error is leaked |
| 233 | + assert caught == [] |
| 234 | + |
| 235 | + |
| 236 | +def test_depend_err_middleware(): |
| 237 | + """ |
| 238 | + Verify that it is possible to write custom WebSocket middleware to catch errors |
| 239 | + """ |
| 240 | + |
| 241 | + @websocket_middleware |
| 242 | + async def errorhandler(websocket: WebSocket, call_next): |
| 243 | + try: |
| 244 | + return await call_next() |
| 245 | + except Exception as e: |
| 246 | + await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(e)) |
| 247 | + |
| 248 | + myapp = make_app(middleware=[Middleware(errorhandler)]) |
| 249 | + client = TestClient(myapp) |
| 250 | + with pytest.raises(WebSocketDisconnect) as e: |
| 251 | + with client.websocket_connect("/depends-err/"): |
| 252 | + pass # pragma: no cover |
| 253 | + assert e.value.code == status.WS_1006_ABNORMAL_CLOSURE |
| 254 | + assert "NotImplementedError" in e.value.reason |
| 255 | + |
| 256 | + |
| 257 | +def test_depend_err_handler(): |
| 258 | + """ |
| 259 | + Verify that it is possible to write custom WebSocket middleware to catch errors |
| 260 | + """ |
| 261 | + |
| 262 | + async def custom_handler(websocket: WebSocket, exc: CustomError) -> None: |
| 263 | + await websocket.close(1002, "foo") |
| 264 | + |
| 265 | + myapp = make_app(exception_handlers={CustomError: custom_handler}) |
| 266 | + client = TestClient(myapp) |
| 267 | + with pytest.raises(WebSocketDisconnect) as e: |
| 268 | + with client.websocket_connect("/custom_error/"): |
| 269 | + pass # pragma: no cover |
| 270 | + assert e.value.code == 1002 |
| 271 | + assert "foo" in e.value.reason |
0 commit comments