Files
SPARC/SPARC/api.py
T
agent-company 3d8922366e Add user-level API key generation for programmatic access
- Add api_keys table (id, user_id, key_hash, label, created_at) to schema
- Add POST /auth/apikeys to generate 32-byte hex API keys (bcrypt-hashed)
- Add GET /auth/apikeys to list active key metadata (no secrets)
- Add DELETE /auth/apikeys/{key_id} to revoke keys
- Extend get_current_user to accept either JWT Bearer or X-API-Key header
- Plaintext key returned only at creation time
- 16 new tests covering creation, listing, revocation, auth, and full flow

Closes leeworks-agents/SPARC#1673

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 15:18:34 +00:00

1527 lines
46 KiB
Python

"""FastAPI web service wrapper for SPARC patent analysis.
Provides REST API endpoints for analyzing company patent portfolios.
"""
from __future__ import annotations
from collections import deque
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, List
if TYPE_CHECKING:
from SPARC.database import DatabaseClient
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Path, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, EmailStr, Field, StringConstraints
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from SPARC import config
from SPARC.analyzer import CompanyAnalyzer
from SPARC.auth import (
TokenResponse,
UserResponse,
check_jwt_secret,
close_db_client,
create_tokens,
decode_token,
generate_api_key,
get_current_admin,
get_current_user,
get_db_client,
hash_api_key,
init_db_client,
)
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
# Validated company name type: 2-100 chars, alphanumeric + spaces/hyphens/ampersands/periods only.
# The first character must be a letter or digit. The same regex is repeated
# inline on the company-name path/query parameters of the endpoints below.
CompanyName = Annotated[
    str,
    StringConstraints(
        min_length=2,
        max_length=100,
        pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
    ),
]
# Pydantic models for API


class CompanyAnalysisResponse(BaseModel):
    """Response model for single company analysis.

    Built from an internal ``CompanyAnalysisResult`` by ``_convert_result``.
    """

    company_name: str
    analysis: str  # AI-generated analysis text
    patent_count: int
    success: bool
    error: str | None = None  # error text carried over from the analyzer result, if any
    model: str | None = None  # LLM model id reported by the analyzer result
    timestamp: datetime


class BatchAnalysisResponse(BaseModel):
    """Response model for batch company analysis.

    Built from an internal ``BatchAnalysisResult`` by ``_convert_batch_result``.
    """

    results: list[CompanyAnalysisResponse]  # one converted result per company
    total_companies: int
    successful: int  # counters are copied as-is from the analyzer's batch result
    failed: int
    timestamp: datetime


class CompanyAnalysisRequest(BaseModel):
    """Request model for single company analysis with optional model selection.

    NOTE(review): not referenced in the visible part of this module
    (``GET /analyze/{company_name}`` takes ``model`` as a query parameter).
    """

    model: str | None = Field(
        default=None,
        description="LLM model to use (e.g. 'anthropic/claude-3.5-sonnet', 'openai/gpt-4o'). Defaults to server config.",
    )


class BatchAnalysisRequest(BaseModel):
    """Request model for batch company analysis."""

    # 1-20 names, each validated against the CompanyName constraints.
    companies: list[CompanyName] = Field(
        ..., min_length=1, max_length=20, description="List of company names to analyze"
    )
    # Concurrency bound for the batch, clamped to 1-5 by validation.
    max_workers: int = Field(
        default=3, ge=1, le=5, description="Max concurrent analyses"
    )
    model: str | None = Field(
        default=None,
        description="LLM model to use for all analyses in this batch. Defaults to server config.",
    )


class JobStatus(BaseModel):
    """Status of a background analysis job."""

    job_id: str
    status: str # "pending", "running", "completed", "failed"
    progress: int # 0-100
    total_companies: int
    completed_companies: int
    # Presumably populated only once the job has finished — verify in the job runner.
    result: BatchAnalysisResponse | None = None
    error: str | None = None  # failure description when status is "failed" (assumed)
class AnalysisRecord(BaseModel):
    """A single stored analysis result.

    Every column except ``id`` may be NULL in storage, hence the optional fields.
    """

    id: int
    company_name: str | None = None
    analysis_type: str | None = None
    model: str | None = None
    response: str | None = None  # raw LLM response text
    timestamp: datetime | None = None


class PaginatedAnalysisResponse(BaseModel):
    """Paginated response for analysis result listings."""

    items: list[AnalysisRecord]
    # Opaque cursor for the next page; None when there are no more results.
    next_cursor: str | None = None


class PaginatedJobsResponse(BaseModel):
    """Paginated response for job listings."""

    items: list["JobStatus"]
    next_cursor: str | None = None  # opaque cursor for the next page, None on the last page


class HealthResponse(BaseModel):
    """Health check response returned by ``GET /health``."""

    status: str
    version: str
    timestamp: datetime
# Historical diff models


class AnalysisDiffResponse(BaseModel):
    """Response model for diffing two analysis runs of the same company.

    Built by ``_compute_analysis_diff`` from patent IDs found in each run's
    response text.
    """

    company_name: str
    from_id: int  # id of the older analysis run
    to_id: int  # id of the newer analysis run
    from_timestamp: datetime
    to_timestamp: datetime
    # Distinct patent IDs in the newer run minus distinct IDs in the older run.
    patent_count_delta: int
    added_patents: list[str]  # IDs present only in the newer run, sorted
    removed_patents: list[str]  # IDs present only in the older run, sorted
    # Field name -> {"from": old_value, "to": new_value} for fields that differ.
    changed_fields: dict[str, dict]
    summary: str  # human-readable one-line description of the differences


class CompanyAnalysisHistoryItem(BaseModel):
    """A summary item from a company's analysis history."""

    id: int  # usable as the ``from``/``to`` value of the diff endpoint
    analysis_type: str | None = None
    model: str | None = None
    timestamp: datetime
# Auth request/response models


class RegisterRequest(BaseModel):
    """User registration request."""

    email: EmailStr
    password: str = Field(..., min_length=8, description="Password (min 8 characters)")


class LoginRequest(BaseModel):
    """User login request."""

    email: EmailStr
    password: str  # no length rule here; checked against the stored credentials


class RefreshRequest(BaseModel):
    """Token refresh request."""

    refresh_token: str


class UpdateRoleRequest(BaseModel):
    """Update user role request."""

    # Only the two roles this API knows about are accepted.
    role: str = Field(..., pattern="^(admin|user)$")


class AnalyticsResponse(BaseModel):
    """Analytics response model (mirrors the dict from ``db.get_analytics``)."""

    total_messages: int
    by_company: List[dict]  # row shape defined by the DB layer — not visible here
    by_type: List[dict]
    period_days: int
