Skip to content

Commit ab03f22

Browse files
kristjanvalurpre-commit-ci[bot]tiangolo
authored
✨ Add exception handler for WebSocketRequestValidationError (which also allows to override it) (fastapi#6030)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sebastián Ramírez <[email protected]>
1 parent f5e2dd8 commit ab03f22

4 files changed

Lines changed: 166 additions & 9 deletions

File tree

fastapi/applications.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
from fastapi.exception_handlers import (
2020
http_exception_handler,
2121
request_validation_exception_handler,
22+
websocket_request_validation_exception_handler,
2223
)
23-
from fastapi.exceptions import RequestValidationError
24+
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
2425
from fastapi.logger import logger
2526
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
2627
from fastapi.openapi.docs import (
@@ -145,6 +146,11 @@ def __init__(
145146
self.exception_handlers.setdefault(
146147
RequestValidationError, request_validation_exception_handler
147148
)
149+
self.exception_handlers.setdefault(
150+
WebSocketRequestValidationError,
151+
# Starlette still has incorrect type specification for the handlers
152+
websocket_request_validation_exception_handler, # type: ignore
153+
)
148154

149155
self.user_middleware: List[Middleware] = (
150156
[] if middleware is None else list(middleware)

fastapi/exception_handlers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from fastapi.encoders import jsonable_encoder
2-
from fastapi.exceptions import RequestValidationError
2+
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
33
from fastapi.utils import is_body_allowed_for_status_code
4+
from fastapi.websockets import WebSocket
45
from starlette.exceptions import HTTPException
56
from starlette.requests import Request
67
from starlette.responses import JSONResponse, Response
7-
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
8+
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
89

910

1011
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
@@ -23,3 +24,11 @@ async def request_validation_exception_handler(
2324
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
2425
content={"detail": jsonable_encoder(exc.errors())},
2526
)
27+
28+
29+
async def websocket_request_validation_exception_handler(
30+
websocket: WebSocket, exc: WebSocketRequestValidationError
31+
) -> None:
32+
await websocket.close(
33+
code=WS_1008_POLICY_VIOLATION, reason=jsonable_encoder(exc.errors())
34+
)

fastapi/routing.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
request_response,
5757
websocket_session,
5858
)
59-
from starlette.status import WS_1008_POLICY_VIOLATION
6059
from starlette.types import ASGIApp, Lifespan, Scope
6160
from starlette.websockets import WebSocket
6261

@@ -283,7 +282,6 @@ async def app(websocket: WebSocket) -> None:
283282
)
284283
values, errors, _, _2, _3 = solved_result
285284
if errors:
286-
await websocket.close(code=WS_1008_POLICY_VIOLATION)
287285
raise WebSocketRequestValidationError(errors)
288286
assert dependant.call is not None, "dependant.call must be a function"
289287
await dependant.call(**values)

tests/test_ws_router.py

Lines changed: 148 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
1-
from fastapi import APIRouter, Depends, FastAPI, WebSocket
1+
import functools
2+
3+
import pytest
4+
from fastapi import (
5+
APIRouter,
6+
Depends,
7+
FastAPI,
8+
Header,
9+
WebSocket,
10+
WebSocketDisconnect,
11+
status,
12+
)
13+
from fastapi.middleware import Middleware
214
from fastapi.testclient import TestClient
315

416
router = APIRouter()
@@ -63,9 +75,44 @@ async def router_native_prefix_ws(websocket: WebSocket):
6375
await websocket.close()
6476

6577

66-
app.include_router(router)
67-
app.include_router(prefix_router, prefix="/prefix")
68-
app.include_router(native_prefix_route)
78+
async def ws_dependency_err():
79+
raise NotImplementedError()
80+
81+
82+
@router.websocket("/depends-err/")
83+
async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)):
84+
pass # pragma: no cover
85+
86+
87+
async def ws_dependency_validate(x_missing: str = Header()):
88+
pass # pragma: no cover
89+
90+
91+
@router.websocket("/depends-validate/")
92+
async def router_ws_depends_validate(
93+
websocket: WebSocket, data=Depends(ws_dependency_validate)
94+
):
95+
pass # pragma: no cover
96+
97+
98+
class CustomError(Exception):
99+
pass
100+
101+
102+
@router.websocket("/custom_error/")
103+
async def router_ws_custom_error(websocket: WebSocket):
104+
raise CustomError()
105+
106+
107+
def make_app(app=None, **kwargs):
108+
app = app or FastAPI(**kwargs)
109+
app.include_router(router)
110+
app.include_router(prefix_router, prefix="/prefix")
111+
app.include_router(native_prefix_route)
112+
return app
113+
114+
115+
app = make_app(app)
69116

