1231 lines
36 KiB
Python
1231 lines
36 KiB
Python
"""FastAPI web service wrapper for SPARC patent analysis.
|
|
|
|
Provides REST API endpoints for analyzing company patent portfolios.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime
|
|
from typing import TYPE_CHECKING, Annotated, List
|
|
|
|
if TYPE_CHECKING:
|
|
from SPARC.database import DatabaseClient
|
|
|
|
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Path, Query, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from pydantic import BaseModel, EmailStr, Field, StringConstraints
|
|
from slowapi import Limiter
|
|
from slowapi.errors import RateLimitExceeded
|
|
from slowapi.util import get_remote_address
|
|
|
|
from SPARC import config
|
|
from SPARC.analyzer import CompanyAnalyzer
|
|
from SPARC.auth import (
|
|
TokenResponse,
|
|
UserResponse,
|
|
check_jwt_secret,
|
|
close_db_client,
|
|
create_tokens,
|
|
decode_token,
|
|
get_current_admin,
|
|
get_current_user,
|
|
get_db_client,
|
|
init_db_client,
|
|
)
|
|
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
|
|
|
# Validated company name type: 2-100 chars, alphanumeric + spaces/hyphens/ampersands/periods only.
# The pattern additionally requires the first character to be alphanumeric.
# NOTE: the same length bounds and pattern are repeated inline on the Path/Query
# parameters of several endpoints in this module; keep them in sync with this alias.
CompanyName = Annotated[
    str,
    StringConstraints(
        min_length=2,
        max_length=100,
        pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
    ),
]
|
|
|
|
|
|
# Pydantic models for API
|
|
class CompanyAnalysisResponse(BaseModel):
    """Response model for single company analysis."""

    # Each field is filled from the same-named attribute of a
    # CompanyAnalysisResult by _convert_result().
    company_name: str
    analysis: str
    patent_count: int
    success: bool
    error: str | None = None  # optional; defaults to None when not supplied
    model: str | None = None  # optional; defaults to None when not supplied
    timestamp: datetime
|
|
|
|
|
|
class BatchAnalysisResponse(BaseModel):
    """Response model for batch company analysis."""

    # Built by _convert_batch_result(): `results` holds one converted entry per
    # item of BatchAnalysisResult.results; the remaining fields are copied from
    # the same-named attributes of the batch result.
    results: list[CompanyAnalysisResponse]
    total_companies: int
    successful: int
    failed: int
    timestamp: datetime
|
|
|
|
|
|
class CompanyAnalysisRequest(BaseModel):
    """Request model for single company analysis with optional model selection."""

    # Optional model override; None means "use the server default".
    model: str | None = Field(
        default=None,
        description="LLM model to use (e.g. 'anthropic/claude-3.5-sonnet', 'openai/gpt-4o'). Defaults to server config.",
    )
|
|
|
|
|
|
class BatchAnalysisRequest(BaseModel):
    """Request model for batch company analysis."""

    # Between 1 and 20 names, each validated against the CompanyName constraints.
    companies: list[CompanyName] = Field(
        ..., min_length=1, max_length=20, description="List of company names to analyze"
    )
    # Concurrency passed through to the analyzer; bounded to 1..5, default 3.
    max_workers: int = Field(
        default=3, ge=1, le=5, description="Max concurrent analyses"
    )
    # Optional model override; the batch endpoints check it with _validate_model().
    model: str | None = Field(
        default=None,
        description="LLM model to use for all analyses in this batch. Defaults to server config.",
    )
|
|
|
|
|
|
class JobStatus(BaseModel):
    """Status of a background analysis job."""

    # Instances are assembled from a database job row by _job_row_to_status().
    job_id: str
    status: str  # "pending", "running", "completed", "failed"
    progress: int  # 0-100
    total_companies: int
    completed_companies: int
    # Rebuilt from the row's "result_json" value when that value is present.
    result: BatchAnalysisResponse | None = None
    error: str | None = None
|
|
|
|
|
|
class AnalysisRecord(BaseModel):
    """A single stored analysis result."""

    # Built in list_analysis_results() by unpacking one row dict per record.
    # Only `id` is required; every other field may be absent/None.
    id: int
    company_name: str | None = None
    analysis_type: str | None = None
    model: str | None = None
    response: str | None = None
    timestamp: datetime | None = None
|
|
|
|
|
|
class PaginatedAnalysisResponse(BaseModel):
    """Paginated response for analysis result listings."""

    items: list[AnalysisRecord]
    # list_analysis_results() sets this to "<timestamp>|<id>" of the last
    # returned item when more rows exist, and leaves it None otherwise.
    next_cursor: str | None = None
|
|
|
|
|
|
class PaginatedJobsResponse(BaseModel):
    """Paginated response for job listings."""

    # "JobStatus" is written as a quoted forward reference to the model defined above.
    items: list["JobStatus"]
    # None when no cursor is supplied by the producer of this response.
    next_cursor: str | None = None
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Health check response."""

    # health_check() fills these with "healthy", "1.0.0" and the current time.
    status: str
    version: str
    timestamp: datetime
|
|
|
|
|
|
# Auth request/response models
|
|
class RegisterRequest(BaseModel):
    """User registration request."""

    email: EmailStr
    # Required; validation rejects passwords shorter than 8 characters.
    password: str = Field(..., min_length=8, description="Password (min 8 characters)")
