Compare commits

..

1 Commits

Author SHA1 Message Date
agent-company 96d5d27b17 feat(jobs): persist async batch job state in PostgreSQL
- Add jobs table to database schema (job_id, status, progress, result_json, etc.)
- Add DatabaseClient methods: create_job, update_job, get_job, list_jobs
- Add mark_stale_jobs_failed() called at startup to handle interrupted jobs
- Refactor _run_batch_job and job endpoints to read/write from PostgreSQL
- Remove in-memory _jobs dict; job state now survives API restarts
- Update init_database.py to list all tables in output

Closes leeworks-agents/SPARC#8

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:22:57 +00:00
6 changed files with 224 additions and 160 deletions
+75 -61
View File
@@ -7,13 +7,9 @@ from contextlib import asynccontextmanager
from datetime import datetime
from typing import Annotated, List
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, EmailStr, Field
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from SPARC import config
from SPARC.analyzer import CompanyAnalyzer
@@ -118,8 +114,7 @@ class AnalyticsResponse(BaseModel):
period_days: int
# In-memory job storage (for demo; production would use Redis/DB)
_jobs: dict[str, JobStatus] = {}
# Job counter for generating unique IDs (the actual state is in PostgreSQL)
_job_counter = 0
@@ -152,9 +147,19 @@ _analyzer: CompanyAnalyzer | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize resources on startup."""
"""Initialize resources on startup, clean up on shutdown."""
global _analyzer
_analyzer = CompanyAnalyzer()
# Mark any jobs that were running/pending before the restart as failed
from SPARC.database import DatabaseClient
_db = DatabaseClient(config.database_url)
_db.connect()
_db.initialize_schema()
stale = _db.mark_stale_jobs_failed()
if stale:
import logging
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
_db.close()
yield
# Cleanup if needed
_analyzer = None
@@ -168,22 +173,6 @@ app = FastAPI(
root_path=config.root_path,
)
# Rate limiter (in-memory storage, suitable for single-instance deployments)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Return 429 with Retry-After header when rate limit is exceeded."""
retry_after = getattr(exc, "retry_after", 60)
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded. Please try again later."},
headers={"Retry-After": str(retry_after)},
)
# Add CORS middleware for React frontend
app.add_middleware(
CORSMiddleware,
@@ -198,8 +187,7 @@ app.add_middleware(
@app.post("/auth/register", response_model=UserResponse, tags=["Auth"])
@limiter.limit("5/minute")
async def register(request: Request, body: RegisterRequest):
async def register(request: RegisterRequest):
"""Register a new user.
The first registered user automatically becomes an admin.
@@ -211,8 +199,8 @@ async def register(request: Request, body: RegisterRequest):
role = "admin" if user_count == 0 else "user"
user = db.create_user(
email=body.email,
password=body.password,
email=request.email,
password=request.password,
role=role,
)
@@ -231,12 +219,11 @@ async def register(request: Request, body: RegisterRequest):
@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"])
@limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest):
async def login(request: LoginRequest):
"""Authenticate user and return JWT tokens."""
db = get_db_client()
user = db.authenticate_user(body.email, body.password)
user = db.authenticate_user(request.email, request.password)
if not user:
raise HTTPException(
@@ -444,20 +431,52 @@ async def analyze_companies_batch(
return _convert_batch_result(result)
def _get_job_db() -> "DatabaseClient":
    """Build a ``DatabaseClient`` used for reading and writing job state.

    A new client is constructed from ``config.database_url`` on every call;
    nothing is cached by this helper.

    Returns:
        A freshly constructed ``DatabaseClient``.
    """
    # Imported inside the function rather than at module import time.
    from SPARC.database import DatabaseClient

    # NOTE(review): no connect()/close() is issued here -- presumably
    # DatabaseClient.get_conn() manages connections itself; confirm against
    # SPARC.database.
    return DatabaseClient(config.database_url)
def _job_row_to_status(row: dict) -> JobStatus:
    """Translate one ``jobs`` table row into a ``JobStatus`` model.

    Args:
        row: Mapping that provides ``job_id``, ``status``, ``progress``,
            ``total_companies`` and ``completed_companies``; ``result_json``
            and ``error`` are optional.

    Returns:
        A ``JobStatus`` whose ``result`` is a ``BatchAnalysisResponse`` when
        the row carries a non-empty ``result_json``, otherwise ``None``.
    """
    import json as _json

    payload = row.get("result_json")
    if payload:
        # The stored result may arrive as JSON text or as an already-decoded
        # object; only the text form needs parsing.
        if isinstance(payload, str):
            payload = _json.loads(payload)
        parsed_result = BatchAnalysisResponse(**payload)
    else:
        parsed_result = None

    return JobStatus(
        job_id=row["job_id"],
        status=row["status"],
        progress=row["progress"],
        total_companies=row["total_companies"],
        completed_companies=row["completed_companies"],
        result=parsed_result,
        error=row.get("error"),
    )
def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
"""Background task for batch analysis."""
global _jobs, _analyzer
import json as _json
global _analyzer
db = _get_job_db()
if not _analyzer:
_jobs[job_id].status = "failed"
_jobs[job_id].error = "Analyzer not initialized"
db.update_job(job_id, status="failed", error="Analyzer not initialized")
return
_jobs[job_id].status = "running"
db.update_job(job_id, status="running")
def progress_callback(company: str, completed: int, total: int):
_jobs[job_id].completed_companies = completed
_jobs[job_id].progress = int((completed / total) * 100)
db.update_job(
job_id,
completed_companies=completed,
progress=int((completed / total) * 100),
)
try:
result = _analyzer.analyze_companies(
@@ -465,12 +484,15 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
max_workers=max_workers,
progress_callback=progress_callback,
)
_jobs[job_id].status = "completed"
_jobs[job_id].progress = 100
_jobs[job_id].result = _convert_batch_result(result)
batch_response = _convert_batch_result(result)
db.update_job(
job_id,
status="completed",
progress=100,
result_json=_json.dumps(batch_response.model_dump(), default=str),
)
except Exception as e:
_jobs[job_id].status = "failed"
_jobs[job_id].error = str(e)
db.update_job(job_id, status="failed", error=str(e))
@app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"])
@@ -495,19 +517,14 @@ async def analyze_companies_async(
_job_counter += 1
job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
_jobs[job_id] = JobStatus(
job_id=job_id,
status="pending",
progress=0,
total_companies=len(request.companies),
completed_companies=0,
)
db = _get_job_db()
job_row = db.create_job(job_id=job_id, total_companies=len(request.companies))
background_tasks.add_task(
_run_batch_job, job_id, request.companies, request.max_workers
)
return _jobs[job_id]
return _job_row_to_status(job_row)
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
@@ -523,10 +540,13 @@ async def get_job_status(
Returns:
Current job status including progress and results when complete
"""
if job_id not in _jobs:
db = _get_job_db()
job_row = db.get_job(job_id)
if not job_row:
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
return _jobs[job_id]
return _job_row_to_status(job_row)
@app.get("/jobs", response_model=list[JobStatus], tags=["Jobs"])
@@ -547,12 +567,6 @@ async def list_jobs(
Returns:
List of job statuses
"""
jobs = list(_jobs.values())
if status:
jobs = [j for j in jobs if j.status == status]
# Return most recent first
jobs.sort(key=lambda j: j.job_id, reverse=True)
return jobs[:limit]
db = _get_job_db()
job_rows = db.list_jobs(status=status, limit=limit)
return [_job_row_to_status(row) for row in job_rows]
+145
View File
@@ -171,6 +171,26 @@ class DatabaseClient:
ON serp_queries(query_hash)
""")
# Create jobs table for persisting async batch job state
cursor.execute("""
CREATE TABLE IF NOT EXISTS jobs (
job_id VARCHAR(128) PRIMARY KEY,
status VARCHAR(20) NOT NULL DEFAULT 'pending',
progress INTEGER NOT NULL DEFAULT 0,
total_companies INTEGER NOT NULL DEFAULT 0,
completed_companies INTEGER NOT NULL DEFAULT 0,
result_json JSONB,
error TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_jobs_status
ON jobs(status)
""")
self.conn.commit()
@staticmethod
@@ -462,6 +482,131 @@ class DatabaseClient:
)
conn.commit()
# Job Persistence Methods
def create_job(
    self,
    job_id: str,
    total_companies: int,
) -> Dict:
    """Insert a new job row in the ``pending`` state and return it.

    Args:
        job_id: Unique job identifier (primary key of the ``jobs`` table).
        total_companies: Number of companies in the batch.

    Returns:
        The inserted row, as produced by ``RETURNING *``, as a plain dict.
    """
    # Progress and completed_companies start at 0; status starts as 'pending'.
    insert_sql = (
        "\n"
        "INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies)\n"
        "VALUES (%s, 'pending', 0, %s, 0)\n"
        "RETURNING *\n"
    )
    with self.get_conn() as conn:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(insert_sql, (job_id, total_companies))
            inserted = cursor.fetchone()
        conn.commit()
    return dict(inserted)
def update_job(
    self,
    job_id: str,
    status: Optional[str] = None,
    progress: Optional[int] = None,
    completed_companies: Optional[int] = None,
    result_json: Optional[str] = None,
    error: Optional[str] = None,
) -> Optional[Dict]:
    """Write the supplied fields of a job row and return the updated row.

    Arguments left as ``None`` are not touched. Whenever at least one field
    is written, ``updated_at`` is refreshed as well.

    Args:
        job_id: Identifier of the job to update.
        status: New status value, if any.
        progress: New progress value, if any.
        completed_companies: New completed-company count, if any.
        result_json: Serialized result to store, if any.
        error: Error message to store, if any.

    Returns:
        The row after the update as a plain dict, or ``None`` when no row
        matched ``job_id``. When no field was supplied, the result of
        ``self.get_job(job_id)`` is returned unchanged.
    """
    # Insertion order fixes both the SET clause order and the parameter order.
    candidate_values = {
        "status": status,
        "progress": progress,
        "completed_companies": completed_companies,
        "result_json": result_json,
        "error": error,
    }
    supplied = {
        column: value
        for column, value in candidate_values.items()
        if value is not None
    }

    # Nothing to write: just report the current row.
    if not supplied:
        return self.get_job(job_id)

    set_clauses = [f"{column} = %s" for column in supplied]
    set_clauses.append("updated_at = CURRENT_TIMESTAMP")
    # job_id binds to the trailing WHERE placeholder.
    params = [*supplied.values(), job_id]

    with self.get_conn() as conn:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(
                f"UPDATE jobs SET {', '.join(set_clauses)} WHERE job_id = %s RETURNING *",
                params,
            )
            updated = cursor.fetchone()
        conn.commit()

    if updated:
        return dict(updated)
    return None
def get_job(self, job_id: str) -> Optional[Dict]:
    """Fetch a single job row by its identifier.

    Args:
        job_id: Identifier of the job to look up.

    Returns:
        The row as a plain dict, or ``None`` when no such job exists.
    """
    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cursor:
        cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,))
        row = cursor.fetchone()

    if row:
        return dict(row)
    return None
