e2d750146c
- Add slowapi rate limiter: 10 req/min for /auth/login, 5 req/min for /auth/register - Return HTTP 429 with Retry-After header when limit is exceeded - Add slowapi to requirements.txt - Add 4 passing tests for rate limit behavior Closes leeworks-agents/SPARC#9 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
559 lines
15 KiB
Python
559 lines
15 KiB
Python
"""FastAPI web service wrapper for SPARC patent analysis.
|
|
|
|
Provides REST API endpoints for analyzing company patent portfolios.
|
|
"""
|
|
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime
|
|
from typing import Annotated, List
|
|
|
|
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel, EmailStr, Field
|
|
from slowapi import Limiter
|
|
from slowapi.errors import RateLimitExceeded
|
|
from slowapi.util import get_remote_address
|
|
|
|
from SPARC import config
|
|
from SPARC.analyzer import CompanyAnalyzer
|
|
from SPARC.auth import (
|
|
TokenResponse,
|
|
UserResponse,
|
|
create_tokens,
|
|
decode_token,
|
|
get_current_admin,
|
|
get_current_user,
|
|
get_db_client,
|
|
)
|
|
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
|
|
|
|
|
# Pydantic models for API
|
|
class CompanyAnalysisResponse(BaseModel):
    """Response model for single company analysis."""

    # Every field below is filled attribute-for-attribute from a
    # CompanyAnalysisResult by _convert_result().
    company_name: str
    analysis: str
    patent_count: int
    success: bool
    # Optional; defaults to None when the result carries no error.
    error: str | None = None
    timestamp: datetime
|
|
|
|
|
|
class BatchAnalysisResponse(BaseModel):
    """Response model for batch company analysis."""

    # One converted entry per item of the internal batch result, same order.
    results: list[CompanyAnalysisResponse]
    # Summary counters copied from the internal BatchAnalysisResult.
    total_companies: int
    successful: int
    failed: int
    timestamp: datetime
|
|
|
|
|
|
class BatchAnalysisRequest(BaseModel):
    """Request model for batch company analysis."""

    # Required; between 1 and 20 company names per request.
    companies: list[str] = Field(
        ...,
        min_length=1,
        max_length=20,
        description="List of company names to analyze",
    )
    # Optional; validated to the range 1-5, defaults to 3.
    max_workers: int = Field(
        default=3,
        ge=1,
        le=5,
        description="Max concurrent analyses",
    )
|
|
|
|
|
|
class JobStatus(BaseModel):
    """Status of a background analysis job."""

    job_id: str
    # One of: "pending", "running", "completed", "failed"
    status: str
    # Percentage, 0-100
    progress: int
    total_companies: int
    completed_companies: int
    # Filled in by the background task once the batch finishes successfully.
    result: BatchAnalysisResponse | None = None
    # Filled in by the background task when the job fails.
    error: str | None = None
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Health check response."""

    # The /health endpoint reports "healthy" here.
    status: str
    # API version string.
    version: str
    # Server time at which the response was built.
    timestamp: datetime
|
|
|
|
|
|
# Auth request/response models
|
|
class RegisterRequest(BaseModel):
    """User registration request."""

    email: EmailStr
    # Required; rejected by validation when shorter than 8 characters.
    password: str = Field(
        ...,
        min_length=8,
        description="Password (min 8 characters)",
    )
|
|
|
|
|
|
class LoginRequest(BaseModel):
    """User login request."""

    email: EmailStr
    # No length constraint here; the value is only passed on for authentication.
    password: str
|
|
|
|
|
|
class RefreshRequest(BaseModel):
    """Token refresh request."""

    # Token string that the refresh endpoint decodes and checks.
    refresh_token: str
|
|
|
|
|
|
class UpdateRoleRequest(BaseModel):
    """Update user role request."""

    # Required; the pattern only admits the exact values "admin" or "user".
    role: str = Field(
        ...,
        pattern="^(admin|user)$",
    )
|
|
|
|
|
|
class AnalyticsResponse(BaseModel):
    """Analytics response model."""

    # All four fields are copied from the same-named keys of the mapping
    # returned by the database client's get_analytics().
    total_messages: int
    by_company: list[dict]
    by_type: list[dict]
    period_days: int
|
|
|
|
|
|
# In-memory job storage (for demo; production would use Redis/DB).
# Maps job_id -> JobStatus; entries are added by the async batch endpoint.
_jobs: dict[str, JobStatus] = {}

# Incremented once per submitted job and embedded in the generated job id.
_job_counter = 0
|
|
|
|
|
|
def _convert_result(result: CompanyAnalysisResult) -> CompanyAnalysisResponse:
    """Build the API response model from an internal single-company result.

    Each response field is read from the attribute of the same name on
    ``result``; no value is transformed along the way.
    """
    copied_fields = (
        "company_name",
        "analysis",
        "patent_count",
        "success",
        "error",
        "timestamp",
    )
    return CompanyAnalysisResponse(
        **{name: getattr(result, name) for name in copied_fields}
    )
|
|
|
|
|
|
def _convert_batch_result(result: BatchAnalysisResult) -> BatchAnalysisResponse:
    """Build the API batch response from an internal batch result.

    Every per-company entry goes through ``_convert_result``; the summary
    counters and the timestamp are copied unchanged.
    """
    converted = list(map(_convert_result, result.results))
    return BatchAnalysisResponse(
        results=converted,
        total_companies=result.total_companies,
        successful=result.successful,
        failed=result.failed,
        timestamp=result.timestamp,
    )
|
|
|
|
|
|
# Global analyzer instance.
# Created by lifespan() on startup and reset to None on shutdown; endpoints
# treat a falsy value as "not initialized".
_analyzer: CompanyAnalyzer | None = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the shared analyzer for the lifetime of the application.

    Code before the ``yield`` runs on startup and creates the module-level
    analyzer; code after it runs on shutdown and clears the reference.
    """
    global _analyzer

    # Startup
    _analyzer = CompanyAnalyzer()

    yield

    # Shutdown: drop the reference (cleanup if needed)
    _analyzer = None