# Job counter for generating unique IDs (the actual state is in PostgreSQL)
# NOTE(review): not referenced in the portion of this module shown here —
# confirm it is still used before relying on it.
_job_counter: int = 0
def _convert_result(result: CompanyAnalysisResult) -> CompanyAnalysisResponse:
    """Map an internal ``CompanyAnalysisResult`` onto the public response model.

    Every response field is copied one-to-one from the attribute of the same
    name on *result*.
    """
    copied_fields = (
        "company_name",
        "analysis",
        "patent_count",
        "success",
        "error",
        "model",
        "timestamp",
    )
    return CompanyAnalysisResponse(
        **{name: getattr(result, name) for name in copied_fields}
    )
def _convert_batch_result(result: BatchAnalysisResult) -> BatchAnalysisResponse:
    """Map an internal ``BatchAnalysisResult`` onto the public response model.

    Each per-company entry goes through ``_convert_result``; the summary
    counters and the timestamp are copied unchanged.
    """
    converted = list(map(_convert_result, result.results))
    return BatchAnalysisResponse(
        results=converted,
        total_companies=result.total_companies,
        successful=result.successful,
        failed=result.failed,
        timestamp=result.timestamp,
    )
# Global analyzer instance
# Created by ``lifespan`` at startup and reset to None at shutdown; the analysis
# endpoints answer HTTP 503 while it is None.
_analyzer: CompanyAnalyzer | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize resources on startup, clean up on shutdown.

    Startup order:

    1. Validate the JWT secret.
    2. Open the shared DB client.
    3. Create the global ``CompanyAnalyzer``.
    4. Mark jobs left running/pending by a previous process as failed. This
       uses a short-lived ``DatabaseClient`` that also runs
       ``initialize_schema``.
    5. Start the scheduler.

    The shutdown cleanup (drop the analyzer, close the shared DB client) runs in
    a ``finally``. It therefore also runs if any startup step after
    ``init_db_client`` raises.
    """
    global _analyzer
    check_jwt_secret()
    init_db_client()
    try:
        _analyzer = CompanyAnalyzer()

        # Mark any jobs that were running/pending before the restart as failed
        from SPARC.database import DatabaseClient

        _db = DatabaseClient(config.database_url)
        _db.connect()
        try:
            _db.initialize_schema()
            stale = _db.mark_stale_jobs_failed()
        finally:
            # Always release the temporary connection, even if schema setup fails.
            _db.close()
        if stale:
            import logging

            logging.getLogger(__name__).warning(
                "Marked %d stale jobs as failed on startup", stale
            )

        # Start scheduled analysis if tracked companies are configured
        from SPARC.scheduler import start_scheduler

        start_scheduler()
        yield
    finally:
        # Cleanup
        _analyzer = None
        close_db_client()
# ASGI application. ``lifespan`` owns startup/shutdown. ``root_path`` comes from
# config (presumably for serving behind a path-prefixing proxy — confirm).
app = FastAPI(
    title="SPARC API",
    description="Semiconductor Patent & Analytics Report Core - Patent portfolio analysis using AI",
    version="1.0.0",
    lifespan=lifespan,
    root_path=config.root_path,
)
# Rate limiter (in-memory storage, suitable for single-instance deployments)
# Requests are keyed by client address (get_remote_address).
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
# In-memory rate limit statistics, keyed by endpoint path. Each value holds
# total/rejected counters plus a per-IP breakdown; see _track_rate_limit_request.
# NOTE(review): nothing shown here prunes the per-IP map, so it grows with the
# number of distinct client IPs for the life of the process.
_rate_limit_stats: dict[str, dict] = {}
# Log of rejected (HTTP 429) requests. It is capped at the most recent 100,000
# entries, not by age. get_rate_limit_stats filters it to the last 24 h on read.
_rejected_log: deque[dict] = deque(maxlen=100_000)
def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None:
    """Count one request against a rate-limited endpoint.

    Updates the per-endpoint and per-IP counters in ``_rate_limit_stats``.
    Rejected requests are also appended, with a UTC ISO-8601 timestamp, to
    ``_rejected_log`` for the admin time-series report.

    Args:
        endpoint: Path used as the statistics key.
        ip: Client address the request came from.
        rejected: True when the request was refused with HTTP 429.
    """
    bucket = _rate_limit_stats.get(endpoint)
    if bucket is None:
        bucket = _rate_limit_stats[endpoint] = {
            "endpoint": endpoint,
            "total_requests": 0,
            "rejected_requests": 0,
            "by_ip": {},
        }

    per_ip = bucket.setdefault("by_ip", {}).setdefault(ip, {"total": 0, "rejected": 0})
    bucket["total_requests"] += 1
    per_ip["total"] += 1

    if not rejected:
        return
    bucket["rejected_requests"] += 1
    per_ip["rejected"] += 1
    _rejected_log.append({
        "endpoint": endpoint,
        "ip": ip,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    })
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Turn a slowapi ``RateLimitExceeded`` into an HTTP 429 JSON response.

    The rejection is recorded against the request path. ``Retry-After`` uses the
    exception's ``retry_after`` attribute when it has one, otherwise 60 seconds.
    """
    _track_rate_limit_request(request.url.path, get_remote_address(request), rejected=True)
    wait_seconds = getattr(exc, "retry_after", 60)
    body = {"detail": "Rate limit exceeded. Please try again later."}
    return JSONResponse(
        status_code=429,
        content=body,
        headers={"Retry-After": str(wait_seconds)},
    )
# Add CORS middleware for React frontend
# Origins come from config.cors_origins. Credentials, every method and every
# header are allowed for those origins.
# NOTE(review): allow_credentials=True together with a "*" entry in
# cors_origins would be unsafe — confirm the deployed value is an explicit list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=config.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============== Auth Endpoints ==============


@app.post("/auth/register", response_model=UserResponse, tags=["Auth"])
@limiter.limit("5/minute")
async def register(request: Request, body: RegisterRequest):
    """Create a user account (rate limited to 5/minute per client).

    The very first account ever created is given the ``admin`` role; every
    later account gets ``user``.

    Raises:
        HTTPException: 400 when ``create_user`` returns nothing, i.e. the
            email is already registered.
    """
    _track_rate_limit_request("/auth/register", get_remote_address(request))
    client = get_db_client()
    # NOTE(review): count-then-insert is not atomic; two concurrent first
    # registrations could both become admin unless the DB layer prevents it.
    assigned_role = "user" if client.get_user_count() != 0 else "admin"
    created = client.create_user(
        email=body.email,
        password=body.password,
        role=assigned_role,
    )
    if not created:
        raise HTTPException(status_code=400, detail="Email already registered")
    public_fields = ("id", "email", "role", "created_at")
    return UserResponse(**{field: created[field] for field in public_fields})
