diff --git a/SPARC/analyzer.py b/SPARC/analyzer.py index c55803b..31ad7f1 100644 --- a/SPARC/analyzer.py +++ b/SPARC/analyzer.py @@ -33,7 +33,7 @@ class CompanyAnalyzer: self.db.connect() self.db.initialize_schema() - def analyze_company(self, company_name: str, patents: "Patents | None" = None) -> str: + def analyze_company(self, company_name: str, patents: "Patents | None" = None, model: str | None = None) -> str: """Analyze a company's performance based on their patent portfolio. This is the main entry point that orchestrates the full pipeline: @@ -46,6 +46,7 @@ class CompanyAnalyzer: Args: company_name: Name of the company to analyze patents: Optional pre-fetched Patents result to avoid duplicate API calls + model: Optional LLM model override (e.g. 'openai/gpt-4o') Returns: Comprehensive analysis of company's innovation and performance outlook @@ -100,12 +101,12 @@ class CompanyAnalyzer: # Analyze the full portfolio with LLM analysis = self.llm_analyzer.analyze_patent_portfolio( - patents_data=processed_patents, company_name=company_name + patents_data=processed_patents, company_name=company_name, model=model ) return analysis - def analyze_single_patent(self, patent_id: str, company_name: str) -> str: + def analyze_single_patent(self, patent_id: str, company_name: str, model: str | None = None) -> str: """Analyze a single patent by ID. If the patent PDF is not already on disk, this method attempts to @@ -116,6 +117,7 @@ class CompanyAnalyzer: Args: patent_id: Publication ID of the patent (e.g. "US-11234567-B2") company_name: Name of the company (for context) + model: Optional LLM model override (e.g. 'openai/gpt-4o') Returns: Analysis of the specific patent's innovation quality @@ -151,7 +153,7 @@ class CompanyAnalyzer: minimized_content = SERP.minimize_patent_for_llm(sections) analysis = self.llm_analyzer.analyze_patent_content( - patent_content=minimized_content, company_name=company_name + patent_content=minimized_content, company_name=company_name, model=model ) return analysis @@ -201,18 +203,19 @@ class CompanyAnalyzer: logger.warning("Failed to process %s: %s", patent.patent_id, e) return None - def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult: + def _analyze_company_safe(self, company_name: str, model: str | None = None) -> CompanyAnalysisResult: """Internal wrapper that catches exceptions and returns structured result. Args: company_name: Name of the company to analyze + model: Optional LLM model override (e.g. 'openai/gpt-4o') Returns: CompanyAnalysisResult with success/failure status """ try: # Delegate to analyze_company which handles SERP/patent caching - analysis = self.analyze_company(company_name) + analysis = self.analyze_company(company_name, model=model) # Determine patent count from cached SERP query query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest() @@ -252,6 +255,7 @@ class CompanyAnalyzer: companies: list[str], max_workers: int = 3, progress_callback: Callable[[str, int, int], None] | None = None, + model: str | None = None, ) -> BatchAnalysisResult: """Analyze multiple companies' patent portfolios in batch. @@ -262,6 +266,7 @@ class CompanyAnalyzer: companies: List of company names to analyze max_workers: Maximum concurrent analyses (default 3 to avoid rate limits) progress_callback: Optional callback(company_name, completed, total) + model: Optional LLM model override (e.g. 'openai/gpt-4o') Returns: BatchAnalysisResult containing all individual results and summary stats @@ -273,7 +278,7 @@ class CompanyAnalyzer: with ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_company = { - executor.submit(self._analyze_company_safe, company): company + executor.submit(self._analyze_company_safe, company, model): company for company in companies } diff --git a/SPARC/api.py b/SPARC/api.py index dbcc01e..6fbbcdb 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -799,6 +799,7 @@ async def health_check(): ) async def analyze_company( company_name: str, + model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."), _: UserResponse = Depends(get_current_user), ): """Analyze a single company's patent portfolio. @@ -808,6 +809,7 @@ async def analyze_company( Args: company_name: Name of the company to analyze (e.g., "nvidia", "intel") + model: Optional LLM model override Returns: Analysis results including patent count, AI insights, and success status @@ -815,7 +817,7 @@ async def analyze_company( if not _analyzer: raise HTTPException(status_code=503, detail="Analyzer not initialized") - result = _analyzer._analyze_company_safe(company_name) + result = _analyzer._analyze_company_safe(company_name, model=model) return _convert_result(result) @@ -877,6 +879,7 @@ async def analyze_companies_batch( result = _analyzer.analyze_companies( companies=request.companies, max_workers=request.max_workers, + model=request.model, ) return _convert_batch_result(result) @@ -908,7 +911,7 @@ def _job_row_to_status(row: dict) -> JobStatus: ) -def _run_batch_job(job_id: str, companies: list[str], max_workers: int): +def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: str | None = None): """Background task for batch analysis.""" import json as _json global _analyzer @@ -933,6 +936,7 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int): companies=companies, max_workers=max_workers, progress_callback=progress_callback, + model=model, ) batch_response = _convert_batch_result(result) db.update_job( @@ -988,7 +992,7 @@ async def analyze_companies_async( job_row = db.create_job(job_id=job_id, total_companies=len(request.companies)) background_tasks.add_task( - _run_batch_job, job_id, request.companies, request.max_workers + _run_batch_job, job_id, request.companies, request.max_workers, request.model ) return _job_row_to_status(job_row) diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 7dd76ff..09a4ae6 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -89,29 +89,53 @@ export const authApi = { }, }; +// Model types +export interface ModelInfo { + id: string; + name: string; + provider: string; +} + +export interface ModelsResponse { + models: ModelInfo[]; + default: string; +} + // Analysis API export const analysisApi = { - analyzeCompany: async (companyName: string): Promise => { - const response = await api.get(`/analyze/${encodeURIComponent(companyName)}`); + analyzeCompany: async (companyName: string, model?: string): Promise => { + const params = new URLSearchParams(); + if (model) params.append('model', model); + const qs = params.toString(); + const response = await api.get( + `/analyze/${encodeURIComponent(companyName)}${qs ? `?${qs}` : ''}` + ); return response.data; }, - analyzeBatch: async (companies: string[], maxWorkers = 3): Promise => { + analyzeBatch: async (companies: string[], maxWorkers = 3, model?: string): Promise => { const response = await api.post('/analyze/batch', { companies, max_workers: maxWorkers, + ...(model ? { model } : {}), }); return response.data; }, - analyzeBatchAsync: async (companies: string[], maxWorkers = 3): Promise => { + analyzeBatchAsync: async (companies: string[], maxWorkers = 3, model?: string): Promise => { const response = await api.post('/analyze/batch/async', { companies, max_workers: maxWorkers, + ...(model ? { model } : {}), }); return response.data; }, + listModels: async (): Promise => { + const response = await api.get('/models'); + return response.data; + }, + getJobStatus: async (jobId: string): Promise => { const response = await api.get(`/jobs/${jobId}`); return response.data; diff --git a/frontend/src/pages/Analysis.tsx b/frontend/src/pages/Analysis.tsx index 1ded981..7ec67f7 100644 --- a/frontend/src/pages/Analysis.tsx +++ b/frontend/src/pages/Analysis.tsx @@ -1,15 +1,21 @@ import { useState } from 'react'; -import { useMutation } from '@tanstack/react-query'; +import { useMutation, useQuery } from '@tanstack/react-query'; import { analysisApi, exportApi } from '../api/client'; -import { Search, CheckCircle, AlertCircle, Clock, FileText, Download } from 'lucide-react'; +import { Search, CheckCircle, AlertCircle, Clock, FileText, Download, ChevronDown } from 'lucide-react'; import type { CompanyAnalysis } from '../types'; export function Analysis() { const [companyName, setCompanyName] = useState(''); + const [selectedModel, setSelectedModel] = useState(''); const [result, setResult] = useState(null); + const modelsQuery = useQuery({ + queryKey: ['models'], + queryFn: () => analysisApi.listModels(), + }); + const mutation = useMutation({ - mutationFn: (name: string) => analysisApi.analyzeCompany(name), + mutationFn: (name: string) => analysisApi.analyzeCompany(name, selectedModel || undefined), onSuccess: (data) => setResult(data), }); @@ -33,31 +39,57 @@ export function Analysis() { {/* Search Form */} -
-
- - setCompanyName(e.target.value)} - placeholder="Enter company name (e.g., nvidia, intel, amd)" - className="w-full bg-bg-card/80 border border-primary/30 rounded-xl pl-12 pr-4 py-3 text-text-primary placeholder-text-secondary/50 focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all" - /> + +
+
+ + setCompanyName(e.target.value)} + placeholder="Enter company name (e.g., nvidia, intel, amd)" + className="w-full bg-bg-card/80 border border-primary/30 rounded-xl pl-12 pr-4 py-3 text-text-primary placeholder-text-secondary/50 focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all" + /> +
+ +
+ + {/* Model Selector */} +
+ +
+ + +
- {/* Error */} diff --git a/frontend/src/pages/Batch.tsx b/frontend/src/pages/Batch.tsx index 6620597..4c53bb0 100644 --- a/frontend/src/pages/Batch.tsx +++ b/frontend/src/pages/Batch.tsx @@ -1,5 +1,5 @@ import { useState } from 'react'; -import { useMutation } from '@tanstack/react-query'; +import { useMutation, useQuery } from '@tanstack/react-query'; import { analysisApi } from '../api/client'; import { Rocket, CheckCircle, AlertCircle, ChevronDown, ChevronUp } from 'lucide-react'; import { BarChart, Bar, XAxis, YAxis, Tooltip, ResponsiveContainer, Cell } from 'recharts'; @@ -8,12 +8,18 @@ import type { BatchAnalysisResult } from '../types'; export function Batch() { const [companiesInput, setCompaniesInput] = useState(''); const [maxWorkers, setMaxWorkers] = useState(3); + const [selectedModel, setSelectedModel] = useState(''); const [result, setResult] = useState(null); const [expandedItems, setExpandedItems] = useState>(new Set()); + const modelsQuery = useQuery({ + queryKey: ['models'], + queryFn: () => analysisApi.listModels(), + }); + const mutation = useMutation({ mutationFn: ({ companies, workers }: { companies: string[]; workers: number }) => - analysisApi.analyzeBatch(companies, workers), + analysisApi.analyzeBatch(companies, workers, selectedModel || undefined), onSuccess: (data) => setResult(data), }); @@ -85,6 +91,29 @@ export function Batch() {
{maxWorkers}
+
+ +
+ + +
+
+