|
|
|
|
|
|
class LoginRequest(BaseModel):
    """User login request."""

    email: EmailStr
    # No length constraint here (unlike RegisterRequest); the value is passed
    # unchanged to db.authenticate_user() by the login endpoint.
    password: str
|
|
|
|
|
|
class RefreshRequest(BaseModel):
    """Token refresh request."""

    # Decoded with decode_token() by the /auth/refresh endpoint, which
    # requires the decoded payload type to be "refresh".
    refresh_token: str
|
|
|
|
|
|
class UpdateRoleRequest(BaseModel):
    """Update user role request."""

    # Required; the pattern only accepts exactly "admin" or "user".
    role: str = Field(..., pattern="^(admin|user)$")
|
|
|
|
|
|
class AnalyticsResponse(BaseModel):
    """Analytics response model."""

    # All four fields are copied from the same-named keys of the dict returned
    # by db.get_analytics() in the /analytics endpoint.
    total_messages: int
    by_company: List[dict]
    by_type: List[dict]
    period_days: int
|
|
|
|
|
|
# Job counter for generating unique IDs (the actual state is in PostgreSQL)
# Process-local: it starts at 0 on every import of this module.
# analyze_companies_async() increments it and combines it with a
# second-resolution timestamp to form the job_id.
_job_counter = 0
|
|
|
|
|
|
def _convert_result(result: CompanyAnalysisResult) -> CompanyAnalysisResponse:
    """Build the API response model from an internal single-company result.

    Args:
        result: Internal analysis result whose attributes are copied 1:1.

    Returns:
        A CompanyAnalysisResponse carrying the same field values.
    """
    copied_fields = (
        "company_name",
        "analysis",
        "patent_count",
        "success",
        "error",
        "model",
        "timestamp",
    )
    # Every response field has the same name as the attribute it is read from.
    return CompanyAnalysisResponse(
        **{name: getattr(result, name) for name in copied_fields}
    )
|
|
|
|
|
|
def _convert_batch_result(result: BatchAnalysisResult) -> BatchAnalysisResponse:
    """Build the API response model from an internal batch result.

    Args:
        result: Internal batch result; each entry of ``result.results`` is
            converted with ``_convert_result``.

    Returns:
        A BatchAnalysisResponse with converted entries and copied counters.
    """
    converted = list(map(_convert_result, result.results))
    return BatchAnalysisResponse(
        results=converted,
        total_companies=result.total_companies,
        successful=result.successful,
        failed=result.failed,
        timestamp=result.timestamp,
    )
|
|
|
|
|
|
# Global analyzer instance
# Set to a CompanyAnalyzer by lifespan() on startup and reset to None on
# shutdown; endpoints answer 503 while it is None.
_analyzer: CompanyAnalyzer | None = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize resources on startup, clean up on shutdown.

    Startup: checks the JWT secret, initializes the shared DB client, creates
    the global analyzer, marks jobs left over from a previous run as failed
    (using a short-lived dedicated DatabaseClient), and starts the scheduler.

    Shutdown: clears the global analyzer and closes the shared DB client. The
    cleanup runs in a ``finally`` so it also happens when startup fails after
    the shared client was initialized or when the serving phase raises.
    """
    global _analyzer
    check_jwt_secret()
    init_db_client()
    try:
        _analyzer = CompanyAnalyzer()

        # Mark any jobs that were running/pending before the restart as failed
        from SPARC.database import DatabaseClient

        _db = DatabaseClient(config.database_url)
        _db.connect()
        try:
            _db.initialize_schema()
            stale = _db.mark_stale_jobs_failed()
        finally:
            # Always release the temporary connection, even if the schema
            # initialization or the stale-job update raised.
            _db.close()
        if stale:
            import logging

            logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)

        # Start scheduled analysis if tracked companies are configured
        from SPARC.scheduler import start_scheduler

        start_scheduler()
        yield
    finally:
        # Cleanup
        _analyzer = None
        close_db_client()
|
|
|
|
|
|
# The ASGI application. Startup/shutdown work is delegated to lifespan();
# root_path comes from configuration.
app = FastAPI(
    title="SPARC API",
    description="Semiconductor Patent & Analytics Report Core - Patent portfolio analysis using AI",
    version="1.0.0",
    lifespan=lifespan,
    root_path=config.root_path,
)
|
|
|
|
# Rate limiter (in-memory storage, suitable for single-instance deployments)
# Requests are keyed by the remote address returned by get_remote_address.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

# In-memory rate limit statistics
# Maps endpoint path -> {"endpoint", "total_requests", "rejected_requests",
# "by_ip"}; entries are created and updated by _track_rate_limit_request().
_rate_limit_stats: dict[str, dict] = {}
|
|
|
|
|
|
def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None:
    """Record one request against a rate-limited endpoint.

    Updates the module-level ``_rate_limit_stats`` mapping in place, creating
    the per-endpoint and per-IP entries on first use.

    Args:
        endpoint: Key under which the request is counted.
        ip: Client address used for the per-IP breakdown.
        rejected: When True, the rejected counters are incremented as well.
    """
    entry = _rate_limit_stats.setdefault(
        endpoint,
        {
            "endpoint": endpoint,
            "total_requests": 0,
            "rejected_requests": 0,
            "by_ip": {},
        },
    )
    per_ip = entry.setdefault("by_ip", {}).setdefault(ip, {"total": 0, "rejected": 0})

    entry["total_requests"] += 1
    per_ip["total"] += 1
    if rejected:
        entry["rejected_requests"] += 1
        per_ip["rejected"] += 1
|
|
|
|
|
|
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Return 429 with Retry-After header when rate limit is exceeded.

    The rejected request is also recorded in the in-memory statistics, keyed
    by the request path and the client's remote address.
    """
    _track_rate_limit_request(
        request.url.path,
        get_remote_address(request),
        rejected=True,
    )
    # Falls back to 60 when the exception object has no ``retry_after`` attribute.
    response_headers = {"Retry-After": str(getattr(exc, "retry_after", 60))}
    payload = {"detail": "Rate limit exceeded. Please try again later."}
    return JSONResponse(status_code=429, content=payload, headers=response_headers)
