Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d4ba13846a | |||
| 3479ba8a46 | |||
| 1c6d903301 | |||
| 84fd0bef32 |
@@ -15,7 +15,10 @@ SPARC automatically collects, parses, and analyzes patents from companies to pro
|
||||
- **Content Minimization**: Removes verbose descriptions to reduce LLM token usage
|
||||
- **AI Analysis**: Uses Claude 3.5 Sonnet via OpenRouter to analyze innovation quality and market potential
|
||||
- **Portfolio Analysis**: Evaluates multiple patents holistically for comprehensive insights
|
||||
- **Robust Testing**: 26 tests covering all major functionality
|
||||
- **Batch Processing**: Analyze multiple companies concurrently with progress tracking
|
||||
- **REST API**: FastAPI web service with async job support
|
||||
- **Dashboard**: Interactive Streamlit visualization dashboard
|
||||
- **Robust Testing**: 40 tests covering all major functionality
|
||||
|
||||
## Architecture
|
||||
|
||||
@@ -24,6 +27,7 @@ SPARC/
|
||||
├── serp_api.py # Patent retrieval and PDF parsing
|
||||
├── llm.py # Claude AI integration via OpenRouter
|
||||
├── analyzer.py # High-level orchestration
|
||||
├── api.py # FastAPI web service
|
||||
├── types.py # Data models
|
||||
└── config.py # Environment configuration
|
||||
```
|
||||
@@ -99,6 +103,87 @@ result = analyzer.analyze_single_patent(
|
||||
)
|
||||
```
|
||||
|
||||
### Multi-Company Batch Analysis
|
||||
|
||||
```python
|
||||
from SPARC.analyzer import CompanyAnalyzer
|
||||
|
||||
analyzer = CompanyAnalyzer()
|
||||
|
||||
# Analyze multiple companies concurrently (default 3 workers)
|
||||
batch_result = analyzer.analyze_companies(
|
||||
["nvidia", "amd", "intel", "qualcomm"],
|
||||
max_workers=3
|
||||
)
|
||||
|
||||
# Access results
|
||||
print(f"Analyzed: {batch_result.total_companies}")
|
||||
print(f"Successful: {batch_result.successful}")
|
||||
print(f"Failed: {batch_result.failed}")
|
||||
|
||||
for result in batch_result.results:
|
||||
if result.success:
|
||||
print(f"{result.company_name}: {result.patent_count} patents")
|
||||
print(result.analysis)
|
||||
|
||||
# Or use sequential processing (safer for rate limits)
|
||||
batch_result = analyzer.analyze_companies_sequential(["nvidia", "amd"])
|
||||
```
|
||||
|
||||
### REST API
|
||||
|
||||
Start the FastAPI server:
|
||||
|
||||
```bash
|
||||
uvicorn SPARC.api:app --reload
|
||||
```
|
||||
|
||||
API endpoints:
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/health` | GET | Health check |
|
||||
| `/analyze/{company}` | GET | Analyze single company |
|
||||
| `/analyze/batch` | POST | Analyze multiple companies |
|
||||
| `/analyze/batch/async` | POST | Start async batch job |
|
||||
| `/jobs/{job_id}` | GET | Get job status |
|
||||
| `/jobs` | GET | List all jobs |
|
||||
|
||||
Interactive docs available at `http://localhost:8000/docs`
|
||||
|
||||
Example API usage:
|
||||
|
||||
```bash
|
||||
# Single company
|
||||
curl http://localhost:8000/analyze/nvidia
|
||||
|
||||
# Batch analysis
|
||||
curl -X POST http://localhost:8000/analyze/batch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"companies": ["nvidia", "amd", "intel"]}'
|
||||
|
||||
# Async batch (for long-running jobs)
|
||||
curl -X POST http://localhost:8000/analyze/batch/async \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"companies": ["nvidia", "amd", "intel", "qualcomm"]}'
|
||||
```
|
||||
|
||||
### Visualization Dashboard
|
||||
|
||||
Launch the interactive Streamlit dashboard:
|
||||
|
||||
```bash
|
||||
streamlit run dashboard.py
|
||||
```
|
||||
|
||||
Dashboard features:
|
||||
- **Company Analysis**: Analyze individual companies with real-time results
|
||||
- **Batch Analysis**: Process multiple companies with progress tracking and charts
|
||||
- **Analytics**: View historical analysis data and trends (requires database mode)
|
||||
- **System Status**: Monitor database and analyzer health
|
||||
|
||||
The dashboard runs at `http://localhost:8501` by default.
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
@@ -130,11 +215,11 @@ pytest tests/ --cov=SPARC --cov-report=term-missing
|
||||
- [X] Extract and minimize patent content
|
||||
- [X] LLM integration for analysis
|
||||
- [X] Company performance estimation
|
||||
- [ ] Multi-company batch processing
|
||||
- [ ] FastAPI web service wrapper
|
||||
- [ ] Docker containerization
|
||||
- [ ] Results persistence (database)
|
||||
- [ ] Visualization dashboard
|
||||
- [X] Multi-company batch processing
|
||||
- [X] FastAPI web service wrapper
|
||||
- [X] Docker containerization
|
||||
- [X] Results persistence (database)
|
||||
- [X] Visualization dashboard
|
||||
|
||||
## Development
|
||||
|
||||
|
||||
+155
-2
@@ -4,10 +4,12 @@ This module ties together patent retrieval, parsing, and LLM analysis
|
||||
to provide company performance estimation based on patent portfolios.
|
||||
"""
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Callable
|
||||
|
||||
from SPARC.serp_api import SERP
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
from SPARC.types import Patent
|
||||
from typing import List
|
||||
from SPARC.types import Patent, CompanyAnalysisResult, BatchAnalysisResult
|
||||
|
||||
|
||||
class CompanyAnalyzer:
|
||||
@@ -110,3 +112,154 @@ class CompanyAnalyzer:
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to analyze patent {patent_id}: {e}"
|
||||
|
||||
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult:
|
||||
"""Internal wrapper that catches exceptions and returns structured result.
|
||||
|
||||
Args:
|
||||
company_name: Name of the company to analyze
|
||||
|
||||
Returns:
|
||||
CompanyAnalysisResult with success/failure status
|
||||
"""
|
||||
try:
|
||||
patents = SERP.query(company_name)
|
||||
patent_count = len(patents.patents) if patents.patents else 0
|
||||
|
||||
analysis = self.analyze_company(company_name)
|
||||
|
||||
# Check if analysis indicates failure
|
||||
if analysis.startswith("No patents found") or analysis.startswith(
|
||||
"Failed to process"
|
||||
):
|
||||
return CompanyAnalysisResult(
|
||||
company_name=company_name,
|
||||
analysis=analysis,
|
||||
patent_count=patent_count,
|
||||
success=False,
|
||||
error=analysis,
|
||||
)
|
||||
|
||||
return CompanyAnalysisResult(
|
||||
company_name=company_name,
|
||||
analysis=analysis,
|
||||
patent_count=patent_count,
|
||||
success=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return CompanyAnalysisResult(
|
||||
company_name=company_name,
|
||||
analysis="",
|
||||
patent_count=0,
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def analyze_companies(
|
||||
self,
|
||||
companies: list[str],
|
||||
max_workers: int = 3,
|
||||
progress_callback: Callable[[str, int, int], None] | None = None,
|
||||
) -> BatchAnalysisResult:
|
||||
"""Analyze multiple companies' patent portfolios in batch.
|
||||
|
||||
Processes companies concurrently for improved performance while
|
||||
respecting API rate limits.
|
||||
|
||||
Args:
|
||||
companies: List of company names to analyze
|
||||
max_workers: Maximum concurrent analyses (default 3 to avoid rate limits)
|
||||
progress_callback: Optional callback(company_name, completed, total)
|
||||
|
||||
Returns:
|
||||
BatchAnalysisResult containing all individual results and summary stats
|
||||
"""
|
||||
results: list[CompanyAnalysisResult] = []
|
||||
total = len(companies)
|
||||
|
||||
print(f"Starting batch analysis of {total} companies...")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_company = {
|
||||
executor.submit(self._analyze_company_safe, company): company
|
||||
for company in companies
|
||||
}
|
||||
|
||||
completed = 0
|
||||
for future in as_completed(future_to_company):
|
||||
company = future_to_company[future]
|
||||
completed += 1
|
||||
|
||||
try:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
|
||||
status = "✓" if result.success else "✗"
|
||||
print(f"[{completed}/{total}] {status} {company}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(company, completed, total)
|
||||
|
||||
except Exception as e:
|
||||
results.append(
|
||||
CompanyAnalysisResult(
|
||||
company_name=company,
|
||||
analysis="",
|
||||
patent_count=0,
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
print(f"[{completed}/{total}] ✗ {company}: {e}")
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = total - successful
|
||||
|
||||
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
|
||||
|
||||
return BatchAnalysisResult(
|
||||
results=results,
|
||||
total_companies=total,
|
||||
successful=successful,
|
||||
failed=failed,
|
||||
)
|
||||
|
||||
def analyze_companies_sequential(
|
||||
self, companies: list[str]
|
||||
) -> BatchAnalysisResult:
|
||||
"""Analyze multiple companies sequentially (safer for rate limits).
|
||||
|
||||
Use this when you want more control over API rate limiting or
|
||||
when debugging issues.
|
||||
|
||||
Args:
|
||||
companies: List of company names to analyze
|
||||
|
||||
Returns:
|
||||
BatchAnalysisResult containing all individual results
|
||||
"""
|
||||
results: list[CompanyAnalysisResult] = []
|
||||
total = len(companies)
|
||||
|
||||
print(f"Starting sequential analysis of {total} companies...")
|
||||
|
||||
for idx, company in enumerate(companies, 1):
|
||||
print(f"\n[{idx}/{total}] Analyzing {company}...")
|
||||
result = self._analyze_company_safe(company)
|
||||
results.append(result)
|
||||
|
||||
status = "✓" if result.success else "✗"
|
||||
print(f"[{idx}/{total}] {status} {company}")
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = total - successful
|
||||
|
||||
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
|
||||
|
||||
return BatchAnalysisResult(
|
||||
results=results,
|
||||
total_companies=total,
|
||||
successful=successful,
|
||||
failed=failed,
|
||||
)
|
||||
|
||||
+286
@@ -0,0 +1,286 @@
|
||||
"""FastAPI web service wrapper for SPARC patent analysis.
|
||||
|
||||
Provides REST API endpoints for analyzing company patent portfolios.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import BackgroundTasks, FastAPI, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from SPARC.analyzer import CompanyAnalyzer
|
||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||
|
||||
|
||||
# Pydantic models for API
|
||||
class CompanyAnalysisResponse(BaseModel):
|
||||
"""Response model for single company analysis."""
|
||||
|
||||
company_name: str
|
||||
analysis: str
|
||||
patent_count: int
|
||||
success: bool
|
||||
error: str | None = None
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
class BatchAnalysisResponse(BaseModel):
|
||||
"""Response model for batch company analysis."""
|
||||
|
||||
results: list[CompanyAnalysisResponse]
|
||||
total_companies: int
|
||||
successful: int
|
||||
failed: int
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
class BatchAnalysisRequest(BaseModel):
|
||||
"""Request model for batch company analysis."""
|
||||
|
||||
companies: list[str] = Field(
|
||||
..., min_length=1, max_length=20, description="List of company names to analyze"
|
||||
)
|
||||
max_workers: int = Field(
|
||||
default=3, ge=1, le=5, description="Max concurrent analyses"
|
||||
)
|
||||
|
||||
|
||||
class JobStatus(BaseModel):
|
||||
"""Status of a background analysis job."""
|
||||
|
||||
job_id: str
|
||||
status: str # "pending", "running", "completed", "failed"
|
||||
progress: int # 0-100
|
||||
total_companies: int
|
||||
completed_companies: int
|
||||
result: BatchAnalysisResponse | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Health check response."""
|
||||
|
||||
status: str
|
||||
version: str
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
# In-memory job storage (for demo; production would use Redis/DB)
|
||||
_jobs: dict[str, JobStatus] = {}
|
||||
_job_counter = 0
|
||||
|
||||
|
||||
def _convert_result(result: CompanyAnalysisResult) -> CompanyAnalysisResponse:
|
||||
"""Convert internal result to API response model."""
|
||||
return CompanyAnalysisResponse(
|
||||
company_name=result.company_name,
|
||||
analysis=result.analysis,
|
||||
patent_count=result.patent_count,
|
||||
success=result.success,
|
||||
error=result.error,
|
||||
timestamp=result.timestamp,
|
||||
)
|
||||
|
||||
|
||||
def _convert_batch_result(result: BatchAnalysisResult) -> BatchAnalysisResponse:
|
||||
"""Convert internal batch result to API response model."""
|
||||
return BatchAnalysisResponse(
|
||||
results=[_convert_result(r) for r in result.results],
|
||||
total_companies=result.total_companies,
|
||||
successful=result.successful,
|
||||
failed=result.failed,
|
||||
timestamp=result.timestamp,
|
||||
)
|
||||
|
||||
|
||||
# Global analyzer instance
|
||||
_analyzer: CompanyAnalyzer | None = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Initialize resources on startup."""
|
||||
global _analyzer
|
||||
_analyzer = CompanyAnalyzer()
|
||||
yield
|
||||
# Cleanup if needed
|
||||
_analyzer = None
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="SPARC API",
|
||||
description="Semiconductor Patent & Analytics Report Core - Patent portfolio analysis using AI",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse, tags=["System"])
|
||||
async def health_check():
|
||||
"""Check API health status."""
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
version="1.0.0",
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
|
||||
|
||||
@app.get(
|
||||
"/analyze/{company_name}",
|
||||
response_model=CompanyAnalysisResponse,
|
||||
tags=["Analysis"],
|
||||
)
|
||||
async def analyze_company(company_name: str):
|
||||
"""Analyze a single company's patent portfolio.
|
||||
|
||||
This endpoint retrieves recent patents for the specified company,
|
||||
parses them, and uses AI to generate a comprehensive analysis.
|
||||
|
||||
Args:
|
||||
company_name: Name of the company to analyze (e.g., "nvidia", "intel")
|
||||
|
||||
Returns:
|
||||
Analysis results including patent count, AI insights, and success status
|
||||
"""
|
||||
if not _analyzer:
|
||||
raise HTTPException(status_code=503, detail="Analyzer not initialized")
|
||||
|
||||
result = _analyzer._analyze_company_safe(company_name)
|
||||
return _convert_result(result)
|
||||
|
||||
|
||||
@app.post(
|
||||
"/analyze/batch",
|
||||
response_model=BatchAnalysisResponse,
|
||||
tags=["Analysis"],
|
||||
)
|
||||
async def analyze_companies_batch(request: BatchAnalysisRequest):
|
||||
"""Analyze multiple companies' patent portfolios.
|
||||
|
||||
Processes companies concurrently for improved performance.
|
||||
Limited to 20 companies per request.
|
||||
|
||||
Args:
|
||||
request: List of company names and optional worker count
|
||||
|
||||
Returns:
|
||||
Batch results with individual company analyses and summary statistics
|
||||
"""
|
||||
if not _analyzer:
|
||||
raise HTTPException(status_code=503, detail="Analyzer not initialized")
|
||||
|
||||
result = _analyzer.analyze_companies(
|
||||
companies=request.companies,
|
||||
max_workers=request.max_workers,
|
||||
)
|
||||
return _convert_batch_result(result)
|
||||
|
||||
|
||||
def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
|
||||
"""Background task for batch analysis."""
|
||||
global _jobs, _analyzer
|
||||
|
||||
if not _analyzer:
|
||||
_jobs[job_id].status = "failed"
|
||||
_jobs[job_id].error = "Analyzer not initialized"
|
||||
return
|
||||
|
||||
_jobs[job_id].status = "running"
|
||||
|
||||
def progress_callback(company: str, completed: int, total: int):
|
||||
_jobs[job_id].completed_companies = completed
|
||||
_jobs[job_id].progress = int((completed / total) * 100)
|
||||
|
||||
try:
|
||||
result = _analyzer.analyze_companies(
|
||||
companies=companies,
|
||||
max_workers=max_workers,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
_jobs[job_id].status = "completed"
|
||||
_jobs[job_id].progress = 100
|
||||
_jobs[job_id].result = _convert_batch_result(result)
|
||||
except Exception as e:
|
||||
_jobs[job_id].status = "failed"
|
||||
_jobs[job_id].error = str(e)
|
||||
|
||||
|
||||
@app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"])
|
||||
async def analyze_companies_async(
|
||||
request: BatchAnalysisRequest, background_tasks: BackgroundTasks
|
||||
):
|
||||
"""Start an asynchronous batch analysis job.
|
||||
|
||||
Returns immediately with a job ID that can be used to poll for status.
|
||||
Useful for large batch analyses that may take a long time.
|
||||
|
||||
Args:
|
||||
request: List of company names and optional worker count
|
||||
|
||||
Returns:
|
||||
Job status with job_id for polling
|
||||
"""
|
||||
global _job_counter
|
||||
|
||||
_job_counter += 1
|
||||
job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
|
||||
_jobs[job_id] = JobStatus(
|
||||
job_id=job_id,
|
||||
status="pending",
|
||||
progress=0,
|
||||
total_companies=len(request.companies),
|
||||
completed_companies=0,
|
||||
)
|
||||
|
||||
background_tasks.add_task(
|
||||
_run_batch_job, job_id, request.companies, request.max_workers
|
||||
)
|
||||
|
||||
return _jobs[job_id]
|
||||
|
||||
|
||||
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
|
||||
async def get_job_status(job_id: str):
|
||||
"""Get the status of a background analysis job.
|
||||
|
||||
Args:
|
||||
job_id: The job ID returned from the async batch endpoint
|
||||
|
||||
Returns:
|
||||
Current job status including progress and results when complete
|
||||
"""
|
||||
if job_id not in _jobs:
|
||||
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
|
||||
|
||||
return _jobs[job_id]
|
||||
|
||||
|
||||
@app.get("/jobs", response_model=list[JobStatus], tags=["Jobs"])
|
||||
async def list_jobs(
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status: pending, running, completed, failed"),
|
||||
] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 10,
|
||||
):
|
||||
"""List all analysis jobs.
|
||||
|
||||
Args:
|
||||
status: Optional filter by job status
|
||||
limit: Maximum number of jobs to return (default 10, max 100)
|
||||
|
||||
Returns:
|
||||
List of job statuses
|
||||
"""
|
||||
jobs = list(_jobs.values())
|
||||
|
||||
if status:
|
||||
jobs = [j for j in jobs if j.status == status]
|
||||
|
||||
# Return most recent first
|
||||
jobs.sort(key=lambda j: j.job_id, reverse=True)
|
||||
|
||||
return jobs[:limit]
|
||||
+25
-1
@@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -12,3 +13,26 @@ class Patent:
|
||||
@dataclass
|
||||
class Patents:
|
||||
patents: list[Patent]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompanyAnalysisResult:
|
||||
"""Result of analyzing a single company's patent portfolio."""
|
||||
|
||||
company_name: str
|
||||
analysis: str
|
||||
patent_count: int
|
||||
success: bool
|
||||
error: str | None = None
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchAnalysisResult:
|
||||
"""Result of batch analyzing multiple companies."""
|
||||
|
||||
results: list[CompanyAnalysisResult]
|
||||
total_companies: int
|
||||
successful: int
|
||||
failed: int
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
+362
@@ -0,0 +1,362 @@
|
||||
"""SPARC Visualization Dashboard.
|
||||
|
||||
A Streamlit-based dashboard for visualizing patent analysis results.
|
||||
Run with: streamlit run dashboard.py
|
||||
"""
|
||||
|
||||
import streamlit as st
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from SPARC.analyzer import CompanyAnalyzer
|
||||
from SPARC.database import DatabaseClient
|
||||
from SPARC import config
|
||||
|
||||
|
||||
st.set_page_config(
|
||||
page_title="SPARC Dashboard",
|
||||
page_icon="📊",
|
||||
layout="wide",
|
||||
initial_sidebar_state="expanded",
|
||||
)
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def get_analyzer():
|
||||
"""Get or create the CompanyAnalyzer instance."""
|
||||
return CompanyAnalyzer()
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def get_db_client():
|
||||
"""Get database client if available."""
|
||||
if config.use_database:
|
||||
try:
|
||||
client = DatabaseClient()
|
||||
client.connect()
|
||||
return client
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def render_header():
|
||||
"""Render the dashboard header."""
|
||||
st.title("SPARC Dashboard")
|
||||
st.markdown("**Semiconductor Patent & Analytics Report Core**")
|
||||
st.markdown("---")
|
||||
|
||||
|
||||
def render_sidebar():
|
||||
"""Render the sidebar with navigation and controls."""
|
||||
st.sidebar.title("Navigation")
|
||||
page = st.sidebar.radio(
|
||||
"Select Page",
|
||||
["Company Analysis", "Batch Analysis", "Analytics", "About"],
|
||||
)
|
||||
return page
|
||||
|
||||
|
||||
def render_company_analysis():
|
||||
"""Render single company analysis page."""
|
||||
st.header("Company Patent Analysis")
|
||||
|
||||
col1, col2 = st.columns([2, 1])
|
||||
|
||||
with col1:
|
||||
company_name = st.text_input(
|
||||
"Company Name",
|
||||
placeholder="e.g., nvidia, intel, amd",
|
||||
help="Enter the company name to analyze their patent portfolio",
|
||||
)
|
||||
|
||||
with col2:
|
||||
analyze_btn = st.button("Analyze", type="primary", use_container_width=True)
|
||||
|
||||
if analyze_btn and company_name:
|
||||
with st.spinner(f"Analyzing {company_name}..."):
|
||||
analyzer = get_analyzer()
|
||||
result = analyzer._analyze_company_safe(company_name)
|
||||
|
||||
if result.success:
|
||||
st.success(f"Analysis complete for {company_name}")
|
||||
|
||||
# Metrics row
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Patents Analyzed", result.patent_count)
|
||||
with col2:
|
||||
st.metric("Status", "Success")
|
||||
with col3:
|
||||
st.metric("Timestamp", result.timestamp.strftime("%H:%M:%S"))
|
||||
|
||||
# Analysis content
|
||||
st.subheader("AI Analysis")
|
||||
st.markdown(result.analysis)
|
||||
|
||||
else:
|
||||
st.error(f"Analysis failed: {result.error}")
|
||||
|
||||
|
||||
def render_batch_analysis():
|
||||
"""Render batch analysis page."""
|
||||
st.header("Batch Company Analysis")
|
||||
|
||||
st.markdown(
|
||||
"Analyze multiple companies simultaneously. Enter company names separated by commas or newlines."
|
||||
)
|
||||
|
||||
companies_input = st.text_area(
|
||||
"Company Names",
|
||||
placeholder="nvidia\namd\nintel\nqualcomm",
|
||||
height=150,
|
||||
)
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
max_workers = st.slider("Concurrent Workers", 1, 5, 3)
|
||||
with col2:
|
||||
analyze_btn = st.button(
|
||||
"Run Batch Analysis", type="primary", use_container_width=True
|
||||
)
|
||||
|
||||
if analyze_btn and companies_input:
|
||||
# Parse company names
|
||||
companies = [
|
||||
c.strip()
|
||||
for c in companies_input.replace(",", "\n").split("\n")
|
||||
if c.strip()
|
||||
]
|
||||
|
||||
if not companies:
|
||||
st.warning("Please enter at least one company name")
|
||||
return
|
||||
|
||||
st.info(f"Starting analysis of {len(companies)} companies...")
|
||||
|
||||
# Progress tracking
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
analyzer = get_analyzer()
|
||||
|
||||
def update_progress(company: str, completed: int, total: int):
|
||||
progress = completed / total
|
||||
progress_bar.progress(progress)
|
||||
status_text.text(f"Analyzing {company}... ({completed}/{total})")
|
||||
|
||||
result = analyzer.analyze_companies(
|
||||
companies=companies,
|
||||
max_workers=max_workers,
|
||||
progress_callback=update_progress,
|
||||
)
|
||||
|
||||
progress_bar.progress(1.0)
|
||||
status_text.text("Analysis complete!")
|
||||
|
||||
# Summary metrics
|
||||
st.subheader("Results Summary")
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
st.metric("Total Companies", result.total_companies)
|
||||
with col2:
|
||||
st.metric("Successful", result.successful)
|
||||
with col3:
|
||||
st.metric("Failed", result.failed)
|
||||
with col4:
|
||||
success_rate = (
|
||||
(result.successful / result.total_companies * 100)
|
||||
if result.total_companies > 0
|
||||
else 0
|
||||
)
|
||||
st.metric("Success Rate", f"{success_rate:.1f}%")
|
||||
|
||||
# Results chart
|
||||
if result.results:
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
"Company": r.company_name,
|
||||
"Patents": r.patent_count,
|
||||
"Status": "Success" if r.success else "Failed",
|
||||
}
|
||||
for r in result.results
|
||||
]
|
||||
)
|
||||
|
||||
fig = px.bar(
|
||||
df,
|
||||
x="Company",
|
||||
y="Patents",
|
||||
color="Status",
|
||||
color_discrete_map={"Success": "#28a745", "Failed": "#dc3545"},
|
||||
title="Patents per Company",
|
||||
)
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
|
||||
# Individual results
|
||||
st.subheader("Individual Results")
|
||||
for r in result.results:
|
||||
with st.expander(
|
||||
f"{'✓' if r.success else '✗'} {r.company_name} ({r.patent_count} patents)"
|
||||
):
|
||||
if r.success:
|
||||
st.markdown(r.analysis)
|
||||
else:
|
||||
st.error(r.error)
|
||||
|
||||
|
||||
def render_analytics():
|
||||
"""Render analytics page with database insights."""
|
||||
st.header("Analytics Dashboard")
|
||||
|
||||
db_client = get_db_client()
|
||||
|
||||
if not db_client:
|
||||
st.warning(
|
||||
"Database mode is not enabled. Set USE_DATABASE=true in your .env file to enable analytics."
|
||||
)
|
||||
st.info(
|
||||
"Analytics features require storing analysis results in PostgreSQL for historical tracking."
|
||||
)
|
||||
return
|
||||
|
||||
# Time range selector
|
||||
days = st.selectbox("Time Range", [7, 14, 30, 90], index=0)
|
||||
|
||||
try:
|
||||
analytics = db_client.get_analytics(days=days)
|
||||
|
||||
if not analytics:
|
||||
st.info("No analytics data available yet. Run some analyses first!")
|
||||
return
|
||||
|
||||
# Summary metrics
|
||||
st.subheader("Summary")
|
||||
col1, col2, col3 = st.columns(3)
|
||||
|
||||
with col1:
|
||||
total = analytics.get("total_messages", 0)
|
||||
st.metric("Total Analyses", total)
|
||||
|
||||
with col2:
|
||||
companies = len(analytics.get("by_company", {}))
|
||||
st.metric("Companies Analyzed", companies)
|
||||
|
||||
with col3:
|
||||
types = len(analytics.get("by_type", {}))
|
||||
st.metric("Analysis Types", types)
|
||||
|
||||
# Charts
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
by_company = analytics.get("by_company", {})
|
||||
if by_company:
|
||||
df = pd.DataFrame(
|
||||
[{"Company": k, "Count": v} for k, v in by_company.items()]
|
||||
)
|
||||
fig = px.pie(
|
||||
df, values="Count", names="Company", title="Analyses by Company"
|
||||
)
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
|
||||
with col2:
|
||||
by_type = analytics.get("by_type", {})
|
||||
if by_type:
|
||||
df = pd.DataFrame(
|
||||
[{"Type": k, "Count": v} for k, v in by_type.items()]
|
||||
)
|
||||
fig = px.bar(df, x="Type", y="Count", title="Analyses by Type")
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
|
||||
# Recent messages
|
||||
st.subheader("Recent Analyses")
|
||||
messages = db_client.get_messages(limit=10)
|
||||
|
||||
if messages:
|
||||
for msg in messages:
|
||||
with st.expander(
|
||||
f"{msg.get('company_name', 'Unknown')} - {msg.get('analysis_type', 'N/A')} ({msg.get('timestamp', 'N/A')})"
|
||||
):
|
||||
st.markdown(f"**Model:** {msg.get('model', 'N/A')}")
|
||||
if msg.get("response"):
|
||||
st.markdown(msg["response"][:500] + "...")
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error fetching analytics: {e}")
|
||||
|
||||
|
||||
def render_about():
|
||||
"""Render about page."""
|
||||
st.header("About SPARC")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
**SPARC** (Semiconductor Patent & Analytics Report Core) is a patent analysis
|
||||
system that estimates company performance by analyzing their patent portfolios
|
||||
using LLM-powered insights.
|
||||
|
||||
### Features
|
||||
|
||||
- **Patent Retrieval**: Automated collection via SerpAPI's Google Patents engine
|
||||
- **Intelligent Parsing**: Extracts key sections from patent PDFs
|
||||
- **AI Analysis**: Uses Claude 3.5 Sonnet for deep analysis
|
||||
- **Batch Processing**: Analyze multiple companies concurrently
|
||||
- **REST API**: FastAPI web service for integration
|
||||
- **Analytics**: Track and visualize analysis history
|
||||
|
||||
### Technology Stack
|
||||
|
||||
- **Backend**: Python, FastAPI
|
||||
- **AI**: Claude 3.5 Sonnet via OpenRouter
|
||||
- **Database**: PostgreSQL
|
||||
- **Dashboard**: Streamlit, Plotly
|
||||
- **Patent Data**: SerpAPI Google Patents
|
||||
|
||||
### Links
|
||||
|
||||
- API Docs: `http://localhost:8000/docs`
|
||||
- Health Check: `http://localhost:8000/health`
|
||||
"""
|
||||
)
|
||||
|
||||
# System status
|
||||
st.subheader("System Status")
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
db_client = get_db_client()
|
||||
if db_client:
|
||||
st.success("Database: Connected")
|
||||
else:
|
||||
st.warning("Database: Not configured")
|
||||
|
||||
with col2:
|
||||
analyzer = get_analyzer()
|
||||
if analyzer:
|
||||
st.success("Analyzer: Ready")
|
||||
else:
|
||||
st.error("Analyzer: Not initialized")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main dashboard entry point."""
|
||||
render_header()
|
||||
page = render_sidebar()
|
||||
|
||||
if page == "Company Analysis":
|
||||
render_company_analysis()
|
||||
elif page == "Batch Analysis":
|
||||
render_batch_analysis()
|
||||
elif page == "Analytics":
|
||||
render_analytics()
|
||||
elif page == "About":
|
||||
render_about()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -6,3 +6,9 @@ pytest
|
||||
pytest-mock
|
||||
openai
|
||||
psycopg2-binary
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
httpx
|
||||
streamlit
|
||||
plotly
|
||||
pandas
|
||||
|
||||
+176
-2
@@ -1,9 +1,9 @@
|
||||
"""Tests for the high-level company analyzer orchestration."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, patch, call
|
||||
from SPARC.analyzer import CompanyAnalyzer
|
||||
from SPARC.types import Patent, Patents
|
||||
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult
|
||||
|
||||
|
||||
class TestCompanyAnalyzer:
|
||||
@@ -176,3 +176,177 @@ class TestCompanyAnalyzer:
|
||||
|
||||
assert "Failed to analyze patent US999" in result
|
||||
assert "PDF not found" in result
|
||||
|
||||
|
||||
class TestBatchProcessing:
|
||||
"""Test multi-company batch processing functionality."""
|
||||
|
||||
def test_analyze_companies_success(self, mocker):
|
||||
"""Test batch analysis of multiple companies."""
|
||||
mock_query = mocker.patch("SPARC.analyzer.SERP.query")
|
||||
mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents")
|
||||
mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf")
|
||||
mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm")
|
||||
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||
|
||||
# Setup mock returns
|
||||
def query_side_effect(company):
|
||||
patent = Patent(
|
||||
patent_id=f"US-{company}",
|
||||
pdf_link=f"http://example.com/{company}.pdf",
|
||||
)
|
||||
return Patents(patents=[patent])
|
||||
|
||||
mock_query.side_effect = query_side_effect
|
||||
|
||||
def save_side_effect(patent):
|
||||
patent.pdf_path = f"patents/{patent.patent_id}.pdf"
|
||||
return patent
|
||||
|
||||
mock_save.side_effect = save_side_effect
|
||||
mock_parse.return_value = {"abstract": "Test"}
|
||||
mock_minimize.return_value = "Content"
|
||||
|
||||
mock_llm_instance = Mock()
|
||||
mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis result"
|
||||
mock_llm.return_value = mock_llm_instance
|
||||
|
||||
analyzer = CompanyAnalyzer()
|
||||
result = analyzer.analyze_companies(["CompanyA", "CompanyB"], max_workers=2)
|
||||
|
||||
assert isinstance(result, BatchAnalysisResult)
|
||||
assert result.total_companies == 2
|
||||
assert result.successful == 2
|
||||
assert result.failed == 0
|
||||
assert len(result.results) == 2
|
||||
|
||||
def test_analyze_companies_with_failures(self, mocker):
|
||||
"""Test batch analysis handles partial failures."""
|
||||
mock_query = mocker.patch("SPARC.analyzer.SERP.query")
|
||||
mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||
|
||||
def query_side_effect(company):
|
||||
if company == "FailCorp":
|
||||
return Patents(patents=[])
|
||||
patent = Patent(
|
||||
patent_id=f"US-{company}",
|
||||
pdf_link=f"http://example.com/{company}.pdf",
|
||||
)
|
||||
return Patents(patents=[patent])
|
||||
|
||||
mock_query.side_effect = query_side_effect
|
||||
|
||||
analyzer = CompanyAnalyzer()
|
||||
result = analyzer.analyze_companies(["GoodCorp", "FailCorp"], max_workers=1)
|
||||
|
||||
assert result.total_companies == 2
|
||||
assert result.failed >= 1 # At least FailCorp should fail
|
||||
|
||||
# Find the failed result
|
||||
fail_result = next(r for r in result.results if r.company_name == "FailCorp")
|
||||
assert fail_result.success is False
|
||||
|
||||
def test_analyze_companies_sequential(self, mocker):
|
||||
"""Test sequential batch analysis."""
|
||||
mock_query = mocker.patch("SPARC.analyzer.SERP.query")
|
||||
mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents")
|
||||
mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf")
|
||||
mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm")
|
||||
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||
|
||||
def query_side_effect(company):
|
||||
patent = Patent(
|
||||
patent_id=f"US-{company}",
|
||||
pdf_link=f"http://example.com/{company}.pdf",
|
||||
)
|
||||
return Patents(patents=[patent])
|
||||
|
||||
mock_query.side_effect = query_side_effect
|
||||
|
||||
def save_side_effect(patent):
|
||||
patent.pdf_path = f"patents/{patent.patent_id}.pdf"
|
||||
return patent
|
||||
|
||||
mock_save.side_effect = save_side_effect
|
||||
mock_parse.return_value = {"abstract": "Test"}
|
||||
mock_minimize.return_value = "Content"
|
||||
|
||||
mock_llm_instance = Mock()
|
||||
mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis"
|
||||
mock_llm.return_value = mock_llm_instance
|
||||
|
||||
analyzer = CompanyAnalyzer()
|
||||
result = analyzer.analyze_companies_sequential(["Corp1", "Corp2", "Corp3"])
|
||||
|
||||
assert result.total_companies == 3
|
||||
assert len(result.results) == 3
|
||||
|
||||
def test_analyze_companies_progress_callback(self, mocker):
|
||||
"""Test that progress callback is invoked correctly."""
|
||||
mock_query = mocker.patch("SPARC.analyzer.SERP.query")
|
||||
mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents")
|
||||
mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf")
|
||||
mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm")
|
||||
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||
|
||||
def query_side_effect(company):
|
||||
patent = Patent(
|
||||
patent_id=f"US-{company}",
|
||||
pdf_link=f"http://example.com/{company}.pdf",
|
||||
)
|
||||
return Patents(patents=[patent])
|
||||
|
||||
mock_query.side_effect = query_side_effect
|
||||
|
||||
def save_side_effect(patent):
|
||||
patent.pdf_path = f"patents/{patent.patent_id}.pdf"
|
||||
return patent
|
||||
|
||||
mock_save.side_effect = save_side_effect
|
||||
mock_parse.return_value = {"abstract": "Test"}
|
||||
mock_minimize.return_value = "Content"
|
||||
|
||||
mock_llm_instance = Mock()
|
||||
mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis"
|
||||
mock_llm.return_value = mock_llm_instance
|
||||
|
||||
callback = Mock()
|
||||
analyzer = CompanyAnalyzer()
|
||||
analyzer.analyze_companies(["A", "B"], max_workers=1, progress_callback=callback)
|
||||
|
||||
assert callback.call_count == 2
|
||||
|
||||
def test_company_analysis_result_structure(self, mocker):
|
||||
"""Test CompanyAnalysisResult has correct structure."""
|
||||
mock_query = mocker.patch("SPARC.analyzer.SERP.query")
|
||||
mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents")
|
||||
mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf")
|
||||
mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm")
|
||||
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||
|
||||
patent = Patent(patent_id="US123", pdf_link="http://example.com/test.pdf")
|
||||
mock_query.return_value = Patents(patents=[patent])
|
||||
|
||||
def save_side_effect(p):
|
||||
p.pdf_path = "patents/US123.pdf"
|
||||
return p
|
||||
|
||||
mock_save.side_effect = save_side_effect
|
||||
mock_parse.return_value = {"abstract": "Test"}
|
||||
mock_minimize.return_value = "Content"
|
||||
|
||||
mock_llm_instance = Mock()
|
||||
mock_llm_instance.analyze_patent_portfolio.return_value = "Strong innovation"
|
||||
mock_llm.return_value = mock_llm_instance
|
||||
|
||||
analyzer = CompanyAnalyzer()
|
||||
result = analyzer.analyze_companies(["TestCorp"], max_workers=1)
|
||||
|
||||
assert len(result.results) == 1
|
||||
company_result = result.results[0]
|
||||
assert company_result.company_name == "TestCorp"
|
||||
assert company_result.analysis == "Strong innovation"
|
||||
assert company_result.patent_count == 1
|
||||
assert company_result.success is True
|
||||
assert company_result.error is None
|
||||
assert company_result.timestamp is not None
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Tests for FastAPI web service endpoints."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from SPARC.api import app, _analyzer, _jobs
|
||||
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_analyzer(mocker):
|
||||
"""Mock the global analyzer."""
|
||||
mock = Mock()
|
||||
mocker.patch("SPARC.api._analyzer", mock)
|
||||
return mock
|
||||
|
||||
|
||||
class TestHealthEndpoint:
|
||||
"""Test health check endpoint."""
|
||||
|
||||
def test_health_returns_ok(self, client):
|
||||
"""Test health endpoint returns healthy status."""
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert data["version"] == "1.0.0"
|
||||
assert "timestamp" in data
|
||||
|
||||
|
||||
class TestAnalyzeCompanyEndpoint:
|
||||
"""Test single company analysis endpoint."""
|
||||
|
||||
def test_analyze_company_success(self, client, mock_analyzer):
|
||||
"""Test successful company analysis."""
|
||||
mock_result = CompanyAnalysisResult(
|
||||
company_name="nvidia",
|
||||
analysis="Strong AI patent portfolio",
|
||||
patent_count=5,
|
||||
success=True,
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
mock_analyzer._analyze_company_safe.return_value = mock_result
|
||||
|
||||
response = client.get("/analyze/nvidia")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["company_name"] == "nvidia"
|
||||
assert data["analysis"] == "Strong AI patent portfolio"
|
||||
assert data["patent_count"] == 5
|
||||
assert data["success"] is True
|
||||
|
||||
def test_analyze_company_failure(self, client, mock_analyzer):
|
||||
"""Test company analysis with error."""
|
||||
mock_result = CompanyAnalysisResult(
|
||||
company_name="unknown",
|
||||
analysis="",
|
||||
patent_count=0,
|
||||
success=False,
|
||||
error="No patents found",
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
mock_analyzer._analyze_company_safe.return_value = mock_result
|
||||
|
||||
response = client.get("/analyze/unknown")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
assert data["error"] == "No patents found"
|
||||
|
||||
|
||||
class TestBatchAnalysisEndpoint:
|
||||
"""Test batch analysis endpoint."""
|
||||
|
||||
def test_batch_analysis_success(self, client, mock_analyzer):
|
||||
"""Test successful batch analysis."""
|
||||
results = [
|
||||
CompanyAnalysisResult(
|
||||
company_name="nvidia",
|
||||
analysis="Strong portfolio",
|
||||
patent_count=5,
|
||||
success=True,
|
||||
timestamp=datetime.now(),
|
||||
),
|
||||
CompanyAnalysisResult(
|
||||
company_name="amd",
|
||||
analysis="Growing portfolio",
|
||||
patent_count=3,
|
||||
success=True,
|
||||
timestamp=datetime.now(),
|
||||
),
|
||||
]
|
||||
mock_batch = BatchAnalysisResult(
|
||||
results=results,
|
||||
total_companies=2,
|
||||
successful=2,
|
||||
failed=0,
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
mock_analyzer.analyze_companies.return_value = mock_batch
|
||||
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["nvidia", "amd"], "max_workers": 2},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_companies"] == 2
|
||||
assert data["successful"] == 2
|
||||
assert data["failed"] == 0
|
||||
assert len(data["results"]) == 2
|
||||
|
||||
def test_batch_analysis_validation(self, client):
|
||||
"""Test batch analysis request validation."""
|
||||
# Empty companies list
|
||||
response = client.post("/analyze/batch", json={"companies": []})
|
||||
assert response.status_code == 422
|
||||
|
||||
# Too many companies
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": [f"company{i}" for i in range(25)]},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
# Invalid max_workers
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["nvidia"], "max_workers": 10},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestAsyncBatchEndpoint:
|
||||
"""Test async batch analysis endpoint."""
|
||||
|
||||
def test_async_batch_creates_job(self, client, mock_analyzer):
|
||||
"""Test async endpoint creates a job."""
|
||||
response = client.post(
|
||||
"/analyze/batch/async",
|
||||
json={"companies": ["nvidia", "amd"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
assert data["status"] == "pending"
|
||||
assert data["total_companies"] == 2
|
||||
assert data["progress"] == 0
|
||||
|
||||
|
||||
class TestJobEndpoints:
|
||||
"""Test job management endpoints."""
|
||||
|
||||
def test_get_job_not_found(self, client):
|
||||
"""Test getting nonexistent job."""
|
||||
response = client.get("/jobs/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_list_jobs(self, client, mocker):
|
||||
"""Test listing jobs."""
|
||||
# Clear existing jobs
|
||||
mocker.patch.dict("SPARC.api._jobs", {}, clear=True)
|
||||
|
||||
response = client.get("/jobs")
|
||||
assert response.status_code == 200
|
||||
assert isinstance(response.json(), list)
|
||||
|
||||
def test_list_jobs_with_filter(self, client, mocker):
|
||||
"""Test listing jobs with status filter."""
|
||||
response = client.get("/jobs?status=completed")
|
||||
assert response.status_code == 200
|
||||
@@ -25,6 +25,8 @@ class TestLLMAnalyzer:
|
||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||
mock_config = mocker.patch("SPARC.llm.config")
|
||||
mock_config.openrouter_api_key = "config-key-456"
|
||||
mock_config.use_database = False
|
||||
mock_config.database_url = "postgresql://localhost/test"
|
||||
|
||||
analyzer = LLMAnalyzer()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user