forked from 0xWheatyz/SPARC
feat: add model picker to Analysis and Batch pages with full backend wiring
Thread the optional model parameter through the entire analysis pipeline: - analyzer.py: analyze_company, _analyze_company_safe, analyze_companies, and analyze_single_patent now accept and forward model override - api.py: single company endpoint accepts model query param; batch and async batch endpoints pass request.model through to the analyzer - client.ts: analyzeCompany, analyzeBatch, analyzeBatchAsync accept model; add listModels() to fetch available models from GET /models - Analysis.tsx: add model selector dropdown that loads from /models API - Batch.tsx: add model selector alongside the workers slider Users can now pick a specific LLM (GPT-4o, Claude 3.5, Gemini, etc.) per analysis request, or leave it on the server default. Closes leeworks-agents/SPARC#351 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+12
-7
@@ -33,7 +33,7 @@ class CompanyAnalyzer:
|
||||
self.db.connect()
|
||||
self.db.initialize_schema()
|
||||
|
||||
def analyze_company(self, company_name: str, patents: "Patents | None" = None) -> str:
|
||||
def analyze_company(self, company_name: str, patents: "Patents | None" = None, model: str | None = None) -> str:
|
||||
"""Analyze a company's performance based on their patent portfolio.
|
||||
|
||||
This is the main entry point that orchestrates the full pipeline:
|
||||
@@ -46,6 +46,7 @@ class CompanyAnalyzer:
|
||||
Args:
|
||||
company_name: Name of the company to analyze
|
||||
patents: Optional pre-fetched Patents result to avoid duplicate API calls
|
||||
model: Optional LLM model override (e.g. 'openai/gpt-4o')
|
||||
|
||||
Returns:
|
||||
Comprehensive analysis of company's innovation and performance outlook
|
||||
@@ -100,12 +101,12 @@ class CompanyAnalyzer:
|
||||
|
||||
# Analyze the full portfolio with LLM
|
||||
analysis = self.llm_analyzer.analyze_patent_portfolio(
|
||||
patents_data=processed_patents, company_name=company_name
|
||||
patents_data=processed_patents, company_name=company_name, model=model
|
||||
)
|
||||
|
||||
return analysis
|
||||
|
||||
def analyze_single_patent(self, patent_id: str, company_name: str) -> str:
|
||||
def analyze_single_patent(self, patent_id: str, company_name: str, model: str | None = None) -> str:
|
||||
"""Analyze a single patent by ID.
|
||||
|
||||
If the patent PDF is not already on disk, this method attempts to
|
||||
@@ -116,6 +117,7 @@ class CompanyAnalyzer:
|
||||
Args:
|
||||
patent_id: Publication ID of the patent (e.g. "US-11234567-B2")
|
||||
company_name: Name of the company (for context)
|
||||
model: Optional LLM model override (e.g. 'openai/gpt-4o')
|
||||
|
||||
Returns:
|
||||
Analysis of the specific patent's innovation quality
|
||||
@@ -151,7 +153,7 @@ class CompanyAnalyzer:
|
||||
minimized_content = SERP.minimize_patent_for_llm(sections)
|
||||
|
||||
analysis = self.llm_analyzer.analyze_patent_content(
|
||||
patent_content=minimized_content, company_name=company_name
|
||||
patent_content=minimized_content, company_name=company_name, model=model
|
||||
)
|
||||
|
||||
return analysis
|
||||
@@ -201,18 +203,19 @@ class CompanyAnalyzer:
|
||||
logger.warning("Failed to process %s: %s", patent.patent_id, e)
|
||||
return None
|
||||
|
||||
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult:
|
||||
def _analyze_company_safe(self, company_name: str, model: str | None = None) -> CompanyAnalysisResult:
|
||||
"""Internal wrapper that catches exceptions and returns structured result.
|
||||
|
||||
Args:
|
||||
company_name: Name of the company to analyze
|
||||
model: Optional LLM model override (e.g. 'openai/gpt-4o')
|
||||
|
||||
Returns:
|
||||
CompanyAnalysisResult with success/failure status
|
||||
"""
|
||||
try:
|
||||
# Delegate to analyze_company which handles SERP/patent caching
|
||||
analysis = self.analyze_company(company_name)
|
||||
analysis = self.analyze_company(company_name, model=model)
|
||||
|
||||
# Determine patent count from cached SERP query
|
||||
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
|
||||
@@ -252,6 +255,7 @@ class CompanyAnalyzer:
|
||||
companies: list[str],
|
||||
max_workers: int = 3,
|
||||
progress_callback: Callable[[str, int, int], None] | None = None,
|
||||
model: str | None = None,
|
||||
) -> BatchAnalysisResult:
|
||||
"""Analyze multiple companies' patent portfolios in batch.
|
||||
|
||||
@@ -262,6 +266,7 @@ class CompanyAnalyzer:
|
||||
companies: List of company names to analyze
|
||||
max_workers: Maximum concurrent analyses (default 3 to avoid rate limits)
|
||||
progress_callback: Optional callback(company_name, completed, total)
|
||||
model: Optional LLM model override (e.g. 'openai/gpt-4o')
|
||||
|
||||
Returns:
|
||||
BatchAnalysisResult containing all individual results and summary stats
|
||||
@@ -273,7 +278,7 @@ class CompanyAnalyzer:
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_company = {
|
||||
executor.submit(self._analyze_company_safe, company): company
|
||||
executor.submit(self._analyze_company_safe, company, model): company
|
||||
for company in companies
|
||||
}
|
||||
|
||||
|
||||
+7
-3
@@ -799,6 +799,7 @@ async def health_check():
|
||||
)
|
||||
async def analyze_company(
|
||||
company_name: str,
|
||||
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
|
||||
_: UserResponse = Depends(get_current_user),
|
||||
):
|
||||
"""Analyze a single company's patent portfolio.
|
||||
@@ -808,6 +809,7 @@ async def analyze_company(
|
||||
|
||||
Args:
|
||||
company_name: Name of the company to analyze (e.g., "nvidia", "intel")
|
||||
model: Optional LLM model override
|
||||
|
||||
Returns:
|
||||
Analysis results including patent count, AI insights, and success status
|
||||
@@ -815,7 +817,7 @@ async def analyze_company(
|
||||
if not _analyzer:
|
||||
raise HTTPException(status_code=503, detail="Analyzer not initialized")
|
||||
|
||||
result = _analyzer._analyze_company_safe(company_name)
|
||||
result = _analyzer._analyze_company_safe(company_name, model=model)
|
||||
return _convert_result(result)
|
||||
|
||||
|
||||
@@ -877,6 +879,7 @@ async def analyze_companies_batch(
|
||||
result = _analyzer.analyze_companies(
|
||||
companies=request.companies,
|
||||
max_workers=request.max_workers,
|
||||
model=request.model,
|
||||
)
|
||||
return _convert_batch_result(result)
|
||||
|
||||
@@ -908,7 +911,7 @@ def _job_row_to_status(row: dict) -> JobStatus:
|
||||
)
|
||||
|
||||
|
||||
def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
|
||||
def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: str | None = None):
|
||||
"""Background task for batch analysis."""
|
||||
import json as _json
|
||||
global _analyzer
|
||||
@@ -933,6 +936,7 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
|
||||
companies=companies,
|
||||
max_workers=max_workers,
|
||||
progress_callback=progress_callback,
|
||||
model=model,
|
||||
)
|
||||
batch_response = _convert_batch_result(result)
|
||||
db.update_job(
|
||||
@@ -988,7 +992,7 @@ async def analyze_companies_async(
|
||||
job_row = db.create_job(job_id=job_id, total_companies=len(request.companies))
|
||||
|
||||
background_tasks.add_task(
|
||||
_run_batch_job, job_id, request.companies, request.max_workers
|
||||
_run_batch_job, job_id, request.companies, request.max_workers, request.model
|
||||
)
|
||||
|
||||
return _job_row_to_status(job_row)
|
||||
|
||||
Reference in New Issue
Block a user