Files
SPARC/SPARC/api.py
T
agent-company 3b6411869d feat: add cursor-based pagination to /jobs endpoint
Add a cursor query parameter to GET /jobs and return a next_cursor
field in the response envelope. Existing clients using only limit
continue to work without modification. The cursor is an opaque token
encoding created_at and job_id for stable keyset pagination.

Closes leeworks-agents/SPARC#25

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:19:01 +00:00

635 lines
17 KiB
Python

"""FastAPI web service wrapper for SPARC patent analysis.
Provides REST API endpoints for analyzing company patent portfolios.
"""
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Annotated, List
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, EmailStr, Field
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from SPARC import config
from SPARC.analyzer import CompanyAnalyzer
from SPARC.auth import (
TokenResponse,
UserResponse,
check_jwt_secret,
close_db_client,
create_tokens,
decode_token,
get_current_admin,
get_current_user,
get_db_client,
init_db_client,
)
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
# Pydantic models for API
class CompanyAnalysisResponse(BaseModel):
    """Response model for single company analysis.

    Instances are produced by ``_convert_result`` from the analyzer's
    internal ``CompanyAnalysisResult``.
    """

    company_name: str
    analysis: str
    patent_count: int
    success: bool
    # Optional error text; None when no error is supplied.
    error: str | None = None
    timestamp: datetime
class BatchAnalysisResponse(BaseModel):
    """Response model for batch company analysis.

    Also used to rehydrate a stored job result in ``_job_row_to_status``.
    """

    # One entry per analyzed company.
    results: list[CompanyAnalysisResponse]
    # Summary counters copied from the internal batch result.
    total_companies: int
    successful: int
    failed: int
    timestamp: datetime
class BatchAnalysisRequest(BaseModel):
    """Request model for batch company analysis."""

    # Required; validation accepts between 1 and 20 entries.
    companies: list[str] = Field(
        ..., min_length=1, max_length=20, description="List of company names to analyze"
    )
    # Forwarded to the analyzer as ``max_workers``; validated to 1..5, default 3.
    max_workers: int = Field(
        default=3, ge=1, le=5, description="Max concurrent analyses"
    )
class JobStatus(BaseModel):
    """Status of a background analysis job.

    Built from a persisted job row by ``_job_row_to_status``.
    """

    job_id: str
    # One of "pending", "running", "completed", "failed" (not enforced by the type).
    status: str
    # Percentage in the range 0-100.
    progress: int
    total_companies: int
    completed_companies: int
    # Populated only when the job row carries a stored result.
    result: BatchAnalysisResponse | None = None
    error: str | None = None
class PaginatedJobsResponse(BaseModel):
    """Paginated response envelope returned by ``GET /jobs``."""

    # The jobs on the current page.
    items: list["JobStatus"]
    # Cursor to pass back as ``cursor`` for the next page; None when there is no further page.
    next_cursor: str | None = None
class HealthResponse(BaseModel):
    """Health check response returned by ``GET /health``."""

    status: str
    version: str
    # Time at which the health response was generated.
    timestamp: datetime
# Auth request/response models
class RegisterRequest(BaseModel):
    """User registration request body for ``POST /auth/register``."""

    email: EmailStr
    # Required; validation rejects values shorter than 8 characters.
    password: str = Field(..., min_length=8, description="Password (min 8 characters)")
class LoginRequest(BaseModel):
    """User login request body for ``POST /auth/login``."""

    email: EmailStr
    # No length constraint here, unlike RegisterRequest.
    password: str
class RefreshRequest(BaseModel):
    """Token refresh request body for ``POST /auth/refresh``."""

    # Encoded refresh token; passed to ``decode_token`` by the endpoint.
    refresh_token: str
class UpdateRoleRequest(BaseModel):
    """Update user role request body for ``PATCH /admin/users/{user_id}/role``."""

    # Must match exactly "admin" or "user".
    role: str = Field(..., pattern="^(admin|user)$")