|
|
|
|
|
|
# Application object; startup/shutdown is delegated to lifespan() and the
# root path comes from the project configuration.
app = FastAPI(
    title="SPARC API",
    description="Semiconductor Patent & Analytics Report Core - Patent portfolio analysis using AI",
    version="1.0.0",
    lifespan=lifespan,
    root_path=config.root_path,
)

# Rate limiter (in-memory storage, suitable for single-instance deployments).
# Requests are keyed by the client's remote address.
limiter = Limiter(key_func=get_remote_address)

# Stored on app.state so the limiter is reachable from the application object.
app.state.limiter = limiter
|
|
|
|
|
|
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Return 429 with Retry-After header when rate limit is exceeded."""
    # Falls back to 60 seconds when the exception has no `retry_after`.
    # NOTE(review): confirm slowapi's RateLimitExceeded ever sets
    # `retry_after`; if it does not, this header is always "60".
    wait_seconds = getattr(exc, "retry_after", 60)

    payload = {"detail": "Rate limit exceeded. Please try again later."}
    response_headers = {"Retry-After": str(wait_seconds)}

    return JSONResponse(
        status_code=429,
        content=payload,
        headers=response_headers,
    )
|
|
|
|
|
|
# Add CORS middleware for React frontend.
# Only the two local dev-server origins are allowed; for those origins every
# method and header is accepted and credentials may be sent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://localhost:5173",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# ============== Auth Endpoints ==============
|
|
|
|
|
|
@app.post("/auth/register", response_model=UserResponse, tags=["Auth"])
@limiter.limit("5/minute")
async def register(request: Request, body: RegisterRequest):
    """Register a new user.

    The first registered user automatically becomes an admin.
    """
    # `request` is not used in the body; it is kept in the signature for the
    # rate-limit decorator.
    db = get_db_client()

    # First user becomes admin: the role is decided from the user count read
    # just before the insert.
    user_count = db.get_user_count()
    role = "user" if user_count != 0 else "admin"

    created = db.create_user(email=body.email, password=body.password, role=role)

    # A falsy result from create_user is reported as a duplicate email.
    if created:
        return UserResponse(
            id=created["id"],
            email=created["email"],
            role=created["role"],
            created_at=created["created_at"],
        )

    raise HTTPException(status_code=400, detail="Email already registered")
|
|
|
|
|
|
@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"])
@limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest):
    """Authenticate user and return JWT tokens."""
    # `request` is not used in the body; it is kept in the signature for the
    # rate-limit decorator.
    account = get_db_client().authenticate_user(body.email, body.password)

    # The same 401 is returned for any falsy authentication result, so the
    # response does not reveal whether the email exists.
    if account:
        return create_tokens(account["id"], account["email"], account["role"])

    raise HTTPException(status_code=401, detail="Invalid email or password")
|
|
|
|
|
|
@app.post("/auth/refresh", response_model=TokenResponse, tags=["Auth"])
async def refresh_token(request: RefreshRequest):
    """Refresh access token using refresh token."""
    # NOTE: `request` here is the parsed JSON body, not the HTTP request.
    payload = decode_token(request.refresh_token)

    # Accept only a decodable token whose type is exactly "refresh".
    is_refresh_token = bool(payload) and payload.type == "refresh"
    if not is_refresh_token:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    # Re-read the user so the new tokens carry the stored email and role.
    account = get_db_client().get_user_by_id(payload.user_id)
    if account:
        return create_tokens(account["id"], account["email"], account["role"])

    raise HTTPException(status_code=401, detail="User not found")
|
|
|
|
|
|
@app.get("/auth/me", response_model=UserResponse, tags=["Auth"])
async def get_me(
    current_user: UserResponse = Depends(get_current_user),
):
    """Get current authenticated user."""
    # All the work happens in the get_current_user dependency; the resolved
    # user is returned unchanged.
    return current_user
|
|
|
|
|
|
# ============== Admin Endpoints ==============
|
|
|
|
|
|
@app.get("/admin/users", response_model=List[UserResponse], tags=["Admin"])
async def list_users(
    limit: int = Query(default=100, ge=1, le=1000),
    offset: int = Query(default=0, ge=0),
    _: UserResponse = Depends(get_current_admin),
):
    """List all users (admin only)."""
    # `_` exists only to enforce the admin dependency; its value is unused.
    rows = get_db_client().get_all_users(limit=limit, offset=offset)

    # Copy just the four public fields of each row into the response model.
    listing = []
    for row in rows:
        listing.append(
            UserResponse(
                id=row["id"],
                email=row["email"],
                role=row["role"],
                created_at=row["created_at"],
            )
        )
    return listing
|
|
|
|
|
|
@app.patch("/admin/users/{user_id}/role", response_model=UserResponse, tags=["Admin"])
async def update_user_role(
    user_id: int,
    request: UpdateRoleRequest,
    current_admin: UserResponse = Depends(get_current_admin),
):
    """Update a user's role (admin only)."""
    # Admins are not allowed to change their own role through this endpoint.
    if user_id == current_admin.id:
        raise HTTPException(status_code=400, detail="Cannot change your own role")

    updated = get_db_client().update_user_role(user_id, request.role)

    # A falsy result from the database client is reported as an unknown user.
    if updated:
        return UserResponse(
            id=updated["id"],
            email=updated["email"],
            role=updated["role"],
            created_at=updated["created_at"],
        )

    raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