|
|
|
|
|
|
# Add CORS middleware for React frontend
# Allowed origins come from configuration; all methods and headers are
# permitted and credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=config.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# ============== Auth Endpoints ==============
|
|
|
|
|
|
@app.post("/auth/register", response_model=UserResponse, tags=["Auth"])
@limiter.limit("5/minute")
async def register(request: Request, body: RegisterRequest):
    """Register a new user.

    The first registered user automatically becomes an admin.
    """
    _track_rate_limit_request("/auth/register", get_remote_address(request))
    db = get_db_client()

    # First user becomes admin.
    # NOTE(review): the count and the insert are two separate calls, so two
    # concurrent first registrations could presumably both see a count of 0.
    is_first_user = db.get_user_count() == 0
    created = db.create_user(
        email=body.email,
        password=body.password,
        role="admin" if is_first_user else "user",
    )

    # A falsy result from create_user is reported as a duplicate email.
    if created:
        return UserResponse(
            **{key: created[key] for key in ("id", "email", "role", "created_at")}
        )
    raise HTTPException(status_code=400, detail="Email already registered")
|
|
|
|
|
|
@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"])
@limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest):
    """Authenticate user and return JWT tokens."""
    _track_rate_limit_request("/auth/login", get_remote_address(request))

    account = get_db_client().authenticate_user(body.email, body.password)
    if account:
        return create_tokens(account["id"], account["email"], account["role"])

    # A single generic message is used for any failed authentication.
    raise HTTPException(status_code=401, detail="Invalid email or password")
|
|
|
|
|
|
@app.post("/auth/refresh", response_model=TokenResponse, tags=["Auth"])
async def refresh_token(request: RefreshRequest):
    """Refresh access token using refresh token."""
    payload = decode_token(request.refresh_token)

    # Accept only a decodable token whose type is "refresh".
    is_refresh_token = bool(payload) and payload.type == "refresh"
    if not is_refresh_token:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    # The user is looked up again so tokens carry the current email and role.
    account = get_db_client().get_user_by_id(payload.user_id)
    if not account:
        raise HTTPException(status_code=401, detail="User not found")

    return create_tokens(account["id"], account["email"], account["role"])
|
|
|
|
|
|
@app.get("/auth/me", response_model=UserResponse, tags=["Auth"])
async def get_me(current_user: UserResponse = Depends(get_current_user)):
    """Get current authenticated user."""
    # The get_current_user dependency already resolved the caller; echo it back.
    return current_user
|
|
|
|
|
|
# ============== Admin Endpoints ==============
|
|
|
|
|
|
@app.get("/admin/users", response_model=List[UserResponse], tags=["Admin"])
async def list_users(
    limit: int = Query(default=100, ge=1, le=1000),
    offset: int = Query(default=0, ge=0),
    _: UserResponse = Depends(get_current_admin),
):
    """List all users (admin only)."""
    exposed_fields = ("id", "email", "role", "created_at")
    records = get_db_client().get_all_users(limit=limit, offset=offset)

    # Only the fields above are copied from each stored record into the response.
    return [
        UserResponse(**{key: record[key] for key in exposed_fields})
        for record in records
    ]
|
|
|
|
|
|
@app.patch("/admin/users/{user_id}/role", response_model=UserResponse, tags=["Admin"])
async def update_user_role(
    user_id: int,
    request: UpdateRoleRequest,
    current_admin: UserResponse = Depends(get_current_admin),
):
    """Update a user's role (admin only)."""
    # Admins may not change their own role.
    targets_self = user_id == current_admin.id
    if targets_self:
        raise HTTPException(status_code=400, detail="Cannot change your own role")

    updated = get_db_client().update_user_role(user_id, request.role)
    if updated:
        return UserResponse(
            **{key: updated[key] for key in ("id", "email", "role", "created_at")}
        )

    # A falsy result from the database layer is reported as an unknown user.
    raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
|
|
@app.delete("/admin/users/{user_id}", tags=["Admin"])
async def delete_user(
    user_id: int,
    current_admin: UserResponse = Depends(get_current_admin),
):
    """Delete a user (admin only)."""
    # Admins may not delete their own account.
    targets_self = user_id == current_admin.id
    if targets_self:
        raise HTTPException(status_code=400, detail="Cannot delete yourself")

    if get_db_client().delete_user(user_id):
        return {"message": "User deleted"}

    # A falsy result from the database layer is reported as an unknown user.
    raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
