92 lines
3.3 KiB
Python
92 lines
3.3 KiB
Python
import re
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models import User
|
|
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
|
|
from app.core.exceptions import BadRequestError, UnauthorizedError
|
|
|
|
|
|
def _validate_password(password: str) -> None:
|
|
if len(password) < 6:
|
|
raise BadRequestError("密码至少 6 位")
|
|
has_letter = bool(re.search(r"[a-zA-Z]", password))
|
|
has_digit = bool(re.search(r"[0-9]", password))
|
|
if not (has_letter and has_digit):
|
|
raise BadRequestError("密码需同时包含字母和数字")
|
|
|
|
|
|
class AuthService:
    """Authentication use-cases: registration, login, token refresh and password reset.

    Every method is a stateless ``@staticmethod`` coroutine that works on the
    ``AsyncSession`` supplied by the caller; methods that change data commit
    that session themselves.
    """

    @staticmethod
    async def _get_user_by_email(db: AsyncSession, email: str) -> "User | None":
        """Return the user whose ``email`` column equals ``email`` exactly, or ``None``."""
        result = await db.execute(select(User).where(User.email == email))
        return result.scalar_one_or_none()

    @staticmethod
    async def register(db: AsyncSession, email: str, password: str) -> User:
        """Create and persist a new user account.

        Args:
            db: Session used for the lookup and the insert; committed here.
            email: Email address; compared exactly (no case folding is done here).
            password: Plain-text password; checked against the password policy
                and stored only as ``hash_password(password)``.

        Returns:
            The newly created ``User``, refreshed from the database after commit.

        Raises:
            BadRequestError: if the password violates the policy or a user with
                this email already exists.
        """
        _validate_password(password)

        if await AuthService._get_user_by_email(db, email) is not None:
            raise BadRequestError("Email already registered")

        # NOTE(review): check-then-insert is not atomic; two concurrent requests
        # can both pass the check above. Presumably a unique constraint on
        # User.email backstops this — confirm, and consider mapping its error.
        user = User(email=email, password_hash=hash_password(password))
        db.add(user)
        await db.commit()
        await db.refresh(user)
        return user

    @staticmethod
    async def login(db: AsyncSession, email: str, password: str) -> dict:
        """Authenticate by email and password and issue a token pair.

        Returns:
            ``{"access_token": ..., "refresh_token": ..., "token_type": "bearer"}``.

        Raises:
            UnauthorizedError: if no user has this email, the password does not
                match, or the account's status is not ``"active"``.
        """
        user = await AuthService._get_user_by_email(db, email)
        # One message for "unknown email" and "wrong password" so the response
        # does not say which of the two failed.
        if user is None or not verify_password(password, user.password_hash):
            raise UnauthorizedError("Invalid email or password")
        # The account status is only reported after the password has been verified.
        if user.status != "active":
            raise UnauthorizedError("Account is disabled")

        return {
            "access_token": create_access_token(user.id),
            "refresh_token": create_refresh_token(user.id),
            "token_type": "bearer",
        }

    @staticmethod
    async def refresh(db: AsyncSession, refresh_token: str) -> dict:
        """Exchange a refresh token for a new access token.

        Returns:
            ``{"access_token": ..., "token_type": "bearer"}``; no new refresh
            token is issued.

        Raises:
            UnauthorizedError: if the token cannot be decoded, is not of type
                ``"refresh"``, carries no ``sub`` claim, or its user is missing
                or not ``"active"``.
        """
        payload = decode_token(refresh_token)
        if not payload or payload.get("type") != "refresh":
            raise UnauthorizedError("Invalid refresh token")

        # A token without a subject is malformed: reject it as unauthorized
        # instead of letting payload["sub"] raise KeyError (an unhandled error).
        user_id = payload.get("sub")
        if user_id is None:
            raise UnauthorizedError("Invalid refresh token")

        # NOTE(review): the claim is passed to db.get() as-is; if the token stores
        # the id as a string while User's primary key is an integer, confirm the
        # driver coerces it.
        user = await db.get(User, user_id)
        if user is None or user.status != "active":
            raise UnauthorizedError("User not found or disabled")

        return {
            "access_token": create_access_token(user.id),
            "token_type": "bearer",
        }

    @staticmethod
    async def forgot_password(db: AsyncSession, email: str) -> str:
        """Generate a password reset token. MVP: returns the token directly (production: send via email).

        Returns:
            The reset token for the user with this email, or ``""`` when no
            such user exists.
        """
        user = await AuthService._get_user_by_email(db, email)
        if user is None:
            # Don't reveal whether email exists.
            # NOTE(review): while the token itself is returned (MVP), "" vs a
            # non-empty token still distinguishes the two cases if the caller
            # forwards this value to the client — confirm at the API layer.
            return ""

        from app.core.security import create_reset_token

        return create_reset_token(user.id)

    @staticmethod
    async def reset_password(db: AsyncSession, token: str, new_password: str) -> None:
        """Set a new password for the user identified by a reset token.

        Raises:
            BadRequestError: if ``new_password`` violates the policy; if the
                token cannot be decoded, is not of type ``"reset"`` or has no
                ``sub`` claim; or if the user no longer exists.
        """
        _validate_password(new_password)

        payload = decode_token(token)
        if not payload or payload.get("type") != "reset":
            raise BadRequestError("Invalid or expired reset token")

        # Treat a token without a subject as invalid instead of raising KeyError.
        user_id = payload.get("sub")
        if user_id is None:
            raise BadRequestError("Invalid or expired reset token")

        user = await db.get(User, user_id)
        if user is None:
            raise BadRequestError("User not found")

        # NOTE(review): nothing here marks the token as used, so it appears to
        # stay valid for repeated resets until decode_token rejects it.
        user.password_hash = hash_password(new_password)
        await db.commit()
|