def list_jobs(
    self,
    status: Optional[str] = None,
    limit: int = 10,
) -> List[Dict]:
    """Return job rows ordered by ``created_at``, newest first.

    Args:
        status: When truthy, only jobs with this status are returned.
        limit: Maximum number of rows to return.

    Returns:
        A list of rows, each converted to a plain dict.
    """
    fragments = ["SELECT * FROM jobs"]
    params: list = []

    # The status filter is optional; the LIMIT placeholder is always bound last.
    if status:
        fragments.append(" WHERE status = %s")
        params.append(status)
    fragments.append(" ORDER BY created_at DESC LIMIT %s")
    params.append(limit)
    query = "".join(fragments)

    with self.get_conn() as conn, conn.cursor(cursor_factory=RealDictCursor) as cursor:
        cursor.execute(query, params)
        rows = cursor.fetchall()
        return [dict(row) for row in rows]
def mark_stale_jobs_failed(self) -> int:
    """Flip every ``running`` or ``pending`` job to ``failed``.

    Each affected row gets the error text ``Interrupted by server restart``
    and a refreshed ``updated_at``. Intended to be called at startup so that
    jobs cut off by a restart do not stay in a non-terminal state forever.

    NOTE(review): the WHERE clause matches every non-terminal job in the
    table. If several API instances share one database, this would presumably
    also fail jobs still running on another instance; confirm the deployment
    model.

    Returns:
        Number of jobs marked as failed.
    """
    fail_stale_sql = (
        "\n"
        "UPDATE jobs SET status = 'failed', error = 'Interrupted by server restart',\n"
        "updated_at = CURRENT_TIMESTAMP\n"
        "WHERE status IN ('running', 'pending')\n"
    )
    with self.get_conn() as conn:
        with conn.cursor() as cursor:
            cursor.execute(fail_stale_sql)
            # rowcount is the number of rows the UPDATE touched.
            affected = cursor.rowcount
        conn.commit()
    return affected
