first commit

This commit is contained in:
xuyong
2026-04-15 21:35:26 +08:00
commit 7097fa6b44
69 changed files with 5642 additions and 0 deletions

0
app/__init__.py Normal file
View File

0
app/api/__init__.py Normal file
View File

0
app/api/v1/__init__.py Normal file
View File

54
app/api/v1/auth.py Normal file
View File

@@ -0,0 +1,54 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_user
from app.models import User
from app.services.auth_service import AuthService
from app.datamodels.schemas import (
RegisterRequest, LoginRequest, TokenResponse, RefreshRequest,
ForgotPasswordRequest, ResetPasswordRequest, UserResponse, MessageResponse,
)
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/register", response_model=UserResponse)
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
user = await AuthService.register(db, body.email, body.password)
return user
@router.post("/login", response_model=TokenResponse)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
return await AuthService.login(db, body.email, body.password)
@router.post("/refresh", response_model=TokenResponse)
async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
return await AuthService.refresh(db, body.refresh_token)
@router.post("/logout", response_model=MessageResponse)
async def logout():
return {"message": "Logged out successfully"}
@router.post("/forgot-password", response_model=MessageResponse)
async def forgot_password(body: ForgotPasswordRequest, db: AsyncSession = Depends(get_db)):
token = await AuthService.forgot_password(db, body.email)
if token:
# MVP: print token to console; production: send email
print(f"[Password Reset] email={body.email} token={token}")
return {"message": "If the email exists, a reset link has been sent"}
@router.post("/reset-password", response_model=MessageResponse)
async def reset_password(body: ResetPasswordRequest, db: AsyncSession = Depends(get_db)):
await AuthService.reset_password(db, body.token, body.new_password)
return {"message": "Password reset successfully"}
@router.get("/me", response_model=UserResponse)
async def me(user: User = Depends(get_current_user)):
return user

20
app/api/v1/example.py Normal file
View File

@@ -0,0 +1,20 @@
from fastapi import APIRouter
from app.services.example_service import ExampleService
router = APIRouter()
service = ExampleService()
@router.get("/examples")
async def list_examples():
return await service.list_all()
@router.get("/examples/{example_id}")
async def get_example(example_id: str):
return await service.get_by_id(example_id)
@router.post("/examples")
async def create_example(data: dict):
return await service.create(data)

8
app/api/v1/health.py Normal file
View File

@@ -0,0 +1,8 @@
from fastapi import APIRouter
router = APIRouter()
@router.get("/health")
async def health_check():
return {"status": "ok", "service": "SuperDream"}

39
app/api/v1/keys.py Normal file
View File