|
|
@app.delete("/admin/users/{user_id}", tags=["Admin"])
async def delete_user(
    user_id: int,
    current_admin: UserResponse = Depends(get_current_admin),
):
    """Delete a user (admin only)."""
    # Admins are not allowed to delete their own account.
    if user_id == current_admin.id:
        raise HTTPException(status_code=400, detail="Cannot delete yourself")

    was_deleted = get_db_client().delete_user(user_id)

    # A falsy result from the database client is reported as an unknown user.
    if was_deleted:
        return {"message": "User deleted"}

    raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
|
|
# ============== Analytics Endpoint ==============
|
|
|
|
|
|
@app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"])
async def get_analytics(
    days: int = Query(default=30, ge=1, le=365),
    _: UserResponse = Depends(get_current_user),
):
    """Get analytics data (authenticated users only)."""
    # `_` exists only to enforce authentication; its value is unused.
    stats = get_db_client().get_analytics(days=days)

    # Forward exactly these four keys of the returned mapping.
    wanted_keys = ("total_messages", "by_company", "by_type", "period_days")
    return AnalyticsResponse(**{key: stats[key] for key in wanted_keys})
|
|
|
|
|
|
# ============== System Endpoints ==============
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse, tags=["System"])
async def health_check():
    """Check API health status."""
    # No dependency is probed here: the endpoint always reports "healthy".
    # The timestamp is a naive local-time datetime.
    checked_at = datetime.now()
    return HealthResponse(status="healthy", version="1.0.0", timestamp=checked_at)
|
|
|
|
|
|
@app.get(
    "/analyze/{company_name}",
    response_model=CompanyAnalysisResponse,
    tags=["Analysis"],
)
async def analyze_company(
    company_name: str,
    _: UserResponse = Depends(get_current_user),
):
    """Analyze a single company's patent portfolio.

    This endpoint retrieves recent patents for the specified company,
    parses them, and uses AI to generate a comprehensive analysis.

    Args:
        company_name: Name of the company to analyze (e.g., "nvidia", "intel")

    Returns:
        Analysis results including patent count, AI insights, and success status

    Raises:
        HTTPException: 503 when the shared analyzer has not been initialized.
    """
    # Imported locally so this block does not depend on a file-level import.
    import asyncio

    if not _analyzer:
        raise HTTPException(status_code=503, detail="Analyzer not initialized")

    # The analyzer call is synchronous. Running it directly inside this
    # coroutine would stall the event loop (and every other request) until it
    # returns, so it is handed to a worker thread instead.
    result = await asyncio.to_thread(_analyzer._analyze_company_safe, company_name)
    return _convert_result(result)