|
|
# ============== Tracked Companies Endpoints ==============
|
|
|
|
|
|
class TrackCompanyRequest(BaseModel):
    """Request to add a company to tracking."""

    # Required; validated against the CompanyName length and pattern constraints.
    company_name: CompanyName = Field(...)
|
|
|
|
|
|
@app.get("/admin/tracked", tags=["Admin"])
async def list_tracked_companies(
    _: UserResponse = Depends(get_current_admin),
):
    """List all tracked companies (admin only)."""
    # The database layer's result is returned as-is (no response model is declared).
    return get_db_client().list_tracked_companies()
|
|
|
|
|
|
@app.post("/admin/tracked", tags=["Admin"])
async def add_tracked_company(
    request: TrackCompanyRequest,
    _: UserResponse = Depends(get_current_admin),
):
    """Add a company to the tracked list (admin only)."""
    tracked = get_db_client().add_tracked_company(request.company_name)
    if tracked:
        return tracked

    # A falsy result from the database layer is reported as a conflict.
    raise HTTPException(status_code=409, detail="Company already tracked")
|
|
|
|
|
|
@app.delete("/admin/tracked/{company_name}", tags=["Admin"])
async def remove_tracked_company(
    company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
    _: UserResponse = Depends(get_current_admin),
):
    """Remove a company from the tracked list (admin only)."""
    was_removed = get_db_client().remove_tracked_company(company_name)
    if was_removed:
        return {"message": f"Stopped tracking {company_name}"}

    # A falsy result from the database layer is reported as "not tracked".
    raise HTTPException(status_code=404, detail="Company not found in tracking list")
|
|
|
|
|
|
@app.get("/admin/rate-limits", tags=["Admin"])
async def get_rate_limit_stats(
    _: UserResponse = Depends(get_current_admin),
):
    """Get rate limit status and usage statistics (admin only).

    Returns current rate limit configuration and request statistics
    for all rate-limited endpoints.

    Returns:
        List of rate limit stats per endpoint with total/rejected counts
    """
    # Limits mirror the @limiter.limit decorators on the auth endpoints.
    configured_limits = {
        "/auth/register": "5/minute",
        "/auth/login": "10/minute",
    }

    def summarize(endpoint: str, limit: str) -> dict:
        # Endpoints that have not been hit yet report zero for both counters.
        recorded = _rate_limit_stats.get(endpoint, {})
        return {
            "endpoint": endpoint,
            "limit": limit,
            "total_requests": recorded.get("total_requests", 0),
            "rejected_requests": recorded.get("rejected_requests", 0),
        }

    return {
        "rate_limits": [
            summarize(endpoint, limit)
            for endpoint, limit in configured_limits.items()
        ]
    }
|
|
|
|
|
|
@app.get("/admin/alerts", tags=["Admin"])
async def list_alerts(
    limit: int = Query(default=50, ge=1, le=200),
    _: UserResponse = Depends(get_current_admin),
):
    """List recent alerts from scheduled analysis (admin only)."""
    # The database layer's result is returned as-is (no response model is declared).
    return get_db_client().list_alerts(limit=limit)
|
|
|
|
|
|
# ============== Analytics Endpoint ==============
|
|
|
|
|
|
@app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"])
async def get_analytics(
    days: int = Query(default=30, ge=1, le=365),
    _: UserResponse = Depends(get_current_user),
):
    """Get analytics data (authenticated users only)."""
    summary = get_db_client().get_analytics(days=days)

    # Only these four keys of the database summary are exposed.
    exposed_keys = ("total_messages", "by_company", "by_type", "period_days")
    return AnalyticsResponse(**{key: summary[key] for key in exposed_keys})
|
|
|
|
|
|
# ============== Model Selection Endpoints ==============
|
|
|
|
# Supported models via OpenRouter
# Returned verbatim by the /models endpoint.
SUPPORTED_MODELS = [
    {"id": "anthropic/claude-3.5-sonnet", "name": "Claude 3.5 Sonnet", "provider": "Anthropic"},
    {"id": "openai/gpt-4o", "name": "GPT-4o", "provider": "OpenAI"},
    {"id": "openai/gpt-4o-mini", "name": "GPT-4o Mini", "provider": "OpenAI"},
    {"id": "google/gemini-pro-1.5", "name": "Gemini Pro 1.5", "provider": "Google"},
    {"id": "meta-llama/llama-3.1-70b-instruct", "name": "Llama 3.1 70B", "provider": "Meta"},
]

# Allow-list of model ids derived from the table above; used by _validate_model().
_SUPPORTED_MODEL_IDS = {m["id"] for m in SUPPORTED_MODELS}
|
|
|
|
|
|
def _validate_model(model: str | None) -> None:
    """Reject model ids that are not on the supported allow-list.

    Args:
        model: Requested model id, or None to use the server default.

    Raises:
        HTTPException: 400 when *model* is given but not in ``_SUPPORTED_MODEL_IDS``.
    """
    # None means "no override requested" and is always accepted.
    if model is None or model in _SUPPORTED_MODEL_IDS:
        return

    supported = ", ".join(sorted(_SUPPORTED_MODEL_IDS))
    raise HTTPException(
        status_code=400,
        detail=f"Unsupported model '{model}'. Supported models: {supported}",
    )