@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"])
@limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest):
    """Exchange email/password credentials for a JWT token pair.

    Rate limited to 10/minute per client.

    Raises:
        HTTPException: 401 when the credentials do not authenticate.
    """
    _track_rate_limit_request("/auth/login", get_remote_address(request))
    account = get_db_client().authenticate_user(body.email, body.password)
    if account:
        return create_tokens(account["id"], account["email"], account["role"])
    raise HTTPException(status_code=401, detail="Invalid email or password")
@app.post("/auth/refresh", response_model=TokenResponse, tags=["Auth"])
async def refresh_token(request: RefreshRequest):
    """Issue a new token pair from a valid refresh token.

    Raises:
        HTTPException: 401 if the token does not decode to a token of type
            ``"refresh"``, or if the user it names no longer exists.
    """
    claims = decode_token(request.refresh_token)
    is_refresh_token = bool(claims) and claims.type == "refresh"
    if not is_refresh_token:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    owner = get_db_client().get_user_by_id(claims.user_id)
    if not owner:
        raise HTTPException(status_code=401, detail="User not found")
    return create_tokens(owner["id"], owner["email"], owner["role"])
@app.get("/auth/me", response_model=UserResponse, tags=["Auth"])
async def get_me(current_user: UserResponse = Depends(get_current_user)):
    """Return the caller's own user record.

    All authentication happens in the ``get_current_user`` dependency; this
    handler simply echoes the user it resolved.
    """
    resolved_user = current_user
    return resolved_user
# ============== API Key Endpoints ==============


class CreateApiKeyRequest(BaseModel):
    """Request to create a new API key.

    The whole body is optional on ``POST /auth/apikeys``.
    """

    label: str | None = Field(default=None, max_length=100, description="Optional label for the key")


class ApiKeyResponse(BaseModel):
    """Response after creating an API key (includes plaintext key)."""

    id: int
    key: str # plaintext key, shown only at creation time
    label: str | None = None
    created_at: datetime


class ApiKeyInfo(BaseModel):
    """API key metadata (no secret) returned by ``GET /auth/apikeys``."""

    id: int
    label: str | None = None
    created_at: datetime
@app.post("/auth/apikeys", response_model=ApiKeyResponse, tags=["Auth"])
async def create_api_key_endpoint(
    body: CreateApiKeyRequest | None = None,
    current_user: UserResponse = Depends(get_current_user),
):
    """Generate a new API key for the authenticated user.

    Only a hash of the key is persisted. The plaintext key is returned
    **only once**, in this response, and cannot be retrieved later, so the
    caller must store it securely.
    """
    secret = generate_api_key()
    digest = hash_api_key(secret)
    record = get_db_client().create_api_key(
        user_id=current_user.id,
        key_hash=digest,
        label=getattr(body, "label", None),
    )
    return ApiKeyResponse(
        id=record["id"],
        key=secret,
        label=record["label"],
        created_at=record["created_at"],
    )
@app.get("/auth/apikeys", response_model=list[ApiKeyInfo], tags=["Auth"])
async def list_api_keys_endpoint(
    current_user: UserResponse = Depends(get_current_user),
):
    """List metadata (id, label, creation time) of the caller's active API keys.

    Secret values are never returned.
    """
    stored = get_db_client().list_api_keys(current_user.id)
    metadata = [ApiKeyInfo(**row) for row in stored]
    return metadata
@app.delete("/auth/apikeys/{key_id}", tags=["Auth"])
async def revoke_api_key_endpoint(
    key_id: int,
    current_user: UserResponse = Depends(get_current_user),
):
    """Revoke (delete) one of the caller's API keys.

    Raises:
        HTTPException: 404 when no key with this id belongs to the caller.
    """
    if get_db_client().delete_api_key(key_id, current_user.id):
        return {"message": "API key revoked"}
    raise HTTPException(status_code=404, detail="API key not found")
# ============== Admin Endpoints ==============


@app.get("/admin/users", response_model=List[UserResponse], tags=["Admin"])
async def list_users(
    limit: int = Query(default=100, ge=1, le=1000),
    offset: int = Query(default=0, ge=0),
    _: UserResponse = Depends(get_current_admin),
):
    """Return one page of users (admin only).

    Args:
        limit: Page size, 1-1000.
        offset: Number of users to skip.
    """
    rows = get_db_client().get_all_users(limit=limit, offset=offset)
    public_fields = ("id", "email", "role", "created_at")
    return [UserResponse(**{field: row[field] for field in public_fields}) for row in rows]
@app.patch("/admin/users/{user_id}/role", response_model=UserResponse, tags=["Admin"])
async def update_user_role(
    user_id: int,
    request: UpdateRoleRequest,
    current_admin: UserResponse = Depends(get_current_admin),
):
    """Change another user's role (admin only).

    Raises:
        HTTPException: 400 when admins target themselves, 404 when the user
            does not exist.
    """
    if current_admin.id == user_id:
        raise HTTPException(status_code=400, detail="Cannot change your own role")
    updated = get_db_client().update_user_role(user_id, request.role)
    if updated:
        return UserResponse(
            id=updated["id"],
            email=updated["email"],
            role=updated["role"],
            created_at=updated["created_at"],
        )
    raise HTTPException(status_code=404, detail="User not found")
@app.delete("/admin/users/{user_id}", tags=["Admin"])
async def delete_user(
    user_id: int,
    current_admin: UserResponse = Depends(get_current_admin),
):
    """Delete another user's account (admin only).

    Raises:
        HTTPException: 400 when admins target themselves, 404 when the user
            does not exist.
    """
    if current_admin.id == user_id:
        raise HTTPException(status_code=400, detail="Cannot delete yourself")
    was_deleted = get_db_client().delete_user(user_id)
    if was_deleted:
        return {"message": "User deleted"}
    raise HTTPException(status_code=404, detail="User not found")
# ============== Tracked Companies Endpoints ==============


class TrackCompanyRequest(BaseModel):
    """Request to add a company to tracking."""

    # Required (Field(...) has no default); validated by the CompanyName constraints.
    company_name: CompanyName = Field(...)
