forked from 0xWheatyz/SPARC
96d5d27b17
- Add jobs table to database schema (job_id, status, progress, result_json, etc.) - Add DatabaseClient methods: create_job, update_job, get_job, list_jobs - Add mark_stale_jobs_failed() called at startup to handle interrupted jobs - Refactor _run_batch_job and job endpoints to read/write from PostgreSQL - Remove in-memory _jobs dict; job state now survives API restarts - Update init_database.py to list all tables in output Closes leeworks-agents/SPARC#8 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
184 lines
5.6 KiB
Python
184 lines
5.6 KiB
Python
"""Tests for FastAPI web service endpoints."""
|
|
|
|
import pytest
|
|
from datetime import datetime
|
|
from unittest.mock import Mock, patch
|
|
from fastapi.testclient import TestClient
|
|
|
|
from SPARC.api import app
|
|
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
|
|
|
|
|
|
@pytest.fixture
def client():
    """Provide a ``TestClient`` bound to the SPARC FastAPI ``app``."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
@pytest.fixture
def mock_analyzer(mocker):
    """Replace ``SPARC.api._analyzer`` with a ``Mock`` and return that mock.

    The patch is installed through the ``mocker`` fixture, so tests can set
    return values on the returned object before calling the API.
    """
    analyzer_double = Mock()
    mocker.patch("SPARC.api._analyzer", new=analyzer_double)
    return analyzer_double
|
|
|
|
|
|
class TestHealthEndpoint:
    """Checks for the ``/health`` endpoint."""

    def test_health_returns_ok(self, client):
        """GET /health answers 200 with status, version and a timestamp."""
        response = client.get("/health")
        assert response.status_code == 200

        payload = response.json()
        # Fields with a fixed expected value are checked in declaration order.
        expected_fields = {"status": "healthy", "version": "1.0.0"}
        for field, expected in expected_fields.items():
            assert payload[field] == expected
        assert "timestamp" in payload
|
|
|
|
|
|
class TestAnalyzeCompanyEndpoint:
    """Checks for the single-company ``/analyze/{company}`` endpoint."""

    def test_analyze_company_success(self, client, mock_analyzer):
        """A successful analyzer result is returned field-for-field with 200."""
        canned_result = CompanyAnalysisResult(
            company_name="nvidia",
            analysis="Strong AI patent portfolio",
            patent_count=5,
            success=True,
            timestamp=datetime.now(),
        )
        mock_analyzer._analyze_company_safe.return_value = canned_result

        response = client.get("/analyze/nvidia")
        assert response.status_code == 200

        body = response.json()
        expected_fields = {
            "company_name": "nvidia",
            "analysis": "Strong AI patent portfolio",
            "patent_count": 5,
        }
        for field, expected in expected_fields.items():
            assert body[field] == expected
        assert body["success"] is True

    def test_analyze_company_failure(self, client, mock_analyzer):
        """A failed analyzer result still yields 200, carrying the error text."""
        canned_result = CompanyAnalysisResult(
            company_name="unknown",
            analysis="",
            patent_count=0,
            success=False,
            error="No patents found",
            timestamp=datetime.now(),
        )
        mock_analyzer._analyze_company_safe.return_value = canned_result

        response = client.get("/analyze/unknown")
        assert response.status_code == 200

        body = response.json()
        assert body["success"] is False
        assert body["error"] == "No patents found"
|
|
|
|
|
|
class TestBatchAnalysisEndpoint:
    """Checks for the synchronous ``/analyze/batch`` endpoint."""

    def test_batch_analysis_success(self, client, mock_analyzer):
        """A two-company batch returns 200 with matching totals and results."""
        company_specs = [
            ("nvidia", "Strong portfolio", 5),
            ("amd", "Growing portfolio", 3),
        ]
        per_company = [
            CompanyAnalysisResult(
                company_name=name,
                analysis=summary,
                patent_count=count,
                success=True,
                timestamp=datetime.now(),
            )
            for name, summary, count in company_specs
        ]
        mock_analyzer.analyze_companies.return_value = BatchAnalysisResult(
            results=per_company,
            total_companies=2,
            successful=2,
            failed=0,
            timestamp=datetime.now(),
        )

        request_body = {"companies": ["nvidia", "amd"], "max_workers": 2}
        response = client.post("/analyze/batch", json=request_body)
        assert response.status_code == 200

        body = response.json()
        expected_totals = (
            ("total_companies", 2),
            ("successful", 2),
            ("failed", 0),
        )
        for field, expected in expected_totals:
            assert body[field] == expected
        assert len(body["results"]) == 2

    def test_batch_analysis_validation(self, client):
        """Each malformed batch request is rejected with 422."""
        invalid_payloads = [
            # Empty companies list
            {"companies": []},
            # Too many companies
            {"companies": [f"company{i}" for i in range(25)]},
            # Invalid max_workers
            {"companies": ["nvidia"], "max_workers": 10},
        ]
        for payload in invalid_payloads:
            response = client.post("/analyze/batch", json=payload)
            assert response.status_code == 422
|
|
|
|
|
|
class TestAsyncBatchEndpoint:
    """Checks for the ``/analyze/batch/async`` endpoint."""

    def test_async_batch_creates_job(self, client, mock_analyzer):
        """Submitting a batch returns 200 with a pending job description."""
        request_body = {"companies": ["nvidia", "amd"]}
        response = client.post("/analyze/batch/async", json=request_body)
        assert response.status_code == 200

        job = response.json()
        assert "job_id" in job
        expected_fields = (
            ("status", "pending"),
            ("total_companies", 2),
            ("progress", 0),
        )
        for field, expected in expected_fields:
            assert job[field] == expected
|
|
|
|
|
|
class TestJobEndpoints:
|
|
"""Test job management endpoints."""
|
|
|
|
def test_get_job_not_found(self, client):
    """Requesting a job id that does not exist yields 404."""
    missing_job_url = "/jobs/nonexistent"
    response = client.get(missing_job_url)
    assert response.status_code == 404
|
|
|
|
def test_list_jobs(self, client, mocker):
    """GET /jobs answers 200 with a JSON list.

    The in-memory job store is replaced with an empty dict for the
    duration of the test so the listing starts from a clean state.
    """
    # ``create=True`` keeps this patch from raising when ``SPARC.api`` has no
    # ``_jobs`` attribute; ``patch.dict`` on a dotted name fails outright if
    # the target cannot be resolved.
    # NOTE(review): the commit description says the in-memory ``_jobs`` dict
    # was removed in favour of PostgreSQL-backed jobs -- confirm whether this
    # test should mock the database client instead.
    mocker.patch("SPARC.api._jobs", {}, create=True)

    response = client.get("/jobs")
    assert response.status_code == 200
    assert isinstance(response.json(), list)
|
|
|
|
def test_list_jobs_with_filter(self, client, mocker):
|
|
"""Test listing jobs with status filter."""
|
|
response = client.get("/jobs?status=completed")
|
|
assert response.status_code == 200
|