|
|
|
|
|
|
@app.get("/models", tags=["System"])
async def list_models():
    """List supported LLM models for analysis.

    Returns the available models that can be passed as the `model` field
    in analysis requests. The default model is determined by the `MODEL`
    environment variable on the server.
    """
    # "default" is read from the server configuration on every call.
    return dict(models=SUPPORTED_MODELS, default=config.model)
|
|
|
|
|
|
@app.get("/analytics/trends", tags=["Analytics"])
async def get_analytics_trends(
    days: int = Query(default=90, ge=7, le=365),
    _: UserResponse = Depends(get_current_user),
):
    """Get trend data for patent analysis over time.

    Returns two datasets:
    - ``by_month``: analysis count per company per month
    - ``by_type_over_time``: analysis type distribution per month

    Args:
        days: Number of days to look back (default 90)

    Returns:
        Trend data suitable for time-series and distribution charts
    """
    db = get_db_client()

    # The look-back window is bound as a real query parameter and multiplied by
    # a one-day interval. A placeholder inside a quoted SQL literal
    # (INTERVAL '%s days') only works with client-side string interpolation.
    with db.get_conn() as conn:
        with conn.cursor() as cur:
            # Analyses per company per month
            cur.execute(
                """
                SELECT
                    TO_CHAR(timestamp, 'YYYY-MM') AS month,
                    company_name,
                    COUNT(*) AS count
                FROM llm_messages
                WHERE timestamp >= NOW() - %s * INTERVAL '1 day'
                  AND is_cached = FALSE
                  AND company_name IS NOT NULL
                GROUP BY month, company_name
                ORDER BY month
                """,
                (days,),
            )
            by_month_rows = cur.fetchall()

            # Analysis type distribution per month
            cur.execute(
                """
                SELECT
                    TO_CHAR(timestamp, 'YYYY-MM') AS month,
                    analysis_type,
                    COUNT(*) AS count
                FROM llm_messages
                WHERE timestamp >= NOW() - %s * INTERVAL '1 day'
                  AND is_cached = FALSE
                GROUP BY month, analysis_type
                ORDER BY month
                """,
                (days,),
            )
            by_type_rows = cur.fetchall()

    # Rows are consumed positionally: (month, group value, count).
    by_month = [
        {"month": row[0], "company_name": row[1], "count": row[2]}
        for row in by_month_rows
    ]
    by_type_over_time = [
        {"month": row[0], "analysis_type": row[1], "count": row[2]}
        for row in by_type_rows
    ]

    return {
        "by_month": by_month,
        "by_type_over_time": by_type_over_time,
        "period_days": days,
    }
