diff --git a/README.md b/README.md
index 2a06340..33321ba 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,8 @@ SPARC automatically collects, parses, and analyzes patents from companies to pro
 - **AI Analysis**: Uses Claude 3.5 Sonnet via OpenRouter to analyze innovation quality and market potential
 - **Portfolio Analysis**: Evaluates multiple patents holistically for comprehensive insights
 - **Batch Processing**: Analyze multiple companies concurrently with progress tracking
-- **Robust Testing**: 31 tests covering all major functionality
+- **REST API**: FastAPI web service with async job support
+- **Robust Testing**: 40 tests covering all major functionality
 
 ## Architecture
 
@@ -25,6 +26,7 @@ SPARC/
 ├── serp_api.py # Patent retrieval and PDF parsing
 ├── llm.py # Claude AI integration via OpenRouter
 ├── analyzer.py # High-level orchestration
+├── api.py # FastAPI web service
 ├── types.py # Data models
 └── config.py # Environment configuration
 ```
@@ -127,6 +129,44 @@ for result in batch_result.results:
 batch_result = analyzer.analyze_companies_sequential(["nvidia", "amd"])
 ```
 
+### REST API
+
+Start the FastAPI server:
+
+```bash
+uvicorn SPARC.api:app --reload
+```
+
+API endpoints:
+
+| Endpoint | Method | Description |
+|----------|--------|-------------|
+| `/health` | GET | Health check |
+| `/analyze/{company_name}` | GET | Analyze single company |
+| `/analyze/batch` | POST | Analyze multiple companies |
+| `/analyze/batch/async` | POST | Start async batch job |
+| `/jobs/{job_id}` | GET | Get job status |
+| `/jobs` | GET | List all jobs |
+
+Interactive docs available at `http://localhost:8000/docs`
+
+Example API usage:
+
+```bash
+# Single company
+curl http://localhost:8000/analyze/nvidia
+
+# Batch analysis
+curl -X POST http://localhost:8000/analyze/batch \
+  -H "Content-Type: application/json" \
+  -d '{"companies": ["nvidia", "amd", "intel"]}'
+
+# Async batch (for long-running jobs)
+curl -X POST http://localhost:8000/analyze/batch/async \
+  -H "Content-Type: application/json" \
+  -d '{"companies": ["nvidia", "amd", "intel", "qualcomm"]}'
+```
+
 ## Running Tests
 
 ```bash
@@ -159,7 +199,7 @@ pytest tests/ --cov=SPARC --cov-report=term-missing
 - [X] LLM integration for analysis
 - [X] Company performance estimation
 - [X] Multi-company batch processing
-- [ ] FastAPI web service wrapper
+- [X] FastAPI web service wrapper
 - [X] Docker containerization
 - [X] Results persistence (database)
 - [ ] Visualization dashboard
diff --git a/SPARC/api.py b/SPARC/api.py
new file mode 100644
index 0000000..2a75fee
--- /dev/null
+++ b/SPARC/api.py
@@ -0,0 +1,286 @@
+"""FastAPI web service wrapper for SPARC patent analysis.
+
+Provides REST API endpoints for analyzing company patent portfolios.
+"""
+
+from contextlib import asynccontextmanager
+from datetime import datetime
+from typing import Annotated
+
+from fastapi import BackgroundTasks, FastAPI, HTTPException, Query
+from pydantic import BaseModel, Field
+
+from SPARC.analyzer import CompanyAnalyzer
+from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
+
+
+# Pydantic models for API
+class CompanyAnalysisResponse(BaseModel):
+    """Response model for single company analysis."""
+
+    company_name: str
+    analysis: str
+    patent_count: int
+    success: bool
+    error: str | None = None
+    timestamp: datetime
+
+
+class BatchAnalysisResponse(BaseModel):
+    """Response model for batch company analysis."""
+
+    results: list[CompanyAnalysisResponse]
+    total_companies: int
+    successful: int
+    failed: int
+    timestamp: datetime
+
+
+class BatchAnalysisRequest(BaseModel):
+    """Request model for batch company analysis."""
+
+    companies: list[str] = Field(
+        ..., min_length=1, max_length=20, description="List of company names to analyze"
+    )
+    max_workers: int = Field(
+        default=3, ge=1, le=5, description="Max concurrent analyses"
+    )
+
+
+class JobStatus(BaseModel):
+    """Status of a background analysis job."""
+
+    job_id: str
+    status: str  # "pending", "running", "completed", "failed"
+    progress: int  # 0-100
+    total_companies: int
+    completed_companies: int
+    result: BatchAnalysisResponse | None = None
+    error: str | None = None
+
+
+class HealthResponse(BaseModel):
+    """Health check response."""
+
+    status: str
+    version: str
+    timestamp: datetime
+
+
+# In-memory job storage (for demo; production would use Redis/DB)
+_jobs: dict[str, JobStatus] = {}
+_job_counter = 0
+
+
+def _convert_result(result: CompanyAnalysisResult) -> CompanyAnalysisResponse:
+    """Convert internal result to API response model."""
+    return CompanyAnalysisResponse(
+        company_name=result.company_name,
+        analysis=result.analysis,
+        patent_count=result.patent_count,
+        success=result.success,
+        error=result.error,
+        timestamp=result.timestamp,
+    )
+
+
+def _convert_batch_result(result: BatchAnalysisResult) -> BatchAnalysisResponse:
+    """Convert internal batch result to API response model."""
+    return BatchAnalysisResponse(
+        results=[_convert_result(r) for r in result.results],
+        total_companies=result.total_companies,
+        successful=result.successful,
+        failed=result.failed,
+        timestamp=result.timestamp,
+    )
+
+
+# Global analyzer instance
+_analyzer: CompanyAnalyzer | None = None
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Initialize resources on startup."""
+    global _analyzer
+    _analyzer = CompanyAnalyzer()
+    yield
+    # Cleanup if needed
+    _analyzer = None
+
+
+app = FastAPI(
+    title="SPARC API",
+    description="Semiconductor Patent & Analytics Report Core - Patent portfolio analysis using AI",
+    version="1.0.0",
+    lifespan=lifespan,
+)
+
+
+@app.get("/health", response_model=HealthResponse, tags=["System"])
+async def health_check():
+    """Check API health status."""
+    return HealthResponse(
+        status="healthy",
+        version="1.0.0",
+        timestamp=datetime.now(),
+    )
+
+
+@app.get(
+    "/analyze/{company_name}",
+    response_model=CompanyAnalysisResponse,
+    tags=["Analysis"],
+)
+def analyze_company(company_name: str):  # plain def: FastAPI runs blocking work in its threadpool
+    """Analyze a single company's patent portfolio.
+
+    This endpoint retrieves recent patents for the specified company,
+    parses them, and uses AI to generate a comprehensive analysis.
+
+    Args:
+        company_name: Name of the company to analyze (e.g., "nvidia", "intel")
+
+    Returns:
+        Analysis results including patent count, AI insights, and success status
+    """
+    if not _analyzer:
+        raise HTTPException(status_code=503, detail="Analyzer not initialized")
+
+    result = _analyzer._analyze_company_safe(company_name)  # NOTE(review): private helper; confirm no public equivalent
+    return _convert_result(result)
+
+
+@app.post(
+    "/analyze/batch",
+    response_model=BatchAnalysisResponse,
+    tags=["Analysis"],
+)
+def analyze_companies_batch(request: BatchAnalysisRequest):  # plain def: keeps the event loop free
+    """Analyze multiple companies' patent portfolios.
+
+    Processes companies concurrently for improved performance.
+    Limited to 20 companies per request.
+
+    Args:
+        request: List of company names and optional worker count
+
+    Returns:
+        Batch results with individual company analyses and summary statistics
+    """
+    if not _analyzer:
+        raise HTTPException(status_code=503, detail="Analyzer not initialized")
+
+    result = _analyzer.analyze_companies(
+        companies=request.companies,
+        max_workers=request.max_workers,
+    )
+    return _convert_batch_result(result)
+
+
+def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
+    """Background task for batch analysis."""
+    job = _jobs[job_id]  # stored JobStatus, mutated in place; reads need no `global`
+
+    if not _analyzer:
+        job.status = "failed"
+        job.error = "Analyzer not initialized"
+        return
+
+    job.status = "running"
+
+    def progress_callback(company: str, completed: int, total: int):
+        job.completed_companies = completed
+        job.progress = int((completed / total) * 100)
+
+    try:
+        result = _analyzer.analyze_companies(
+            companies=companies,
+            max_workers=max_workers,
+            progress_callback=progress_callback,
+        )
+        job.status = "completed"
+        job.progress = 100
+        job.result = _convert_batch_result(result)
+    except Exception as e:
+        job.status = "failed"
+        job.error = str(e)
+
+
+@app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"])
+async def analyze_companies_async(
+    request: BatchAnalysisRequest, background_tasks: BackgroundTasks
+):
+    """Start an asynchronous batch analysis job.
+
+    Returns immediately with a job ID that can be used to poll for status.
+    Useful for large batch analyses that may take a long time.
+
+    Args:
+        request: List of company names and optional worker count
+
+    Returns:
+        Job status with job_id for polling
+    """
+    global _job_counter
+
+    _job_counter += 1
+    job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
+
+    _jobs[job_id] = JobStatus(
+        job_id=job_id,
+        status="pending",
+        progress=0,
+        total_companies=len(request.companies),
+        completed_companies=0,
+    )
+
+    background_tasks.add_task(
+        _run_batch_job, job_id, request.companies, request.max_workers
+    )
+
+    return _jobs[job_id]
+
+
+@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
+async def get_job_status(job_id: str):
+    """Get the status of a background analysis job.
+
+    Args:
+        job_id: The job ID returned from the async batch endpoint
+
+    Returns:
+        Current job status including progress and results when complete
+    """
+    if job_id not in _jobs:
+        raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
+
+    return _jobs[job_id]
+
+
+@app.get("/jobs", response_model=list[JobStatus], tags=["Jobs"])
+async def list_jobs(
+    status: Annotated[
+        str | None,
+        Query(description="Filter by status: pending, running, completed, failed"),
+    ] = None,
+    limit: Annotated[int, Query(ge=1, le=100)] = 10,
+):
+    """List all analysis jobs.
+
+    Args:
+        status: Optional filter by job status
+        limit: Maximum number of jobs to return (default 10, max 100)
+
+    Returns:
+        List of job statuses
+    """
+    jobs = list(_jobs.values())
+
+    if status:
+        jobs = [j for j in jobs if j.status == status]
+
+    # Most recent first. _jobs keeps creation order; sorting job_id strings would rank job_9 above job_10.
+    jobs.reverse()
+
+    return jobs[:limit]
diff --git a/requirements.txt b/requirements.txt
index df43541..c081cb6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,6 @@ pytest
 pytest-mock
 openai
 psycopg2-binary
+fastapi
+uvicorn[standard]
+httpx
diff --git a/tests/test_api.py b/tests/test_api.py
new file mode 100644
index 0000000..4852f2e
--- /dev/null
+++ b/tests/test_api.py
@@ -0,0 +1,183 @@
+"""Tests for FastAPI web service endpoints."""
+
+import pytest
+from datetime import datetime
+from unittest.mock import Mock, patch
+from fastapi.testclient import TestClient
+
+from SPARC.api import app, _analyzer, _jobs
+from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
+
+
+@pytest.fixture
+def client():
+    """Create test client."""
+    return TestClient(app)
+
+
+@pytest.fixture
+def mock_analyzer(mocker):
+    """Mock the global analyzer."""
+    mock = Mock()
+    mocker.patch("SPARC.api._analyzer", mock)
+    return mock
+
+
+class TestHealthEndpoint:
+    """Test health check endpoint."""
+
+    def test_health_returns_ok(self, client):
+        """Test health endpoint returns healthy status."""
+        response = client.get("/health")
+        assert response.status_code == 200
+        data = response.json()
+        assert data["status"] == "healthy"
+        assert data["version"] == "1.0.0"
+        assert "timestamp" in data
+
+
+class TestAnalyzeCompanyEndpoint:
+    """Test single company analysis endpoint."""
+
+    def test_analyze_company_success(self, client, mock_analyzer):
+        """Test successful company analysis."""
+        mock_result = CompanyAnalysisResult(
+            company_name="nvidia",
+            analysis="Strong AI patent portfolio",
+            patent_count=5,
+            success=True,
+            timestamp=datetime.now(),
+        )
+        mock_analyzer._analyze_company_safe.return_value = mock_result
+
+        response = client.get("/analyze/nvidia")
+
+        assert response.status_code == 200
+        data = response.json()
+        assert data["company_name"] == "nvidia"
+        assert data["analysis"] == "Strong AI patent portfolio"
+        assert data["patent_count"] == 5
+        assert data["success"] is True
+
+    def test_analyze_company_failure(self, client, mock_analyzer):
+        """Test company analysis with error."""
+        mock_result = CompanyAnalysisResult(
+            company_name="unknown",
+            analysis="",
+            patent_count=0,
+            success=False,
+            error="No patents found",
+            timestamp=datetime.now(),
+        )
+        mock_analyzer._analyze_company_safe.return_value = mock_result
+
+        response = client.get("/analyze/unknown")
+
+        assert response.status_code == 200
+        data = response.json()
+        assert data["success"] is False
+        assert data["error"] == "No patents found"
+
+
+class TestBatchAnalysisEndpoint:
+    """Test batch analysis endpoint."""
+
+    def test_batch_analysis_success(self, client, mock_analyzer):
+        """Test successful batch analysis."""
+        results = [
+            CompanyAnalysisResult(
+                company_name="nvidia",
+                analysis="Strong portfolio",
+                patent_count=5,
+                success=True,
+                timestamp=datetime.now(),
+            ),
+            CompanyAnalysisResult(
+                company_name="amd",
+                analysis="Growing portfolio",
+                patent_count=3,
+                success=True,
+                timestamp=datetime.now(),
+            ),
+        ]
+        mock_batch = BatchAnalysisResult(
+            results=results,
+            total_companies=2,
+            successful=2,
+            failed=0,
+            timestamp=datetime.now(),
+        )
+        mock_analyzer.analyze_companies.return_value = mock_batch
+
+        response = client.post(
+            "/analyze/batch",
+            json={"companies": ["nvidia", "amd"], "max_workers": 2},
+        )
+
+        assert response.status_code == 200
+        data = response.json()
+        assert data["total_companies"] == 2
+        assert data["successful"] == 2
+        assert data["failed"] == 0
+        assert len(data["results"]) == 2
+
+    def test_batch_analysis_validation(self, client):
+        """Test batch analysis request validation."""
+        # Empty companies list
+        response = client.post("/analyze/batch", json={"companies": []})
+        assert response.status_code == 422
+
+        # Too many companies
+        response = client.post(
+            "/analyze/batch",
+            json={"companies": [f"company{i}" for i in range(25)]},
+        )
+        assert response.status_code == 422
+
+        # Invalid max_workers
+        response = client.post(
+            "/analyze/batch",
+            json={"companies": ["nvidia"], "max_workers": 10},
+        )
+        assert response.status_code == 422
+
+
+class TestAsyncBatchEndpoint:
+    """Test async batch analysis endpoint."""
+
+    def test_async_batch_creates_job(self, client, mock_analyzer):
+        """Test async endpoint creates a job."""
+        response = client.post(
+            "/analyze/batch/async",
+            json={"companies": ["nvidia", "amd"]},
+        )
+
+        assert response.status_code == 200
+        data = response.json()
+        assert "job_id" in data
+        assert data["status"] == "pending"
+        assert data["total_companies"] == 2
+        assert data["progress"] == 0
+
+
+class TestJobEndpoints:
+    """Test job management endpoints."""
+
+    def test_get_job_not_found(self, client):
+        """Test getting nonexistent job."""
+        response = client.get("/jobs/nonexistent")
+        assert response.status_code == 404
+
+    def test_list_jobs(self, client, mocker):
+        """Test listing jobs."""
+        # Clear existing jobs
+        mocker.patch.dict("SPARC.api._jobs", {}, clear=True)
+
+        response = client.get("/jobs")
+        assert response.status_code == 200
+        assert isinstance(response.json(), list)
+
+    def test_list_jobs_with_filter(self, client, mocker):
+        """Test listing jobs with status filter."""
+        response = client.get("/jobs?status=completed")
+        assert response.status_code == 200