@@ -0,0 +1,39 @@
from typing import List
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_user
from app.models import User
from app.services.key_service import KeyService
from app.datamodels.schemas import CreateKeyRequest, ApiKeyResponse, ApiKeyCreatedResponse, MessageResponse
router = APIRouter(prefix="/keys", tags=["keys"])
@router.get("", response_model=List[ApiKeyResponse])
async def list_keys(
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return await KeyService.list_keys(db, user.id)
@router.post("", response_model=ApiKeyCreatedResponse)
async def create_key(
body: CreateKeyRequest,
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return await KeyService.create_key(db, user.id, body.name)
@router.delete("/{key_id}", response_model=MessageResponse)
async def delete_key(
key_id: str,
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await KeyService.delete_key(db, user.id, key_id)
return {"message": "Key deleted"}

16
app/api/v1/models.py Normal file
View File

@@ -0,0 +1,16 @@
from typing import List
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.services.model_service import ModelService
from app.datamodels.schemas import ModelPricingResponse
router = APIRouter(prefix="/models", tags=["models"])
@router.get("", response_model=List[ModelPricingResponse])
async def list_models(db: AsyncSession = Depends(get_db)):
"""Public endpoint: list available models and pricing."""
return await ModelService.list_models(db)

64
app/api/v1/usage.py Normal file
View File

@@ -0,0 +1,64 @@
from datetime import date
from typing import List, Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_user
from app.models import User
from app.services.usage_service import UsageService
from app.datamodels.schemas import (
UsageSummaryResponse, DailyUsageResponse,
ModelUsageResponse, KeyUsageResponse, UsageLogResponse,
)
router = APIRouter(prefix="/usage", tags=["usage"])
@router.get("/summary", response_model=UsageSummaryResponse)
async def usage_summary(
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return await UsageService.summary(db, user.id)
@router.get("/daily", response_model=List[DailyUsageResponse])
async def usage_daily(
start: Optional[date] = Query(None),
end: Optional[date] = Query(None),
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return await UsageService.daily(db, user.id, start, end)
@router.get("/by-model", response_model=List[ModelUsageResponse])
async def usage_by_model(
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return await UsageService.by_model(db, user.id)
@router.get("/by-key", response_model=List[KeyUsageResponse])
async def usage_by_key(
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return await UsageService.by_key(db, user.id)
@router.get("/logs", response_model=List[UsageLogResponse])
async def usage_logs(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
model: Optional[str] = Query(None),
key_id: Optional[str] = Query(None),
start: Optional[date] = Query(None),
end: Optional[date] = Query(None),
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return await UsageService.logs(db, user.id, page, size, model, key_id, start, end)

42
app/api/v1/wallet.py Normal file
View File

@@ -0,0 +1,42 @@
from typing import List
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_user
from app.models import User
from app.services.wallet_service import WalletService
from app.datamodels.schemas import (
BalanceResponse, RedeemCodeRequest, TransactionResponse,
)
router = APIRouter(prefix="/wallet", tags=["wallet"])
@router.get("/balance", response_model=BalanceResponse)
async def get_balance(
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
balance = await WalletService.get_balance(db, user.id)
return {"balance": balance}
@router.post("/redeem", response_model=TransactionResponse)
async def redeem_code(
body: RedeemCodeRequest,
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return await WalletService.redeem_code(db, user.id, body.code)
@router.get("/transactions", response_model=List[TransactionResponse])
async def list_transactions(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return await WalletService.list_transactions(db, user.id, page, size)

0
app/config/__init__.py Normal file
View File

45
app/config/settings.py Normal file
View File

@@ -0,0 +1,45 @@
from pydantic_settings import BaseSettings
from functools import lru_cache
class Settings(BaseSettings):
app_name: str = "SuperDream"
debug: bool = False
host: str = "0.0.0.0"
port: int = 18000
# Database
db_type: str = "mysql"
db_host: str = "10.11.0.43"
db_port: int = 3306
db_user: str = "root"
db_password: str = "for_develop_only"
db_name: str = "superdream"
# Storage
data_dir: str = "./data"
# JWT
jwt_secret: str = "superdream-secret-change-me"
jwt_algorithm: str = "HS256"
jwt_access_expire_minutes: int = 30
jwt_refresh_expire_days: int = 7
@property
def database_url(self) -> str:
return (
f"mysql+aiomysql://{self.db_user}:{self.db_password}"
f"@{self.db_host}:{self.db_port}/{self.db_name}"
)
class Config:
env_file = ".env"
env_prefix = "SD_"
@lru_cache
def get_settings() -> Settings:
return Settings()
settings = get_settings()

0
app/core/__init__.py Normal file
View File

21
app/core/database.py Normal file
View File

@@ -0,0 +1,21 @@
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.orm import DeclarativeBase
from app.config.settings import settings
engine = create_async_engine(settings.database_url, echo=settings.debug)
async_session = async_sessionmaker(engine, expire_on_commit=False)
class Base(DeclarativeBase):
pass
async def get_db() -> AsyncSession:
async with async_session() as session:
yield session
async def init_db():
import app.models # noqa: F401 — ensure all models are registered
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

25
app/core/dependencies.py Normal file
View File

@@ -0,0 +1,25 @@
from fastapi import Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db # noqa: F401
from app.core.security import decode_token
from app.core.exceptions import UnauthorizedError
from app.models import User
security_scheme = HTTPBearer()
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
db: AsyncSession = Depends(get_db),
) -> User:
payload = decode_token(credentials.credentials)
if not payload or payload.get("type") != "access":
raise UnauthorizedError("Invalid or expired token")
user_id = payload.get("sub")
user = await db.get(User, user_id)
if not user or user.status != "active":
raise UnauthorizedError("User not found or disabled")
return user

21
app/core/exceptions.py Normal file
View File

@@ -0,0 +1,21 @@
from fastapi import HTTPException
class NotFoundError(HTTPException):
def __init__(self, detail: str = "Resource not found"):
super().__init__(status_code=404, detail=detail)
class BadRequestError(HTTPException):
def __init__(self, detail: str = "Bad request"):
super().__init__(status_code=400, detail=detail)
class UnauthorizedError(HTTPException):
def __init__(self, detail: str = "Not authenticated"):
super().__init__(status_code=401, detail=detail, headers={"WWW-Authenticate": "Bearer"})
class ForbiddenError(HTTPException):
def __init__(self, detail: str = "Forbidden"):
super().__init__(status_code=403, detail=detail)

40
app/core/security.py Normal file
View File

@@ -0,0 +1,40 @@
from datetime import datetime, timedelta
from typing import Optional
import jwt
import bcrypt
from app.config.settings import settings
def hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def verify_password(password: str, hashed: str) -> bool:
return bcrypt.checkpw(password.encode(), hashed.encode())
def create_access_token(user_id: str) -> str:
expire = datetime.utcnow() + timedelta(minutes=settings.jwt_access_expire_minutes)
payload = {"sub": user_id, "exp": expire, "type": "access"}
return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm)
def create_refresh_token(user_id: str) -> str:
expire = datetime.utcnow() + timedelta(days=settings.jwt_refresh_expire_days)
payload = {"sub": user_id, "exp": expire, "type": "refresh"}
return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm)
def create_reset_token(user_id: str) -> str:
expire = datetime.utcnow() + timedelta(hours=1)
payload = {"sub": user_id, "exp": expire, "type": "reset"}
return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm)
def decode_token(token: str) -> Optional[dict]:
try:
return jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
except jwt.PyJWTError:
return None

View File

179
app/datamodels/schemas.py Normal file
View File

@@ -0,0 +1,179 @@
from decimal import Decimal
from pydantic import BaseModel, EmailStr
from datetime import datetime, date
from typing import Optional, List
# ── Auth ──
class RegisterRequest(BaseModel):
email: str
password: str
class LoginRequest(BaseModel):
email: str
password: str
class TokenResponse(BaseModel):
access_token: str
refresh_token: Optional[str] = None
token_type: str = "bearer"
class RefreshRequest(BaseModel):
refresh_token: str
class ForgotPasswordRequest(BaseModel):
email: str
class ResetPasswordRequest(BaseModel):
token: str
new_password: str
class UserResponse(BaseModel):
id: str
email: str
balance: Decimal
status: str
created_at: datetime
class Config:
from_attributes = True
# ── API Key ──
class CreateKeyRequest(BaseModel):
name: str = ""
class ApiKeyResponse(BaseModel):
id: str
name: str
key_prefix: str
key_suffix: str
status: str
created_at: datetime
class Config:
from_attributes = True
class ApiKeyCreatedResponse(BaseModel):
id: str
name: str
key: str
key_prefix: str
key_suffix: str
created_at: datetime
# ── Wallet ──
class RedeemCodeRequest(BaseModel):
code: str
class TransactionResponse(BaseModel):
id: str
type: str
amount: Decimal
balance_after: Decimal
reference_id: str
created_at: datetime
class Config:
from_attributes = True
class BalanceResponse(BaseModel):
balance: Decimal
# ── Models ──
class ModelPricingResponse(BaseModel):
id: int
model_name: str
provider: str
input_price_per_1k: Decimal
output_price_per_1k: Decimal
status: str
updated_at: datetime
class Config:
from_attributes = True
# ── Example (legacy) ──
class ExampleCreate(BaseModel):
name: str
description: str = ""
class ExampleResponse(BaseModel):
id: str
name: str
description: str
created_at: datetime
# ── Common ──
class MessageResponse(BaseModel):
message: str
# ── Usage ──
class UsageSummaryResponse(BaseModel):
today_tokens: int
today_cost: Decimal
month_tokens: int
month_cost: Decimal
total_requests: int
class DailyUsageResponse(BaseModel):
date: date
total_tokens: int
cost: Decimal
requests: int
class ModelUsageResponse(BaseModel):
model: str
total_tokens: int
cost: Decimal
requests: int
class KeyUsageResponse(BaseModel):
key_id: str
key_name: str
key_prefix: str
key_suffix: str
total_tokens: int
cost: Decimal
requests: int
class UsageLogResponse(BaseModel):
id: int
key_id: str
model: str
prompt_tokens: int
completion_tokens: int
total_tokens: int
cost: Decimal
request_time: datetime
status: str
class Config:
from_attributes = True

63
app/main.py Normal file
View File

@@ -0,0 +1,63 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from app.api.v1 import health, example, auth, keys, models as models_api, wallet, usage
from app.config.settings import settings
from app.core.database import init_db
import os
@asynccontextmanager
async def lifespan(app: FastAPI):
await init_db()
yield
def create_app() -> FastAPI:
app = FastAPI(
title="SuperDream",
description="SuperDream API",
version="0.1.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Register API routers
app.include_router(health.router, prefix="/api/v1", tags=["health"])
app.include_router(auth.router, prefix="/api/v1")
app.include_router(keys.router, prefix="/api/v1")
app.include_router(models_api.router, prefix="/api/v1")
app.include_router(wallet.router, prefix="/api/v1")
app.include_router(usage.router, prefix="/api/v1")
app.include_router(example.router, prefix="/api/v1", tags=["example"])
# Serve static files
static_dir = os.path.join(os.path.dirname(__file__), "static")
if os.path.exists(static_dir):
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Serve frontend dist in production
frontend_dist = os.path.join(os.path.dirname(__file__), "..", "frontend", "dist")
if os.path.exists(frontend_dist):
app.mount("/assets", StaticFiles(directory=os.path.join(frontend_dist, "assets")), name="frontend-assets")
@app.get("/{full_path:path}")
async def serve_frontend(full_path: str):
file_path = os.path.join(frontend_dist, full_path)
if os.path.isfile(file_path):
return FileResponse(file_path)
return FileResponse(os.path.join(frontend_dist, "index.html"))
return app
app = create_app()

87
app/models/__init__.py Normal file
View File

@@ -0,0 +1,87 @@
from __future__ import annotations
import uuid
from datetime import datetime
from decimal import Decimal
from typing import List
from sqlalchemy import String, Text, Integer, BigInteger, DateTime, Numeric, Enum, ForeignKey, Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.core.database import Base
class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
balance: Mapped[Decimal] = mapped_column(Numeric(16, 6), default=Decimal("0"))
status: Mapped[str] = mapped_column(String(20), default="active") # active / disabled
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
api_keys: Mapped[List[ApiKey]] = relationship(back_populates="user")
transactions: Mapped[List[Transaction]] = relationship(back_populates="user")
class ApiKey(Base):
__tablename__ = "api_keys"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
user_id: Mapped[str] = mapped_column(String(36), ForeignKey("users.id"), nullable=False, index=True)
name: Mapped[str] = mapped_column(String(100), default="")
key_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False) # SHA256
key_prefix: Mapped[str] = mapped_column(String(10), nullable=False)
key_suffix: Mapped[str] = mapped_column(String(10), nullable=False)
status: Mapped[str] = mapped_column(String(20), default="active") # active / revoked
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
user: Mapped["User"] = relationship(back_populates="api_keys")
class UsageLog(Base):
__tablename__ = "usage_logs"
__table_args__ = (
Index("ix_usage_user_time", "user_id", "request_time"),
Index("ix_usage_model", "user_id", "model"),
)
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
user_id: Mapped[str] = mapped_column(String(36), ForeignKey("users.id"), nullable=False)
key_id: Mapped[str] = mapped_column(String(36), ForeignKey("api_keys.id"), nullable=False)
model: Mapped[str] = mapped_column(String(100), nullable=False)
prompt_tokens: Mapped[int] = mapped_column(Integer, default=0)
completion_tokens: Mapped[int] = mapped_column(Integer, default=0)
total_tokens: Mapped[int] = mapped_column(Integer, default=0)
cost: Mapped[Decimal] = mapped_column(Numeric(16, 6), default=Decimal("0"))
request_time: Mapped[datetime] = mapped_column(DateTime, nullable=False)
response_time: Mapped[datetime] = mapped_column(DateTime, nullable=True)
status: Mapped[str] = mapped_column(String(20), default="success") # success / error
class Transaction(Base):
__tablename__ = "transactions"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
user_id: Mapped[str] = mapped_column(String(36), ForeignKey("users.id"), nullable=False, index=True)
type: Mapped[str] = mapped_column(String(20), nullable=False) # topup / consume / refund
amount: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
balance_after: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
reference_id: Mapped[str] = mapped_column(String(100), default="")
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
user: Mapped["User"] = relationship(back_populates="transactions")
class ModelPricing(Base):
__tablename__ = "models_pricing"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
model_name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
provider: Mapped[str] = mapped_column(String(50), nullable=False)
input_price_per_1k: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
output_price_per_1k: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
status: Mapped[str] = mapped_column(String(20), default="available") # available / offline
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())

0
app/services/__init__.py Normal file
View File

View File

@@ -0,0 +1,91 @@
import re
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import User
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
from app.core.exceptions import BadRequestError, UnauthorizedError
def _validate_password(password: str) -> None:
if len(password) < 6:
raise BadRequestError("密码至少 6 位")
has_letter = bool(re.search(r"[a-zA-Z]", password))
has_digit = bool(re.search(r"[0-9]", password))
if not (has_letter and has_digit):
raise BadRequestError("密码需同时包含字母和数字")
class AuthService:
@staticmethod
async def register(db: AsyncSession, email: str, password: str) -> User:
_validate_password(password)
existing = await db.execute(select(User).where(User.email == email))
if existing.scalar_one_or_none():
raise BadRequestError("Email already registered")
user = User(email=email, password_hash=hash_password(password))
db.add(user)
await db.commit()
await db.refresh(user)
return user
@staticmethod
async def login(db: AsyncSession, email: str, password: str) -> dict:
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
if not user or not verify_password(password, user.password_hash):
raise UnauthorizedError("Invalid email or password")
if user.status != "active":
raise UnauthorizedError("Account is disabled")
return {
"access_token": create_access_token(user.id),
"refresh_token": create_refresh_token(user.id),
"token_type": "bearer",
}
@staticmethod
async def refresh(db: AsyncSession, refresh_token: str) -> dict:
payload = decode_token(refresh_token)
if not payload or payload.get("type") != "refresh":
raise UnauthorizedError("Invalid refresh token")
user = await db.get(User, payload["sub"])
if not user or user.status != "active":
raise UnauthorizedError("User not found or disabled")
return {
"access_token": create_access_token(user.id),
"token_type": "bearer",
}
@staticmethod
async def forgot_password(db: AsyncSession, email: str) -> str:
"""Generate a password reset token. MVP: returns the token directly (production: send via email)."""
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
if not user:
# Don't reveal whether email exists
return ""
from app.core.security import create_reset_token
return create_reset_token(user.id)
@staticmethod
async def reset_password(db: AsyncSession, token: str, new_password: str) -> None:
_validate_password(new_password)
payload = decode_token(token)
if not payload or payload.get("type") != "reset":
raise BadRequestError("Invalid or expired reset token")
user = await db.get(User, payload["sub"])
if not user:
raise BadRequestError("User not found")
user.password_hash = hash_password(new_password)
await db.commit()

View File

@@ -0,0 +1,24 @@
import uuid
from datetime import datetime
from typing import Dict, List, Optional
class ExampleService:
def __init__(self):
self._store: Dict[str, dict] = {}
async def list_all(self) -> List[dict]:
return list(self._store.values())
async def get_by_id(self, example_id: str) -> Optional[dict]:
return self._store.get(example_id)
async def create(self, data: dict) -> dict:
example_id = str(uuid.uuid4())
item = {
"id": example_id,
"created_at": datetime.utcnow().isoformat(),
**data,
}
self._store[example_id] = item
return item

View File

@@ -0,0 +1,84 @@
import hashlib
import secrets
import uuid
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import ApiKey
from app.core.exceptions import BadRequestError, NotFoundError
MAX_KEYS_PER_USER = 5
KEY_PREFIX = "sk-sd-"
class KeyService:
@staticmethod
def _generate_key() -> str:
return KEY_PREFIX + secrets.token_urlsafe(36)
@staticmethod
def _hash_key(raw_key: str) -> str:
return hashlib.sha256(raw_key.encode()).hexdigest()
@staticmethod
async def list_keys(db: AsyncSession, user_id: str) -> list:
result = await db.execute(
select(ApiKey)
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
.order_by(ApiKey.created_at.desc())
)
return result.scalars().all()
@staticmethod
async def create_key(db: AsyncSession, user_id: str, name: str = "") -> dict:
# Check limit
count_result = await db.execute(
select(func.count()).select_from(ApiKey)
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
)
count = count_result.scalar()
if count >= MAX_KEYS_PER_USER:
raise BadRequestError(f"最多创建 {MAX_KEYS_PER_USER} 个 Key")
raw_key = KeyService._generate_key()
key_hash = KeyService._hash_key(raw_key)
# prefix/suffix for masked display (after "sk-sd-")
body = raw_key[len(KEY_PREFIX):]
key_prefix = body[:4]
key_suffix = body[-4:]
api_key = ApiKey(
id=str(uuid.uuid4()),
user_id=user_id,
name=name,
key_hash=key_hash,
key_prefix=key_prefix,
key_suffix=key_suffix,
)
db.add(api_key)
await db.commit()
await db.refresh(api_key)
return {
"id": api_key.id,
"name": api_key.name,
"key": raw_key, # only returned once
"key_prefix": key_prefix,
"key_suffix": key_suffix,
"created_at": api_key.created_at,
}
@staticmethod
async def delete_key(db: AsyncSession, user_id: str, key_id: str) -> None:
result = await db.execute(
select(ApiKey).where(ApiKey.id == key_id, ApiKey.user_id == user_id)
)
api_key = result.scalar_one_or_none()
if not api_key:
raise NotFoundError("Key not found")
api_key.status = "revoked"
await db.commit()

View File

@@ -0,0 +1,16 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import ModelPricing
class ModelService:
@staticmethod
async def list_models(db: AsyncSession) -> list:
result = await db.execute(
select(ModelPricing)
.where(ModelPricing.status == "available")
.order_by(ModelPricing.provider, ModelPricing.model_name)
)
return result.scalars().all()

View File

@@ -0,0 +1,139 @@
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Optional
from sqlalchemy import select, func, cast, Date
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import UsageLog, ApiKey
class UsageService:
@staticmethod
async def summary(db: AsyncSession, user_id: str) -> dict:
today_start = datetime.combine(date.today(), datetime.min.time())
month_start = today_start.replace(day=1)
# Today
today_row = (await db.execute(
select(
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
).where(UsageLog.user_id == user_id, UsageLog.request_time >= today_start)
)).one()
# This month
month_row = (await db.execute(
select(
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
).where(UsageLog.user_id == user_id, UsageLog.request_time >= month_start)
)).one()
return {
"today_tokens": int(today_row[0]),
"today_cost": today_row[1],
"month_tokens": int(month_row[0]),
"month_cost": month_row[1],
"total_requests": int(month_row[2]),
}
@staticmethod
async def daily(
db: AsyncSession, user_id: str,
start: Optional[date] = None, end: Optional[date] = None,
) -> list:
if not start:
start = date.today() - timedelta(days=29)
if not end:
end = date.today()
day_col = cast(UsageLog.request_time, Date).label("day")
result = await db.execute(
select(
day_col,
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
)
.where(
UsageLog.user_id == user_id,
cast(UsageLog.request_time, Date) >= start,
cast(UsageLog.request_time, Date) <= end,
)
.group_by(day_col)
.order_by(day_col)
)
return [
{"date": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
for row in result.all()
]
@staticmethod
async def by_model(db: AsyncSession, user_id: str) -> list:
result = await db.execute(
select(
UsageLog.model,
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
)
.where(UsageLog.user_id == user_id)
.group_by(UsageLog.model)
.order_by(func.sum(UsageLog.cost).desc())
)
return [
{"model": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
for row in result.all()
]
@staticmethod
async def by_key(db: AsyncSession, user_id: str) -> list:
result = await db.execute(
select(
UsageLog.key_id,
ApiKey.name,
ApiKey.key_prefix,
ApiKey.key_suffix,
func.coalesce(func.sum(UsageLog.total_tokens), 0),
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
func.count(),
)
.join(ApiKey, UsageLog.key_id == ApiKey.id)
.where(UsageLog.user_id == user_id)
.group_by(UsageLog.key_id, ApiKey.name, ApiKey.key_prefix, ApiKey.key_suffix)
.order_by(func.sum(UsageLog.cost).desc())
)
return [
{
"key_id": row[0], "key_name": row[1] or "",
"key_prefix": row[2], "key_suffix": row[3],
"total_tokens": int(row[4]), "cost": row[5], "requests": int(row[6]),
}
for row in result.all()
]
@staticmethod
async def logs(
db: AsyncSession, user_id: str,
page: int = 1, size: int = 20,
model: Optional[str] = None,
key_id: Optional[str] = None,
start: Optional[date] = None,
end: Optional[date] = None,
) -> list:
q = select(UsageLog).where(UsageLog.user_id == user_id)
if model:
q = q.where(UsageLog.model == model)
if key_id:
q = q.where(UsageLog.key_id == key_id)
if start:
q = q.where(UsageLog.request_time >= datetime.combine(start, datetime.min.time()))
if end:
q = q.where(UsageLog.request_time < datetime.combine(end + timedelta(days=1), datetime.min.time()))
q = q.order_by(UsageLog.request_time.desc()).offset((page - 1) * size).limit(size)
result = await db.execute(q)
return result.scalars().all()

View File

@@ -0,0 +1,60 @@
import uuid
from decimal import Decimal
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import User, Transaction
from app.core.exceptions import BadRequestError
# MVP: hardcoded redeem codes → amount mapping
REDEEM_CODES = {
"SUPERDREAM10": Decimal("10"),
"SUPERDREAM50": Decimal("50"),
"SUPERDREAM100": Decimal("100"),
}
class WalletService:
@staticmethod
async def get_balance(db: AsyncSession, user_id: str) -> Decimal:
user = await db.get(User, user_id)
return user.balance
@staticmethod
async def redeem_code(db: AsyncSession, user_id: str, code: str) -> Transaction:
amount = REDEEM_CODES.get(code.upper())
if not amount:
raise BadRequestError("无效的兑换码")
user = await db.get(User, user_id)
user.balance += amount
new_balance = user.balance
txn = Transaction(
id=str(uuid.uuid4()),
user_id=user_id,
type="topup",
amount=amount,
balance_after=new_balance,
reference_id=f"redeem:{code.upper()}",
)
db.add(txn)
await db.commit()
await db.refresh(txn)
return txn
@staticmethod
async def list_transactions(
db: AsyncSession, user_id: str, page: int = 1, size: int = 20
) -> list:
offset = (page - 1) * size
result = await db.execute(
select(Transaction)
.where(Transaction.user_id == user_id)
.order_by(Transaction.created_at.desc())
.offset(offset)
.limit(size)
)
return result.scalars().all()