|
|
|
|
|
|
# ============== Export Endpoints ==============
|
|
|
|
|
|
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
    company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
    _: UserResponse = Depends(get_current_user),
):
    """Export analysis results for a company as a CSV file.

    Returns all stored analysis records for the given company, including
    analysis type, model used, response text, and timestamp.

    Args:
        company_name: Company name to export results for

    Returns:
        CSV file download
    """
    import csv
    import io

    # Fetch every non-cached record for the company (case-insensitive match),
    # newest first.
    with get_db_client().get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT company_name, analysis_type, model, response, timestamp
                FROM llm_messages
                WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
                ORDER BY timestamp DESC
                """,
                (company_name,),
            )
            records = cur.fetchall()

    if not records:
        raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")

    # Header row first, then one CSV row per database row in query order.
    csv_buffer = io.StringIO()
    csv_writer = csv.writer(csv_buffer)
    csv_writer.writerow(["company_name", "analysis_type", "model", "analysis", "timestamp"])
    csv_writer.writerows(records)

    # Spaces become underscores and the name is lower-cased for the filename.
    safe_name = company_name.replace(" ", "_").lower()
    disposition = f'attachment; filename="sparc_{safe_name}_export.csv"'
    return StreamingResponse(
        iter([csv_buffer.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": disposition},
    )
|
|
|
|
|
|
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
    company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
    _: UserResponse = Depends(get_current_user),
):
    """Export analysis results for a company as a formatted PDF report.

    Returns all stored analysis records for the given company, including
    analysis type, model used, response text, and timestamp, formatted
    as a downloadable PDF document.

    Args:
        company_name: Company name to export results for

    Returns:
        PDF file download

    Raises:
        HTTPException: 404 when no stored analysis exists for the company.
    """
    import io
    from xml.sax.saxutils import escape

    from reportlab.lib import colors
    from reportlab.lib.pagesizes import letter
    from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
    from reportlab.lib.units import inch
    from reportlab.platypus import (
        Paragraph,
        SimpleDocTemplate,
        Spacer,
        Table,
        TableStyle,
    )

    db = get_db_client()
    with db.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT company_name, analysis_type, model, response, timestamp
                FROM llm_messages
                WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
                ORDER BY timestamp DESC
                """,
                (company_name,),
            )
            rows = cur.fetchall()

    if not rows:
        raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")

    buffer = io.BytesIO()
    doc = SimpleDocTemplate(
        buffer,
        pagesize=letter,
        rightMargin=0.75 * inch,
        leftMargin=0.75 * inch,
        topMargin=0.75 * inch,
        bottomMargin=0.75 * inch,
    )

    styles = getSampleStyleSheet()
    title_style = ParagraphStyle(
        "CustomTitle",
        parent=styles["Title"],
        fontSize=20,
        spaceAfter=6,
    )
    subtitle_style = ParagraphStyle(
        "Subtitle",
        parent=styles["Normal"],
        fontSize=11,
        textColor=colors.grey,
        spaceAfter=20,
    )
    heading_style = ParagraphStyle(
        "SectionHeading",
        parent=styles["Heading2"],
        fontSize=13,
        spaceBefore=16,
        spaceAfter=8,
        textColor=colors.HexColor("#1a1a2e"),
    )
    body_style = ParagraphStyle(
        "BodyText",
        parent=styles["Normal"],
        fontSize=9,
        leading=13,
        spaceAfter=10,
    )

    def markup_safe(value) -> str:
        # Paragraph text is parsed as XML-like markup, so every piece of
        # database text must have &, < and > escaped before interpolation.
        return escape(str(value))

    elements = []

    # Title and date
    display_name = rows[0][0]  # Use the casing from the database
    analysis_date = datetime.now().strftime("%Y-%m-%d")
    elements.append(Paragraph(f"SPARC Analysis Report: {markup_safe(display_name)}", title_style))
    elements.append(Paragraph(f"Generated on {analysis_date}", subtitle_style))

    # Summary table. NULL analysis types / models are skipped so that sorting
    # and joining cannot fail on a mix of None and str values. Table cells are
    # plain strings here (not Paragraphs), so they are not escaped.
    analysis_types = sorted({str(r[1]) for r in rows if r[1] is not None})
    models_used = sorted({str(r[2]) for r in rows if r[2] is not None})
    summary_data = [
        ["Total Analyses", str(len(rows))],
        ["Analysis Types", ", ".join(analysis_types)],
        ["Models Used", ", ".join(models_used)],
    ]
    summary_table = Table(summary_data, colWidths=[2 * inch, 4.5 * inch])
    summary_table.setStyle(
        TableStyle(
            [
                ("BACKGROUND", (0, 0), (0, -1), colors.HexColor("#f0f0f5")),
                ("FONTNAME", (0, 0), (0, -1), "Helvetica-Bold"),
                ("FONTSIZE", (0, 0), (-1, -1), 9),
                ("PADDING", (0, 0), (-1, -1), 6),
                ("GRID", (0, 0), (-1, -1), 0.5, colors.HexColor("#cccccc")),
                ("VALIGN", (0, 0), (-1, -1), "TOP"),
            ]
        )
    )
    elements.append(summary_table)
    elements.append(Spacer(1, 16))

    # Individual analysis sections
    for i, (_company, analysis_type, model, response, timestamp) in enumerate(rows, 1):
        ts_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") if hasattr(timestamp, "strftime") else str(timestamp)

        elements.append(
            Paragraph(
                f"Analysis {i}: {markup_safe(analysis_type)} (via {markup_safe(model)})",
                heading_style,
            )
        )
        elements.append(
            Paragraph(f"<i>Performed: {ts_str}</i>", body_style)
        )

        # Escape XML special characters in the stored response; a NULL
        # response is rendered as an empty section instead of crashing.
        safe_response = escape(response) if response is not None else ""

        # One Paragraph per non-blank line; blank lines become small spacers.
        for line in safe_response.split("\n"):
            if line.strip():
                elements.append(Paragraph(line, body_style))
            else:
                elements.append(Spacer(1, 4))

        elements.append(Spacer(1, 10))

    doc.build(elements)

    # Spaces become underscores and the name is lower-cased for the filename.
    safe_name = company_name.replace(" ", "_").lower()
    filename = f"{safe_name}-analysis-{analysis_date}.pdf"
    return StreamingResponse(
        iter([buffer.getvalue()]),
        media_type="application/pdf",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
|
|
|
|
|
|
# ============== System Endpoints ==============
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse, tags=["System"])
async def health_check():
    """Check API health status."""
    # Always reports "healthy"; no dependency is probed here.
    checked_at = datetime.now()
    return HealthResponse(status="healthy", version="1.0.0", timestamp=checked_at)
|
|
|
|
|
|
# NOTE: route registration order matters. The static GET /analyze/batch and
# the /analyze/patent/... routes are registered BEFORE the catch-all
# GET /analyze/{company_name}; otherwise "batch" is captured as a company name
# (it satisfies the company-name pattern) and the listing endpoint can never
# be reached.