# User Authentication Methods
@staticmethod
-1
View File
@@ -14,4 +14,3 @@ numpy
pandas
bcrypt
PyJWT
slowapi
+3
View File
@@ -40,6 +40,9 @@ def main():
print("\nTables created:")
print(" - llm_messages: Stores all LLM prompts and responses")
print(" - users: Stores user accounts")
print(" - jobs: Stores async batch job state")
print(" - patents: Patent PDF cache")
print(" - serp_queries: SERP query result cache")
print("\nIndexes created:")
print(" - idx_messages_timestamp: For time-based queries")
print(" - idx_messages_company: For company-specific queries")
+1 -1
View File
@@ -5,7 +5,7 @@ from datetime import datetime
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from SPARC.api import app, _analyzer, _jobs
from SPARC.api import app
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
-97
View File
@@ -1,97 +0,0 @@
"""Tests for rate limiting on auth endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from fastapi.testclient import TestClient
from SPARC.api import app
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the application under test."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture(autouse=True)
def reset_limiter():
    """Reset the rate limiter before each test so counts do not carry over."""
    # Imported inside the fixture rather than at module import time.
    from SPARC.api import limiter as rate_limiter

    rate_limiter.reset()
    yield
class TestRateLimiting:
    """Rate-limit behaviour of the login and register endpoints."""

    @patch("SPARC.api.get_db_client")
    def test_login_allows_requests_under_limit(self, mock_db_client, client):
        """Five consecutive logins are answered normally, never with 429."""
        fake_db = MagicMock()
        fake_db.authenticate_user.return_value = None
        mock_db_client.return_value = fake_db

        credentials = {"email": "test@example.com", "password": "password123"}
        attempts = 0
        while attempts < 5:
            reply = client.post("/auth/login", json=credentials)
            # Invalid credentials yield 401; a 429 would mean throttling.
            assert reply.status_code == 401
            attempts += 1

    @patch("SPARC.api.get_db_client")
    def test_login_rate_limited_after_threshold(self, mock_db_client, client):
        """Fifteen logins against a 10/minute limit produce at least one 429."""
        fake_db = MagicMock()
        fake_db.authenticate_user.return_value = None
        mock_db_client.return_value = fake_db

        credentials = {"email": "test@example.com", "password": "password123"}
        statuses = [
            client.post("/auth/login", json=credentials).status_code
            for _ in range(15)
        ]

        assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"

    @patch("SPARC.api.get_db_client")
    def test_register_rate_limited_after_threshold(self, mock_db_client, client):
        """Ten registrations against a 5/minute limit produce at least one 429."""
        fake_db = MagicMock()
        fake_db.get_user_count.return_value = 1
        # A None user makes the endpoint answer 400 (email exists).
        fake_db.create_user.return_value = None
        mock_db_client.return_value = fake_db

        registration = {"email": "test@example.com", "password": "password123"}
        statuses = [
            client.post("/auth/register", json=registration).status_code
            for _ in range(10)
        ]

        assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"

    @patch("SPARC.api.get_db_client")
    def test_rate_limit_returns_retry_after_header(self, mock_db_client, client):
        """The first throttled response carries a ``Retry-After`` header."""
        fake_db = MagicMock()
        fake_db.authenticate_user.return_value = None
        mock_db_client.return_value = fake_db

        credentials = {"email": "test@example.com", "password": "password123"}
        for _ in range(15):
            reply = client.post("/auth/login", json=credentials)
            if reply.status_code != 429:
                continue
            # Only the first 429 is inspected; later requests are not sent.
            assert "Retry-After" in reply.headers
            break