70117

71118
def test_app():
@@ -125,3 +172,100 @@ def test_router_with_params():
125172
assert data == "path/to/file"
126173
data = websocket.receive_text()
127174
assert data == "a_query_param"
175+
176+
177+
def test_wrong_uri():
178+
"""
179+
Verify that a websocket connection to a non-existent endpoing returns in a shutdown
180+
"""
181+
client = TestClient(app)
182+
with pytest.raises(WebSocketDisconnect) as e:
183+
with client.websocket_connect("/no-router/"):
184+
pass # pragma: no cover
185+
assert e.value.code == status.WS_1000_NORMAL_CLOSURE
186+
187+
188+
def websocket_middleware(middleware_func):
189+
"""
190+
Helper to create a Starlette pure websocket middleware
191+
"""
192+
193+
def middleware_constructor(app):
194+
@functools.wraps(app)
195+
async def wrapped_app(scope, receive, send):
196+
if scope["type"] != "websocket":
197+
return await app(scope, receive, send) # pragma: no cover
198+
199+
async def call_next():
200+
return await app(scope, receive, send)
201+
202+
websocket = WebSocket(scope, receive=receive, send=send)
203+
return await middleware_func(websocket, call_next)
204+
205+
return wrapped_app
206+
207+
return middleware_constructor
208+
209+
210+
def test_depend_validation():
211+
"""
212+
Verify that a validation in a dependency invokes the correct exception handler
213+
"""
214+
caught = []
215+
216+
@websocket_middleware
217+
async def catcher(websocket, call_next):
218+
try:
219+
return await call_next()
220+
except Exception as e: # pragma: no cover
221+
caught.append(e)
222+
raise
223+
224+
myapp = make_app(middleware=[Middleware(catcher)])
225+
226+
client = TestClient(myapp)
227+
with pytest.raises(WebSocketDisconnect) as e:
228+
with client.websocket_connect("/depends-validate/"):
229+
pass # pragma: no cover
230+
# the validation error does produce a close message
231+
assert e.value.code == status.WS_1008_POLICY_VIOLATION
232+
# and no error is leaked
233+
assert caught == []
234+
235+
236+
def test_depend_err_middleware():
237+
"""
238+
Verify that it is possible to write custom WebSocket middleware to catch errors
239+
"""
240+
241+
@websocket_middleware
242+
async def errorhandler(websocket: WebSocket, call_next):
243+
try:
244+
return await call_next()
245+
except Exception as e:
246+
await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(e))
247+
248+
myapp = make_app(middleware=[Middleware(errorhandler)])
249+
client = TestClient(myapp)
250+
with pytest.raises(WebSocketDisconnect) as e:
251+
with client.websocket_connect("/depends-err/"):
252+
pass # pragma: no cover
253+
assert e.value.code == status.WS_1006_ABNORMAL_CLOSURE
254+
assert "NotImplementedError" in e.value.reason
255+
256+
257+
def test_depend_err_handler():
258+
"""
259+
Verify that it is possible to write custom WebSocket middleware to catch errors
260+
"""
261+
262+
async def custom_handler(websocket: WebSocket, exc: CustomError) -> None:
263+
await websocket.close(1002, "foo")
264+
265+
myapp = make_app(exception_handlers={CustomError: custom_handler})
266+
client = TestClient(myapp)
267+
with pytest.raises(WebSocketDisconnect) as e:
268+
with client.websocket_connect("/custom_error/"):
269+
pass # pragma: no cover
270+
assert e.value.code == 1002
271+
assert "foo" in e.value.reason

0 commit comments

Comments
 (0)