@app.get(
    "/analyze/batch",
    response_model=PaginatedAnalysisResponse,
    tags=["Analysis"],
)
async def list_analysis_results(
    company_name: Annotated[
        str | None,
        Query(description="Filter results by company name"),
    ] = None,
    limit: Annotated[int, Query(ge=1, le=200)] = 50,
    cursor: Annotated[
        str | None,
        Query(description="Opaque cursor from a previous response's next_cursor field"),
    ] = None,
    _: UserResponse = Depends(get_current_user),
):
    """List stored analysis results with cursor-based pagination.

    Returns past analysis results ordered by timestamp descending. Use
    ``limit`` to control page size (default 50, max 200). The response
    includes a ``next_cursor`` field; pass it back as the ``cursor`` query
    parameter to fetch the next page. When ``next_cursor`` is ``null``,
    there are no more results.

    Args:
        company_name: Optional filter by company name
        limit: Maximum number of results to return (default 50, max 200)
        cursor: Opaque pagination cursor from a previous response

    Returns:
        Paginated list of analysis results
    """
    db = _get_job_db()
    # Ask for one extra row: its presence tells us whether another page exists.
    rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor)

    has_next = len(rows) > limit
    if has_next:
        rows = rows[:limit]

    items = [AnalysisRecord(**row) for row in rows]

    next_cursor = None
    if has_next and rows:
        # The cursor encodes the last returned row as "<timestamp>|<id>".
        last = rows[-1]
        ts = last["timestamp"]
        ts_str = ts.isoformat() if hasattr(ts, "isoformat") else str(ts)
        next_cursor = f"{ts_str}|{last['id']}"

    return PaginatedAnalysisResponse(items=items, next_cursor=next_cursor)


@app.get(
    "/analyze/patent/{patent_id}",
    tags=["Analysis"],
)
async def analyze_single_patent(
    patent_id: str,
    company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
    _: UserResponse = Depends(get_current_user),
):
    """Analyze a single patent by its publication ID.

    If the patent PDF is not already cached locally, the system will attempt
    to download it automatically from a previously cached link. If no link
    is available, a 404 error is returned.

    Args:
        patent_id: Patent publication ID (e.g. "US-11234567-B2")
        company_name: Company name for analysis context

    Returns:
        Analysis text for the patent
    """
    if not _analyzer:
        raise HTTPException(status_code=503, detail="Analyzer not initialized")

    # NOTE(review): patent_id is not format-validated here; confirm the
    # analyzer sanitizes it before using it to locate cached files.
    try:
        analysis = _analyzer.analyze_single_patent(patent_id, company_name)
    except FileNotFoundError as e:
        # Keep the original exception chained for server-side diagnostics.
        raise HTTPException(status_code=404, detail=str(e)) from e
    return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis}


@app.get(
    "/analyze/{company_name}",
    response_model=CompanyAnalysisResponse,
    tags=["Analysis"],
)
async def analyze_company(
    company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
    model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
    _: UserResponse = Depends(get_current_user),
):
    """Analyze a single company's patent portfolio.

    This endpoint retrieves recent patents for the specified company,
    parses them, and uses AI to generate a comprehensive analysis.

    Args:
        company_name: Name of the company to analyze (e.g., "nvidia", "intel")
        model: Optional LLM model override

    Returns:
        Analysis results including patent count, AI insights, and success status
    """
    _validate_model(model)
    if not _analyzer:
        raise HTTPException(status_code=503, detail="Analyzer not initialized")

    # NOTE(review): this is a synchronous call inside an async handler, so it
    # runs on the event loop for its full duration.
    result = _analyzer._analyze_company_safe(company_name, model=model)
    return _convert_result(result)
|
|
|
|
|
|
@app.post(
    "/analyze/batch",
    response_model=BatchAnalysisResponse,
    tags=["Analysis"],
)
async def analyze_companies_batch(
    request: BatchAnalysisRequest,
    _: UserResponse = Depends(get_current_user),
):
    """Analyze multiple companies' patent portfolios.

    Processes companies concurrently for improved performance.
    Limited to 20 companies per request.

    Args:
        request: List of company names and optional worker count

    Returns:
        Batch results with individual company analyses and summary statistics
    """
    # Model validation runs first, so an unsupported model yields 400 even
    # when the analyzer is unavailable.
    _validate_model(request.model)
    if not _analyzer:
        raise HTTPException(status_code=503, detail="Analyzer not initialized")

    analyzer_kwargs = {
        "companies": request.companies,
        "max_workers": request.max_workers,
        "model": request.model,
    }
    return _convert_batch_result(_analyzer.analyze_companies(**analyzer_kwargs))
|
|
|
|
|
|
def _get_job_db() -> "DatabaseClient":
    """Create a DatabaseClient for job persistence.

    Returns:
        A new client built from ``config.database_url``. No ``connect()`` call
        is made here.
    """
    # Imported lazily inside the function, as elsewhere in this module.
    from SPARC.database import DatabaseClient

    return DatabaseClient(config.database_url)
|
|
|
|
|
|
def _job_row_to_status(row: dict) -> JobStatus:
    """Build a JobStatus model from a database job row.

    Args:
        row: Mapping with the keys ``job_id``, ``status``, ``progress``,
            ``total_companies`` and ``completed_companies``; ``result_json``
            and ``error`` are optional.

    Returns:
        The JobStatus for the row, with ``result`` populated only when the row
        carries a truthy ``result_json`` value.
    """
    import json as _json

    batch_result = None
    stored = row.get("result_json")
    if stored:
        # The stored value may be a JSON string or an already-decoded mapping.
        decoded = _json.loads(stored) if isinstance(stored, str) else stored
        batch_result = BatchAnalysisResponse(**decoded)

    required_keys = ("job_id", "status", "progress", "total_companies", "completed_companies")
    return JobStatus(
        **{key: row[key] for key in required_keys},
        result=batch_result,
        error=row.get("error"),
    )