class AnalyticsResponse(BaseModel):
    """Analytics response model returned by ``GET /analytics``.

    Each field is copied from the same-named key of the dict returned by
    the database client's ``get_analytics``.
    """

    total_messages: int
    # Untyped dict rows; their keys are defined by the database layer, not here.
    by_company: List[dict]
    by_type: List[dict]
    period_days: int
# Module-level counter embedded in generated job IDs (the actual job state is
# in PostgreSQL). It starts at 0 each time this module is loaded and is
# incremented by ``analyze_companies_async``.
_job_counter = 0
def _convert_result(result: CompanyAnalysisResult) -> CompanyAnalysisResponse:
    """Map an internal analysis result onto its API response model.

    Args:
        result: Single-company outcome produced by the analyzer.

    Returns:
        A ``CompanyAnalysisResponse`` carrying the same six fields.
    """
    copied_fields = (
        "company_name",
        "analysis",
        "patent_count",
        "success",
        "error",
        "timestamp",
    )
    payload = {name: getattr(result, name) for name in copied_fields}
    return CompanyAnalysisResponse(**payload)
def _convert_batch_result(result: BatchAnalysisResult) -> BatchAnalysisResponse:
    """Map an internal batch result onto its API response model.

    Args:
        result: Batch outcome produced by the analyzer.

    Returns:
        A ``BatchAnalysisResponse`` whose ``results`` are the individually
        converted company results and whose counters mirror ``result``.
    """
    converted = []
    for company_result in result.results:
        converted.append(_convert_result(company_result))

    summary = {
        "results": converted,
        "total_companies": result.total_companies,
        "successful": result.successful,
        "failed": result.failed,
        "timestamp": result.timestamp,
    }
    return BatchAnalysisResponse(**summary)
# Global analyzer instance. Set by ``lifespan`` on startup and reset to None on
# shutdown; endpoints treat None as "not initialized".
_analyzer: CompanyAnalyzer | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize resources on startup, clean up on shutdown.

    Startup validates the JWT secret, opens the shared DB client, creates the
    global analyzer, and uses a short-lived ``DatabaseClient`` to initialize
    the schema and mark jobs left running/pending by a previous process as
    failed. Shutdown clears the analyzer and closes the shared DB client.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None, once startup has completed.
    """
    global _analyzer
    import logging

    from SPARC.database import DatabaseClient

    check_jwt_secret()
    init_db_client()
    # From here on the shared DB client is open, so everything below is
    # wrapped to guarantee it is released even if startup or the app fails.
    try:
        _analyzer = CompanyAnalyzer()

        # Mark any jobs that were running/pending before the restart as failed
        startup_db = DatabaseClient(config.database_url)
        startup_db.connect()
        try:
            startup_db.initialize_schema()
            stale = startup_db.mark_stale_jobs_failed()
            if stale:
                logging.getLogger(__name__).warning(
                    "Marked %d stale jobs as failed on startup", stale
                )
        finally:
            # Always release the temporary connection, including on errors.
            startup_db.close()

        yield
    finally:
        # Cleanup
        _analyzer = None
        close_db_client()
# The ASGI application. Startup/shutdown work is delegated to ``lifespan``;
# ``root_path`` comes from configuration.
app = FastAPI(
    title="SPARC API",
    description="Semiconductor Patent & Analytics Report Core - Patent portfolio analysis using AI",
    version="1.0.0",
    lifespan=lifespan,
    root_path=config.root_path,
)
# Rate limiter (in-memory storage, suitable for single-instance deployments).
# Requests are bucketed by the client's remote address, and the limiter is
# attached to the application state.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Translate a rate-limit violation into an HTTP 429 response.

    Args:
        request: The request that exceeded the limit (unused).
        exc: The raised ``RateLimitExceeded``.

    Returns:
        A ``JSONResponse`` with status 429 and a ``Retry-After`` header taken
        from ``exc.retry_after``, falling back to 60 when that attribute is
        absent.
    """
    wait_hint = getattr(exc, "retry_after", 60)
    body = {"detail": "Rate limit exceeded. Please try again later."}
    extra_headers = {"Retry-After": str(wait_hint)}
    return JSONResponse(status_code=429, content=body, headers=extra_headers)
