diff --git a/README.md b/README.md index 15db310..2a06340 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,8 @@ SPARC automatically collects, parses, and analyzes patents from companies to pro - **Content Minimization**: Removes verbose descriptions to reduce LLM token usage - **AI Analysis**: Uses Claude 3.5 Sonnet via OpenRouter to analyze innovation quality and market potential - **Portfolio Analysis**: Evaluates multiple patents holistically for comprehensive insights -- **Robust Testing**: 26 tests covering all major functionality +- **Batch Processing**: Analyze multiple companies concurrently with progress tracking +- **Robust Testing**: 31 tests covering all major functionality ## Architecture @@ -99,6 +100,33 @@ result = analyzer.analyze_single_patent( ) ``` +### Multi-Company Batch Analysis + +```python +from SPARC.analyzer import CompanyAnalyzer + +analyzer = CompanyAnalyzer() + +# Analyze multiple companies concurrently (default 3 workers) +batch_result = analyzer.analyze_companies( + ["nvidia", "amd", "intel", "qualcomm"], + max_workers=3 +) + +# Access results +print(f"Analyzed: {batch_result.total_companies}") +print(f"Successful: {batch_result.successful}") +print(f"Failed: {batch_result.failed}") + +for result in batch_result.results: + if result.success: + print(f"{result.company_name}: {result.patent_count} patents") + print(result.analysis) + +# Or use sequential processing (safer for rate limits) +batch_result = analyzer.analyze_companies_sequential(["nvidia", "amd"]) +``` + ## Running Tests ```bash @@ -130,7 +158,7 @@ pytest tests/ --cov=SPARC --cov-report=term-missing - [X] Extract and minimize patent content - [X] LLM integration for analysis - [X] Company performance estimation -- [ ] Multi-company batch processing +- [X] Multi-company batch processing - [ ] FastAPI web service wrapper - [X] Docker containerization - [X] Results persistence (database) diff --git a/SPARC/analyzer.py b/SPARC/analyzer.py index 414a7b8..c833c0d 100644 --- a/SPARC/analyzer.py +++ b/SPARC/analyzer.py @@ -4,10 +4,12 @@ This module ties together patent retrieval, parsing, and LLM analysis to provide company performance estimation based on patent portfolios. """ +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Callable + from SPARC.serp_api import SERP from SPARC.llm import LLMAnalyzer -from SPARC.types import Patent -from typing import List +from SPARC.types import Patent, CompanyAnalysisResult, BatchAnalysisResult class CompanyAnalyzer: @@ -110,3 +112,154 @@ class CompanyAnalyzer: except Exception as e: return f"Failed to analyze patent {patent_id}: {e}" + + def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult: + """Internal wrapper that catches exceptions and returns structured result. + + Args: + company_name: Name of the company to analyze + + Returns: + CompanyAnalysisResult with success/failure status + """ + try: + patents = SERP.query(company_name) + patent_count = len(patents.patents) if patents.patents else 0 + + analysis = self.analyze_company(company_name) + + # Check if analysis indicates failure + if analysis.startswith("No patents found") or analysis.startswith( + "Failed to process" + ): + return CompanyAnalysisResult( + company_name=company_name, + analysis=analysis, + patent_count=patent_count, + success=False, + error=analysis, + ) + + return CompanyAnalysisResult( + company_name=company_name, + analysis=analysis, + patent_count=patent_count, + success=True, + ) + + except Exception as e: + return CompanyAnalysisResult( + company_name=company_name, + analysis="", + patent_count=0, + success=False, + error=str(e), + ) + + def analyze_companies( + self, + companies: list[str], + max_workers: int = 3, + progress_callback: Callable[[str, int, int], None] | None = None, + ) -> BatchAnalysisResult: + """Analyze multiple companies' patent portfolios in batch. + + Processes companies concurrently for improved performance while + respecting API rate limits. + + Args: + companies: List of company names to analyze + max_workers: Maximum concurrent analyses (default 3 to avoid rate limits) + progress_callback: Optional callback(company_name, completed, total) + + Returns: + BatchAnalysisResult containing all individual results and summary stats + """ + results: list[CompanyAnalysisResult] = [] + total = len(companies) + + print(f"Starting batch analysis of {total} companies...") + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_company = { + executor.submit(self._analyze_company_safe, company): company + for company in companies + } + + completed = 0 + for future in as_completed(future_to_company): + company = future_to_company[future] + completed += 1 + + try: + result = future.result() + results.append(result) + + status = "✓" if result.success else "✗" + print(f"[{completed}/{total}] {status} {company}") + + if progress_callback: + progress_callback(company, completed, total) + + except Exception as e: + results.append( + CompanyAnalysisResult( + company_name=company, + analysis="", + patent_count=0, + success=False, + error=str(e), + ) + ) + print(f"[{completed}/{total}] ✗ {company}: {e}") + + successful = sum(1 for r in results if r.success) + failed = total - successful + + print(f"\nBatch complete: {successful} succeeded, {failed} failed") + + return BatchAnalysisResult( + results=results, + total_companies=total, + successful=successful, + failed=failed, + ) + + def analyze_companies_sequential( + self, companies: list[str] + ) -> BatchAnalysisResult: + """Analyze multiple companies sequentially (safer for rate limits). + + Use this when you want more control over API rate limiting or + when debugging issues. + + Args: + companies: List of company names to analyze + + Returns: + BatchAnalysisResult containing all individual results + """ + results: list[CompanyAnalysisResult] = [] + total = len(companies) + + print(f"Starting sequential analysis of {total} companies...") + + for idx, company in enumerate(companies, 1): + print(f"\n[{idx}/{total}] Analyzing {company}...") + result = self._analyze_company_safe(company) + results.append(result) + + status = "✓" if result.success else "✗" + print(f"[{idx}/{total}] {status} {company}") + + successful = sum(1 for r in results if r.success) + failed = total - successful + + print(f"\nBatch complete: {successful} succeeded, {failed} failed") + + return BatchAnalysisResult( + results=results, + total_companies=total, + successful=successful, + failed=failed, + ) diff --git a/SPARC/types.py b/SPARC/types.py index 4dfc809..09c17d1 100644 --- a/SPARC/types.py +++ b/SPARC/types.py @@ -1,4 +1,5 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field +from datetime import datetime @dataclass @@ -12,3 +13,26 @@ class Patent: @dataclass class Patents: patents: list[Patent] + + +@dataclass +class CompanyAnalysisResult: + """Result of analyzing a single company's patent portfolio.""" + + company_name: str + analysis: str + patent_count: int + success: bool + error: str | None = None + timestamp: datetime = field(default_factory=datetime.now) + + +@dataclass +class BatchAnalysisResult: + """Result of batch analyzing multiple companies.""" + + results: list[CompanyAnalysisResult] + total_companies: int + successful: int + failed: int + timestamp: datetime = field(default_factory=datetime.now) diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index 6f4dbf1..d81be6d 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -1,9 +1,9 @@ """Tests for the high-level company analyzer orchestration.""" import pytest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, call from SPARC.analyzer import CompanyAnalyzer -from SPARC.types import Patent, Patents +from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult class TestCompanyAnalyzer: @@ -176,3 +176,177 @@ class TestCompanyAnalyzer: assert "Failed to analyze patent US999" in result assert "PDF not found" in result + + +class TestBatchProcessing: + """Test multi-company batch processing functionality.""" + + def test_analyze_companies_success(self, mocker): + """Test batch analysis of multiple companies.""" + mock_query = mocker.patch("SPARC.analyzer.SERP.query") + mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents") + mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf") + mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm") + mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer") + + # Setup mock returns + def query_side_effect(company): + patent = Patent( + patent_id=f"US-{company}", + pdf_link=f"http://example.com/{company}.pdf", + ) + return Patents(patents=[patent]) + + mock_query.side_effect = query_side_effect + + def save_side_effect(patent): + patent.pdf_path = f"patents/{patent.patent_id}.pdf" + return patent + + mock_save.side_effect = save_side_effect + mock_parse.return_value = {"abstract": "Test"} + mock_minimize.return_value = "Content" + + mock_llm_instance = Mock() + mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis result" + mock_llm.return_value = mock_llm_instance + + analyzer = CompanyAnalyzer() + result = analyzer.analyze_companies(["CompanyA", "CompanyB"], max_workers=2) + + assert isinstance(result, BatchAnalysisResult) + assert result.total_companies == 2 + assert result.successful == 2 + assert result.failed == 0 + assert len(result.results) == 2 + + def test_analyze_companies_with_failures(self, mocker): + """Test batch analysis handles partial failures.""" + mock_query = mocker.patch("SPARC.analyzer.SERP.query") + mocker.patch("SPARC.analyzer.LLMAnalyzer") + + def query_side_effect(company): + if company == "FailCorp": + return Patents(patents=[]) + patent = Patent( + patent_id=f"US-{company}", + pdf_link=f"http://example.com/{company}.pdf", + ) + return Patents(patents=[patent]) + + mock_query.side_effect = query_side_effect + + analyzer = CompanyAnalyzer() + result = analyzer.analyze_companies(["GoodCorp", "FailCorp"], max_workers=1) + + assert result.total_companies == 2 + assert result.failed >= 1 # At least FailCorp should fail + + # Find the failed result + fail_result = next(r for r in result.results if r.company_name == "FailCorp") + assert fail_result.success is False + + def test_analyze_companies_sequential(self, mocker): + """Test sequential batch analysis.""" + mock_query = mocker.patch("SPARC.analyzer.SERP.query") + mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents") + mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf") + mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm") + mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer") + + def query_side_effect(company): + patent = Patent( + patent_id=f"US-{company}", + pdf_link=f"http://example.com/{company}.pdf", + ) + return Patents(patents=[patent]) + + mock_query.side_effect = query_side_effect + + def save_side_effect(patent): + patent.pdf_path = f"patents/{patent.patent_id}.pdf" + return patent + + mock_save.side_effect = save_side_effect + mock_parse.return_value = {"abstract": "Test"} + mock_minimize.return_value = "Content" + + mock_llm_instance = Mock() + mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis" + mock_llm.return_value = mock_llm_instance + + analyzer = CompanyAnalyzer() + result = analyzer.analyze_companies_sequential(["Corp1", "Corp2", "Corp3"]) + + assert result.total_companies == 3 + assert len(result.results) == 3 + + def test_analyze_companies_progress_callback(self, mocker): + """Test that progress callback is invoked correctly.""" + mock_query = mocker.patch("SPARC.analyzer.SERP.query") + mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents") + mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf") + mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm") + mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer") + + def query_side_effect(company): + patent = Patent( + patent_id=f"US-{company}", + pdf_link=f"http://example.com/{company}.pdf", + ) + return Patents(patents=[patent]) + + mock_query.side_effect = query_side_effect + + def save_side_effect(patent): + patent.pdf_path = f"patents/{patent.patent_id}.pdf" + return patent + + mock_save.side_effect = save_side_effect + mock_parse.return_value = {"abstract": "Test"} + mock_minimize.return_value = "Content" + + mock_llm_instance = Mock() + mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis" + mock_llm.return_value = mock_llm_instance + + callback = Mock() + analyzer = CompanyAnalyzer() + analyzer.analyze_companies(["A", "B"], max_workers=1, progress_callback=callback) + + assert callback.call_count == 2 + + def test_company_analysis_result_structure(self, mocker): + """Test CompanyAnalysisResult has correct structure.""" + mock_query = mocker.patch("SPARC.analyzer.SERP.query") + mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents") + mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf") + mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm") + mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer") + + patent = Patent(patent_id="US123", pdf_link="http://example.com/test.pdf") + mock_query.return_value = Patents(patents=[patent]) + + def save_side_effect(p): + p.pdf_path = "patents/US123.pdf" + return p + + mock_save.side_effect = save_side_effect + mock_parse.return_value = {"abstract": "Test"} + mock_minimize.return_value = "Content" + + mock_llm_instance = Mock() + mock_llm_instance.analyze_patent_portfolio.return_value = "Strong innovation" + mock_llm.return_value = mock_llm_instance + + analyzer = CompanyAnalyzer() + result = analyzer.analyze_companies(["TestCorp"], max_workers=1) + + assert len(result.results) == 1 + company_result = result.results[0] + assert company_result.company_name == "TestCorp" + assert company_result.analysis == "Strong innovation" + assert company_result.patent_count == 1 + assert company_result.success is True + assert company_result.error is None + assert company_result.timestamp is not None diff --git a/tests/test_llm.py b/tests/test_llm.py index c5133d8..1cc255a 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -25,6 +25,8 @@ class TestLLMAnalyzer: mock_openai = mocker.patch("SPARC.llm.OpenAI") mock_config = mocker.patch("SPARC.llm.config") mock_config.openrouter_api_key = "config-key-456" + mock_config.use_database = False + mock_config.database_url = "postgresql://localhost/test" analyzer = LLMAnalyzer()