|
|
|
|
|
|
@app.post(
    "/analyze/batch",
    response_model=BatchAnalysisResponse,
    tags=["Analysis"],
)
async def analyze_companies_batch(
    request: BatchAnalysisRequest,
    _: UserResponse = Depends(get_current_user),
):
    """Analyze multiple companies' patent portfolios.

    Processes companies concurrently for improved performance.
    Limited to 20 companies per request.

    Args:
        request: List of company names and optional worker count

    Returns:
        Batch results with individual company analyses and summary statistics

    Raises:
        HTTPException: 503 when the shared analyzer has not been initialized.
    """
    # Imported locally so this block does not depend on a file-level import.
    import asyncio

    if not _analyzer:
        raise HTTPException(status_code=503, detail="Analyzer not initialized")

    # analyze_companies() is a synchronous call that only returns once the
    # whole batch is done. It runs in a worker thread so the event loop stays
    # free to serve other requests (health checks, job polling) meanwhile.
    result = await asyncio.to_thread(
        _analyzer.analyze_companies,
        companies=request.companies,
        max_workers=request.max_workers,
    )
    return _convert_batch_result(result)
|
|
|
|
|
|
def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
    """Run one batch analysis and record its outcome on the stored job entry.

    The job entry must already exist in ``_jobs``. Its status moves to
    "running", then to "completed" (with the converted result attached) or
    to "failed" (with the error text attached). Nothing is returned.
    """
    job = _jobs[job_id]

    # Without an analyzer there is nothing to run: mark the job failed.
    if not _analyzer:
        job.status = "failed"
        job.error = "Analyzer not initialized"
        return

    job.status = "running"

    def _record_progress(company: str, completed: int, total: int):
        # Passed to the analyzer as its progress callback; keeps the job's
        # counters current while the batch is still running.
        job.completed_companies = completed
        job.progress = int((completed / total) * 100)

    try:
        batch = _analyzer.analyze_companies(
            companies=companies,
            max_workers=max_workers,
            progress_callback=_record_progress,
        )
        job.status = "completed"
        job.progress = 100
        job.result = _convert_batch_result(batch)
    except Exception as exc:
        # Task boundary: any failure is recorded on the job, not propagated.
        job.status = "failed"
        job.error = str(exc)
|
|
|
|
|
|
@app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"])
async def analyze_companies_async(
    request: BatchAnalysisRequest,
    background_tasks: BackgroundTasks,
    _: UserResponse = Depends(get_current_user),
):
    """Start an asynchronous batch analysis job.

    Returns immediately with a job ID that can be used to poll for status.
    Useful for large batch analyses that may take a long time.

    Args:
        request: List of company names and optional worker count

    Returns:
        Job status with job_id for polling
    """
    global _job_counter

    # The id combines the process-wide counter with a second-resolution stamp.
    _job_counter += 1
    stamp = datetime.now().strftime('%Y%m%d%H%M%S')
    job_id = f"job_{_job_counter}_{stamp}"

    # Register the job as pending before scheduling it, so the background
    # task always finds its entry.
    job = JobStatus(
        job_id=job_id,
        status="pending",
        progress=0,
        total_companies=len(request.companies),
        completed_companies=0,
    )
    _jobs[job_id] = job

    background_tasks.add_task(
        _run_batch_job,
        job_id,
        request.companies,
        request.max_workers,
    )

    return job
|
|
|
|
|
|
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
async def get_job_status(
    job_id: str,
    _: UserResponse = Depends(get_current_user),
):
    """Get the status of a background analysis job.

    Args:
        job_id: The job ID returned from the async batch endpoint

    Returns:
        Current job status including progress and results when complete
    """
    # Single lookup; stored values are JobStatus objects, so None means the
    # id is unknown.
    job = _jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job {job_id} not found")

    return job
|
|
|
|
|
|
@app.get("/jobs", response_model=list[JobStatus], tags=["Jobs"])
async def list_jobs(
    status: Annotated[
        str | None,
        Query(description="Filter by status: pending, running, completed, failed"),
    ] = None,
    limit: Annotated[int, Query(ge=1, le=100)] = 10,
    _: UserResponse = Depends(get_current_user),
):
    """List all analysis jobs.

    Args:
        status: Optional filter by job status
        limit: Maximum number of jobs to return (default 10, max 100)

    Returns:
        List of job statuses, most recently created first
    """
    # Sorting by the job_id string is lexicographic, so "job_9_..." would rank
    # above "job_10_..." and the newest-first order would break from the tenth
    # job on. Dicts keep insertion order and jobs are inserted as they are
    # created, so walking the values in reverse gives newest first.
    # NOTE(review): assumes nothing else re-inserts or reorders entries in
    # _jobs; verify if job eviction or persistence is added later.
    newest_first = [
        job
        for job in reversed(_jobs.values())
        if not status or job.status == status
    ]

    return newest_first[:limit]
|