# Add CORS middleware for React frontend. Allowed origins come from
# configuration; all methods and headers are allowed and credentials are
# enabled.
# NOTE(review): credentialed CORS requires explicit origins — confirm
# config.cors_origins never contains "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=config.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============== Auth Endpoints ==============
@app.post("/auth/register", response_model=UserResponse, tags=["Auth"])
@limiter.limit("5/minute")
async def register(request: Request, body: RegisterRequest):
    """Create a new user account.

    The role is "admin" when the user table is empty at the time of the
    call, otherwise "user". Limited to 5 requests per minute per client
    address.

    Args:
        request: Incoming request (required by the rate limiter).
        body: Email and password for the new account.

    Returns:
        The created user as a ``UserResponse``.

    Raises:
        HTTPException: 400 when ``create_user`` returns a falsy value.
    """
    db = get_db_client()

    # An empty user table means this registration bootstraps the first admin.
    if db.get_user_count() == 0:
        assigned_role = "admin"
    else:
        assigned_role = "user"

    created = db.create_user(
        email=body.email,
        password=body.password,
        role=assigned_role,
    )
    if created:
        return UserResponse(
            id=created["id"],
            email=created["email"],
            role=created["role"],
            created_at=created["created_at"],
        )

    raise HTTPException(status_code=400, detail="Email already registered")
@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"])
@limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest):
    """Verify credentials and issue JWT tokens.

    Limited to 10 requests per minute per client address.

    Args:
        request: Incoming request (required by the rate limiter).
        body: Email and password to authenticate.

    Returns:
        The result of ``create_tokens`` for the authenticated user.

    Raises:
        HTTPException: 401 when ``authenticate_user`` returns a falsy value.
    """
    account = get_db_client().authenticate_user(body.email, body.password)
    if account:
        return create_tokens(account["id"], account["email"], account["role"])

    raise HTTPException(status_code=401, detail="Invalid email or password")
@app.post("/auth/refresh", response_model=TokenResponse, tags=["Auth"])
async def refresh_token(request: RefreshRequest):
    """Exchange a refresh token for a new token set.

    Args:
        request: Request body carrying the encoded refresh token.

    Returns:
        The result of ``create_tokens`` for the token's user.

    Raises:
        HTTPException: 401 when the token cannot be decoded or its type is
            not "refresh", or when the referenced user no longer exists.
    """
    claims = decode_token(request.refresh_token)
    is_refresh_token = bool(claims) and claims.type == "refresh"
    if not is_refresh_token:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    account = get_db_client().get_user_by_id(claims.user_id)
    if not account:
        raise HTTPException(status_code=401, detail="User not found")

    return create_tokens(account["id"], account["email"], account["role"])
@app.get("/auth/me", response_model=UserResponse, tags=["Auth"])
async def get_me(current_user: UserResponse = Depends(get_current_user)):
    """Get current authenticated user.

    Args:
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        The same ``UserResponse`` supplied by the dependency, unchanged.
    """
    return current_user
# ============== Admin Endpoints ==============
@app.get("/admin/users", response_model=List[UserResponse], tags=["Admin"])
async def list_users(
    limit: int = Query(default=100, ge=1, le=1000),
    offset: int = Query(default=0, ge=0),
    _: UserResponse = Depends(get_current_admin),
):
    """Return a page of users (admin only).

    Args:
        limit: Page size, 1..1000 (default 100).
        offset: Number of users to skip, >= 0 (default 0).
        _: Admin user from ``get_current_admin``; used only for access control.

    Returns:
        A list of ``UserResponse`` objects, one per row returned by the
        database client.
    """

    def as_response(record) -> UserResponse:
        # Copy only the fields exposed by the API model.
        return UserResponse(
            id=record["id"],
            email=record["email"],
            role=record["role"],
            created_at=record["created_at"],
        )

    page = get_db_client().get_all_users(limit=limit, offset=offset)
    return [as_response(record) for record in page]
