forked from 0xWheatyz/SPARC
feat: add FastAPI web service wrapper
- Create REST API with endpoints for single and batch analysis - Add async job support for long-running batch operations - Implement job status tracking and listing endpoints - Add 9 tests for API endpoints - Update requirements.txt with fastapi, uvicorn, httpx - Document API usage in README 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,183 @@
|
||||
"""Tests for FastAPI web service endpoints."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from SPARC.api import app, _analyzer, _jobs
|
||||
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_analyzer(mocker):
|
||||
"""Mock the global analyzer."""
|
||||
mock = Mock()
|
||||
mocker.patch("SPARC.api._analyzer", mock)
|
||||
return mock
|
||||
|
||||
|
||||
class TestHealthEndpoint:
|
||||
"""Test health check endpoint."""
|
||||
|
||||
def test_health_returns_ok(self, client):
|
||||
"""Test health endpoint returns healthy status."""
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert data["version"] == "1.0.0"
|
||||
assert "timestamp" in data
|
||||
|
||||
|
||||
class TestAnalyzeCompanyEndpoint:
|
||||
"""Test single company analysis endpoint."""
|
||||
|
||||
def test_analyze_company_success(self, client, mock_analyzer):
|
||||
"""Test successful company analysis."""
|
||||
mock_result = CompanyAnalysisResult(
|
||||
company_name="nvidia",
|
||||
analysis="Strong AI patent portfolio",
|
||||
patent_count=5,
|
||||
success=True,
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
mock_analyzer._analyze_company_safe.return_value = mock_result
|
||||
|
||||
response = client.get("/analyze/nvidia")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["company_name"] == "nvidia"
|
||||
assert data["analysis"] == "Strong AI patent portfolio"
|
||||
assert data["patent_count"] == 5
|
||||
assert data["success"] is True
|
||||
|
||||
def test_analyze_company_failure(self, client, mock_analyzer):
|
||||
"""Test company analysis with error."""
|
||||
mock_result = CompanyAnalysisResult(
|
||||
company_name="unknown",
|
||||
analysis="",
|
||||
patent_count=0,
|
||||
success=False,
|
||||
error="No patents found",
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
mock_analyzer._analyze_company_safe.return_value = mock_result
|
||||
|
||||
response = client.get("/analyze/unknown")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
assert data["error"] == "No patents found"
|
||||
|
||||
|
||||
class TestBatchAnalysisEndpoint:
|
||||
"""Test batch analysis endpoint."""
|
||||
|
||||
def test_batch_analysis_success(self, client, mock_analyzer):
|
||||
"""Test successful batch analysis."""
|
||||
results = [
|
||||
CompanyAnalysisResult(
|
||||
company_name="nvidia",
|
||||
analysis="Strong portfolio",
|
||||
patent_count=5,
|
||||
success=True,
|
||||
timestamp=datetime.now(),
|
||||
),
|
||||
CompanyAnalysisResult(
|
||||
company_name="amd",
|
||||
analysis="Growing portfolio",
|
||||
patent_count=3,
|
||||
success=True,
|
||||
timestamp=datetime.now(),
|
||||
),
|
||||
]
|
||||
mock_batch = BatchAnalysisResult(
|
||||
results=results,
|
||||
total_companies=2,
|
||||
successful=2,
|
||||
failed=0,
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
mock_analyzer.analyze_companies.return_value = mock_batch
|
||||
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["nvidia", "amd"], "max_workers": 2},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_companies"] == 2
|
||||
assert data["successful"] == 2
|
||||
assert data["failed"] == 0
|
||||
assert len(data["results"]) == 2
|
||||
|
||||
def test_batch_analysis_validation(self, client):
|
||||
"""Test batch analysis request validation."""
|
||||
# Empty companies list
|
||||
response = client.post("/analyze/batch", json={"companies": []})
|
||||
assert response.status_code == 422
|
||||
|
||||
# Too many companies
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": [f"company{i}" for i in range(25)]},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
# Invalid max_workers
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["nvidia"], "max_workers": 10},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestAsyncBatchEndpoint:
|
||||
"""Test async batch analysis endpoint."""
|
||||
|
||||
def test_async_batch_creates_job(self, client, mock_analyzer):
|
||||
"""Test async endpoint creates a job."""
|
||||
response = client.post(
|
||||
"/analyze/batch/async",
|
||||
json={"companies": ["nvidia", "amd"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
assert data["status"] == "pending"
|
||||
assert data["total_companies"] == 2
|
||||
assert data["progress"] == 0
|
||||
|
||||
|
||||
class TestJobEndpoints:
|
||||
"""Test job management endpoints."""
|
||||
|
||||
def test_get_job_not_found(self, client):
|
||||
"""Test getting nonexistent job."""
|
||||
response = client.get("/jobs/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_list_jobs(self, client, mocker):
|
||||
"""Test listing jobs."""
|
||||
# Clear existing jobs
|
||||
mocker.patch.dict("SPARC.api._jobs", {}, clear=True)
|
||||
|
||||
response = client.get("/jobs")
|
||||
assert response.status_code == 200
|
||||
assert isinstance(response.json(), list)
|
||||
|
||||
def test_list_jobs_with_filter(self, client, mocker):
|
||||
"""Test listing jobs with status filter."""
|
||||
response = client.get("/jobs?status=completed")
|
||||
assert response.status_code == 200
|
||||
Reference in New Issue
Block a user