Files
SPARC/tests/test_api.py
T
agent-company 96d5d27b17 feat(jobs): persist async batch job state in PostgreSQL
- Add jobs table to database schema (job_id, status, progress, result_json, etc.)
- Add DatabaseClient methods: create_job, update_job, get_job, list_jobs
- Add mark_stale_jobs_failed() called at startup to handle interrupted jobs
- Refactor _run_batch_job and job endpoints to read/write from PostgreSQL
- Remove in-memory _jobs dict; job state now survives API restarts
- Update init_database.py to list all tables in output

Closes leeworks-agents/SPARC#8

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:22:57 +00:00

184 lines
5.6 KiB
Python

"""Tests for FastAPI web service endpoints."""
import pytest
from datetime import datetime
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the SPARC FastAPI ``app``."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def mock_analyzer(mocker):
    """Swap ``SPARC.api._analyzer`` for a ``Mock`` and return that mock."""
    analyzer_double = Mock()
    # Patch by dotted path so the API module sees the double for the test's duration.
    mocker.patch("SPARC.api._analyzer", analyzer_double)
    return analyzer_double
class TestHealthEndpoint:
    """Checks for the ``/health`` endpoint."""

    def test_health_returns_ok(self, client):
        """GET /health answers 200 with status, version and a timestamp key."""
        resp = client.get("/health")
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["status"] == "healthy"
        assert payload["version"] == "1.0.0"
        assert "timestamp" in payload
class TestAnalyzeCompanyEndpoint:
    """Checks for the single-company ``/analyze/{company}`` endpoint."""

    def test_analyze_company_success(self, client, mock_analyzer):
        """A successful analysis result is returned field-for-field as JSON."""
        fields = {
            "company_name": "nvidia",
            "analysis": "Strong AI patent portfolio",
            "patent_count": 5,
            "success": True,
            "timestamp": datetime.now(),
        }
        # The endpoint's result comes from the patched analyzer's
        # ``_analyze_company_safe`` call.
        mock_analyzer._analyze_company_safe.return_value = CompanyAnalysisResult(
            **fields
        )

        resp = client.get("/analyze/nvidia")
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["company_name"] == "nvidia"
        assert payload["analysis"] == "Strong AI patent portfolio"
        assert payload["patent_count"] == 5
        assert payload["success"] is True

    def test_analyze_company_failure(self, client, mock_analyzer):
        """A failed analysis still answers 200 and carries the error text."""
        fields = {
            "company_name": "unknown",
            "analysis": "",
            "patent_count": 0,
            "success": False,
            "error": "No patents found",
            "timestamp": datetime.now(),
        }
        mock_analyzer._analyze_company_safe.return_value = CompanyAnalysisResult(
            **fields
        )

        resp = client.get("/analyze/unknown")
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["success"] is False
        assert payload["error"] == "No patents found"
class TestBatchAnalysisEndpoint:
    """Checks for the synchronous ``/analyze/batch`` endpoint."""

    def test_batch_analysis_success(self, client, mock_analyzer):
        """A two-company batch returns the aggregate counts and both results."""
        portfolio = (
            ("nvidia", "Strong portfolio", 5),
            ("amd", "Growing portfolio", 3),
        )
        results = [
            CompanyAnalysisResult(
                company_name=name,
                analysis=summary,
                patent_count=count,
                success=True,
                timestamp=datetime.now(),
            )
            for name, summary, count in portfolio
        ]
        # The endpoint's batch result comes from the patched analyzer's
        # ``analyze_companies`` call.
        mock_analyzer.analyze_companies.return_value = BatchAnalysisResult(
            results=results,
            total_companies=2,
            successful=2,
            failed=0,
            timestamp=datetime.now(),
        )

        request_body = {"companies": ["nvidia", "amd"], "max_workers": 2}
        resp = client.post("/analyze/batch", json=request_body)
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["total_companies"] == 2
        assert payload["successful"] == 2
        assert payload["failed"] == 0
        assert len(payload["results"]) == 2

    def test_batch_analysis_validation(self, client):
        """Each invalid request body is rejected with HTTP 422."""
        rejected_bodies = [
            # Empty companies list
            {"companies": []},
            # Too many companies
            {"companies": [f"company{i}" for i in range(25)]},
            # Invalid max_workers
            {"companies": ["nvidia"], "max_workers": 10},
        ]
        for body in rejected_bodies:
            resp = client.post("/analyze/batch", json=body)
            assert resp.status_code == 422
class TestAsyncBatchEndpoint:
    """Checks for the asynchronous ``/analyze/batch/async`` endpoint."""

    def test_async_batch_creates_job(self, client, mock_analyzer):
        """Submitting a batch returns a pending job description.

        ``mock_analyzer`` is requested only so that ``SPARC.api._analyzer``
        is patched while the request runs; the test does not configure it.
        """
        request_body = {"companies": ["nvidia", "amd"]}
        resp = client.post("/analyze/batch/async", json=request_body)
        assert resp.status_code == 200

        job = resp.json()
        assert "job_id" in job
        assert job["status"] == "pending"
        assert job["total_companies"] == 2
        assert job["progress"] == 0
class TestJobEndpoints:
    """Checks for the job management endpoints (``/jobs`` and ``/jobs/<id>``)."""

    def test_get_job_not_found(self, client):
        """Requesting a job id that does not exist answers HTTP 404."""
        response = client.get("/jobs/nonexistent")
        assert response.status_code == 404

    def test_list_jobs(self, client, mocker):
        """Listing jobs answers HTTP 200 with a JSON array."""
        # No in-memory job dict is patched here: per the change notes for
        # this file, job state is persisted in PostgreSQL and the
        # ``SPARC.api._jobs`` dict was removed. ``patch.dict`` on a missing
        # attribute raises before the request is ever made, and the
        # assertions below do not depend on the listing being empty.
        response = client.get("/jobs")
        assert response.status_code == 200
        assert isinstance(response.json(), list)

    def test_list_jobs_with_filter(self, client, mocker):
        """Listing jobs with a ``status`` filter answers HTTP 200 with a JSON array."""
        response = client.get("/jobs?status=completed")
        assert response.status_code == 200
        # Same shape check as the unfiltered listing.
        assert isinstance(response.json(), list)