@app.get("/admin/tracked", tags=["Admin"])
async def list_tracked_companies(
    _: UserResponse = Depends(get_current_admin),
):
    """Return the tracked-company list exactly as the DB layer provides it (admin only)."""
    tracked = get_db_client().list_tracked_companies()
    return tracked
@app.post("/admin/tracked", tags=["Admin"])
async def add_tracked_company(
    request: TrackCompanyRequest,
    _: UserResponse = Depends(get_current_admin),
):
    """Start tracking a company (admin only).

    Raises:
        HTTPException: 409 when the company is already tracked.
    """
    created = get_db_client().add_tracked_company(request.company_name)
    if created:
        return created
    raise HTTPException(status_code=409, detail="Company already tracked")
@app.delete("/admin/tracked/{company_name}", tags=["Admin"])
async def remove_tracked_company(
    company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
    _: UserResponse = Depends(get_current_admin),
):
    """Stop tracking a company (admin only).

    Raises:
        HTTPException: 404 when the company is not in the tracking list.
    """
    was_tracked = get_db_client().remove_tracked_company(company_name)
    if was_tracked:
        return {"message": f"Stopped tracking {company_name}"}
    raise HTTPException(status_code=404, detail="Company not found in tracking list")
@app.get("/admin/rate-limits", tags=["Admin"])
async def get_rate_limit_stats(
    _: UserResponse = Depends(get_current_admin),
):
    """Report rate-limit configuration and in-memory usage counters (admin only).

    Returns:
        A dict with:

        - ``rate_limits``: per endpoint, the configured limit, total and
          rejected request counts, and a per-IP breakdown.
        - ``throttled_24h``: number of rejected requests in the last 24 hours.
        - ``throttled_over_time``: those rejections bucketed by UTC hour,
          sorted by bucket.
    """
    # Mirrors the @limiter.limit decorators on the auth endpoints.
    configured_limits = {
        "/auth/register": {"limit": "5/minute"},
        "/auth/login": {"limit": "10/minute"},
    }

    def _summarize(endpoint: str, limit: str) -> dict:
        # Counters for one endpoint; empty when it has never been hit.
        counters = _rate_limit_stats.get(endpoint, {})
        return {
            "endpoint": endpoint,
            "limit": limit,
            "total_requests": counters.get("total_requests", 0),
            "rejected_requests": counters.get("rejected_requests", 0),
            "by_ip": [
                {"ip": addr, "total": tally["total"], "rejected": tally["rejected"]}
                for addr, tally in counters.get("by_ip", {}).items()
            ],
        }

    rate_limits = [
        _summarize(endpoint, conf["limit"])
        for endpoint, conf in configured_limits.items()
    ]

    window_start = datetime.now(timezone.utc) - timedelta(hours=24)
    per_hour: dict[str, int] = {}
    for entry in _rejected_log:
        raw_ts = entry["timestamp"]
        try:
            when = datetime.fromisoformat(raw_ts)
        except (ValueError, TypeError):
            continue
        if when < window_start:
            continue
        hour_key = when.strftime("%Y-%m-%dT%H:00:00Z")
        per_hour[hour_key] = per_hour.get(hour_key, 0) + 1

    return {
        "rate_limits": rate_limits,
        # Every in-window rejection lands in exactly one hourly bucket.
        "throttled_24h": sum(per_hour.values()),
        "throttled_over_time": [
            {"timestamp": hour, "count": count}
            for hour, count in sorted(per_hour.items())
        ],
    }
@app.get("/admin/alerts", tags=["Admin"])
async def list_alerts(
    limit: int = Query(default=50, ge=1, le=200),
    _: UserResponse = Depends(get_current_admin),
):
    """Return up to ``limit`` (1-200) recent scheduled-analysis alerts (admin only)."""
    alerts = get_db_client().list_alerts(limit=limit)
    return alerts
# ============== Analytics Endpoint ==============


@app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"])
async def get_analytics(
    days: int = Query(default=30, ge=1, le=365),
    _: UserResponse = Depends(get_current_user),
):
    """Summarize stored LLM messages over the last ``days`` days (1-365).

    Requires an authenticated user.
    """
    stats = get_db_client().get_analytics(days=days)
    wanted = ("total_messages", "by_company", "by_type", "period_days")
    return AnalyticsResponse(**{key: stats[key] for key in wanted})
# ============== Model Selection Endpoints ==============
# Supported models via OpenRouter
# Allow-list of LLM model ids. It is returned by GET /models and enforced by
# _validate_model.
SUPPORTED_MODELS = [
    {"id": "anthropic/claude-3.5-sonnet", "name": "Claude 3.5 Sonnet", "provider": "Anthropic"},
    {"id": "openai/gpt-4o", "name": "GPT-4o", "provider": "OpenAI"},
    {"id": "openai/gpt-4o-mini", "name": "GPT-4o Mini", "provider": "OpenAI"},
    {"id": "google/gemini-pro-1.5", "name": "Gemini Pro 1.5", "provider": "Google"},
    {"id": "meta-llama/llama-3.1-70b-instruct", "name": "Llama 3.1 70B", "provider": "Meta"},
]
# Just the ids, as a set for fast membership checks.
_SUPPORTED_MODEL_IDS = {m["id"] for m in SUPPORTED_MODELS}
def _validate_model(model: str | None) -> None:
    """Reject a requested LLM model that is not on the allow-list.

    ``None`` means "use the server default" and is always accepted.

    Raises:
        HTTPException: 400 listing the supported model ids.
    """
    if model is None or model in _SUPPORTED_MODEL_IDS:
        return
    allowed = ", ".join(sorted(_SUPPORTED_MODEL_IDS))
    raise HTTPException(
        status_code=400,
        detail=f"Unsupported model '{model}'. Supported models: {allowed}",
    )
@app.get("/models", tags=["System"])
async def list_models():
    """List the LLM models accepted in analysis requests.

    ``models`` is the allow-list. ``default`` is the server's configured model
    (``config.model``), which is used when a request does not name one.
    """
    return dict(models=SUPPORTED_MODELS, default=config.model)
