Compare commits

..

3 Commits

Author SHA1 Message Date
agent-company 6aa71eb17e merge: resolve Batch.tsx conflict between model picker and job history
Combine both useQuery hooks (modelsQuery for model selector, jobsQuery for
job history) and pass selectedModel to analyzeBatch while also triggering
jobsQuery.refetch() on successful submission.
2026-03-27 16:44:47 +00:00
AI-Manager fb52d08387 Merge pull request 'feat: add loading skeletons and error states to Batch page' (#352) from feature/343-batch-loading-states into main 2026-03-27 16:43:40 +00:00
agent-company 223d5f7e5d feat: add model picker to Analysis and Batch pages with full backend wiring
Thread the optional model parameter through the entire analysis pipeline:
- analyzer.py: analyze_company, _analyze_company_safe, analyze_companies,
  and analyze_single_patent now accept and forward model override
- api.py: single company endpoint accepts model query param; batch and
  async batch endpoints pass request.model through to the analyzer
- client.ts: analyzeCompany, analyzeBatch, analyzeBatchAsync accept model;
  add listModels() to fetch available models from GET /models
- Analysis.tsx: add model selector dropdown that loads from /models API
- Batch.tsx: add model selector alongside the workers slider

Users can now pick a specific LLM (GPT-4o, Claude 3.5, Gemini, etc.)
per analysis request, or leave it on the server default.

Closes leeworks-agents/SPARC#351

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 16:13:00 +00:00
5 changed files with 136 additions and 42 deletions
+12 -7
View File
@@ -33,7 +33,7 @@ class CompanyAnalyzer:
self.db.connect() self.db.connect()
self.db.initialize_schema() self.db.initialize_schema()
def analyze_company(self, company_name: str, patents: "Patents | None" = None) -> str: def analyze_company(self, company_name: str, patents: "Patents | None" = None, model: str | None = None) -> str:
"""Analyze a company's performance based on their patent portfolio. """Analyze a company's performance based on their patent portfolio.
This is the main entry point that orchestrates the full pipeline: This is the main entry point that orchestrates the full pipeline:
@@ -46,6 +46,7 @@ class CompanyAnalyzer:
Args: Args:
company_name: Name of the company to analyze company_name: Name of the company to analyze
patents: Optional pre-fetched Patents result to avoid duplicate API calls patents: Optional pre-fetched Patents result to avoid duplicate API calls
model: Optional LLM model override (e.g. 'openai/gpt-4o')
Returns: Returns:
Comprehensive analysis of company's innovation and performance outlook Comprehensive analysis of company's innovation and performance outlook
@@ -100,12 +101,12 @@ class CompanyAnalyzer:
# Analyze the full portfolio with LLM # Analyze the full portfolio with LLM
analysis = self.llm_analyzer.analyze_patent_portfolio( analysis = self.llm_analyzer.analyze_patent_portfolio(
patents_data=processed_patents, company_name=company_name patents_data=processed_patents, company_name=company_name, model=model
) )
return analysis return analysis
def analyze_single_patent(self, patent_id: str, company_name: str) -> str: def analyze_single_patent(self, patent_id: str, company_name: str, model: str | None = None) -> str:
"""Analyze a single patent by ID. """Analyze a single patent by ID.
If the patent PDF is not already on disk, this method attempts to If the patent PDF is not already on disk, this method attempts to
@@ -116,6 +117,7 @@ class CompanyAnalyzer:
Args: Args:
patent_id: Publication ID of the patent (e.g. "US-11234567-B2") patent_id: Publication ID of the patent (e.g. "US-11234567-B2")
company_name: Name of the company (for context) company_name: Name of the company (for context)
model: Optional LLM model override (e.g. 'openai/gpt-4o')
Returns: Returns:
Analysis of the specific patent's innovation quality Analysis of the specific patent's innovation quality
@@ -151,7 +153,7 @@ class CompanyAnalyzer:
minimized_content = SERP.minimize_patent_for_llm(sections) minimized_content = SERP.minimize_patent_for_llm(sections)
analysis = self.llm_analyzer.analyze_patent_content( analysis = self.llm_analyzer.analyze_patent_content(
patent_content=minimized_content, company_name=company_name patent_content=minimized_content, company_name=company_name, model=model
) )
return analysis return analysis
@@ -201,18 +203,19 @@ class CompanyAnalyzer:
logger.warning("Failed to process %s: %s", patent.patent_id, e) logger.warning("Failed to process %s: %s", patent.patent_id, e)
return None return None
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult: def _analyze_company_safe(self, company_name: str, model: str | None = None) -> CompanyAnalysisResult:
"""Internal wrapper that catches exceptions and returns structured result. """Internal wrapper that catches exceptions and returns structured result.
Args: Args:
company_name: Name of the company to analyze company_name: Name of the company to analyze
model: Optional LLM model override (e.g. 'openai/gpt-4o')
Returns: Returns:
CompanyAnalysisResult with success/failure status CompanyAnalysisResult with success/failure status
""" """
try: try:
# Delegate to analyze_company which handles SERP/patent caching # Delegate to analyze_company which handles SERP/patent caching
analysis = self.analyze_company(company_name) analysis = self.analyze_company(company_name, model=model)
# Determine patent count from cached SERP query # Determine patent count from cached SERP query
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest() query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
@@ -252,6 +255,7 @@ class CompanyAnalyzer:
companies: list[str], companies: list[str],
max_workers: int = 3, max_workers: int = 3,
progress_callback: Callable[[str, int, int], None] | None = None, progress_callback: Callable[[str, int, int], None] | None = None,
model: str | None = None,
) -> BatchAnalysisResult: ) -> BatchAnalysisResult:
"""Analyze multiple companies' patent portfolios in batch. """Analyze multiple companies' patent portfolios in batch.
@@ -262,6 +266,7 @@ class CompanyAnalyzer:
companies: List of company names to analyze companies: List of company names to analyze
max_workers: Maximum concurrent analyses (default 3 to avoid rate limits) max_workers: Maximum concurrent analyses (default 3 to avoid rate limits)
progress_callback: Optional callback(company_name, completed, total) progress_callback: Optional callback(company_name, completed, total)
model: Optional LLM model override (e.g. 'openai/gpt-4o')
Returns: Returns:
BatchAnalysisResult containing all individual results and summary stats BatchAnalysisResult containing all individual results and summary stats
@@ -273,7 +278,7 @@ class CompanyAnalyzer:
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_company = { future_to_company = {
executor.submit(self._analyze_company_safe, company): company executor.submit(self._analyze_company_safe, company, model): company
for company in companies for company in companies
} }
+7 -3
View File
@@ -799,6 +799,7 @@ async def health_check():
) )
async def analyze_company( async def analyze_company(
company_name: str, company_name: str,
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
"""Analyze a single company's patent portfolio. """Analyze a single company's patent portfolio.
@@ -808,6 +809,7 @@ async def analyze_company(
Args: Args:
company_name: Name of the company to analyze (e.g., "nvidia", "intel") company_name: Name of the company to analyze (e.g., "nvidia", "intel")
model: Optional LLM model override
Returns: Returns:
Analysis results including patent count, AI insights, and success status Analysis results including patent count, AI insights, and success status
@@ -815,7 +817,7 @@ async def analyze_company(
if not _analyzer: if not _analyzer:
raise HTTPException(status_code=503, detail="Analyzer not initialized") raise HTTPException(status_code=503, detail="Analyzer not initialized")
result = _analyzer._analyze_company_safe(company_name) result = _analyzer._analyze_company_safe(company_name, model=model)
return _convert_result(result) return _convert_result(result)
@@ -877,6 +879,7 @@ async def analyze_companies_batch(
result = _analyzer.analyze_companies( result = _analyzer.analyze_companies(
companies=request.companies, companies=request.companies,
max_workers=request.max_workers, max_workers=request.max_workers,
model=request.model,
) )
return _convert_batch_result(result) return _convert_batch_result(result)
@@ -908,7 +911,7 @@ def _job_row_to_status(row: dict) -> JobStatus:
) )
def _run_batch_job(job_id: str, companies: list[str], max_workers: int): def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: str | None = None):
"""Background task for batch analysis.""" """Background task for batch analysis."""
import json as _json import json as _json
global _analyzer global _analyzer
@@ -933,6 +936,7 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
companies=companies, companies=companies,
max_workers=max_workers, max_workers=max_workers,
progress_callback=progress_callback, progress_callback=progress_callback,
model=model,
) )
batch_response = _convert_batch_result(result) batch_response = _convert_batch_result(result)
db.update_job( db.update_job(
@@ -988,7 +992,7 @@ async def analyze_companies_async(
job_row = db.create_job(job_id=job_id, total_companies=len(request.companies)) job_row = db.create_job(job_id=job_id, total_companies=len(request.companies))
background_tasks.add_task( background_tasks.add_task(
_run_batch_job, job_id, request.companies, request.max_workers _run_batch_job, job_id, request.companies, request.max_workers, request.model
) )
return _job_row_to_status(job_row) return _job_row_to_status(job_row)
+28 -4
View File
@@ -89,29 +89,53 @@ export const authApi = {
}, },
}; };
// Model types
export interface ModelInfo {
id: string;
name: string;
provider: string;
}
export interface ModelsResponse {
models: ModelInfo[];
default: string;
}
// Analysis API // Analysis API
export const analysisApi = { export const analysisApi = {
analyzeCompany: async (companyName: string): Promise<CompanyAnalysis> => { analyzeCompany: async (companyName: string, model?: string): Promise<CompanyAnalysis> => {
const response = await api.get<CompanyAnalysis>(`/analyze/${encodeURIComponent(companyName)}`); const params = new URLSearchParams();
if (model) params.append('model', model);
const qs = params.toString();
const response = await api.get<CompanyAnalysis>(
`/analyze/${encodeURIComponent(companyName)}${qs ? `?${qs}` : ''}`
);
return response.data; return response.data;
}, },
analyzeBatch: async (companies: string[], maxWorkers = 3): Promise<BatchAnalysisResult> => { analyzeBatch: async (companies: string[], maxWorkers = 3, model?: string): Promise<BatchAnalysisResult> => {
const response = await api.post<BatchAnalysisResult>('/analyze/batch', { const response = await api.post<BatchAnalysisResult>('/analyze/batch', {
companies, companies,
max_workers: maxWorkers, max_workers: maxWorkers,
...(model ? { model } : {}),
}); });
return response.data; return response.data;
}, },
analyzeBatchAsync: async (companies: string[], maxWorkers = 3): Promise<JobStatus> => { analyzeBatchAsync: async (companies: string[], maxWorkers = 3, model?: string): Promise<JobStatus> => {
const response = await api.post<JobStatus>('/analyze/batch/async', { const response = await api.post<JobStatus>('/analyze/batch/async', {
companies, companies,
max_workers: maxWorkers, max_workers: maxWorkers,
...(model ? { model } : {}),
}); });
return response.data; return response.data;
}, },
listModels: async (): Promise<ModelsResponse> => {
const response = await api.get<ModelsResponse>('/models');
return response.data;
},
getJobStatus: async (jobId: string): Promise<JobStatus> => { getJobStatus: async (jobId: string): Promise<JobStatus> => {
const response = await api.get<JobStatus>(`/jobs/${jobId}`); const response = await api.get<JobStatus>(`/jobs/${jobId}`);
return response.data; return response.data;
+59 -27
View File
@@ -1,15 +1,21 @@
import { useState } from 'react'; import { useState } from 'react';
import { useMutation } from '@tanstack/react-query'; import { useMutation, useQuery } from '@tanstack/react-query';
import { analysisApi, exportApi } from '../api/client'; import { analysisApi, exportApi } from '../api/client';
import { Search, CheckCircle, AlertCircle, Clock, FileText, Download } from 'lucide-react'; import { Search, CheckCircle, AlertCircle, Clock, FileText, Download, ChevronDown } from 'lucide-react';
import type { CompanyAnalysis } from '../types'; import type { CompanyAnalysis } from '../types';
export function Analysis() { export function Analysis() {
const [companyName, setCompanyName] = useState(''); const [companyName, setCompanyName] = useState('');
const [selectedModel, setSelectedModel] = useState('');
const [result, setResult] = useState<CompanyAnalysis | null>(null); const [result, setResult] = useState<CompanyAnalysis | null>(null);
const modelsQuery = useQuery({
queryKey: ['models'],
queryFn: () => analysisApi.listModels(),
});
const mutation = useMutation({ const mutation = useMutation({
mutationFn: (name: string) => analysisApi.analyzeCompany(name), mutationFn: (name: string) => analysisApi.analyzeCompany(name, selectedModel || undefined),
onSuccess: (data) => setResult(data), onSuccess: (data) => setResult(data),
}); });
@@ -33,31 +39,57 @@ export function Analysis() {
</div> </div>
{/* Search Form */} {/* Search Form */}
<form onSubmit={handleSubmit} className="flex gap-4"> <form onSubmit={handleSubmit} className="space-y-4">
<div className="flex-1 relative"> <div className="flex gap-4">
<Search className="absolute left-4 top-1/2 -translate-y-1/2 text-text-secondary" size={18} /> <div className="flex-1 relative">
<input <Search className="absolute left-4 top-1/2 -translate-y-1/2 text-text-secondary" size={18} />
type="text" <input
value={companyName} type="text"
onChange={(e) => setCompanyName(e.target.value)} value={companyName}
placeholder="Enter company name (e.g., nvidia, intel, amd)" onChange={(e) => setCompanyName(e.target.value)}
className="w-full bg-bg-card/80 border border-primary/30 rounded-xl pl-12 pr-4 py-3 text-text-primary placeholder-text-secondary/50 focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all" placeholder="Enter company name (e.g., nvidia, intel, amd)"
/> className="w-full bg-bg-card/80 border border-primary/30 rounded-xl pl-12 pr-4 py-3 text-text-primary placeholder-text-secondary/50 focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all"
/>
</div>
<button
type="submit"
disabled={mutation.isPending || !companyName.trim()}
className="bg-gradient-to-r from-primary to-primary-dark text-white font-semibold py-3 px-6 rounded-xl hover:shadow-lg hover:shadow-primary/30 transition-all disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2"
>
{mutation.isPending ? (
<div className="animate-spin rounded-full h-5 w-5 border-t-2 border-b-2 border-white"></div>
) : (
<>
<Search size={18} />
Analyze
</>
)}
</button>
</div>
{/* Model Selector */}
<div className="flex items-center gap-3">
<label className="text-sm font-medium text-text-secondary whitespace-nowrap">
LLM Model
</label>
<div className="relative flex-1 max-w-xs">
<select
value={selectedModel}
onChange={(e) => setSelectedModel(e.target.value)}
className="w-full appearance-none bg-bg-card/80 border border-primary/30 rounded-lg pl-3 pr-8 py-2 text-sm text-text-primary focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all cursor-pointer"
>
<option value="">
{modelsQuery.data ? `Default (${modelsQuery.data.default})` : 'Default'}
</option>
{modelsQuery.data?.models.map((m) => (
<option key={m.id} value={m.id}>
{m.name} ({m.provider})
</option>
))}
</select>
<ChevronDown className="absolute right-2 top-1/2 -translate-y-1/2 text-text-secondary pointer-events-none" size={16} />
</div>
</div> </div>
<button
type="submit"
disabled={mutation.isPending || !companyName.trim()}
className="bg-gradient-to-r from-primary to-primary-dark text-white font-semibold py-3 px-6 rounded-xl hover:shadow-lg hover:shadow-primary/30 transition-all disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2"
>
{mutation.isPending ? (
<div className="animate-spin rounded-full h-5 w-5 border-t-2 border-b-2 border-white"></div>
) : (
<>
<Search size={18} />
Analyze
</>
)}
</button>
</form> </form>
{/* Error */} {/* Error */}
+30 -1
View File
@@ -8,9 +8,15 @@ import type { BatchAnalysisResult } from '../types';
export function Batch() { export function Batch() {
const [companiesInput, setCompaniesInput] = useState(''); const [companiesInput, setCompaniesInput] = useState('');
const [maxWorkers, setMaxWorkers] = useState(3); const [maxWorkers, setMaxWorkers] = useState(3);
const [selectedModel, setSelectedModel] = useState('');
const [result, setResult] = useState<BatchAnalysisResult | null>(null); const [result, setResult] = useState<BatchAnalysisResult | null>(null);
const [expandedItems, setExpandedItems] = useState<Set<string>>(new Set()); const [expandedItems, setExpandedItems] = useState<Set<string>>(new Set());
const modelsQuery = useQuery({
queryKey: ['models'],
queryFn: () => analysisApi.listModels(),
});
const jobsQuery = useQuery({ const jobsQuery = useQuery({
queryKey: ['jobs'], queryKey: ['jobs'],
queryFn: () => analysisApi.listJobs(undefined, 20), queryFn: () => analysisApi.listJobs(undefined, 20),
@@ -18,7 +24,7 @@ export function Batch() {
const mutation = useMutation({ const mutation = useMutation({
mutationFn: ({ companies, workers }: { companies: string[]; workers: number }) => mutationFn: ({ companies, workers }: { companies: string[]; workers: number }) =>
analysisApi.analyzeBatch(companies, workers), analysisApi.analyzeBatch(companies, workers, selectedModel || undefined),
onSuccess: (data) => { onSuccess: (data) => {
setResult(data); setResult(data);
jobsQuery.refetch(); jobsQuery.refetch();
@@ -93,6 +99,29 @@ export function Batch() {
<div className="text-center text-text-primary font-semibold">{maxWorkers}</div> <div className="text-center text-text-primary font-semibold">{maxWorkers}</div>
</div> </div>
<div>
<label className="block text-sm font-medium text-text-secondary mb-2">
LLM Model
</label>
<div className="relative">
<select
value={selectedModel}
onChange={(e) => setSelectedModel(e.target.value)}
className="w-full appearance-none bg-bg-card/80 border border-primary/30 rounded-lg pl-3 pr-8 py-2 text-sm text-text-primary focus:outline-none focus:border-primary focus:ring-2 focus:ring-primary/20 transition-all cursor-pointer"
>
<option value="">
{modelsQuery.data ? `Default (${modelsQuery.data.default})` : 'Default'}
</option>
{modelsQuery.data?.models.map((m) => (
<option key={m.id} value={m.id}>
{m.name} ({m.provider})
</option>
))}
</select>
<ChevronDown className="absolute right-2 top-1/2 -translate-y-1/2 text-text-secondary pointer-events-none" size={16} />
</div>
</div>
<button <button
type="submit" type="submit"
disabled={mutation.isPending || !companiesInput.trim()} disabled={mutation.isPending || !companiesInput.trim()}