Skip to content

Commit bf87b56

Browse files
committed
增加一种通用的验证字段的方式
1 parent c0bb3c6 commit bf87b56

File tree

4 files changed

+240
-23
lines changed

4 files changed

+240
-23
lines changed

run.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,28 +55,27 @@
5555
async def request_validation_exception_handler2(request: Request, exc: tutorial.chapter03.ValDtoError):
5656
print(f"===>参数校验异常{request.method} {request.url}")
5757
return JSONResponse({"message":exc.message})
58-
# @app.exception_handler(RequestValidationError) # 重写请求验证异常处理器
59-
# async def validation_exception_handler(request, exc):
60-
#
61-
# """
62-
# :param request: 这个参数不能省
63-
# :param exc:
64-
# :return:
65-
# """
66-
# print("222"*10)
67-
# # 自定义错误输出信息
68-
# errors = []
69-
# for error in exc.errors():
70-
# print(error)
71-
# print(type(error))
72-
# errors.append({
73-
# "loc": error["loc"],
74-
# "msg": error["msg"],
75-
# "type": error["type"]
76-
# })
77-
# # raise HTTPException(status_code=422, detail=errors)
78-
#
79-
# return PlainTextResponse(str(errors), status_code=400)
58+
@app.exception_handler(RequestValidationError) # 重写请求验证异常处理器
59+
async def validation_exception_handler(request, exc):
60+
61+
"""
62+
:param request: 这个参数不能省
63+
:param exc:
64+
:return:
65+
"""
66+
errors = []
67+
for error in exc.errors():
68+
print(error)
69+
print(type(error))
70+
errors.append({
71+
"loc": error["loc"],
72+
"msg": error["msg"],
73+
"type": error["type"]
74+
})
75+
# raise HTTPException(status_code=422, detail=errors)
76+
return JSONResponse({"code":402,"msg":errors[0]['msg'],"errors":errors})
77+
78+
# return PlainTextResponse(str(errors), status_code=400)
8079

8180

8281
# @app.middleware('http')