@app.patch("/admin/users/{user_id}/role", response_model=UserResponse, tags=["Admin"])
async def update_user_role(
    user_id: int,
    request: UpdateRoleRequest,
    current_admin: UserResponse = Depends(get_current_admin),
):
    """Change the role of another user (admin only).

    Args:
        user_id: ID of the user whose role is being changed.
        request: Body carrying the new role.
        current_admin: Admin performing the change.

    Returns:
        The updated user as a ``UserResponse``.

    Raises:
        HTTPException: 400 when an admin targets their own account; 404 when
            ``update_user_role`` returns a falsy value.
    """
    targets_self = user_id == current_admin.id
    if targets_self:
        raise HTTPException(status_code=400, detail="Cannot change your own role")

    updated = get_db_client().update_user_role(user_id, request.role)
    if updated:
        return UserResponse(
            id=updated["id"],
            email=updated["email"],
            role=updated["role"],
            created_at=updated["created_at"],
        )

    raise HTTPException(status_code=404, detail="User not found")
@app.delete("/admin/users/{user_id}", tags=["Admin"])
async def delete_user(
    user_id: int,
    current_admin: UserResponse = Depends(get_current_admin),
):
    """Remove another user's account (admin only).

    Args:
        user_id: ID of the user to delete.
        current_admin: Admin performing the deletion.

    Returns:
        ``{"message": "User deleted"}`` on success.

    Raises:
        HTTPException: 400 when an admin targets their own account; 404 when
            ``delete_user`` on the database client returns a falsy value.
    """
    targets_self = user_id == current_admin.id
    if targets_self:
        raise HTTPException(status_code=400, detail="Cannot delete yourself")

    was_removed = get_db_client().delete_user(user_id)
    if was_removed:
        return {"message": "User deleted"}

    raise HTTPException(status_code=404, detail="User not found")
# ============== Analytics Endpoint ==============
@app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"])
async def get_analytics(
    days: int = Query(default=30, ge=1, le=365),
    _: UserResponse = Depends(get_current_user),
):
    """Return analytics for a trailing window (authenticated users only).

    Args:
        days: Window size passed to the database client, 1..365 (default 30).
        _: Authenticated user; used only for access control.

    Returns:
        An ``AnalyticsResponse`` populated from the same-named keys of the
        dict returned by ``get_analytics`` on the database client.
    """
    stats = get_db_client().get_analytics(days=days)
    exposed_keys = ("total_messages", "by_company", "by_type", "period_days")
    return AnalyticsResponse(**{key: stats[key] for key in exposed_keys})
# ============== System Endpoints ==============
@app.get("/health", response_model=HealthResponse, tags=["System"])
async def health_check():
    """Check API health status.

    Returns:
        A ``HealthResponse`` with status "healthy", the version the FastAPI
        application was constructed with, and the current local time.
    """
    return HealthResponse(
        status="healthy",
        # Read the version from the app so it cannot drift from the value
        # declared on the FastAPI instance.
        version=app.version,
        timestamp=datetime.now(),
    )
@app.get(
    "/analyze/{company_name}",
    response_model=CompanyAnalysisResponse,
    tags=["Analysis"],
)
async def analyze_company(
    company_name: str,
    _: UserResponse = Depends(get_current_user),
):
    """Analyze a single company's patent portfolio.

    This endpoint retrieves recent patents for the specified company,
    parses them, and uses AI to generate a comprehensive analysis.

    Args:
        company_name: Name of the company to analyze (e.g., "nvidia", "intel")
        _: Authenticated user; used only for access control.

    Returns:
        Analysis results including patent count, AI insights, and success status

    Raises:
        HTTPException: 503 when the global analyzer has not been initialized.
    """
    import asyncio

    # Snapshot the global so a concurrent shutdown cannot null it mid-request.
    analyzer = _analyzer
    if not analyzer:
        raise HTTPException(status_code=503, detail="Analyzer not initialized")

    # The analyzer call is synchronous; run it in a worker thread so the
    # event loop keeps serving other requests while it runs.
    result = await asyncio.to_thread(analyzer._analyze_company_safe, company_name)
    return _convert_result(result)