@app.get("/analytics/trends", tags=["Analytics"])
async def get_analytics_trends(
    days: int = Query(default=90, ge=7, le=365),
    _: UserResponse = Depends(get_current_user),
):
    """Get trend data for patent analysis over time.

    Only non-cached rows of ``llm_messages`` from the last ``days`` days are
    counted.

    Returns two datasets:
    - ``by_month``: analysis count per company per month
    - ``by_type_over_time``: analysis type distribution per month

    Args:
        days: Number of days to look back (default 90, range 7-365)

    Returns:
        Trend data suitable for time-series and distribution charts
    """
    db = get_db_client()
    with db.get_conn() as conn:
        with conn.cursor() as cur:
            # The window is ``%s * INTERVAL '1 day'``, not ``INTERVAL '%s days'``.
            # A placeholder inside a quoted SQL literal only works by accident
            # with client-side interpolation and breaks with server-side binding.
            # Analyses per company per month
            cur.execute(
                """
                SELECT
                    TO_CHAR(timestamp, 'YYYY-MM') AS month,
                    company_name,
                    COUNT(*) AS count
                FROM llm_messages
                WHERE timestamp >= NOW() - (%s * INTERVAL '1 day')
                  AND is_cached = FALSE
                  AND company_name IS NOT NULL
                GROUP BY month, company_name
                ORDER BY month, company_name
                """,
                (days,),
            )
            by_month_rows = cur.fetchall()
            # Analysis type distribution per month
            cur.execute(
                """
                SELECT
                    TO_CHAR(timestamp, 'YYYY-MM') AS month,
                    analysis_type,
                    COUNT(*) AS count
                FROM llm_messages
                WHERE timestamp >= NOW() - (%s * INTERVAL '1 day')
                  AND is_cached = FALSE
                GROUP BY month, analysis_type
                ORDER BY month, analysis_type
                """,
                (days,),
            )
            by_type_rows = cur.fetchall()
    by_month = [
        {"month": month, "company_name": company, "count": count}
        for month, company, count in by_month_rows
    ]
    by_type_over_time = [
        {"month": month, "analysis_type": analysis_type, "count": count}
        for month, analysis_type, count in by_type_rows
    ]
    return {
        "by_month": by_month,
        "by_type_over_time": by_type_over_time,
        "period_days": days,
    }
# ============== Export Endpoints ==============