test-jwt.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
from datetime import timedelta, datetime
2+
from typing import Union
3+
4+
import uvicorn
5+
from fastapi import FastAPI, Depends, HTTPException
6+
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
7+
from jose import jwt, JWTError
8+
from passlib.context import CryptContext
9+
from pydantic import BaseModel
10+
from starlette import status
11+
12+
# openssl rand -hex 32
13+
# 使用命令行生成jwt的随机密钥
14+
SECRET_KEY = "25a7a7c926bdc1a840daced5aeaa0f8ed1abebaa975a49522752aa16c4a63fc9"
15+
ALGORITHM = "HS256" # 算法
16+
ACCESS_TOKEN_EXPIRE_MINUTES = 20 # token过期时间
17+
18+
# 这是数据库中的用户列表
19+
fake_users_db = {
20+
"johndoe": {
21+
"username": "johndoe",
22+
"full_name": "John Doe",
23+
"email": "[email protected]",
24+
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",
25+
"disabled": False,
26+
}
27+
}
28+
29+
30+
class Token(BaseModel):
31+
"""
32+
token返回对象
33+
"""
34+
access_token: str
35+
token_type: str
36+
37+
38+
class TokenData(BaseModel):
39+
username: Union[str, None] = None
40+
41+
42+
class User(BaseModel):
43+
username: str
44+
email: Union[str, None] = None
45+
full_name: Union[str, None] = None
46+
disabled: Union[bool, None] = None
47+
48+
49+
# 继承了User类,存储与数据库的user
50+
class UserInDB(User):
51+
hashed_password: str
52+
53+
54+
# 这是一个实例对象
55+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
56+
57+
# 这是一个实例对象
58+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token")
59+
60+
app = FastAPI()
61+
62+
63+
def verify_password(plain_password, hashed_password) -> bool:
64+
"""
65+
校验密码是否正确
66+
:param plain_password: 密码明文
67+
:param hashed_password: 加密后的密码
68+
:return: 校验结果
69+
"""
70+
return pwd_context.verify(plain_password, hashed_password)
71+
72+
73+
def get_password_hash(password) -> str:
74+
"""
75+
获取密码的hash值
76+
:param password: 密码明文
77+
:return: 密码hash值
78+
"""
79+
return pwd_context.hash(password)
80+
81+
82+
def get_user(db: dict, username: str):
83+
"""
84+
从数据库中获取用户
85+
:param db: 数据库,此处假设为字典
86+
:param username: 用户名
87+
:return: 用户信息
88+
"""
89+
if username in db:
90+
user_dict = db[username]
91+
return UserInDB(**user_dict)
92+
93+
94+
def authenticate_user(fake_db: dict, username: str, password: str):
95+
"""
96+
验证用户身份
97+
:param fake_db:
98+
:param username:
99+
:param password:
100+
:return:
101+
"""
102+
# 从数据库获取数据库中的用户对象
103+
user = get_user(fake_db, username)
104+
if not user:
105+
# 如果数据库中没有这个用户,返回false
106+
return False
107+
if not verify_password(password, user.hashed_password):
108+
# 如果数据库中有这个用户,但是密码校验不通过,返回false
109+
return False
110+
return user
111+
112+
113+
def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None):
114+
"""
115+
生成token
116+
:param data: 用户信息
117+
:param expires_delta: 过期时间
118+
:return: 利用用户信息生成的token
119+
"""
120+
to_encode = data.copy()
121+
if expires_delta:
122+
expire = datetime.utcnow() + expires_delta
123+
else:
124+
expire = datetime.utcnow() + timedelta(minutes=15)
125+
to_encode.update({"exp": expire})
126+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
127+
return encoded_jwt
128+
129+
130+
async def get_current_user(token: str = Depends(oauth2_scheme)):
131+
"""
132+
获取当前用户
133+
:param token:
134+
:return:
135+
"""
136+
credentials_exception = HTTPException(
137+
status_code=status.HTTP_401_UNAUTHORIZED,
138+
detail="Could not validate credentials",
139+
headers={"WWW-Authenticate": "Bearer"},
140+
)
141+
try:
142+
# 利用jwt解析请求头中的token值,获取实际的用户信息
143+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
144+
username: str = payload.get('sub')
145+
if username is None:
146+
raise credentials_exception
147+
token_data = TokenData(username=username)
148+
except JWTError:
149+
raise credentials_exception
150+
user = get_user(fake_users_db, username=token_data.username)
151+
if user is None:
152+
raise credentials_exception
153+
return user
154+
155+
156+
async def get_current_active_user(current_user: User = Depends(get_current_user)):
157+
"""
158+
获取当前用户
159+
:param current_user: 当前用户
160+
:return:
161+
"""
162+
if current_user.disabled:
163+
raise HTTPException(status_code=400, detail="Inactive user")
164+
return current_user
165+
166+
167+
@app.post("/token", response_model=Token)
168+
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
169+
user = authenticate_user(fake_users_db, form_data.username, form_data.password)
170+
if not user:
171+
raise HTTPException(
172+
status_code=status.HTTP_401_UNAUTHORIZED,
173+
detail="Incorrect username or password",
174+
headers={"WWW-Authenticate": "Bearer"},
175+
)
176+
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
177+
access_token = create_access_token(
178+
data={"sub": user.username},
179+
expires_delta=access_token_expires
180+
)
181+
return {"access_token": access_token, "token_type": "bearer"}
182+
183+
184+
@app.get("/users/me/", response_model=User)
185+
async def read_users_me(current_user: User = Depends(get_current_active_user)):
186+
"""
187+
获取当前用户
188+
:param current_user:
189+
:return:
190+
"""
191+
return current_user
192+
193+
194+
@app.get("/users/me/items/")
195+
async def read_own_items(current_user: User = Depends(get_current_active_user)):
196+
"""
197+
获取当前用户拥有的物品
198+
:param current_user:
199+
:return:
200+
"""
201+
return [{"item_id": "Foo", "owner": current_user.username}]
202+
203+
204+
if __name__ == '__main__':
205+
uvicorn.run('test-jwt:app', host='0.0.0.0', port=8000, reload=True, debug=True, workers=1)

test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
c = 16.6
2+
price = 0.6
3+
print(price * c / 100)

tutorial/chapter03.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Optional, List, Annotated
88

99
from fastapi import APIRouter, Query, Path, Body, Cookie, Header
10-
from pydantic import BaseModel, Field, BeforeValidator
10+
from pydantic import BaseModel, Field, BeforeValidator, field_validator
1111

1212
app03 = APIRouter()
1313

@@ -100,6 +100,16 @@ class CityInfo(BaseModel):
100100

101101
currPage: Annotated[int, BeforeValidator(curr_page_v)]
102102

103+
@field_validator('name')
104+
def name_validator(cls, v):
105+
assert v.startswith('a'), "name必须以a开头"
106+
return v
107+
108+
@field_validator('country')
109+
def country(cls, v):
110+
if ' ' not in v:
111+
raise ValueError('country必须包含空格')
112+
return v
103113
class Config:
104114
schema_extra = {
105115
"example": {

0 commit comments

Comments
 (0)