@app.post(
    "/analyze/batch",
    response_model=BatchAnalysisResponse,
    tags=["Analysis"],
)
async def analyze_companies_batch(
    request: BatchAnalysisRequest,
    _: UserResponse = Depends(get_current_user),
):
    """Analyze multiple companies' patent portfolios.

    Processes companies concurrently for improved performance.
    Limited to 20 companies per request.

    Args:
        request: List of company names and optional worker count
        _: Authenticated user; used only for access control.

    Returns:
        Batch results with individual company analyses and summary statistics

    Raises:
        HTTPException: 503 when the global analyzer has not been initialized.
    """
    import asyncio

    # Snapshot the global so a concurrent shutdown cannot null it mid-request.
    analyzer = _analyzer
    if not analyzer:
        raise HTTPException(status_code=503, detail="Analyzer not initialized")

    # The batch analysis is synchronous; run it in a worker thread so the
    # event loop is not blocked for the duration of the whole batch.
    result = await asyncio.to_thread(
        analyzer.analyze_companies,
        companies=request.companies,
        max_workers=request.max_workers,
    )
    return _convert_batch_result(result)
def _get_job_db() -> "DatabaseClient":
    """Build a fresh ``DatabaseClient`` for job persistence.

    A new client is constructed from ``config.database_url`` on every call.

    NOTE(review): unlike ``lifespan``, no explicit ``connect()`` or
    ``close()`` is issued for this client — presumably it connects lazily;
    confirm against ``SPARC.database``.

    Returns:
        A newly constructed ``DatabaseClient``.
    """
    # Imported locally, matching the other uses of DatabaseClient in this module.
    from SPARC.database import DatabaseClient

    return DatabaseClient(config.database_url)
def _job_row_to_status(row: dict) -> JobStatus:
    """Build a ``JobStatus`` from a persisted job row.

    Args:
        row: Mapping with ``job_id``, ``status``, ``progress``,
            ``total_companies`` and ``completed_companies`` keys, plus
            optional ``result_json`` and ``error`` keys.

    Returns:
        The corresponding ``JobStatus``. ``result`` is set only when
        ``result_json`` is present and truthy.
    """
    import json as _json

    stored = row.get("result_json")
    if not stored:
        parsed_result = None
    else:
        # The stored result may arrive as JSON text or as an already-decoded mapping.
        payload = _json.loads(stored) if isinstance(stored, str) else stored
        parsed_result = BatchAnalysisResponse(**payload)

    return JobStatus(
        job_id=row["job_id"],
        status=row["status"],
        progress=row["progress"],
        total_companies=row["total_companies"],
        completed_companies=row["completed_companies"],
        result=parsed_result,
        error=row.get("error"),
    )
def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
    """Background task for batch analysis.

    Marks the job as running, runs the analyzer while persisting progress,
    then stores either the serialized result (status "completed") or the
    error message (status "failed").

    Args:
        job_id: Identifier of the job row to update.
        companies: Company names to analyze.
        max_workers: Worker count forwarded to the analyzer.
    """
    import json as _json
    import logging

    db = _get_job_db()

    # Snapshot the global (read-only here, so no ``global`` statement is
    # needed) so a shutdown cannot null it between the check and the call.
    analyzer = _analyzer
    if not analyzer:
        db.update_job(job_id, status="failed", error="Analyzer not initialized")
        return

    db.update_job(job_id, status="running")

    def progress_callback(company: str, completed: int, total: int):
        # Guard against a zero total so a progress update can never raise
        # ZeroDivisionError and abort the whole job.
        percent = int((completed / total) * 100) if total else 0
        db.update_job(
            job_id,
            completed_companies=completed,
            progress=percent,
        )

    try:
        result = analyzer.analyze_companies(
            companies=companies,
            max_workers=max_workers,
            progress_callback=progress_callback,
        )
        batch_response = _convert_batch_result(result)
        db.update_job(
            job_id,
            status="completed",
            progress=100,
            # default=str stringifies values json cannot encode (e.g. datetimes).
            result_json=_json.dumps(batch_response.model_dump(), default=str),
        )
    except Exception as e:
        # Top-level boundary of the background task: keep the traceback in
        # the logs, then record the failure on the job row.
        logging.getLogger(__name__).exception("Batch job %s failed", job_id)
        db.update_job(job_id, status="failed", error=str(e))
