integrate sub2api as upstream for auth/keys/usage via FastAPI BFF
Preserve local user table for superDream-specific features while syncing user lifecycle, API key CRUD and usage queries through sub2api. Admin token handles reads and user lifecycle; per-user tokens (Fernet-encrypted in DB) handle key writes that admin endpoints do not expose. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -3,12 +3,20 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.datamodels.schemas import (
|
||||
ForgotPasswordRequest,
|
||||
LoginRequest,
|
||||
MessageResponse,
|
||||
RefreshRequest,
|
||||
RegisterRequest,
|
||||
ResetPasswordRequest,
|
||||
TokenResponse,
|
||||
UserResponse,
|
||||
)
|
||||
from app.integrations.sub2api import admin as sub2api_admin
|
||||
from app.integrations.sub2api.client import Sub2APIError, Sub2APITransportError
|
||||
from app.models import User
|
||||
from app.services.auth_service import AuthService
|
||||
from app.datamodels.schemas import (
|
||||
RegisterRequest, LoginRequest, TokenResponse, RefreshRequest,
|
||||
ForgotPasswordRequest, ResetPasswordRequest, UserResponse, MessageResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
@@ -16,7 +24,7 @@ router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
@router.post("/register", response_model=UserResponse)
|
||||
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
user = await AuthService.register(db, body.email, body.password)
|
||||
return user
|
||||
return _user_to_response(user)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
@@ -38,7 +46,7 @@ async def logout():
|
||||
async def forgot_password(body: ForgotPasswordRequest, db: AsyncSession = Depends(get_db)):
|
||||
token = await AuthService.forgot_password(db, body.email)
|
||||
if token:
|
||||
# MVP: print token to console; production: send email
|
||||
# MVP: print to console; production should send via email
|
||||
print(f"[Password Reset] email={body.email} token={token}")
|
||||
return {"message": "If the email exists, a reset link has been sent"}
|
||||
|
||||
@@ -51,4 +59,24 @@ async def reset_password(body: ResetPasswordRequest, db: AsyncSession = Depends(
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def me(user: User = Depends(get_current_user)):
|
||||
return user
|
||||
"""Merge the local row with sub2api's live balance when available."""
|
||||
response = _user_to_response(user)
|
||||
if user.sub2api_user_id:
|
||||
try:
|
||||
remote = await sub2api_admin.get_user(user.sub2api_user_id)
|
||||
response.balance = float(remote.get("balance") or 0)
|
||||
except (Sub2APIError, Sub2APITransportError):
|
||||
# non-fatal; fall back to 0
|
||||
pass
|
||||
return response
|
||||
|
||||
|
||||
def _user_to_response(user: User) -> UserResponse:
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
status=user.status,
|
||||
created_at=user.created_at,
|
||||
sub2api_user_id=user.sub2api_user_id,
|
||||
balance=0.0,
|
||||
)
|
||||
@@ -1,39 +1,74 @@
|
||||
from typing import List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.datamodels.schemas import CreateKeyRequest, MessageResponse, UpdateKeyRequest
|
||||
from app.models import User
|
||||
from app.services.key_service import KeyService
|
||||
from app.datamodels.schemas import CreateKeyRequest, ApiKeyResponse, ApiKeyCreatedResponse, MessageResponse
|
||||
|
||||
router = APIRouter(prefix="/keys", tags=["keys"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[ApiKeyResponse])
|
||||
@router.get("")
|
||||
async def list_keys(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
sort_by: str = Query("created_at"),
|
||||
sort_order: str = Query("desc"),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await KeyService.list_keys(db, user.id)
|
||||
) -> Dict[str, Any]:
|
||||
return await KeyService.list_keys(
|
||||
db, user, page=page, page_size=page_size, sort_by=sort_by, sort_order=sort_order
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=ApiKeyCreatedResponse)
|
||||
@router.get("/meta/available-groups")
|
||||
async def available_groups(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> List[Dict[str, Any]]:
|
||||
return await KeyService.available_groups(db, user)
|
||||
|
||||
|
||||
@router.get("/{key_id}")
|
||||
async def get_key(
|
||||
key_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
return await KeyService.get_key(db, user, key_id)
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_key(
|
||||
body: CreateKeyRequest,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await KeyService.create_key(db, user.id, body.name)
|
||||
) -> Dict[str, Any]:
|
||||
payload = body.model_dump(exclude_none=True)
|
||||
return await KeyService.create_key(db, user, payload)
|
||||
|
||||
|
||||
@router.put("/{key_id}")
|
||||
async def update_key(
|
||||
key_id: int,
|
||||
body: UpdateKeyRequest,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
payload = body.model_dump(exclude_none=True)
|
||||
return await KeyService.update_key(db, user, key_id, payload)
|
||||
|
||||
|
||||
@router.delete("/{key_id}", response_model=MessageResponse)
|
||||
async def delete_key(
|
||||
key_id: str,
|
||||
key_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await KeyService.delete_key(db, user.id, key_id)
|
||||
return {"message": "Key deleted"}
|
||||
await KeyService.delete_key(db, user, key_id)
|
||||
return {"message": "API key deleted successfully"}
|
||||
@@ -1,64 +1,112 @@
|
||||
from datetime import date
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.datamodels.schemas import DashboardAPIKeysUsageRequest
|
||||
from app.models import User
|
||||
from app.services.usage_service import UsageService
|
||||
from app.datamodels.schemas import (
|
||||
UsageSummaryResponse, DailyUsageResponse,
|
||||
ModelUsageResponse, KeyUsageResponse, UsageLogResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/usage", tags=["usage"])
|
||||
|
||||
|
||||
@router.get("/summary", response_model=UsageSummaryResponse)
|
||||
async def usage_summary(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await UsageService.summary(db, user.id)
|
||||
|
||||
|
||||
@router.get("/daily", response_model=List[DailyUsageResponse])
|
||||
async def usage_daily(
|
||||
start: Optional[date] = Query(None),
|
||||
end: Optional[date] = Query(None),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await UsageService.daily(db, user.id, start, end)
|
||||
|
||||
|
||||
@router.get("/by-model", response_model=List[ModelUsageResponse])
|
||||
async def usage_by_model(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await UsageService.by_model(db, user.id)
|
||||
|
||||
|
||||
@router.get("/by-key", response_model=List[KeyUsageResponse])
|
||||
async def usage_by_key(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await UsageService.by_key(db, user.id)
|
||||
|
||||
|
||||
@router.get("/logs", response_model=List[UsageLogResponse])
|
||||
async def usage_logs(
|
||||
@router.get("")
|
||||
async def list_logs(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
sort_by: str = Query("created_at"),
|
||||
sort_order: str = Query("desc"),
|
||||
api_key_id: Optional[int] = Query(None),
|
||||
model: Optional[str] = Query(None),
|
||||
key_id: Optional[str] = Query(None),
|
||||
start: Optional[date] = Query(None),
|
||||
end: Optional[date] = Query(None),
|
||||
request_type: Optional[str] = Query(None),
|
||||
stream: Optional[bool] = Query(None),
|
||||
billing_type: Optional[int] = Query(None),
|
||||
start_date: Optional[str] = Query(None),
|
||||
end_date: Optional[str] = Query(None),
|
||||
timezone: Optional[str] = Query(None),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await UsageService.logs(db, user.id, page, size, model, key_id, start, end)
|
||||
) -> Dict[str, Any]:
|
||||
return await UsageService.list_logs(
|
||||
db,
|
||||
user,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
api_key_id=api_key_id,
|
||||
model=model,
|
||||
request_type=request_type,
|
||||
stream=stream,
|
||||
billing_type=billing_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=timezone,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def stats(
|
||||
period: Optional[str] = Query(None, regex="^(today|week|month)$"),
|
||||
start_date: Optional[str] = Query(None),
|
||||
end_date: Optional[str] = Query(None),
|
||||
api_key_id: Optional[int] = Query(None),
|
||||
timezone: Optional[str] = Query(None),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
return await UsageService.stats(
|
||||
db,
|
||||
user,
|
||||
period=period,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
api_key_id=api_key_id,
|
||||
timezone=timezone,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/dashboard/stats")
|
||||
async def dashboard_stats(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
return await UsageService.dashboard_stats(db, user)
|
||||
|
||||
|
||||
@router.get("/dashboard/trend")
|
||||
async def dashboard_trend(
|
||||
granularity: str = Query("day"),
|
||||
start_date: Optional[str] = Query(None),
|
||||
end_date: Optional[str] = Query(None),
|
||||
timezone: Optional[str] = Query(None),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
return await UsageService.dashboard_trend(
|
||||
db, user, granularity=granularity, start_date=start_date, end_date=end_date, timezone=timezone
|
||||
)
|
||||
|
||||
|
||||
@router.get("/dashboard/models")
|
||||
async def dashboard_models(
|
||||
start_date: Optional[str] = Query(None),
|
||||
end_date: Optional[str] = Query(None),
|
||||
timezone: Optional[str] = Query(None),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
return await UsageService.dashboard_models(
|
||||
db, user, start_date=start_date, end_date=end_date, timezone=timezone
|
||||
)
|
||||
|
||||
|
||||
@router.post("/dashboard/api-keys-usage")
|
||||
async def dashboard_api_keys_usage(
|
||||
body: DashboardAPIKeysUsageRequest,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
return await UsageService.dashboard_api_keys_usage(db, user, body.api_key_ids)
|
||||
@@ -25,6 +25,15 @@ class Settings(BaseSettings):
|
||||
jwt_access_expire_minutes: int = 30
|
||||
jwt_refresh_expire_days: int = 7
|
||||
|
||||
# sub2api upstream
|
||||
sub2api_base_url: str = "http://127.0.0.1:8080"
|
||||
sub2api_admin_token: str = ""
|
||||
sub2api_request_timeout: float = 10.0
|
||||
|
||||
# Fernet key (urlsafe base64, 44 chars). Generate with:
|
||||
# python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
token_encryption_key: str = ""
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
return (
|
||||
@@ -32,6 +41,10 @@ class Settings(BaseSettings):
|
||||
f"@{self.db_host}:{self.db_port}/{self.db_name}"
|
||||
)
|
||||
|
||||
@property
|
||||
def sub2api_api_prefix(self) -> str:
|
||||
return self.sub2api_base_url.rstrip("/") + "/api/v1"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_prefix = "SD_"
|
||||
|
||||
29
app/core/crypto.py
Normal file
29
app/core/crypto.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
key = settings.token_encryption_key
|
||||
if not key:
|
||||
raise RuntimeError(
|
||||
"SD_TOKEN_ENCRYPTION_KEY is not configured. "
|
||||
"Generate one with: python -c \"from cryptography.fernet import Fernet; "
|
||||
"print(Fernet.generate_key().decode())\""
|
||||
)
|
||||
return Fernet(key.encode() if isinstance(key, str) else key)
|
||||
|
||||
|
||||
def encrypt_token(plaintext: str) -> str:
|
||||
if not plaintext:
|
||||
return ""
|
||||
return _get_fernet().encrypt(plaintext.encode()).decode()
|
||||
|
||||
|
||||
def decrypt_token(ciphertext: str) -> str:
|
||||
if not ciphertext:
|
||||
return ""
|
||||
try:
|
||||
return _get_fernet().decrypt(ciphertext.encode()).decode()
|
||||
except InvalidToken as exc:
|
||||
raise ValueError("Failed to decrypt token (key mismatch or corrupted value)") from exc
|
||||
@@ -1,10 +1,12 @@
|
||||
from decimal import Decimal
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from datetime import datetime, date
|
||||
from typing import Optional, List
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# ── Auth ──
|
||||
# ── Auth ──────────────────────────────────────────────────────────────
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: str
|
||||
@@ -36,144 +38,59 @@ class ResetPasswordRequest(BaseModel):
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""Local user view. Balance mirrored from sub2api when available."""
|
||||
|
||||
id: str
|
||||
email: str
|
||||
balance: Decimal
|
||||
status: str
|
||||
created_at: datetime
|
||||
sub2api_user_id: Optional[int] = None
|
||||
balance: float = 0.0
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ── API Key ──
|
||||
# ── API Key (shapes match sub2api dto.APIKey 1:1) ─────────────────────
|
||||
|
||||
class CreateKeyRequest(BaseModel):
|
||||
name: str = ""
|
||||
|
||||
|
||||
class ApiKeyResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
key_prefix: str
|
||||
key_suffix: str
|
||||
status: str
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
group_id: Optional[int] = None
|
||||
custom_key: Optional[str] = None
|
||||
ip_whitelist: List[str] = []
|
||||
ip_blacklist: List[str] = []
|
||||
quota: Optional[float] = None
|
||||
expires_in_days: Optional[int] = None
|
||||
rate_limit_5h: Optional[float] = None
|
||||
rate_limit_1d: Optional[float] = None
|
||||
rate_limit_7d: Optional[float] = None
|
||||
|
||||
|
||||
class ApiKeyCreatedResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
key: str
|
||||
key_prefix: str
|
||||
key_suffix: str
|
||||
created_at: datetime
|
||||
class UpdateKeyRequest(BaseModel):
|
||||
name: Optional[str] = None
|
||||
group_id: Optional[int] = None
|
||||
status: Optional[str] = None
|
||||
ip_whitelist: Optional[List[str]] = None
|
||||
ip_blacklist: Optional[List[str]] = None
|
||||
quota: Optional[float] = None
|
||||
expires_at: Optional[str] = None
|
||||
reset_quota: Optional[bool] = None
|
||||
rate_limit_5h: Optional[float] = None
|
||||
rate_limit_1d: Optional[float] = None
|
||||
rate_limit_7d: Optional[float] = None
|
||||
reset_rate_limit_usage: Optional[bool] = None
|
||||
|
||||
|
||||
# ── Wallet ──
|
||||
# ── Usage ─────────────────────────────────────────────────────────────
|
||||
|
||||
class RedeemCodeRequest(BaseModel):
|
||||
code: str
|
||||
class DashboardAPIKeysUsageRequest(BaseModel):
|
||||
api_key_ids: List[int]
|
||||
|
||||
|
||||
class TransactionResponse(BaseModel):
|
||||
id: str
|
||||
type: str
|
||||
amount: Decimal
|
||||
balance_after: Decimal
|
||||
reference_id: str
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class BalanceResponse(BaseModel):
|
||||
balance: Decimal
|
||||
|
||||
|
||||
# ── Models ──
|
||||
|
||||
class ModelPricingResponse(BaseModel):
|
||||
id: int
|
||||
model_name: str
|
||||
provider: str
|
||||
input_price_per_1k: Decimal
|
||||
output_price_per_1k: Decimal
|
||||
status: str
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ── Example (legacy) ──
|
||||
|
||||
class ExampleCreate(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
class ExampleResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ── Common ──
|
||||
# ── Common ────────────────────────────────────────────────────────────
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
# ── Usage ──
|
||||
|
||||
class UsageSummaryResponse(BaseModel):
|
||||
today_tokens: int
|
||||
today_cost: Decimal
|
||||
month_tokens: int
|
||||
month_cost: Decimal
|
||||
total_requests: int
|
||||
|
||||
|
||||
class DailyUsageResponse(BaseModel):
|
||||
date: date
|
||||
total_tokens: int
|
||||
cost: Decimal
|
||||
requests: int
|
||||
|
||||
|
||||
class ModelUsageResponse(BaseModel):
|
||||
model: str
|
||||
total_tokens: int
|
||||
cost: Decimal
|
||||
requests: int
|
||||
|
||||
|
||||
class KeyUsageResponse(BaseModel):
|
||||
key_id: str
|
||||
key_name: str
|
||||
key_prefix: str
|
||||
key_suffix: str
|
||||
total_tokens: int
|
||||
cost: Decimal
|
||||
requests: int
|
||||
|
||||
|
||||
class UsageLogResponse(BaseModel):
|
||||
id: int
|
||||
key_id: str
|
||||
model: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
cost: Decimal
|
||||
request_time: datetime
|
||||
status: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
JSONDict = Dict[str, Any]
|
||||
0
app/integrations/__init__.py
Normal file
0
app/integrations/__init__.py
Normal file
18
app/integrations/sub2api/__init__.py
Normal file
18
app/integrations/sub2api/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from app.integrations.sub2api.client import (
|
||||
Sub2APIError,
|
||||
Sub2APIReauthRequired,
|
||||
Sub2APITransportError,
|
||||
get_client,
|
||||
close_client,
|
||||
)
|
||||
from app.integrations.sub2api import admin, user
|
||||
|
||||
__all__ = [
|
||||
"Sub2APIError",
|
||||
"Sub2APIReauthRequired",
|
||||
"Sub2APITransportError",
|
||||
"get_client",
|
||||
"close_client",
|
||||
"admin",
|
||||
"user",
|
||||
]
|
||||
97
app/integrations/sub2api/admin.py
Normal file
97
app/integrations/sub2api/admin.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Admin-token-authenticated calls to sub2api.
|
||||
|
||||
Used for:
|
||||
- User lifecycle sync (create / update / delete / lookup)
|
||||
- Reading a user's API keys and usage (admin endpoints only support reads)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.integrations.sub2api.client import admin_request
|
||||
|
||||
|
||||
# ── Users ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def create_user(
|
||||
*,
|
||||
email: str,
|
||||
password: str,
|
||||
username: str = "",
|
||||
notes: str = "",
|
||||
balance: float = 0,
|
||||
concurrency: int = 0,
|
||||
allowed_groups: list[int] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return await admin_request(
|
||||
"POST",
|
||||
"/admin/users",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"username": username,
|
||||
"notes": notes,
|
||||
"balance": balance,
|
||||
"concurrency": concurrency,
|
||||
"allowed_groups": allowed_groups or [],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def update_user(user_id: int, **fields: Any) -> dict[str, Any]:
|
||||
"""Partial update. Only non-None fields are sent."""
|
||||
payload = {k: v for k, v in fields.items() if v is not None}
|
||||
return await admin_request("PUT", f"/admin/users/{user_id}", json=payload)
|
||||
|
||||
|
||||
async def delete_user(user_id: int) -> None:
|
||||
await admin_request("DELETE", f"/admin/users/{user_id}")
|
||||
|
||||
|
||||
async def get_user(user_id: int) -> dict[str, Any]:
|
||||
return await admin_request("GET", f"/admin/users/{user_id}")
|
||||
|
||||
|
||||
async def find_user_by_email(email: str) -> dict[str, Any] | None:
|
||||
data = await admin_request(
|
||||
"GET",
|
||||
"/admin/users",
|
||||
params={"search": email, "page": 1, "page_size": 5},
|
||||
)
|
||||
items = (data or {}).get("items") or []
|
||||
for item in items:
|
||||
if (item.get("email") or "").lower() == email.lower():
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
# ── API Keys (read-only from admin side) ──────────────────────────────
|
||||
|
||||
async def list_user_api_keys(
|
||||
user_id: int,
|
||||
*,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> dict[str, Any]:
|
||||
return await admin_request(
|
||||
"GET",
|
||||
f"/admin/users/{user_id}/api-keys",
|
||||
params={
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"sort_by": sort_by,
|
||||
"sort_order": sort_order,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── Usage (admin view per user) ───────────────────────────────────────
|
||||
|
||||
async def get_user_usage_stats(user_id: int, period: str = "month") -> dict[str, Any]:
|
||||
return await admin_request(
|
||||
"GET",
|
||||
f"/admin/users/{user_id}/usage",
|
||||
params={"period": period},
|
||||
)
|
||||
131
app/integrations/sub2api/client.py
Normal file
131
app/integrations/sub2api/client.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Low-level HTTP client for sub2api.
|
||||
|
||||
Responsibilities:
|
||||
- Singleton httpx.AsyncClient with configured base URL and timeout.
|
||||
- Uniform envelope parsing: {code, message, reason, metadata, data}.
|
||||
- Error translation to domain exceptions.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Mapping
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Sub2APIError(Exception):
|
||||
"""sub2api returned a non-zero envelope code or HTTP error."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code: int,
|
||||
message: str,
|
||||
reason: str = "",
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
http_status: int | None = None,
|
||||
) -> None:
|
||||
super().__init__(f"[{code}] {message}" + (f" ({reason})" if reason else ""))
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.reason = reason
|
||||
self.metadata = dict(metadata) if metadata else {}
|
||||
self.http_status = http_status
|
||||
|
||||
|
||||
class Sub2APIReauthRequired(Sub2APIError):
|
||||
"""User-level token invalid/expired; caller must re-authenticate with password."""
|
||||
|
||||
|
||||
class Sub2APITransportError(Exception):
|
||||
"""Network / timeout / invalid JSON etc."""
|
||||
|
||||
|
||||
_client: httpx.AsyncClient | None = None
|
||||
|
||||
|
||||
def get_client() -> httpx.AsyncClient:
|
||||
global _client
|
||||
if _client is None or _client.is_closed:
|
||||
_client = httpx.AsyncClient(
|
||||
base_url=settings.sub2api_api_prefix,
|
||||
timeout=settings.sub2api_request_timeout,
|
||||
)
|
||||
return _client
|
||||
|
||||
|
||||
async def close_client() -> None:
|
||||
global _client
|
||||
if _client is not None and not _client.is_closed:
|
||||
await _client.aclose()
|
||||
_client = None
|
||||
|
||||
|
||||
def _parse_envelope(resp: httpx.Response) -> Any:
|
||||
try:
|
||||
body = resp.json()
|
||||
except ValueError as exc:
|
||||
raise Sub2APITransportError(
|
||||
f"sub2api returned non-JSON body (status={resp.status_code}): {resp.text[:200]}"
|
||||
) from exc
|
||||
|
||||
code = body.get("code", resp.status_code)
|
||||
message = body.get("message", "")
|
||||
reason = body.get("reason", "") or ""
|
||||
metadata = body.get("metadata") or {}
|
||||
|
||||
if code == 0:
|
||||
return body.get("data")
|
||||
|
||||
# Token invalidation signals from sub2api admin middleware
|
||||
if reason in {"TOKEN_EXPIRED", "INVALID_TOKEN", "TOKEN_REVOKED", "USER_INACTIVE"}:
|
||||
raise Sub2APIReauthRequired(code, message, reason, metadata, resp.status_code)
|
||||
|
||||
raise Sub2APIError(code, message, reason, metadata, resp.status_code)
|
||||
|
||||
|
||||
async def request(
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
json: Any = None,
|
||||
params: Mapping[str, Any] | None = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
) -> Any:
|
||||
"""Execute a sub2api call and return the unwrapped ``data`` field."""
|
||||
client = get_client()
|
||||
try:
|
||||
resp = await client.request(
|
||||
method,
|
||||
path,
|
||||
json=json,
|
||||
params={k: v for k, v in (params or {}).items() if v is not None},
|
||||
headers=headers,
|
||||
)
|
||||
except httpx.TimeoutException as exc:
|
||||
raise Sub2APITransportError(f"sub2api timeout on {method} {path}") from exc
|
||||
except httpx.HTTPError as exc:
|
||||
raise Sub2APITransportError(f"sub2api transport error on {method} {path}: {exc}") from exc
|
||||
|
||||
return _parse_envelope(resp)
|
||||
|
||||
|
||||
async def admin_request(method: str, path: str, **kwargs: Any) -> Any:
|
||||
"""Issue a request authenticated with the admin API key."""
|
||||
if not settings.sub2api_admin_token:
|
||||
raise RuntimeError("SD_SUB2API_ADMIN_TOKEN is not configured")
|
||||
headers = dict(kwargs.pop("headers", None) or {})
|
||||
headers["x-api-key"] = settings.sub2api_admin_token
|
||||
return await request(method, path, headers=headers, **kwargs)
|
||||
|
||||
|
||||
async def user_request(access_token: str, method: str, path: str, **kwargs: Any) -> Any:
|
||||
"""Issue a request authenticated with a user's JWT access token."""
|
||||
if not access_token:
|
||||
raise Sub2APIReauthRequired(401, "missing user access token", "MISSING_TOKEN")
|
||||
headers = dict(kwargs.pop("headers", None) or {})
|
||||
headers["Authorization"] = f"Bearer {access_token}"
|
||||
return await request(method, path, headers=headers, **kwargs)
|
||||
96
app/integrations/sub2api/user.py
Normal file
96
app/integrations/sub2api/user.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""User-JWT-authenticated calls to sub2api.
|
||||
|
||||
Used for:
|
||||
- Auth login / refresh / me (to obtain tokens for BFF proxying)
|
||||
- API Key CRUD (admin endpoints cannot create / delete keys)
|
||||
- Usage list / detail / dashboard for the authenticated user
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.integrations.sub2api.client import request, user_request
|
||||
|
||||
|
||||
# ── Auth (public; no Bearer required) ─────────────────────────────────
|
||||
|
||||
async def login(email: str, password: str) -> dict[str, Any]:
|
||||
"""Returns AuthResponse: {access_token, refresh_token, expires_in, token_type, user}."""
|
||||
return await request(
|
||||
"POST",
|
||||
"/auth/login",
|
||||
json={"email": email, "password": password, "turnstile_token": ""},
|
||||
)
|
||||
|
||||
|
||||
async def refresh_tokens(refresh_token: str) -> dict[str, Any]:
|
||||
"""Returns RefreshTokenResponse: {access_token, refresh_token, expires_in, token_type}."""
|
||||
return await request(
|
||||
"POST",
|
||||
"/auth/refresh",
|
||||
json={"refresh_token": refresh_token},
|
||||
)
|
||||
|
||||
|
||||
async def logout(refresh_token: str | None = None) -> dict[str, Any]:
|
||||
body = {"refresh_token": refresh_token} if refresh_token else {}
|
||||
return await request("POST", "/auth/logout", json=body)
|
||||
|
||||
|
||||
# ── API Key CRUD ──────────────────────────────────────────────────────
|
||||
|
||||
async def create_key(access_token: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return await user_request(access_token, "POST", "/keys", json=payload)
|
||||
|
||||
|
||||
async def update_key(access_token: str, key_id: int, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return await user_request(access_token, "PUT", f"/keys/{key_id}", json=payload)
|
||||
|
||||
|
||||
async def delete_key(access_token: str, key_id: int) -> dict[str, Any]:
|
||||
return await user_request(access_token, "DELETE", f"/keys/{key_id}")
|
||||
|
||||
|
||||
async def get_key(access_token: str, key_id: int) -> dict[str, Any]:
|
||||
return await user_request(access_token, "GET", f"/keys/{key_id}")
|
||||
|
||||
|
||||
# ── Groups ────────────────────────────────────────────────────────────
|
||||
|
||||
async def list_available_groups(access_token: str) -> list[dict[str, Any]]:
|
||||
return await user_request(access_token, "GET", "/groups/available")
|
||||
|
||||
|
||||
async def get_user_group_rates(access_token: str) -> dict[str, float]:
|
||||
return await user_request(access_token, "GET", "/groups/rates")
|
||||
|
||||
|
||||
# ── Usage (user view) ─────────────────────────────────────────────────
|
||||
|
||||
async def list_usage(access_token: str, **params: Any) -> dict[str, Any]:
|
||||
return await user_request(access_token, "GET", "/usage", params=params)
|
||||
|
||||
|
||||
async def usage_stats(access_token: str, **params: Any) -> dict[str, Any]:
|
||||
return await user_request(access_token, "GET", "/usage/stats", params=params)
|
||||
|
||||
|
||||
async def dashboard_stats(access_token: str) -> dict[str, Any]:
|
||||
return await user_request(access_token, "GET", "/usage/dashboard/stats")
|
||||
|
||||
|
||||
async def dashboard_trend(access_token: str, **params: Any) -> dict[str, Any]:
|
||||
return await user_request(access_token, "GET", "/usage/dashboard/trend", params=params)
|
||||
|
||||
|
||||
async def dashboard_models(access_token: str, **params: Any) -> dict[str, Any]:
|
||||
return await user_request(access_token, "GET", "/usage/dashboard/models", params=params)
|
||||
|
||||
|
||||
async def dashboard_api_keys_usage(access_token: str, api_key_ids: list[int]) -> dict[str, Any]:
|
||||
return await user_request(
|
||||
access_token,
|
||||
"POST",
|
||||
"/usage/dashboard/api-keys-usage",
|
||||
json={"api_key_ids": api_key_ids},
|
||||
)
|
||||
@@ -3,7 +3,7 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
from app.api.v1 import health, example, auth, keys, models as models_api, wallet, usage
|
||||
from app.api.v1 import health, example, auth, keys, usage
|
||||
from app.config.settings import settings
|
||||
from app.core.database import init_db
|
||||
import os
|
||||
@@ -35,8 +35,6 @@ def create_app() -> FastAPI:
|
||||
app.include_router(health.router, prefix="/api/v1", tags=["health"])
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(keys.router, prefix="/api/v1")
|
||||
app.include_router(models_api.router, prefix="/api/v1")
|
||||
app.include_router(wallet.router, prefix="/api/v1")
|
||||
app.include_router(usage.router, prefix="/api/v1")
|
||||
app.include_router(example.router, prefix="/api/v1", tags=["example"])
|
||||
|
||||
|
||||
@@ -2,11 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import String, Text, Integer, BigInteger, DateTime, Numeric, Enum, ForeignKey, Index, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy import String, Text, BigInteger, DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
@@ -17,71 +16,12 @@ class User(Base):
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
balance: Mapped[Decimal] = mapped_column(Numeric(16, 6), default=Decimal("0"))
|
||||
status: Mapped[str] = mapped_column(String(20), default="active") # active / disabled
|
||||
status: Mapped[str] = mapped_column(String(20), default="active")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
|
||||
|
||||
api_keys: Mapped[List[ApiKey]] = relationship(back_populates="user")
|
||||
transactions: Mapped[List[Transaction]] = relationship(back_populates="user")
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
user_id: Mapped[str] = mapped_column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
name: Mapped[str] = mapped_column(String(100), default="")
|
||||
key_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False) # SHA256
|
||||
key_prefix: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
key_suffix: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(20), default="active") # active / revoked
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="api_keys")
|
||||
|
||||
|
||||
class UsageLog(Base):
|
||||
__tablename__ = "usage_logs"
|
||||
__table_args__ = (
|
||||
Index("ix_usage_user_time", "user_id", "request_time"),
|
||||
Index("ix_usage_model", "user_id", "model"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[str] = mapped_column(String(36), ForeignKey("users.id"), nullable=False)
|
||||
key_id: Mapped[str] = mapped_column(String(36), ForeignKey("api_keys.id"), nullable=False)
|
||||
model: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
prompt_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
completion_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
total_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
cost: Mapped[Decimal] = mapped_column(Numeric(16, 6), default=Decimal("0"))
|
||||
request_time: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
response_time: Mapped[datetime] = mapped_column(DateTime, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="success") # success / error
|
||||
|
||||
|
||||
class Transaction(Base):
|
||||
__tablename__ = "transactions"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
user_id: Mapped[str] = mapped_column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
type: Mapped[str] = mapped_column(String(20), nullable=False) # topup / consume / refund
|
||||
amount: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
|
||||
balance_after: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
|
||||
reference_id: Mapped[str] = mapped_column(String(100), default="")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="transactions")
|
||||
|
||||
|
||||
class ModelPricing(Base):
|
||||
__tablename__ = "models_pricing"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
model_name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
input_price_per_1k: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
|
||||
output_price_per_1k: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(20), default="available") # available / offline
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
|
||||
# sub2api 用户体系映射
|
||||
sub2api_user_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, index=True)
|
||||
sub2api_refresh_token_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
sub2api_access_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
sub2api_access_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
@@ -1,11 +1,41 @@
|
||||
"""Auth service: local user table + sub2api sync.
|
||||
|
||||
Registration and login keep two systems in sync:
|
||||
- superDream local User row (owns password hash, superDream JWT)
|
||||
- sub2api User (owns API keys, usage, balance)
|
||||
|
||||
On registration, sub2api user is created via admin API; then we login to sub2api
|
||||
with the same password to obtain and store a refresh_token. Any sub2api failure
|
||||
rolls back the local insert so the two systems never drift apart silently.
|
||||
|
||||
On login, after local password verification we also refresh the stored sub2api
|
||||
tokens (by re-logging in to sub2api) so that subsequent proxied calls have a
|
||||
valid refresh_token without ever persisting the plaintext password.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import User
|
||||
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
|
||||
from app.core.exceptions import BadRequestError, UnauthorizedError
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
create_reset_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
from app.integrations.sub2api import admin as sub2api_admin
|
||||
from app.integrations.sub2api import user as sub2api_user
|
||||
from app.integrations.sub2api.client import Sub2APIError, Sub2APITransportError
|
||||
from app.models import User
|
||||
from app.services.sub2api_session import store_tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_password(password: str) -> None:
|
||||
@@ -17,8 +47,14 @@ def _validate_password(password: str) -> None:
|
||||
raise BadRequestError("密码需同时包含字母和数字")
|
||||
|
||||
|
||||
class AuthService:
|
||||
def _upstream_failure(action: str, exc: Exception) -> BadRequestError:
|
||||
logger.error("sub2api %s failed: %s", action, exc)
|
||||
if isinstance(exc, Sub2APIError):
|
||||
return BadRequestError(f"上游同步失败:{exc.message or exc.reason or 'unknown'}")
|
||||
return BadRequestError("上游服务不可用,请稍后重试")
|
||||
|
||||
|
||||
class AuthService:
|
||||
@staticmethod
|
||||
async def register(db: AsyncSession, email: str, password: str) -> User:
|
||||
_validate_password(password)
|
||||
@@ -27,10 +63,44 @@ class AuthService:
|
||||
if existing.scalar_one_or_none():
|
||||
raise BadRequestError("Email already registered")
|
||||
|
||||
user = User(email=email, password_hash=hash_password(password))
|
||||
# 1. Create the sub2api user first (so a conflict on their side halts us
|
||||
# before we touch our DB).
|
||||
try:
|
||||
remote = await sub2api_admin.create_user(email=email, password=password)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _upstream_failure("create_user", exc) from exc
|
||||
|
||||
remote_user_id = int(remote.get("id"))
|
||||
|
||||
# 2. Grab a token pair by logging in as the new user.
|
||||
try:
|
||||
tokens = await sub2api_user.login(email, password)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
# Roll back the remote user so re-registration works.
|
||||
try:
|
||||
await sub2api_admin.delete_user(remote_user_id)
|
||||
except Exception: # pragma: no cover - best-effort cleanup
|
||||
logger.exception("failed to roll back sub2api user %s", remote_user_id)
|
||||
raise _upstream_failure("login_after_register", exc) from exc
|
||||
|
||||
# 3. Persist the local user with sub2api references.
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=hash_password(password),
|
||||
sub2api_user_id=remote_user_id,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
try:
|
||||
await db.flush()
|
||||
await store_tokens(db, user, tokens)
|
||||
except Exception:
|
||||
await db.rollback()
|
||||
try:
|
||||
await sub2api_admin.delete_user(remote_user_id)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("failed to roll back sub2api user %s", remote_user_id)
|
||||
raise
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
@@ -42,6 +112,18 @@ class AuthService:
|
||||
if user.status != "active":
|
||||
raise UnauthorizedError("Account is disabled")
|
||||
|
||||
# Refresh stored sub2api tokens with the same password. If sub2api is
|
||||
# down we still allow local login; proxied calls will surface 503 later.
|
||||
try:
|
||||
tokens = await sub2api_user.login(email, password)
|
||||
if not user.sub2api_user_id and tokens.get("user", {}).get("id"):
|
||||
user.sub2api_user_id = int(tokens["user"]["id"])
|
||||
await store_tokens(db, user, tokens)
|
||||
except Sub2APIError as exc:
|
||||
logger.warning("sub2api login failed for %s: %s", email, exc)
|
||||
except Sub2APITransportError as exc:
|
||||
logger.warning("sub2api unreachable during login: %s", exc)
|
||||
|
||||
return {
|
||||
"access_token": create_access_token(user.id),
|
||||
"refresh_token": create_refresh_token(user.id),
|
||||
@@ -65,14 +147,10 @@ class AuthService:
|
||||
|
||||
@staticmethod
|
||||
async def forgot_password(db: AsyncSession, email: str) -> str:
|
||||
"""Generate a password reset token. MVP: returns the token directly (production: send via email)."""
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
# Don't reveal whether email exists
|
||||
return ""
|
||||
|
||||
from app.core.security import create_reset_token
|
||||
return create_reset_token(user.id)
|
||||
|
||||
@staticmethod
|
||||
@@ -88,4 +166,16 @@ class AuthService:
|
||||
raise BadRequestError("User not found")
|
||||
|
||||
user.password_hash = hash_password(new_password)
|
||||
await db.commit()
|
||||
|
||||
# Sync the new password to sub2api so subsequent proxied calls continue
|
||||
# to work. If this fails the user can still log in locally but won't be
|
||||
# able to mutate keys until the systems re-converge.
|
||||
if user.sub2api_user_id:
|
||||
try:
|
||||
await sub2api_admin.update_user(user.sub2api_user_id, password=new_password)
|
||||
tokens = await sub2api_user.login(user.email, new_password)
|
||||
await store_tokens(db, user, tokens)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
logger.error("sub2api password sync failed for user %s: %s", user.id, exc)
|
||||
|
||||
await db.commit()
|
||||
@@ -1,84 +1,111 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
"""API Key service: proxied to sub2api.
|
||||
|
||||
from sqlalchemy import select, func
|
||||
- Reads use the admin API key to call ``/admin/users/:id/api-keys``.
|
||||
- Writes (create/update/delete) require the user's own sub2api JWT; we fetch
|
||||
one via ``ensure_access_token`` using the stored refresh token.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import ApiKey
|
||||
from app.core.exceptions import BadRequestError, NotFoundError
|
||||
from app.core.exceptions import BadRequestError
|
||||
from app.integrations.sub2api import admin as sub2api_admin
|
||||
from app.integrations.sub2api import user as sub2api_user
|
||||
from app.integrations.sub2api.client import (
|
||||
Sub2APIError,
|
||||
Sub2APIReauthRequired,
|
||||
Sub2APITransportError,
|
||||
)
|
||||
from app.models import User
|
||||
from app.services.sub2api_session import ensure_access_token
|
||||
|
||||
MAX_KEYS_PER_USER = 5
|
||||
KEY_PREFIX = "sk-sd-"
|
||||
|
||||
def _require_sub2api_binding(user: User) -> int:
|
||||
if not user.sub2api_user_id:
|
||||
raise BadRequestError("账号未完成 sub2api 绑定,请重新登录")
|
||||
return user.sub2api_user_id
|
||||
|
||||
|
||||
def _translate_upstream(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, Sub2APIReauthRequired):
|
||||
return HTTPException(status_code=401, detail="sub2api_reauth_required")
|
||||
if isinstance(exc, Sub2APIError):
|
||||
status = exc.http_status or 502
|
||||
if status < 400 or status >= 600:
|
||||
status = 502
|
||||
return HTTPException(status_code=status, detail=exc.message or exc.reason or "upstream_error")
|
||||
return HTTPException(status_code=504, detail="upstream_timeout")
|
||||
|
||||
|
||||
class KeyService:
|
||||
@staticmethod
|
||||
async def list_keys(
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
*,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> dict[str, Any]:
|
||||
uid = _require_sub2api_binding(user)
|
||||
try:
|
||||
return await sub2api_admin.list_user_api_keys(
|
||||
uid,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
def _generate_key() -> str:
|
||||
return KEY_PREFIX + secrets.token_urlsafe(36)
|
||||
async def create_key(db: AsyncSession, user: User, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
_require_sub2api_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.create_key(token, payload)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
def _hash_key(raw_key: str) -> str:
|
||||
return hashlib.sha256(raw_key.encode()).hexdigest()
|
||||
async def update_key(
|
||||
db: AsyncSession, user: User, key_id: int, payload: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
_require_sub2api_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.update_key(token, key_id, payload)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def list_keys(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(ApiKey)
|
||||
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
|
||||
.order_by(ApiKey.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
async def delete_key(db: AsyncSession, user: User, key_id: int) -> dict[str, Any]:
|
||||
_require_sub2api_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.delete_key(token, key_id)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def create_key(db: AsyncSession, user_id: str, name: str = "") -> dict:
|
||||
# Check limit
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(ApiKey)
|
||||
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
|
||||
)
|
||||
count = count_result.scalar()
|
||||
if count >= MAX_KEYS_PER_USER:
|
||||
raise BadRequestError(f"最多创建 {MAX_KEYS_PER_USER} 个 Key")
|
||||
|
||||
raw_key = KeyService._generate_key()
|
||||
key_hash = KeyService._hash_key(raw_key)
|
||||
|
||||
# prefix/suffix for masked display (after "sk-sd-")
|
||||
body = raw_key[len(KEY_PREFIX):]
|
||||
key_prefix = body[:4]
|
||||
key_suffix = body[-4:]
|
||||
|
||||
api_key = ApiKey(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
key_suffix=key_suffix,
|
||||
)
|
||||
db.add(api_key)
|
||||
await db.commit()
|
||||
await db.refresh(api_key)
|
||||
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"name": api_key.name,
|
||||
"key": raw_key, # only returned once
|
||||
"key_prefix": key_prefix,
|
||||
"key_suffix": key_suffix,
|
||||
"created_at": api_key.created_at,
|
||||
}
|
||||
async def get_key(db: AsyncSession, user: User, key_id: int) -> dict[str, Any]:
|
||||
_require_sub2api_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.get_key(token, key_id)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def delete_key(db: AsyncSession, user_id: str, key_id: str) -> None:
|
||||
result = await db.execute(
|
||||
select(ApiKey).where(ApiKey.id == key_id, ApiKey.user_id == user_id)
|
||||
)
|
||||
api_key = result.scalar_one_or_none()
|
||||
if not api_key:
|
||||
raise NotFoundError("Key not found")
|
||||
|
||||
api_key.status = "revoked"
|
||||
await db.commit()
|
||||
async def available_groups(db: AsyncSession, user: User) -> list[dict[str, Any]]:
|
||||
_require_sub2api_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.list_available_groups(token)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate_upstream(exc) from exc
|
||||
74
app/services/sub2api_session.py
Normal file
74
app/services/sub2api_session.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Helpers to obtain a fresh sub2api user access token from stored state.
|
||||
|
||||
The User row carries:
|
||||
- sub2api_access_token + sub2api_access_expires_at (cache)
|
||||
- sub2api_refresh_token_enc (Fernet-encrypted refresh_token)
|
||||
|
||||
``ensure_access_token`` returns a usable access_token, refreshing via sub2api
|
||||
``/auth/refresh`` when the cached one is expired. On refresh failure the caller
|
||||
receives ``Sub2APIReauthRequired`` and should surface 401 so the frontend can
|
||||
prompt re-login.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.crypto import decrypt_token, encrypt_token
|
||||
from app.integrations.sub2api import user as sub2api_user
|
||||
from app.integrations.sub2api.client import Sub2APIError, Sub2APIReauthRequired
|
||||
from app.models import User
|
||||
|
||||
_CLOCK_SKEW = timedelta(seconds=60)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
def _is_access_fresh(user: User) -> bool:
|
||||
if not user.sub2api_access_token or not user.sub2api_access_expires_at:
|
||||
return False
|
||||
expires_at = user.sub2api_access_expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return expires_at - _CLOCK_SKEW > _utcnow()
|
||||
|
||||
|
||||
async def store_tokens(db: AsyncSession, user: User, token_response: dict) -> None:
|
||||
"""Persist a fresh token pair (from /auth/login or /auth/refresh)."""
|
||||
access = token_response.get("access_token") or ""
|
||||
refresh = token_response.get("refresh_token") or ""
|
||||
expires_in = int(token_response.get("expires_in") or 3600)
|
||||
|
||||
user.sub2api_access_token = access
|
||||
user.sub2api_access_expires_at = _utcnow() + timedelta(seconds=expires_in)
|
||||
if refresh:
|
||||
user.sub2api_refresh_token_enc = encrypt_token(refresh)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
async def ensure_access_token(db: AsyncSession, user: User) -> str:
|
||||
"""Return a usable sub2api access_token, refreshing if necessary."""
|
||||
if _is_access_fresh(user):
|
||||
return user.sub2api_access_token # type: ignore[return-value]
|
||||
|
||||
if not user.sub2api_refresh_token_enc:
|
||||
raise Sub2APIReauthRequired(401, "no refresh token stored", "NO_REFRESH_TOKEN")
|
||||
|
||||
try:
|
||||
refresh_plain = decrypt_token(user.sub2api_refresh_token_enc)
|
||||
except ValueError as exc:
|
||||
raise Sub2APIReauthRequired(401, str(exc), "DECRYPT_FAILED") from exc
|
||||
|
||||
try:
|
||||
token_response = await sub2api_user.refresh_tokens(refresh_plain)
|
||||
except Sub2APIError as exc:
|
||||
raise Sub2APIReauthRequired(
|
||||
exc.code, exc.message, exc.reason or "REFRESH_FAILED", exc.metadata
|
||||
) from exc
|
||||
|
||||
await store_tokens(db, user, token_response)
|
||||
return user.sub2api_access_token # type: ignore[return-value]
|
||||
@@ -1,139 +1,92 @@
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
"""Usage service: fully proxied to sub2api (user JWT for user-scoped endpoints)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import select, func, cast, Date
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import UsageLog, ApiKey
|
||||
from app.core.exceptions import BadRequestError
|
||||
from app.integrations.sub2api import user as sub2api_user
|
||||
from app.integrations.sub2api.client import (
|
||||
Sub2APIError,
|
||||
Sub2APIReauthRequired,
|
||||
Sub2APITransportError,
|
||||
)
|
||||
from app.models import User
|
||||
from app.services.sub2api_session import ensure_access_token
|
||||
|
||||
|
||||
def _require_binding(user: User) -> int:
|
||||
if not user.sub2api_user_id:
|
||||
raise BadRequestError("账号未完成 sub2api 绑定,请重新登录")
|
||||
return user.sub2api_user_id
|
||||
|
||||
|
||||
def _translate(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, Sub2APIReauthRequired):
|
||||
return HTTPException(status_code=401, detail="sub2api_reauth_required")
|
||||
if isinstance(exc, Sub2APIError):
|
||||
status = exc.http_status or 502
|
||||
if status < 400 or status >= 600:
|
||||
status = 502
|
||||
return HTTPException(status_code=status, detail=exc.message or exc.reason or "upstream_error")
|
||||
return HTTPException(status_code=504, detail="upstream_timeout")
|
||||
|
||||
|
||||
class UsageService:
|
||||
@staticmethod
|
||||
async def list_logs(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.list_usage(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def summary(db: AsyncSession, user_id: str) -> dict:
|
||||
today_start = datetime.combine(date.today(), datetime.min.time())
|
||||
month_start = today_start.replace(day=1)
|
||||
|
||||
# Today
|
||||
today_row = (await db.execute(
|
||||
select(
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
).where(UsageLog.user_id == user_id, UsageLog.request_time >= today_start)
|
||||
)).one()
|
||||
|
||||
# This month
|
||||
month_row = (await db.execute(
|
||||
select(
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
).where(UsageLog.user_id == user_id, UsageLog.request_time >= month_start)
|
||||
)).one()
|
||||
|
||||
return {
|
||||
"today_tokens": int(today_row[0]),
|
||||
"today_cost": today_row[1],
|
||||
"month_tokens": int(month_row[0]),
|
||||
"month_cost": month_row[1],
|
||||
"total_requests": int(month_row[2]),
|
||||
}
|
||||
async def stats(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.usage_stats(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def daily(
|
||||
db: AsyncSession, user_id: str,
|
||||
start: Optional[date] = None, end: Optional[date] = None,
|
||||
) -> list:
|
||||
if not start:
|
||||
start = date.today() - timedelta(days=29)
|
||||
if not end:
|
||||
end = date.today()
|
||||
|
||||
day_col = cast(UsageLog.request_time, Date).label("day")
|
||||
result = await db.execute(
|
||||
select(
|
||||
day_col,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
UsageLog.user_id == user_id,
|
||||
cast(UsageLog.request_time, Date) >= start,
|
||||
cast(UsageLog.request_time, Date) <= end,
|
||||
)
|
||||
.group_by(day_col)
|
||||
.order_by(day_col)
|
||||
)
|
||||
return [
|
||||
{"date": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
|
||||
for row in result.all()
|
||||
]
|
||||
async def dashboard_stats(db: AsyncSession, user: User) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_stats(token)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def by_model(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(
|
||||
UsageLog.model,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.where(UsageLog.user_id == user_id)
|
||||
.group_by(UsageLog.model)
|
||||
.order_by(func.sum(UsageLog.cost).desc())
|
||||
)
|
||||
return [
|
||||
{"model": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
|
||||
for row in result.all()
|
||||
]
|
||||
async def dashboard_trend(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_trend(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def by_key(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(
|
||||
UsageLog.key_id,
|
||||
ApiKey.name,
|
||||
ApiKey.key_prefix,
|
||||
ApiKey.key_suffix,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.join(ApiKey, UsageLog.key_id == ApiKey.id)
|
||||
.where(UsageLog.user_id == user_id)
|
||||
.group_by(UsageLog.key_id, ApiKey.name, ApiKey.key_prefix, ApiKey.key_suffix)
|
||||
.order_by(func.sum(UsageLog.cost).desc())
|
||||
)
|
||||
return [
|
||||
{
|
||||
"key_id": row[0], "key_name": row[1] or "",
|
||||
"key_prefix": row[2], "key_suffix": row[3],
|
||||
"total_tokens": int(row[4]), "cost": row[5], "requests": int(row[6]),
|
||||
}
|
||||
for row in result.all()
|
||||
]
|
||||
async def dashboard_models(db: AsyncSession, user: User, **params: Any) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_models(token, **params)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
@staticmethod
|
||||
async def logs(
|
||||
db: AsyncSession, user_id: str,
|
||||
page: int = 1, size: int = 20,
|
||||
model: Optional[str] = None,
|
||||
key_id: Optional[str] = None,
|
||||
start: Optional[date] = None,
|
||||
end: Optional[date] = None,
|
||||
) -> list:
|
||||
q = select(UsageLog).where(UsageLog.user_id == user_id)
|
||||
if model:
|
||||
q = q.where(UsageLog.model == model)
|
||||
if key_id:
|
||||
q = q.where(UsageLog.key_id == key_id)
|
||||
if start:
|
||||
q = q.where(UsageLog.request_time >= datetime.combine(start, datetime.min.time()))
|
||||
if end:
|
||||
q = q.where(UsageLog.request_time < datetime.combine(end + timedelta(days=1), datetime.min.time()))
|
||||
|
||||
q = q.order_by(UsageLog.request_time.desc()).offset((page - 1) * size).limit(size)
|
||||
result = await db.execute(q)
|
||||
return result.scalars().all()
|
||||
async def dashboard_api_keys_usage(
|
||||
db: AsyncSession, user: User, api_key_ids: list[int]
|
||||
) -> dict[str, Any]:
|
||||
_require_binding(user)
|
||||
try:
|
||||
token = await ensure_access_token(db, user)
|
||||
return await sub2api_user.dashboard_api_keys_usage(token, api_key_ids)
|
||||
except (Sub2APIError, Sub2APITransportError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
Reference in New Issue
Block a user