|
|
|
|
|
|
def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: str | None = None):
    """Background task for batch analysis.

    Runs the analysis, persists progress and the final state of the job, and
    then fires the job-completed webhook.

    The webhook is sent outside the analysis ``try`` block on purpose: a
    failing notification must not flip an already completed job to "failed"
    or trigger a second, contradictory notification.

    Args:
        job_id: Identifier of the job row to update.
        companies: Company names to analyze.
        max_workers: Concurrency passed to the analyzer.
        model: Optional model override passed to the analyzer.
    """
    import json as _json
    import logging

    db = _get_job_db()

    analyzer = _analyzer
    if not analyzer:
        db.update_job(job_id, status="failed", error="Analyzer not initialized")
        return

    db.update_job(job_id, status="running")

    def progress_callback(company: str, completed: int, total: int):
        # Persist the running count and the integer percentage after each company.
        db.update_job(
            job_id,
            completed_companies=completed,
            progress=int((completed / total) * 100),
        )

    try:
        result = analyzer.analyze_companies(
            companies=companies,
            max_workers=max_workers,
            progress_callback=progress_callback,
            model=model,
        )
        batch_response = _convert_batch_result(result)
        db.update_job(
            job_id,
            status="completed",
            progress=100,
            result_json=_json.dumps(batch_response.model_dump(), default=str),
        )
    except Exception as e:
        # Top-level boundary of the background task: record the failure on the job.
        db.update_job(job_id, status="failed", error=str(e))
        outcome = {
            "status": "failed",
            "total_companies": len(companies),
            "successful": 0,
            "failed": len(companies),
        }
    else:
        outcome = {
            "status": "completed",
            "total_companies": result.total_companies,
            "successful": result.successful,
            "failed": result.failed,
        }

    # Fire webhook notification (best effort; the job state is already stored).
    try:
        from SPARC.webhooks import notify_job_completed

        notify_job_completed(job_id=job_id, **outcome)
    except Exception:
        logging.getLogger(__name__).exception(
            "Webhook notification failed for job %s", job_id
        )
|
|
|
|
|
|
@app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"])
async def analyze_companies_async(
    request: BatchAnalysisRequest,
    background_tasks: BackgroundTasks,
    _: UserResponse = Depends(get_current_user),
):
    """Start an asynchronous batch analysis job.

    Responds right away with a job ID; poll that ID to follow the job's status.
    Intended for large batch analyses that may run for a long time.

    Args:
        request: List of company names and optional worker count

    Returns:
        Job status with job_id for polling
    """
    global _job_counter

    _validate_model(request.model)

    # Job IDs combine a per-process counter with a second-resolution timestamp.
    _job_counter += 1
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    job_id = f"job_{_job_counter}_{stamp}"

    company_names = request.companies
    created_row = _get_job_db().create_job(
        job_id=job_id,
        total_companies=len(company_names),
    )

    # The analysis itself is scheduled as a background task.
    background_tasks.add_task(
        _run_batch_job,
        job_id,
        company_names,
        request.max_workers,
        request.model,
    )

    return _job_row_to_status(created_row)
|
|
|
|
|
|
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
async def get_job_status(
    job_id: str,
    _: UserResponse = Depends(get_current_user),
):
    """Look up the status of a background analysis job.

    Args:
        job_id: The job ID returned from the async batch endpoint

    Returns:
        Current job status, including progress and, once complete, the results
    """
    record = _get_job_db().get_job(job_id)
    if record:
        return _job_row_to_status(record)

    # No (or an empty) row for this ID: report it as missing.
    raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
|
|
|
|
|
|
@app.get("/jobs", response_model=PaginatedJobsResponse, tags=["Jobs"])
async def list_jobs(
    status: Annotated[
        str | None,
        Query(description="Filter by status: pending, running, completed, failed"),
    ] = None,
    limit: Annotated[int, Query(ge=1, le=200)] = 50,
    cursor: Annotated[
        str | None,
        Query(description="Opaque cursor from a previous response's next_cursor field"),
    ] = None,
    _: UserResponse = Depends(get_current_user),
):
    """List analysis jobs with cursor-based pagination.

    Pass ``limit`` to control page size. The response includes a ``next_cursor``
    field; pass it back as the ``cursor`` query parameter to fetch the next page.
    When ``next_cursor`` is ``null``, there are no more results.

    Existing clients that use only ``limit`` (without ``cursor``) continue to
    work without modification.

    Args:
        status: Optional filter by job status
        limit: Maximum number of jobs to return (default 50, min 1, max 200)
        cursor: Opaque pagination cursor from a previous response

    Returns:
        Paginated list of job statuses
    """
    db = _get_job_db()
    # Fetch one extra to determine if there is a next page
    job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor)

    has_next = len(job_rows) > limit
    if has_next:
        # Drop the look-ahead row so the page holds at most ``limit`` items.
        job_rows = job_rows[:limit]

    items = [_job_row_to_status(row) for row in job_rows]

    next_cursor = None
    if has_next and job_rows:
        # The cursor is "<created_at>|<job_id>" of the last row on this page.
        last = job_rows[-1]
        created = last["created_at"]
        ts = created.isoformat() if hasattr(created, "isoformat") else str(created)
        next_cursor = f"{ts}|{last['job_id']}"

    return PaginatedJobsResponse(items=items, next_cursor=next_cursor)
|