@app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"])
async def analyze_companies_async(
    request: BatchAnalysisRequest,
    background_tasks: BackgroundTasks,
    _: UserResponse = Depends(get_current_user),
):
    """Queue a batch analysis and return its job record immediately.

    The job row is created first, then ``_run_batch_job`` is scheduled as a
    background task. Poll ``GET /jobs/{job_id}`` for progress and results.

    NOTE(review): the job ID combines a module-level counter with a
    second-resolution timestamp, so uniqueness relies on a single process.

    Args:
        request: List of company names and optional worker count
        background_tasks: FastAPI background task scheduler.
        _: Authenticated user; used only for access control.

    Returns:
        Job status with job_id for polling
    """
    global _job_counter
    _job_counter += 1
    stamp = datetime.now().strftime('%Y%m%d%H%M%S')
    job_id = f"job_{_job_counter}_{stamp}"

    created_row = _get_job_db().create_job(
        job_id=job_id,
        total_companies=len(request.companies),
    )
    background_tasks.add_task(
        _run_batch_job,
        job_id,
        request.companies,
        request.max_workers,
    )
    return _job_row_to_status(created_row)
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
async def get_job_status(
    job_id: str,
    _: UserResponse = Depends(get_current_user),
):
    """Look up a background analysis job by ID.

    Args:
        job_id: The job ID returned from the async batch endpoint
        _: Authenticated user; used only for access control.

    Returns:
        Current job status including progress and results when complete

    Raises:
        HTTPException: 404 when no row exists for ``job_id``.
    """
    found = _get_job_db().get_job(job_id)
    if found:
        return _job_row_to_status(found)

    raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
@app.get("/jobs", response_model=PaginatedJobsResponse, tags=["Jobs"])
async def list_jobs(
    status: Annotated[
        str | None,
        Query(description="Filter by status: pending, running, completed, failed"),
    ] = None,
    limit: Annotated[int, Query(ge=1, le=100)] = 10,
    cursor: Annotated[
        str | None,
        Query(description="Opaque cursor from a previous response's next_cursor field"),
    ] = None,
    _: UserResponse = Depends(get_current_user),
):
    """Return one page of analysis jobs using cursor-based pagination.

    The response carries a ``next_cursor``; send it back as ``cursor`` to get
    the following page. A ``null`` ``next_cursor`` means the listing is
    exhausted. Callers that only pass ``limit`` are unaffected.

    Args:
        status: Optional filter by job status
        limit: Maximum number of jobs to return (default 10, max 100)
        cursor: Opaque pagination cursor from a previous response
        _: Authenticated user; used only for access control.

    Returns:
        Paginated list of job statuses
    """
    # Ask for one row beyond the page size; its presence signals another page.
    fetched = _get_job_db().list_jobs(status=status, limit=limit + 1, cursor=cursor)
    has_more = len(fetched) > limit
    page = fetched[:limit] if has_more else fetched

    items = list(map(_job_row_to_status, page))

    if not (has_more and page):
        return PaginatedJobsResponse(items=items, next_cursor=None)

    # The cursor is "<created_at>|<job_id>" of the last row on this page.
    tail = page[-1]
    created = tail["created_at"]
    if hasattr(created, "isoformat"):
        ts = created.isoformat()
    else:
        ts = str(created)
    return PaginatedJobsResponse(items=items, next_cursor=f"{ts}|{tail['job_id']}")