@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
    company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
    _: UserResponse = Depends(get_current_user),
):
    """Download every stored, non-cached analysis for a company as CSV.

    The company name is matched case-insensitively, and rows are ordered
    newest first. The columns are company_name, analysis_type, model, analysis
    (the stored LLM response text) and timestamp.

    Args:
        company_name: Company whose analyses should be exported.

    Returns:
        A ``text/csv`` attachment named ``sparc_<company>_export.csv``.

    Raises:
        HTTPException: 404 when the company has no stored analyses.
    """
    import csv
    import io

    with get_db_client().get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            """
            SELECT company_name, analysis_type, model, response, timestamp
            FROM llm_messages
            WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
            ORDER BY timestamp DESC
            """,
            (company_name,),
        )
        records = cur.fetchall()
    if not records:
        raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")

    # NOTE(review): cells are written verbatim; LLM text starting with "=", "+",
    # "-" or "@" may be evaluated as a formula by spreadsheet applications.
    sink = io.StringIO()
    csv_writer = csv.writer(sink)
    csv_writer.writerow(["company_name", "analysis_type", "model", "analysis", "timestamp"])
    csv_writer.writerows(records)

    file_stub = company_name.replace(" ", "_").lower()
    return StreamingResponse(
        iter([sink.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": f'attachment; filename="sparc_{file_stub}_export.csv"'},
    )
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
    company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
    _: UserResponse = Depends(get_current_user),
):
    """Export analysis results for a company as a formatted PDF report.

    Returns all stored (non-cached) analysis records for the given company,
    matched case-insensitively. The report contains a summary table followed
    by one section per analysis (type, model, timestamp, response text), and
    is returned as a downloadable PDF document.

    Args:
        company_name: Company name to export results for

    Returns:
        PDF file download named ``<company>-analysis-<YYYY-MM-DD>.pdf``

    Raises:
        HTTPException: 404 when the company has no stored analyses.
    """
    import io
    from xml.sax.saxutils import escape

    from reportlab.lib import colors
    from reportlab.lib.pagesizes import letter
    from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
    from reportlab.lib.units import inch
    from reportlab.platypus import (
        Paragraph,
        SimpleDocTemplate,
        Spacer,
        Table,
        TableStyle,
    )

    def _markup(value: object, default: str = "") -> str:
        # Paragraph text is parsed as XML-like markup, so &, < and > coming from
        # the database (e.g. "AT&T", or "<" in LLM output) must be escaped.
        # NULL columns render as *default* instead of crashing.
        return escape(default if value is None else str(value))

    def _distinct(values) -> str:
        # Summary-table cell: sorted distinct non-NULL values, comma-separated.
        # Skipping NULLs avoids mixing None and str in sorted()/join().
        present = sorted({str(v) for v in values if v is not None})
        return ", ".join(present) if present else "unknown"

    db = get_db_client()
    with db.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT company_name, analysis_type, model, response, timestamp
                FROM llm_messages
                WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
                ORDER BY timestamp DESC
                """,
                (company_name,),
            )
            rows = cur.fetchall()
    if not rows:
        raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")

    buffer = io.BytesIO()
    doc = SimpleDocTemplate(
        buffer,
        pagesize=letter,
        rightMargin=0.75 * inch,
        leftMargin=0.75 * inch,
        topMargin=0.75 * inch,
        bottomMargin=0.75 * inch,
    )
    styles = getSampleStyleSheet()
    title_style = ParagraphStyle(
        "CustomTitle",
        parent=styles["Title"],
        fontSize=20,
        spaceAfter=6,
    )
    subtitle_style = ParagraphStyle(
        "Subtitle",
        parent=styles["Normal"],
        fontSize=11,
        textColor=colors.grey,
        spaceAfter=20,
    )
    heading_style = ParagraphStyle(
        "SectionHeading",
        parent=styles["Heading2"],
        fontSize=13,
        spaceBefore=16,
        spaceAfter=8,
        textColor=colors.HexColor("#1a1a2e"),
    )
    body_style = ParagraphStyle(
        "BodyText",
        parent=styles["Normal"],
        fontSize=9,
        leading=13,
        spaceAfter=10,
    )

    elements = []
    # Title and date
    display_name = rows[0][0]  # Use the casing from the database
    analysis_date = datetime.now().strftime("%Y-%m-%d")
    elements.append(Paragraph(f"SPARC Analysis Report: {_markup(display_name)}", title_style))
    elements.append(Paragraph(f"Generated on {analysis_date}", subtitle_style))

    # Summary table (Table cells are plain strings, not parsed as markup)
    summary_data = [
        ["Total Analyses", str(len(rows))],
        ["Analysis Types", _distinct(r[1] for r in rows)],
        ["Models Used", _distinct(r[2] for r in rows)],
    ]
    summary_table = Table(summary_data, colWidths=[2 * inch, 4.5 * inch])
    summary_table.setStyle(
        TableStyle(
            [
                ("BACKGROUND", (0, 0), (0, -1), colors.HexColor("#f0f0f5")),
                ("FONTNAME", (0, 0), (0, -1), "Helvetica-Bold"),
                ("FONTSIZE", (0, 0), (-1, -1), 9),
                ("PADDING", (0, 0), (-1, -1), 6),
                ("GRID", (0, 0), (-1, -1), 0.5, colors.HexColor("#cccccc")),
                ("VALIGN", (0, 0), (-1, -1), "TOP"),
            ]
        )
    )
    elements.append(summary_table)
    elements.append(Spacer(1, 16))

    # Individual analysis sections
    for i, (_, analysis_type, model, response, timestamp) in enumerate(rows, 1):
        ts_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") if hasattr(timestamp, "strftime") else str(timestamp)
        elements.append(
            Paragraph(
                f"Analysis {i}: {_markup(analysis_type, 'unknown')} (via {_markup(model, 'unknown')})",
                heading_style,
            )
        )
        elements.append(Paragraph(f"<i>Performed: {_markup(ts_str)}</i>", body_style))
        # Split the escaped response into one paragraph per line to avoid
        # overflow; blank lines become small spacers.
        for line in _markup(response).split("\n"):
            if line.strip():
                elements.append(Paragraph(line, body_style))
            else:
                elements.append(Spacer(1, 4))
        elements.append(Spacer(1, 10))

    doc.build(elements)
    buffer.seek(0)
    safe_name = company_name.replace(" ", "_").lower()
    filename = f"{safe_name}-analysis-{analysis_date}.pdf"
    return StreamingResponse(
        iter([buffer.getvalue()]),
        media_type="application/pdf",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
# ============== System Endpoints ==============


@app.get("/health", response_model=HealthResponse, tags=["System"])
async def health_check():
    """Check API health status.

    Returns:
        HealthResponse containing:

        - a static ``"healthy"`` status;
        - the version declared on the FastAPI app, so it cannot drift from the
          OpenAPI metadata;
        - the current time as a timezone-aware UTC timestamp, consistent with
          the other UTC timestamps this module produces.
    """
    return HealthResponse(
        status="healthy",
        version=app.version,
        timestamp=datetime.now(timezone.utc),
    )
@app.get(
    "/analyze/{company_name}",
    response_model=CompanyAnalysisResponse,
    tags=["Analysis"],
)
async def analyze_company(
    company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
    model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
    _: UserResponse = Depends(get_current_user),
):
    """Analyze a single company's patent portfolio.

    This endpoint retrieves recent patents for the specified company,
    parses them, and uses AI to generate a comprehensive analysis.

    Args:
        company_name: Name of the company to analyze (e.g., "nvidia", "intel")
        model: Optional LLM model override (must be in SUPPORTED_MODELS)

    Returns:
        Analysis results including patent count, AI insights, and success status

    Raises:
        HTTPException: 400 for an unsupported model, 503 when the analyzer
            has not been initialized.
    """
    import asyncio

    # NOTE(review): this route is registered before GET /analyze/batch, and
    # "batch" satisfies the company-name pattern. FastAPI dispatches to the
    # first matching route, so GET /analyze/batch lands here. Register the
    # static /analyze/batch route before this one.
    _validate_model(model)
    analyzer = _analyzer  # snapshot, so a concurrent shutdown cannot swap it mid-call
    if not analyzer:
        raise HTTPException(status_code=503, detail="Analyzer not initialized")
    # The analyzer call is synchronous (patent retrieval plus an LLM request).
    # Run it in a worker thread so it does not block the event loop.
    result = await asyncio.to_thread(analyzer._analyze_company_safe, company_name, model=model)
    return _convert_result(result)
def _extract_patent_ids(response_text: str) -> set[str]:
"""Extract patent IDs from an analysis response text.
Looks for patterns like US-12345678-B2, US12345678B2, etc.
"""
import re
pattern = r"US[-\s]?\d{7,8}[-\s]?[A-Z]\d?"
return set(re.findall(pattern, response_text or ""))
def _compute_analysis_diff(from_rec: dict, to_rec: dict) -> AnalysisDiffResponse:
    """Build a structured diff between an older and a newer analysis record.

    Patent IDs are pulled from each record's ``response`` text with
    ``_extract_patent_ids``. The diff reports:

    - IDs found only in the newer run (added);
    - IDs found only in the older run (removed);
    - the change in the number of distinct IDs;
    - any difference in ``model`` or ``analysis_type``.

    Args:
        from_rec: Older analysis row. Must carry ``id`` and ``timestamp``.
        to_rec: Newer analysis row. Must carry ``id``, ``timestamp`` and
            ``company_name``.

    Returns:
        An ``AnalysisDiffResponse`` that includes a one-line human-readable
        summary.
    """
    before_ids = _extract_patent_ids(from_rec.get("response", "") or "")
    after_ids = _extract_patent_ids(to_rec.get("response", "") or "")
    added = sorted(after_ids - before_ids)
    removed = sorted(before_ids - after_ids)
    delta = len(after_ids) - len(before_ids)

    changed_fields: dict[str, dict] = {}
    for field_name in ("model", "analysis_type"):
        old_value, new_value = from_rec.get(field_name), to_rec.get(field_name)
        if old_value != new_value:
            changed_fields[field_name] = {"from": old_value, "to": new_value}

    notes: list[str] = []
    if added:
        notes.append(f"{len(added)} new patent(s) appeared")
    if removed:
        notes.append(f"{len(removed)} patent(s) no longer referenced")
    if delta:
        trend = "increased" if delta > 0 else "decreased"
        notes.append(f"patent mention count {trend} by {abs(delta)}")
    if changed_fields:
        notes.append(f"field(s) changed: {', '.join(changed_fields)}")

    return AnalysisDiffResponse(
        company_name=to_rec["company_name"],
        from_id=from_rec["id"],
        to_id=to_rec["id"],
        from_timestamp=from_rec["timestamp"],
        to_timestamp=to_rec["timestamp"],
        patent_count_delta=delta,
        added_patents=added,
        removed_patents=removed,
        changed_fields=changed_fields,
        summary="; ".join(notes) or "No significant differences detected.",
    )
@app.get(
    "/analyze/{company_name}/history",
    response_model=list[CompanyAnalysisHistoryItem],
    tags=["Analysis"],
)
async def list_company_analysis_history(
    company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
    limit: int = Query(default=20, ge=1, le=100),
    _: UserResponse = Depends(get_current_user),
):
    """List a company's previous analysis runs, as returned by the job DB.

    The returned ids can be passed as ``from``/``to`` to the diff endpoint.

    Args:
        company_name: Company name to look up
        limit: Maximum number of runs to return (1-100)

    Returns:
        List of CompanyAnalysisHistoryItem
    """
    records = _get_job_db().list_company_analyses(company_name, limit=limit)
    history: list[CompanyAnalysisHistoryItem] = []
    for rec in records:
        history.append(
            CompanyAnalysisHistoryItem(
                id=rec["id"],
                analysis_type=rec.get("analysis_type"),
                model=rec.get("model"),
                timestamp=rec["timestamp"],
            )
        )
    return history
@app.get(
    "/analyze/{company_name}/diff",
    response_model=AnalysisDiffResponse,
    tags=["Analysis"],
)
async def diff_company_analyses(
    company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
    from_id: int = Query(..., alias="from", description="Analysis ID of the older run"),
    to_id: int = Query(..., alias="to", description="Analysis ID of the newer run"),
    _: UserResponse = Depends(get_current_user),
):
    """Compare two analysis runs for the same company.

    Returns a structured diff with added/removed patents, the patent count
    delta, changed fields and a summary narrative.

    Args:
        company_name: Company name; both records must belong to it
            (case-insensitive)
        from_id: ID of the older analysis run
        to_id: ID of the newer analysis run

    Returns:
        AnalysisDiffResponse

    Raises:
        HTTPException: 404 if either ID does not exist or belongs to a
            different company.
    """
    job_db = _get_job_db()
    wanted = company_name.lower()

    def _fetch_owned(analysis_id: int) -> dict:
        # Load one record and make sure it belongs to the requested company.
        rec = job_db.get_analysis_by_id(analysis_id)
        if not rec or (rec["company_name"] or "").lower() != wanted:
            raise HTTPException(
                status_code=404,
                detail=f"Analysis ID {analysis_id} not found for company '{company_name}'",
            )
        return rec

    older = _fetch_owned(from_id)
    newer = _fetch_owned(to_id)
    return _compute_analysis_diff(older, newer)
@app.get(
    "/analyze/patent/{patent_id}",
    tags=["Analysis"],
)
async def analyze_single_patent(
    patent_id: str,
    company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
    _: UserResponse = Depends(get_current_user),
):
    """Analyze a single patent by its publication ID.

    If the patent PDF is not already cached locally, the system will attempt
    to download it automatically from a previously cached link. If no link
    is available, a 404 error is returned.

    Args:
        patent_id: Patent publication ID (e.g. "US-11234567-B2")
        company_name: Company name for analysis context

    Returns:
        Dict with ``patent_id``, ``company_name`` and the ``analysis`` text

    Raises:
        503: If the analyzer has not been initialized
        404: If the analyzer raises FileNotFoundError (no PDF available)
    """
    import asyncio

    if not _analyzer:
        raise HTTPException(status_code=503, detail="Analyzer not initialized")
    try:
        # The analyzer call is synchronous (it may download a PDF and run the
        # analysis). Running it on a worker thread keeps this async endpoint
        # from blocking the event loop, and with it every other request.
        analysis = await asyncio.to_thread(
            _analyzer.analyze_single_patent, patent_id, company_name
        )
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis}
@app.get(
    "/analyze/batch",
    response_model=PaginatedAnalysisResponse,
    tags=["Analysis"],
)
async def list_analysis_results(
    company_name: Annotated[
        str | None,
        Query(description="Filter results by company name"),
    ] = None,
    limit: Annotated[int, Query(ge=1, le=200)] = 50,
    cursor: Annotated[
        str | None,
        Query(description="Opaque cursor from a previous response's next_cursor field"),
    ] = None,
    _: UserResponse = Depends(get_current_user),
):
    """Return one page of stored analysis results, newest first.

    Page size is controlled by ``limit`` (default 50, max 200). When more
    results remain, the response carries a ``next_cursor`` string that should
    be sent back as the ``cursor`` query parameter to continue; a ``null``
    ``next_cursor`` means the listing is exhausted.

    Args:
        company_name: Optional filter by company name
        limit: Maximum number of results to return (default 50, max 200)
        cursor: Opaque pagination cursor from a previous response

    Returns:
        Paginated list of analysis results
    """
    db = _get_job_db()
    # Ask for one row beyond the page size; if it comes back, another page exists.
    fetched = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor)
    more_available = len(fetched) > limit
    page = fetched[:limit] if more_available else fetched

    records = [AnalysisRecord(**entry) for entry in page]

    if not (more_available and page):
        return PaginatedAnalysisResponse(items=records, next_cursor=None)

    # Cursor format is "<timestamp>|<id>" of the final row on this page.
    tail = page[-1]
    stamp = tail["timestamp"]
    stamp_text = stamp.isoformat() if hasattr(stamp, "isoformat") else str(stamp)
    return PaginatedAnalysisResponse(items=records, next_cursor=f"{stamp_text}|{tail['id']}")
@app.post(
    "/analyze/batch",
    response_model=BatchAnalysisResponse,
    tags=["Analysis"],
)
async def analyze_companies_batch(
    request: BatchAnalysisRequest,
    _: UserResponse = Depends(get_current_user),
):
    """Analyze multiple companies' patent portfolios.

    Processes companies concurrently for improved performance.
    Limited to 20 companies per request.

    Args:
        request: List of company names, optional worker count and optional
            model override

    Returns:
        Batch results with individual company analyses and summary statistics

    Raises:
        503: If the analyzer has not been initialized
    """
    import asyncio

    _validate_model(request.model)
    if not _analyzer:
        raise HTTPException(status_code=503, detail="Analyzer not initialized")
    # analyze_companies is synchronous and can run for a long time across many
    # companies; offload it so this async endpoint does not stall the event loop.
    result = await asyncio.to_thread(
        _analyzer.analyze_companies,
        companies=request.companies,
        max_workers=request.max_workers,
        model=request.model,
    )
    return _convert_batch_result(result)
def _get_job_db() -> "DatabaseClient":
    """Return a new DatabaseClient bound to ``config.database_url``.

    Used for job and analysis persistence. A fresh client is constructed on
    every call. NOTE(review): callers never close it, so confirm that
    DatabaseClient pools or releases its connections on its own.
    """
    from SPARC.database import DatabaseClient

    return DatabaseClient(config.database_url)
def _job_row_to_status(row: dict) -> JobStatus:
    """Build a JobStatus API model from a persisted job row.

    ``result_json`` may arrive either as a JSON string or as an
    already-decoded mapping; both forms are accepted. When it is empty or
    missing, the status carries no result.

    Args:
        row: Job row mapping with ``job_id``, ``status``, ``progress``,
            ``total_companies``, ``completed_companies`` and optional
            ``result_json`` / ``error`` keys.

    Returns:
        The corresponding JobStatus model.
    """
    import json as _json

    parsed_result = None
    if payload := row.get("result_json"):
        # Some storage backends hand back the JSON column as text.
        decoded = _json.loads(payload) if isinstance(payload, str) else payload
        parsed_result = BatchAnalysisResponse(**decoded)

    fields = {
        "job_id": row["job_id"],
        "status": row["status"],
        "progress": row["progress"],
        "total_companies": row["total_companies"],
        "completed_companies": row["completed_companies"],
        "result": parsed_result,
        "error": row.get("error"),
    }
    return JobStatus(**fields)
def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: str | None = None):
    """Background task for batch analysis.

    Marks the job ``running``, records per-company progress on the job row,
    then persists either the serialized batch result (``completed``) or the
    error message (``failed``). A webhook notification is sent for either
    outcome *after* the job row holds its final state.

    Args:
        job_id: ID of the job row previously created for this batch
        companies: Company names to analyze
        max_workers: Concurrency passed through to the analyzer
        model: Optional model override passed through to the analyzer
    """
    import json as _json
    import logging

    logger = logging.getLogger(__name__)
    db = _get_job_db()

    if not _analyzer:
        db.update_job(job_id, status="failed", error="Analyzer not initialized")
        return

    db.update_job(job_id, status="running")

    def progress_callback(company: str, completed: int, total: int):
        # Guard total == 0 so an empty batch cannot raise ZeroDivisionError.
        percent = int((completed / total) * 100) if total else 100
        db.update_job(
            job_id,
            completed_companies=completed,
            progress=percent,
        )

    def _notify(status: str, total_companies: int, successful: int, failed: int) -> None:
        # Webhook delivery is best-effort. A delivery error must not rewrite
        # the job's already-persisted final status, so it is logged instead of
        # propagated.
        try:
            from SPARC.webhooks import notify_job_completed

            notify_job_completed(
                job_id=job_id,
                status=status,
                total_companies=total_companies,
                successful=successful,
                failed=failed,
            )
        except Exception:
            logger.exception("Webhook notification failed for job %s", job_id)

    try:
        result = _analyzer.analyze_companies(
            companies=companies,
            max_workers=max_workers,
            progress_callback=progress_callback,
            model=model,
        )
        batch_response = _convert_batch_result(result)
        db.update_job(
            job_id,
            status="completed",
            progress=100,
            result_json=_json.dumps(batch_response.model_dump(), default=str),
        )
    except Exception as e:
        logger.exception("Batch job %s failed", job_id)
        db.update_job(job_id, status="failed", error=str(e))
        _notify(
            status="failed",
            total_companies=len(companies),
            successful=0,
            failed=len(companies),
        )
        return

    # Sent outside the try so a webhook problem cannot flip a completed job
    # to "failed" or trigger a contradictory second notification.
    _notify(
        status="completed",
        total_companies=result.total_companies,
        successful=result.successful,
        failed=result.failed,
    )
@app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"])
async def analyze_companies_async(
    request: BatchAnalysisRequest,
    background_tasks: BackgroundTasks,
    _: UserResponse = Depends(get_current_user),
):
    """Start an asynchronous batch analysis job.

    Returns immediately with a job ID that can be used to poll for status.
    Useful for large batch analyses that may take a long time.

    Args:
        request: List of company names, optional worker count and optional
            model override

    Returns:
        Job status with job_id for polling
    """
    global _job_counter
    import secrets

    _validate_model(request.model)
    _job_counter += 1
    # _job_counter lives in this process only and resets on restart. If the
    # app runs with several worker processes, counter + second-resolution
    # timestamp alone can repeat. The random suffix keeps job IDs unique
    # while preserving the "job_<n>_<timestamp>" prefix.
    job_id = (
        f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
        f"_{secrets.token_hex(4)}"
    )
    db = _get_job_db()
    job_row = db.create_job(job_id=job_id, total_companies=len(request.companies))
    background_tasks.add_task(
        _run_batch_job, job_id, request.companies, request.max_workers, request.model
    )
    return _job_row_to_status(job_row)
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
async def get_job_status(
    job_id: str,
    _: UserResponse = Depends(get_current_user),
):
    """Look up a background analysis job by ID.

    Args:
        job_id: The job ID returned from the async batch endpoint

    Returns:
        Current job status, including progress and, once finished, results

    Raises:
        404: If no job with that ID exists
    """
    stored = _get_job_db().get_job(job_id)
    if stored:
        return _job_row_to_status(stored)
    raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
@app.get("/jobs", response_model=PaginatedJobsResponse, tags=["Jobs"])
async def list_jobs(
    status: Annotated[
        str | None,
        Query(description="Filter by status: pending, running, completed, failed"),
    ] = None,
    limit: Annotated[int, Query(ge=1, le=200)] = 50,
    cursor: Annotated[
        str | None,
        Query(description="Opaque cursor from a previous response's next_cursor field"),
    ] = None,
    _: UserResponse = Depends(get_current_user),
):
    """List analysis jobs with cursor-based pagination.

    Pass ``limit`` to control page size (default 50, max 200). The response
    includes a ``next_cursor`` field; pass it back as the ``cursor`` query
    parameter to fetch the next page. When ``next_cursor`` is ``null``, there
    are no more results.

    Existing clients that use only ``limit`` (without ``cursor``) continue to
    work without modification.

    Args:
        status: Optional filter by job status
        limit: Maximum number of jobs to return (default 50, max 200)
        cursor: Opaque pagination cursor from a previous response

    Returns:
        Paginated list of job statuses
    """
    db = _get_job_db()
    # Fetch one extra to determine if there is a next page
    job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor)
    has_next = len(job_rows) > limit
    if has_next:
        job_rows = job_rows[:limit]
    items = [_job_row_to_status(row) for row in job_rows]

    next_cursor = None
    if has_next and job_rows:
        # Cursor format is "<created_at>|<job_id>" of the last job on this page.
        last = job_rows[-1]
        created = last["created_at"]
        ts = created.isoformat() if hasattr(created, "isoformat") else str(created)
        next_cursor = f"{ts}|{last['job_id']}"
    return PaginatedJobsResponse(items=items, next_cursor=next_cursor)