diff --git a/SPARC/api.py b/SPARC/api.py index 23d2f2b..a929fff 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -115,8 +115,7 @@ class AnalyticsResponse(BaseModel): period_days: int -# In-memory job storage (for demo; production would use Redis/DB) -_jobs: dict[str, JobStatus] = {} +# Job counter for generating unique IDs (the actual state is in PostgreSQL) _job_counter = 0 @@ -149,10 +148,20 @@ _analyzer: CompanyAnalyzer | None = None @asynccontextmanager async def lifespan(app: FastAPI): - """Initialize resources on startup.""" + """Initialize resources on startup, clean up on shutdown.""" global _analyzer check_jwt_secret() _analyzer = CompanyAnalyzer() + # Mark any jobs that were running/pending before the restart as failed + from SPARC.database import DatabaseClient + _db = DatabaseClient(config.database_url) + _db.connect() + _db.initialize_schema() + stale = _db.mark_stale_jobs_failed() + if stale: + import logging + logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale) + _db.close() yield # Cleanup if needed _analyzer = None @@ -424,20 +433,52 @@ async def analyze_companies_batch( return _convert_batch_result(result) +def _get_job_db() -> "DatabaseClient": + """Get a DatabaseClient for job persistence.""" + from SPARC.database import DatabaseClient + db = DatabaseClient(config.database_url) + return db + + +def _job_row_to_status(row: dict) -> JobStatus: + """Convert a database job row to a JobStatus model.""" + import json as _json + result = None + if row.get("result_json"): + result_data = row["result_json"] + if isinstance(result_data, str): + result_data = _json.loads(result_data) + result = BatchAnalysisResponse(**result_data) + return JobStatus( + job_id=row["job_id"], + status=row["status"], + progress=row["progress"], + total_companies=row["total_companies"], + completed_companies=row["completed_companies"], + result=result, + error=row.get("error"), + ) + + def _run_batch_job(job_id: str, companies: list[str], max_workers: int): """Background task for batch analysis.""" - global _jobs, _analyzer + import json as _json + global _analyzer + + db = _get_job_db() if not _analyzer: - _jobs[job_id].status = "failed" - _jobs[job_id].error = "Analyzer not initialized" + db.update_job(job_id, status="failed", error="Analyzer not initialized") return - _jobs[job_id].status = "running" + db.update_job(job_id, status="running") def progress_callback(company: str, completed: int, total: int): - _jobs[job_id].completed_companies = completed - _jobs[job_id].progress = int((completed / total) * 100) + db.update_job( + job_id, + completed_companies=completed, + progress=int((completed / total) * 100), + ) try: result = _analyzer.analyze_companies( @@ -445,12 +486,15 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int): max_workers=max_workers, progress_callback=progress_callback, ) - _jobs[job_id].status = "completed" - _jobs[job_id].progress = 100 - _jobs[job_id].result = _convert_batch_result(result) + batch_response = _convert_batch_result(result) + db.update_job( + job_id, + status="completed", + progress=100, + result_json=_json.dumps(batch_response.model_dump(), default=str), + ) except Exception as e: - _jobs[job_id].status = "failed" - _jobs[job_id].error = str(e) + db.update_job(job_id, status="failed", error=str(e)) @app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"]) @@ -475,19 +519,14 @@ async def analyze_companies_async( _job_counter += 1 job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}" - _jobs[job_id] = JobStatus( - job_id=job_id, - status="pending", - progress=0, - total_companies=len(request.companies), - completed_companies=0, - ) + db = _get_job_db() + job_row = db.create_job(job_id=job_id, total_companies=len(request.companies)) background_tasks.add_task( _run_batch_job, job_id, request.companies, request.max_workers ) - return _jobs[job_id] + return _job_row_to_status(job_row) @app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"]) @@ -503,10 +542,13 @@ async def get_job_status( Returns: Current job status including progress and results when complete """ - if job_id not in _jobs: + db = _get_job_db() + job_row = db.get_job(job_id) + + if not job_row: raise HTTPException(status_code=404, detail=f"Job {job_id} not found") - return _jobs[job_id] + return _job_row_to_status(job_row) @app.get("/jobs", response_model=list[JobStatus], tags=["Jobs"]) @@ -527,12 +569,6 @@ async def list_jobs( Returns: List of job statuses """ - jobs = list(_jobs.values()) - - if status: - jobs = [j for j in jobs if j.status == status] - - # Return most recent first - jobs.sort(key=lambda j: j.job_id, reverse=True) - - return jobs[:limit] + db = _get_job_db() + job_rows = db.list_jobs(status=status, limit=limit) + return [_job_row_to_status(row) for row in job_rows] diff --git a/SPARC/database.py b/SPARC/database.py index 0468312..cc55304 100644 --- a/SPARC/database.py +++ b/SPARC/database.py @@ -171,6 +171,26 @@ class DatabaseClient: ON serp_queries(query_hash) """) + # Create jobs table for persisting async batch job state + cursor.execute(""" + CREATE TABLE IF NOT EXISTS jobs ( + job_id VARCHAR(128) PRIMARY KEY, + status VARCHAR(20) NOT NULL DEFAULT 'pending', + progress INTEGER NOT NULL DEFAULT 0, + total_companies INTEGER NOT NULL DEFAULT 0, + completed_companies INTEGER NOT NULL DEFAULT 0, + result_json JSONB, + error TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_jobs_status + ON jobs(status) + """) + self.conn.commit() @staticmethod @@ -462,6 +482,131 @@ class DatabaseClient: ) conn.commit() + # Job Persistence Methods + + def create_job( + self, + job_id: str, + total_companies: int, + ) -> Dict: + """Create a new job record. + + Args: + job_id: Unique job identifier + total_companies: Number of companies in the batch + + Returns: + Job dict + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + """ + INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies) + VALUES (%s, 'pending', 0, %s, 0) + RETURNING * + """, + (job_id, total_companies), + ) + job = cursor.fetchone() + conn.commit() + return dict(job) + + def update_job( + self, + job_id: str, + status: Optional[str] = None, + progress: Optional[int] = None, + completed_companies: Optional[int] = None, + result_json: Optional[str] = None, + error: Optional[str] = None, + ) -> Optional[Dict]: + """Update a job's state. + + Only non-None fields are updated. + """ + updates = [] + params = [] + if status is not None: + updates.append("status = %s") + params.append(status) + if progress is not None: + updates.append("progress = %s") + params.append(progress) + if completed_companies is not None: + updates.append("completed_companies = %s") + params.append(completed_companies) + if result_json is not None: + updates.append("result_json = %s") + params.append(result_json) + if error is not None: + updates.append("error = %s") + params.append(error) + + if not updates: + return self.get_job(job_id) + + updates.append("updated_at = CURRENT_TIMESTAMP") + params.append(job_id) + + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + f"UPDATE jobs SET {', '.join(updates)} WHERE job_id = %s RETURNING *", + params, + ) + job = cursor.fetchone() + conn.commit() + return dict(job) if job else None + + def get_job(self, job_id: str) -> Optional[Dict]: + """Get a job by ID.""" + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,)) + job = cursor.fetchone() + return dict(job) if job else None + + def list_jobs( + self, + status: Optional[str] = None, + limit: int = 10, + ) -> List[Dict]: + """List jobs, optionally filtered by status.""" + query = "SELECT * FROM jobs" + params: list = [] + if status: + query += " WHERE status = %s" + params.append(status) + query += " ORDER BY created_at DESC LIMIT %s" + params.append(limit) + + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute(query, params) + return [dict(row) for row in cursor.fetchall()] + + def mark_stale_jobs_failed(self) -> int: + """Mark any jobs in 'running' or 'pending' state as 'failed'. + + Called at startup to clean up jobs that were interrupted by a restart. + + Returns: + Number of jobs marked as failed. + """ + with self.get_conn() as conn: + with conn.cursor() as cursor: + cursor.execute( + """ + UPDATE jobs SET status = 'failed', error = 'Interrupted by server restart', + updated_at = CURRENT_TIMESTAMP + WHERE status IN ('running', 'pending') + """ + ) + count = cursor.rowcount + conn.commit() + return count + # User Authentication Methods @staticmethod diff --git a/scripts/init_database.py b/scripts/init_database.py index 607ca1f..a61d68f 100644 --- a/scripts/init_database.py +++ b/scripts/init_database.py @@ -40,6 +40,9 @@ def main(): print("\nTables created:") print(" - llm_messages: Stores all LLM prompts and responses") print(" - users: Stores user accounts") + print(" - jobs: Stores async batch job state") + print(" - patents: Patent PDF cache") + print(" - serp_queries: SERP query result cache") print("\nIndexes created:") print(" - idx_messages_timestamp: For time-based queries") print(" - idx_messages_company: For company-specific queries") diff --git a/tests/test_api.py b/tests/test_api.py index 4852f2e..a5923c6 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -5,7 +5,7 @@ from datetime import datetime from unittest.mock import Mock, patch from fastapi.testclient import TestClient -from SPARC.api import app, _analyzer, _jobs +from SPARC.api import app from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult