first commit
This commit is contained in:
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
0
app/api/__init__.py
Normal file
0
app/api/__init__.py
Normal file
0
app/api/v1/__init__.py
Normal file
0
app/api/v1/__init__.py
Normal file
54
app/api/v1/auth.py
Normal file
54
app/api/v1/auth.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.services.auth_service import AuthService
|
||||
from app.datamodels.schemas import (
|
||||
RegisterRequest, LoginRequest, TokenResponse, RefreshRequest,
|
||||
ForgotPasswordRequest, ResetPasswordRequest, UserResponse, MessageResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse)
|
||||
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
user = await AuthService.register(db, body.email, body.password)
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
return await AuthService.login(db, body.email, body.password)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
|
||||
return await AuthService.refresh(db, body.refresh_token)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=MessageResponse)
|
||||
async def logout():
|
||||
return {"message": "Logged out successfully"}
|
||||
|
||||
|
||||
@router.post("/forgot-password", response_model=MessageResponse)
|
||||
async def forgot_password(body: ForgotPasswordRequest, db: AsyncSession = Depends(get_db)):
|
||||
token = await AuthService.forgot_password(db, body.email)
|
||||
if token:
|
||||
# MVP: print token to console; production: send email
|
||||
print(f"[Password Reset] email={body.email} token={token}")
|
||||
return {"message": "If the email exists, a reset link has been sent"}
|
||||
|
||||
|
||||
@router.post("/reset-password", response_model=MessageResponse)
|
||||
async def reset_password(body: ResetPasswordRequest, db: AsyncSession = Depends(get_db)):
|
||||
await AuthService.reset_password(db, body.token, body.new_password)
|
||||
return {"message": "Password reset successfully"}
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def me(user: User = Depends(get_current_user)):
|
||||
return user
|
||||
20
app/api/v1/example.py
Normal file
20
app/api/v1/example.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from fastapi import APIRouter
|
||||
from app.services.example_service import ExampleService
|
||||
|
||||
router = APIRouter()
|
||||
service = ExampleService()
|
||||
|
||||
|
||||
@router.get("/examples")
|
||||
async def list_examples():
|
||||
return await service.list_all()
|
||||
|
||||
|
||||
@router.get("/examples/{example_id}")
|
||||
async def get_example(example_id: str):
|
||||
return await service.get_by_id(example_id)
|
||||
|
||||
|
||||
@router.post("/examples")
|
||||
async def create_example(data: dict):
|
||||
return await service.create(data)
|
||||
8
app/api/v1/health.py
Normal file
8
app/api/v1/health.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "ok", "service": "SuperDream"}
|
||||
39
app/api/v1/keys.py
Normal file
39
app/api/v1/keys.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.services.key_service import KeyService
|
||||
from app.datamodels.schemas import CreateKeyRequest, ApiKeyResponse, ApiKeyCreatedResponse, MessageResponse
|
||||
|
||||
router = APIRouter(prefix="/keys", tags=["keys"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[ApiKeyResponse])
|
||||
async def list_keys(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await KeyService.list_keys(db, user.id)
|
||||
|
||||
|
||||
@router.post("", response_model=ApiKeyCreatedResponse)
|
||||
async def create_key(
|
||||
body: CreateKeyRequest,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await KeyService.create_key(db, user.id, body.name)
|
||||
|
||||
|
||||
@router.delete("/{key_id}", response_model=MessageResponse)
|
||||
async def delete_key(
|
||||
key_id: str,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await KeyService.delete_key(db, user.id, key_id)
|
||||
return {"message": "Key deleted"}
|
||||
16
app/api/v1/models.py
Normal file
16
app/api/v1/models.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.services.model_service import ModelService
|
||||
from app.datamodels.schemas import ModelPricingResponse
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[ModelPricingResponse])
|
||||
async def list_models(db: AsyncSession = Depends(get_db)):
|
||||
"""Public endpoint: list available models and pricing."""
|
||||
return await ModelService.list_models(db)
|
||||
64
app/api/v1/usage.py
Normal file
64
app/api/v1/usage.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from datetime import date
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.services.usage_service import UsageService
|
||||
from app.datamodels.schemas import (
|
||||
UsageSummaryResponse, DailyUsageResponse,
|
||||
ModelUsageResponse, KeyUsageResponse, UsageLogResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/usage", tags=["usage"])
|
||||
|
||||
|
||||
@router.get("/summary", response_model=UsageSummaryResponse)
|
||||
async def usage_summary(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await UsageService.summary(db, user.id)
|
||||
|
||||
|
||||
@router.get("/daily", response_model=List[DailyUsageResponse])
|
||||
async def usage_daily(
|
||||
start: Optional[date] = Query(None),
|
||||
end: Optional[date] = Query(None),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await UsageService.daily(db, user.id, start, end)
|
||||
|
||||
|
||||
@router.get("/by-model", response_model=List[ModelUsageResponse])
|
||||
async def usage_by_model(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await UsageService.by_model(db, user.id)
|
||||
|
||||
|
||||
@router.get("/by-key", response_model=List[KeyUsageResponse])
|
||||
async def usage_by_key(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await UsageService.by_key(db, user.id)
|
||||
|
||||
|
||||
@router.get("/logs", response_model=List[UsageLogResponse])
|
||||
async def usage_logs(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
model: Optional[str] = Query(None),
|
||||
key_id: Optional[str] = Query(None),
|
||||
start: Optional[date] = Query(None),
|
||||
end: Optional[date] = Query(None),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await UsageService.logs(db, user.id, page, size, model, key_id, start, end)
|
||||
42
app/api/v1/wallet.py
Normal file
42
app/api/v1/wallet.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.services.wallet_service import WalletService
|
||||
from app.datamodels.schemas import (
|
||||
BalanceResponse, RedeemCodeRequest, TransactionResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/wallet", tags=["wallet"])
|
||||
|
||||
|
||||
@router.get("/balance", response_model=BalanceResponse)
|
||||
async def get_balance(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
balance = await WalletService.get_balance(db, user.id)
|
||||
return {"balance": balance}
|
||||
|
||||
|
||||
@router.post("/redeem", response_model=TransactionResponse)
|
||||
async def redeem_code(
|
||||
body: RedeemCodeRequest,
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await WalletService.redeem_code(db, user.id, body.code)
|
||||
|
||||
|
||||
@router.get("/transactions", response_model=List[TransactionResponse])
|
||||
async def list_transactions(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await WalletService.list_transactions(db, user.id, page, size)
|
||||
0
app/config/__init__.py
Normal file
0
app/config/__init__.py
Normal file
45
app/config/settings.py
Normal file
45
app/config/settings.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
app_name: str = "SuperDream"
|
||||
debug: bool = False
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 18000
|
||||
|
||||
# Database
|
||||
db_type: str = "mysql"
|
||||
db_host: str = "10.11.0.43"
|
||||
db_port: int = 3306
|
||||
db_user: str = "root"
|
||||
db_password: str = "for_develop_only"
|
||||
db_name: str = "superdream"
|
||||
|
||||
# Storage
|
||||
data_dir: str = "./data"
|
||||
|
||||
# JWT
|
||||
jwt_secret: str = "superdream-secret-change-me"
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_access_expire_minutes: int = 30
|
||||
jwt_refresh_expire_days: int = 7
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
return (
|
||||
f"mysql+aiomysql://{self.db_user}:{self.db_password}"
|
||||
f"@{self.db_host}:{self.db_port}/{self.db_name}"
|
||||
)
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_prefix = "SD_"
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
0
app/core/__init__.py
Normal file
0
app/core/__init__.py
Normal file
21
app/core/database.py
Normal file
21
app/core/database.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from app.config.settings import settings
|
||||
|
||||
engine = create_async_engine(settings.database_url, echo=settings.debug)
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def init_db():
|
||||
import app.models # noqa: F401 — ensure all models are registered
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
25
app/core/dependencies.py
Normal file
25
app/core/dependencies.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from fastapi import Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db # noqa: F401
|
||||
from app.core.security import decode_token
|
||||
from app.core.exceptions import UnauthorizedError
|
||||
from app.models import User
|
||||
|
||||
security_scheme = HTTPBearer()
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
payload = decode_token(credentials.credentials)
|
||||
if not payload or payload.get("type") != "access":
|
||||
raise UnauthorizedError("Invalid or expired token")
|
||||
|
||||
user_id = payload.get("sub")
|
||||
user = await db.get(User, user_id)
|
||||
if not user or user.status != "active":
|
||||
raise UnauthorizedError("User not found or disabled")
|
||||
return user
|
||||
21
app/core/exceptions.py
Normal file
21
app/core/exceptions.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
class NotFoundError(HTTPException):
|
||||
def __init__(self, detail: str = "Resource not found"):
|
||||
super().__init__(status_code=404, detail=detail)
|
||||
|
||||
|
||||
class BadRequestError(HTTPException):
|
||||
def __init__(self, detail: str = "Bad request"):
|
||||
super().__init__(status_code=400, detail=detail)
|
||||
|
||||
|
||||
class UnauthorizedError(HTTPException):
|
||||
def __init__(self, detail: str = "Not authenticated"):
|
||||
super().__init__(status_code=401, detail=detail, headers={"WWW-Authenticate": "Bearer"})
|
||||
|
||||
|
||||
class ForbiddenError(HTTPException):
|
||||
def __init__(self, detail: str = "Forbidden"):
|
||||
super().__init__(status_code=403, detail=detail)
|
||||
40
app/core/security.py
Normal file
40
app/core/security.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import jwt
|
||||
import bcrypt
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def verify_password(password: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
|
||||
|
||||
def create_access_token(user_id: str) -> str:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.jwt_access_expire_minutes)
|
||||
payload = {"sub": user_id, "exp": expire, "type": "access"}
|
||||
return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: str) -> str:
|
||||
expire = datetime.utcnow() + timedelta(days=settings.jwt_refresh_expire_days)
|
||||
payload = {"sub": user_id, "exp": expire, "type": "refresh"}
|
||||
return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def create_reset_token(user_id: str) -> str:
|
||||
expire = datetime.utcnow() + timedelta(hours=1)
|
||||
payload = {"sub": user_id, "exp": expire, "type": "reset"}
|
||||
return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[dict]:
|
||||
try:
|
||||
return jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
|
||||
except jwt.PyJWTError:
|
||||
return None
|
||||
0
app/datamodels/__init__.py
Normal file
0
app/datamodels/__init__.py
Normal file
179
app/datamodels/schemas.py
Normal file
179
app/datamodels/schemas.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from decimal import Decimal
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from datetime import datetime, date
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
# ── Auth ──
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class ForgotPasswordRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
balance: Decimal
|
||||
status: str
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ── API Key ──
|
||||
|
||||
class CreateKeyRequest(BaseModel):
|
||||
name: str = ""
|
||||
|
||||
|
||||
class ApiKeyResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
key_prefix: str
|
||||
key_suffix: str
|
||||
status: str
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ApiKeyCreatedResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
key: str
|
||||
key_prefix: str
|
||||
key_suffix: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ── Wallet ──
|
||||
|
||||
class RedeemCodeRequest(BaseModel):
|
||||
code: str
|
||||
|
||||
|
||||
class TransactionResponse(BaseModel):
|
||||
id: str
|
||||
type: str
|
||||
amount: Decimal
|
||||
balance_after: Decimal
|
||||
reference_id: str
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class BalanceResponse(BaseModel):
|
||||
balance: Decimal
|
||||
|
||||
|
||||
# ── Models ──
|
||||
|
||||
class ModelPricingResponse(BaseModel):
|
||||
id: int
|
||||
model_name: str
|
||||
provider: str
|
||||
input_price_per_1k: Decimal
|
||||
output_price_per_1k: Decimal
|
||||
status: str
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ── Example (legacy) ──
|
||||
|
||||
class ExampleCreate(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
class ExampleResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ── Common ──
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
# ── Usage ──
|
||||
|
||||
class UsageSummaryResponse(BaseModel):
|
||||
today_tokens: int
|
||||
today_cost: Decimal
|
||||
month_tokens: int
|
||||
month_cost: Decimal
|
||||
total_requests: int
|
||||
|
||||
|
||||
class DailyUsageResponse(BaseModel):
|
||||
date: date
|
||||
total_tokens: int
|
||||
cost: Decimal
|
||||
requests: int
|
||||
|
||||
|
||||
class ModelUsageResponse(BaseModel):
|
||||
model: str
|
||||
total_tokens: int
|
||||
cost: Decimal
|
||||
requests: int
|
||||
|
||||
|
||||
class KeyUsageResponse(BaseModel):
|
||||
key_id: str
|
||||
key_name: str
|
||||
key_prefix: str
|
||||
key_suffix: str
|
||||
total_tokens: int
|
||||
cost: Decimal
|
||||
requests: int
|
||||
|
||||
|
||||
class UsageLogResponse(BaseModel):
|
||||
id: int
|
||||
key_id: str
|
||||
model: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
cost: Decimal
|
||||
request_time: datetime
|
||||
status: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
63
app/main.py
Normal file
63
app/main.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
from app.api.v1 import health, example, auth, keys, models as models_api, wallet, usage
|
||||
from app.config.settings import settings
|
||||
from app.core.database import init_db
|
||||
import os
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await init_db()
|
||||
yield
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="SuperDream",
|
||||
description="SuperDream API",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Register API routers
|
||||
app.include_router(health.router, prefix="/api/v1", tags=["health"])
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(keys.router, prefix="/api/v1")
|
||||
app.include_router(models_api.router, prefix="/api/v1")
|
||||
app.include_router(wallet.router, prefix="/api/v1")
|
||||
app.include_router(usage.router, prefix="/api/v1")
|
||||
app.include_router(example.router, prefix="/api/v1", tags=["example"])
|
||||
|
||||
# Serve static files
|
||||
static_dir = os.path.join(os.path.dirname(__file__), "static")
|
||||
if os.path.exists(static_dir):
|
||||
app.mount("/static", StaticFiles(directory=static_dir), name="static")
|
||||
|
||||
# Serve frontend dist in production
|
||||
frontend_dist = os.path.join(os.path.dirname(__file__), "..", "frontend", "dist")
|
||||
if os.path.exists(frontend_dist):
|
||||
app.mount("/assets", StaticFiles(directory=os.path.join(frontend_dist, "assets")), name="frontend-assets")
|
||||
|
||||
@app.get("/{full_path:path}")
|
||||
async def serve_frontend(full_path: str):
|
||||
file_path = os.path.join(frontend_dist, full_path)
|
||||
if os.path.isfile(file_path):
|
||||
return FileResponse(file_path)
|
||||
return FileResponse(os.path.join(frontend_dist, "index.html"))
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
87
app/models/__init__.py
Normal file
87
app/models/__init__.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import String, Text, Integer, BigInteger, DateTime, Numeric, Enum, ForeignKey, Index, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
balance: Mapped[Decimal] = mapped_column(Numeric(16, 6), default=Decimal("0"))
|
||||
status: Mapped[str] = mapped_column(String(20), default="active") # active / disabled
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
|
||||
|
||||
api_keys: Mapped[List[ApiKey]] = relationship(back_populates="user")
|
||||
transactions: Mapped[List[Transaction]] = relationship(back_populates="user")
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
user_id: Mapped[str] = mapped_column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
name: Mapped[str] = mapped_column(String(100), default="")
|
||||
key_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False) # SHA256
|
||||
key_prefix: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
key_suffix: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(20), default="active") # active / revoked
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="api_keys")
|
||||
|
||||
|
||||
class UsageLog(Base):
|
||||
__tablename__ = "usage_logs"
|
||||
__table_args__ = (
|
||||
Index("ix_usage_user_time", "user_id", "request_time"),
|
||||
Index("ix_usage_model", "user_id", "model"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[str] = mapped_column(String(36), ForeignKey("users.id"), nullable=False)
|
||||
key_id: Mapped[str] = mapped_column(String(36), ForeignKey("api_keys.id"), nullable=False)
|
||||
model: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
prompt_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
completion_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
total_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
cost: Mapped[Decimal] = mapped_column(Numeric(16, 6), default=Decimal("0"))
|
||||
request_time: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
response_time: Mapped[datetime] = mapped_column(DateTime, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="success") # success / error
|
||||
|
||||
|
||||
class Transaction(Base):
|
||||
__tablename__ = "transactions"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
user_id: Mapped[str] = mapped_column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
type: Mapped[str] = mapped_column(String(20), nullable=False) # topup / consume / refund
|
||||
amount: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
|
||||
balance_after: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
|
||||
reference_id: Mapped[str] = mapped_column(String(100), default="")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="transactions")
|
||||
|
||||
|
||||
class ModelPricing(Base):
|
||||
__tablename__ = "models_pricing"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
model_name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
input_price_per_1k: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
|
||||
output_price_per_1k: Mapped[Decimal] = mapped_column(Numeric(16, 6), nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(20), default="available") # available / offline
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
|
||||
0
app/services/__init__.py
Normal file
0
app/services/__init__.py
Normal file
91
app/services/auth_service.py
Normal file
91
app/services/auth_service.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import re
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import User
|
||||
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
|
||||
from app.core.exceptions import BadRequestError, UnauthorizedError
|
||||
|
||||
|
||||
def _validate_password(password: str) -> None:
|
||||
if len(password) < 6:
|
||||
raise BadRequestError("密码至少 6 位")
|
||||
has_letter = bool(re.search(r"[a-zA-Z]", password))
|
||||
has_digit = bool(re.search(r"[0-9]", password))
|
||||
if not (has_letter and has_digit):
|
||||
raise BadRequestError("密码需同时包含字母和数字")
|
||||
|
||||
|
||||
class AuthService:
|
||||
|
||||
@staticmethod
|
||||
async def register(db: AsyncSession, email: str, password: str) -> User:
|
||||
_validate_password(password)
|
||||
|
||||
existing = await db.execute(select(User).where(User.email == email))
|
||||
if existing.scalar_one_or_none():
|
||||
raise BadRequestError("Email already registered")
|
||||
|
||||
user = User(email=email, password_hash=hash_password(password))
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def login(db: AsyncSession, email: str, password: str) -> dict:
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not verify_password(password, user.password_hash):
|
||||
raise UnauthorizedError("Invalid email or password")
|
||||
if user.status != "active":
|
||||
raise UnauthorizedError("Account is disabled")
|
||||
|
||||
return {
|
||||
"access_token": create_access_token(user.id),
|
||||
"refresh_token": create_refresh_token(user.id),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def refresh(db: AsyncSession, refresh_token: str) -> dict:
|
||||
payload = decode_token(refresh_token)
|
||||
if not payload or payload.get("type") != "refresh":
|
||||
raise UnauthorizedError("Invalid refresh token")
|
||||
|
||||
user = await db.get(User, payload["sub"])
|
||||
if not user or user.status != "active":
|
||||
raise UnauthorizedError("User not found or disabled")
|
||||
|
||||
return {
|
||||
"access_token": create_access_token(user.id),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def forgot_password(db: AsyncSession, email: str) -> str:
|
||||
"""Generate a password reset token. MVP: returns the token directly (production: send via email)."""
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
# Don't reveal whether email exists
|
||||
return ""
|
||||
|
||||
from app.core.security import create_reset_token
|
||||
return create_reset_token(user.id)
|
||||
|
||||
@staticmethod
|
||||
async def reset_password(db: AsyncSession, token: str, new_password: str) -> None:
|
||||
_validate_password(new_password)
|
||||
|
||||
payload = decode_token(token)
|
||||
if not payload or payload.get("type") != "reset":
|
||||
raise BadRequestError("Invalid or expired reset token")
|
||||
|
||||
user = await db.get(User, payload["sub"])
|
||||
if not user:
|
||||
raise BadRequestError("User not found")
|
||||
|
||||
user.password_hash = hash_password(new_password)
|
||||
await db.commit()
|
||||
24
app/services/example_service.py
Normal file
24
app/services/example_service.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
class ExampleService:
|
||||
def __init__(self):
|
||||
self._store: Dict[str, dict] = {}
|
||||
|
||||
async def list_all(self) -> List[dict]:
|
||||
return list(self._store.values())
|
||||
|
||||
async def get_by_id(self, example_id: str) -> Optional[dict]:
|
||||
return self._store.get(example_id)
|
||||
|
||||
async def create(self, data: dict) -> dict:
|
||||
example_id = str(uuid.uuid4())
|
||||
item = {
|
||||
"id": example_id,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
**data,
|
||||
}
|
||||
self._store[example_id] = item
|
||||
return item
|
||||
84
app/services/key_service.py
Normal file
84
app/services/key_service.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import ApiKey
|
||||
from app.core.exceptions import BadRequestError, NotFoundError
|
||||
|
||||
MAX_KEYS_PER_USER = 5
|
||||
KEY_PREFIX = "sk-sd-"
|
||||
|
||||
|
||||
class KeyService:
|
||||
|
||||
@staticmethod
|
||||
def _generate_key() -> str:
|
||||
return KEY_PREFIX + secrets.token_urlsafe(36)
|
||||
|
||||
@staticmethod
|
||||
def _hash_key(raw_key: str) -> str:
|
||||
return hashlib.sha256(raw_key.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
async def list_keys(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(ApiKey)
|
||||
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
|
||||
.order_by(ApiKey.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@staticmethod
|
||||
async def create_key(db: AsyncSession, user_id: str, name: str = "") -> dict:
|
||||
# Check limit
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(ApiKey)
|
||||
.where(ApiKey.user_id == user_id, ApiKey.status == "active")
|
||||
)
|
||||
count = count_result.scalar()
|
||||
if count >= MAX_KEYS_PER_USER:
|
||||
raise BadRequestError(f"最多创建 {MAX_KEYS_PER_USER} 个 Key")
|
||||
|
||||
raw_key = KeyService._generate_key()
|
||||
key_hash = KeyService._hash_key(raw_key)
|
||||
|
||||
# prefix/suffix for masked display (after "sk-sd-")
|
||||
body = raw_key[len(KEY_PREFIX):]
|
||||
key_prefix = body[:4]
|
||||
key_suffix = body[-4:]
|
||||
|
||||
api_key = ApiKey(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
key_suffix=key_suffix,
|
||||
)
|
||||
db.add(api_key)
|
||||
await db.commit()
|
||||
await db.refresh(api_key)
|
||||
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"name": api_key.name,
|
||||
"key": raw_key, # only returned once
|
||||
"key_prefix": key_prefix,
|
||||
"key_suffix": key_suffix,
|
||||
"created_at": api_key.created_at,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def delete_key(db: AsyncSession, user_id: str, key_id: str) -> None:
|
||||
result = await db.execute(
|
||||
select(ApiKey).where(ApiKey.id == key_id, ApiKey.user_id == user_id)
|
||||
)
|
||||
api_key = result.scalar_one_or_none()
|
||||
if not api_key:
|
||||
raise NotFoundError("Key not found")
|
||||
|
||||
api_key.status = "revoked"
|
||||
await db.commit()
|
||||
16
app/services/model_service.py
Normal file
16
app/services/model_service.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import ModelPricing
|
||||
|
||||
|
||||
class ModelService:
|
||||
|
||||
@staticmethod
|
||||
async def list_models(db: AsyncSession) -> list:
|
||||
result = await db.execute(
|
||||
select(ModelPricing)
|
||||
.where(ModelPricing.status == "available")
|
||||
.order_by(ModelPricing.provider, ModelPricing.model_name)
|
||||
)
|
||||
return result.scalars().all()
|
||||
139
app/services/usage_service.py
Normal file
139
app/services/usage_service.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func, cast, Date
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import UsageLog, ApiKey
|
||||
|
||||
|
||||
class UsageService:
|
||||
|
||||
@staticmethod
|
||||
async def summary(db: AsyncSession, user_id: str) -> dict:
|
||||
today_start = datetime.combine(date.today(), datetime.min.time())
|
||||
month_start = today_start.replace(day=1)
|
||||
|
||||
# Today
|
||||
today_row = (await db.execute(
|
||||
select(
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
).where(UsageLog.user_id == user_id, UsageLog.request_time >= today_start)
|
||||
)).one()
|
||||
|
||||
# This month
|
||||
month_row = (await db.execute(
|
||||
select(
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
).where(UsageLog.user_id == user_id, UsageLog.request_time >= month_start)
|
||||
)).one()
|
||||
|
||||
return {
|
||||
"today_tokens": int(today_row[0]),
|
||||
"today_cost": today_row[1],
|
||||
"month_tokens": int(month_row[0]),
|
||||
"month_cost": month_row[1],
|
||||
"total_requests": int(month_row[2]),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def daily(
|
||||
db: AsyncSession, user_id: str,
|
||||
start: Optional[date] = None, end: Optional[date] = None,
|
||||
) -> list:
|
||||
if not start:
|
||||
start = date.today() - timedelta(days=29)
|
||||
if not end:
|
||||
end = date.today()
|
||||
|
||||
day_col = cast(UsageLog.request_time, Date).label("day")
|
||||
result = await db.execute(
|
||||
select(
|
||||
day_col,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
UsageLog.user_id == user_id,
|
||||
cast(UsageLog.request_time, Date) >= start,
|
||||
cast(UsageLog.request_time, Date) <= end,
|
||||
)
|
||||
.group_by(day_col)
|
||||
.order_by(day_col)
|
||||
)
|
||||
return [
|
||||
{"date": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
|
||||
for row in result.all()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def by_model(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(
|
||||
UsageLog.model,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.where(UsageLog.user_id == user_id)
|
||||
.group_by(UsageLog.model)
|
||||
.order_by(func.sum(UsageLog.cost).desc())
|
||||
)
|
||||
return [
|
||||
{"model": row[0], "total_tokens": int(row[1]), "cost": row[2], "requests": int(row[3])}
|
||||
for row in result.all()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def by_key(db: AsyncSession, user_id: str) -> list:
|
||||
result = await db.execute(
|
||||
select(
|
||||
UsageLog.key_id,
|
||||
ApiKey.name,
|
||||
ApiKey.key_prefix,
|
||||
ApiKey.key_suffix,
|
||||
func.coalesce(func.sum(UsageLog.total_tokens), 0),
|
||||
func.coalesce(func.sum(UsageLog.cost), Decimal("0")),
|
||||
func.count(),
|
||||
)
|
||||
.join(ApiKey, UsageLog.key_id == ApiKey.id)
|
||||
.where(UsageLog.user_id == user_id)
|
||||
.group_by(UsageLog.key_id, ApiKey.name, ApiKey.key_prefix, ApiKey.key_suffix)
|
||||
.order_by(func.sum(UsageLog.cost).desc())
|
||||
)
|
||||
return [
|
||||
{
|
||||
"key_id": row[0], "key_name": row[1] or "",
|
||||
"key_prefix": row[2], "key_suffix": row[3],
|
||||
"total_tokens": int(row[4]), "cost": row[5], "requests": int(row[6]),
|
||||
}
|
||||
for row in result.all()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def logs(
|
||||
db: AsyncSession, user_id: str,
|
||||
page: int = 1, size: int = 20,
|
||||
model: Optional[str] = None,
|
||||
key_id: Optional[str] = None,
|
||||
start: Optional[date] = None,
|
||||
end: Optional[date] = None,
|
||||
) -> list:
|
||||
q = select(UsageLog).where(UsageLog.user_id == user_id)
|
||||
if model:
|
||||
q = q.where(UsageLog.model == model)
|
||||
if key_id:
|
||||
q = q.where(UsageLog.key_id == key_id)
|
||||
if start:
|
||||
q = q.where(UsageLog.request_time >= datetime.combine(start, datetime.min.time()))
|
||||
if end:
|
||||
q = q.where(UsageLog.request_time < datetime.combine(end + timedelta(days=1), datetime.min.time()))
|
||||
|
||||
q = q.order_by(UsageLog.request_time.desc()).offset((page - 1) * size).limit(size)
|
||||
result = await db.execute(q)
|
||||
return result.scalars().all()
|
||||
60
app/services/wallet_service.py
Normal file
60
app/services/wallet_service.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import User, Transaction
|
||||
from app.core.exceptions import BadRequestError
|
||||
|
||||
# MVP: hardcoded redeem codes → amount mapping
|
||||
REDEEM_CODES = {
|
||||
"SUPERDREAM10": Decimal("10"),
|
||||
"SUPERDREAM50": Decimal("50"),
|
||||
"SUPERDREAM100": Decimal("100"),
|
||||
}
|
||||
|
||||
|
||||
class WalletService:
|
||||
|
||||
@staticmethod
|
||||
async def get_balance(db: AsyncSession, user_id: str) -> Decimal:
|
||||
user = await db.get(User, user_id)
|
||||
return user.balance
|
||||
|
||||
@staticmethod
|
||||
async def redeem_code(db: AsyncSession, user_id: str, code: str) -> Transaction:
|
||||
amount = REDEEM_CODES.get(code.upper())
|
||||
if not amount:
|
||||
raise BadRequestError("无效的兑换码")
|
||||
|
||||
user = await db.get(User, user_id)
|
||||
user.balance += amount
|
||||
new_balance = user.balance
|
||||
|
||||
txn = Transaction(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
type="topup",
|
||||
amount=amount,
|
||||
balance_after=new_balance,
|
||||
reference_id=f"redeem:{code.upper()}",
|
||||
)
|
||||
db.add(txn)
|
||||
await db.commit()
|
||||
await db.refresh(txn)
|
||||
return txn
|
||||
|
||||
@staticmethod
|
||||
async def list_transactions(
|
||||
db: AsyncSession, user_id: str, page: int = 1, size: int = 20
|
||||
) -> list:
|
||||
offset = (page - 1) * size
|
||||
result = await db.execute(
|
||||
select(Transaction)
|
||||
.where(Transaction.user_id == user_id)
|
||||
.order_by(Transaction.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(size)
|
||||
)
|
||||
return result.scalars().all()
|
||||
Reference in New Issue
Block a user