forked from 0xWheatyz/SPARC
feat: add multi-company batch processing
- Add CompanyAnalysisResult and BatchAnalysisResult dataclasses - Implement analyze_companies() for concurrent batch analysis - Implement analyze_companies_sequential() for rate-limited scenarios - Add progress callback support for monitoring batch jobs - Include 5 new tests for batch processing functionality - Fix pre-existing test mock issue in test_llm.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
+155
-2
@@ -4,10 +4,12 @@ This module ties together patent retrieval, parsing, and LLM analysis
|
||||
to provide company performance estimation based on patent portfolios.
|
||||
"""
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Callable
|
||||
|
||||
from SPARC.serp_api import SERP
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
from SPARC.types import Patent
|
||||
from typing import List
|
||||
from SPARC.types import Patent, CompanyAnalysisResult, BatchAnalysisResult
|
||||
|
||||
|
||||
class CompanyAnalyzer:
|
||||
@@ -110,3 +112,154 @@ class CompanyAnalyzer:
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to analyze patent {patent_id}: {e}"
|
||||
|
||||
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult:
|
||||
"""Internal wrapper that catches exceptions and returns structured result.
|
||||
|
||||
Args:
|
||||
company_name: Name of the company to analyze
|
||||
|
||||
Returns:
|
||||
CompanyAnalysisResult with success/failure status
|
||||
"""
|
||||
try:
|
||||
patents = SERP.query(company_name)
|
||||
patent_count = len(patents.patents) if patents.patents else 0
|
||||
|
||||
analysis = self.analyze_company(company_name)
|
||||
|
||||
# Check if analysis indicates failure
|
||||
if analysis.startswith("No patents found") or analysis.startswith(
|
||||
"Failed to process"
|
||||
):
|
||||
return CompanyAnalysisResult(
|
||||
company_name=company_name,
|
||||
analysis=analysis,
|
||||
patent_count=patent_count,
|
||||
success=False,
|
||||
error=analysis,
|
||||
)
|
||||
|
||||
return CompanyAnalysisResult(
|
||||
company_name=company_name,
|
||||
analysis=analysis,
|
||||
patent_count=patent_count,
|
||||
success=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return CompanyAnalysisResult(
|
||||
company_name=company_name,
|
||||
analysis="",
|
||||
patent_count=0,
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def analyze_companies(
|
||||
self,
|
||||
companies: list[str],
|
||||
max_workers: int = 3,
|
||||
progress_callback: Callable[[str, int, int], None] | None = None,
|
||||
) -> BatchAnalysisResult:
|
||||
"""Analyze multiple companies' patent portfolios in batch.
|
||||
|
||||
Processes companies concurrently for improved performance while
|
||||
respecting API rate limits.
|
||||
|
||||
Args:
|
||||
companies: List of company names to analyze
|
||||
max_workers: Maximum concurrent analyses (default 3 to avoid rate limits)
|
||||
progress_callback: Optional callback(company_name, completed, total)
|
||||
|
||||
Returns:
|
||||
BatchAnalysisResult containing all individual results and summary stats
|
||||
"""
|
||||
results: list[CompanyAnalysisResult] = []
|
||||
total = len(companies)
|
||||
|
||||
print(f"Starting batch analysis of {total} companies...")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_company = {
|
||||
executor.submit(self._analyze_company_safe, company): company
|
||||
for company in companies
|
||||
}
|
||||
|
||||
completed = 0
|
||||
for future in as_completed(future_to_company):
|
||||
company = future_to_company[future]
|
||||
completed += 1
|
||||
|
||||
try:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
|
||||
status = "✓" if result.success else "✗"
|
||||
print(f"[{completed}/{total}] {status} {company}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(company, completed, total)
|
||||
|
||||
except Exception as e:
|
||||
results.append(
|
||||
CompanyAnalysisResult(
|
||||
company_name=company,
|
||||
analysis="",
|
||||
patent_count=0,
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
print(f"[{completed}/{total}] ✗ {company}: {e}")
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = total - successful
|
||||
|
||||
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
|
||||
|
||||
return BatchAnalysisResult(
|
||||
results=results,
|
||||
total_companies=total,
|
||||
successful=successful,
|
||||
failed=failed,
|
||||
)
|
||||
|
||||
def analyze_companies_sequential(
|
||||
self, companies: list[str]
|
||||
) -> BatchAnalysisResult:
|
||||
"""Analyze multiple companies sequentially (safer for rate limits).
|
||||
|
||||
Use this when you want more control over API rate limiting or
|
||||
when debugging issues.
|
||||
|
||||
Args:
|
||||
companies: List of company names to analyze
|
||||
|
||||
Returns:
|
||||
BatchAnalysisResult containing all individual results
|
||||
"""
|
||||
results: list[CompanyAnalysisResult] = []
|
||||
total = len(companies)
|
||||
|
||||
print(f"Starting sequential analysis of {total} companies...")
|
||||
|
||||
for idx, company in enumerate(companies, 1):
|
||||
print(f"\n[{idx}/{total}] Analyzing {company}...")
|
||||
result = self._analyze_company_safe(company)
|
||||
results.append(result)
|
||||
|
||||
status = "✓" if result.success else "✗"
|
||||
print(f"[{idx}/{total}] {status} {company}")
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = total - successful
|
||||
|
||||
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
|
||||
|
||||
return BatchAnalysisResult(
|
||||
results=results,
|
||||
total_companies=total,
|
||||
successful=successful,
|
||||
failed=failed,
|
||||
)
|
||||
|
||||
+25
-1
@@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -12,3 +13,26 @@ class Patent:
|
||||
@dataclass
|
||||
class Patents:
|
||||
patents: list[Patent]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompanyAnalysisResult:
|
||||
"""Result of analyzing a single company's patent portfolio."""
|
||||
|
||||
company_name: str
|
||||
analysis: str
|
||||
patent_count: int
|
||||
success: bool
|
||||
error: str | None = None
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchAnalysisResult:
|
||||
"""Result of batch analyzing multiple companies."""
|
||||
|
||||
results: list[CompanyAnalysisResult]
|
||||
total_companies: int
|
||||
successful: int
|
||||
failed: int
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
Reference in New Issue
Block a user