Add LLM-based patent classification tagging by technology domain #1692

Open
AI-Manager wants to merge 1 commits from feature/patent-classification-tags into main
8 changed files with 590 additions and 17 deletions
+28 -4
View File
@@ -40,8 +40,9 @@ class CompanyAnalyzer:
1. Retrieve patents from SERP API 1. Retrieve patents from SERP API
2. Download and parse each patent PDF 2. Download and parse each patent PDF
3. Minimize patent content (remove bloat) 3. Minimize patent content (remove bloat)
4. Analyze portfolio with LLM 4. Classify patent technology domain tags
5. Return performance estimation 5. Analyze portfolio with LLM
6. Return performance estimation
Args: Args:
company_name: Name of the company to analyze company_name: Name of the company to analyze
@@ -97,6 +98,17 @@ class CompanyAnalyzer:
if not processed_patents: if not processed_patents:
return f"Failed to process any patents for {company_name}" return f"Failed to process any patents for {company_name}"
# Classify each patent's technology domain tags
logger.info("Classifying patent technology domains...")
for patent_data in processed_patents:
if "tags" not in patent_data or not patent_data["tags"]:
tags = self.llm_analyzer.classify_patent_tags(
patent_content=patent_data["content"], model=model
)
patent_data["tags"] = tags
# Persist tags to the database
self.db.update_patent_tags(patent_data["patent_id"], tags)
logger.info("Analyzing portfolio with LLM...") logger.info("Analyzing portfolio with LLM...")
# Analyze the full portfolio with LLM # Analyze the full portfolio with LLM
@@ -181,7 +193,10 @@ class CompanyAnalyzer:
if db: if db:
cached = db.get_cached_patent(patent.patent_id) cached = db.get_cached_patent(patent.patent_id)
if cached and cached.get("minimized_content"): if cached and cached.get("minimized_content"):
return {"patent_id": patent.patent_id, "content": cached["minimized_content"]} result = {"patent_id": patent.patent_id, "content": cached["minimized_content"]}
if cached.get("patent_tags"):
result["tags"] = cached["patent_tags"]
return result
# Full processing: download, parse, minimize # Full processing: download, parse, minimize
patent = SERP.save_patents(patent) patent = SERP.save_patents(patent)
@@ -217,11 +232,19 @@ class CompanyAnalyzer:
# Delegate to analyze_company which handles SERP/patent caching # Delegate to analyze_company which handles SERP/patent caching
analysis = self.analyze_company(company_name, model=model) analysis = self.analyze_company(company_name, model=model)
# Determine patent count from cached SERP query # Determine patent count and aggregate tags from cached SERP query
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest() query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
cached_ids = self.db.get_cached_serp_query(query_hash) cached_ids = self.db.get_cached_serp_query(query_hash)
patent_count = len(cached_ids) if cached_ids else 0 patent_count = len(cached_ids) if cached_ids else 0
# Collect unique tags across all patents for this company
all_tags: set[str] = set()
if cached_ids:
for pid in cached_ids:
cached_patent = self.db.get_cached_patent(pid)
if cached_patent and cached_patent.get("patent_tags"):
all_tags.update(cached_patent["patent_tags"])
# Check if analysis indicates failure # Check if analysis indicates failure
if analysis.startswith("No patents found") or analysis.startswith( if analysis.startswith("No patents found") or analysis.startswith(
"Failed to process" "Failed to process"
@@ -239,6 +262,7 @@ class CompanyAnalyzer:
analysis=analysis, analysis=analysis,
patent_count=patent_count, patent_count=patent_count,
success=True, success=True,
tags=sorted(all_tags),
) )
except Exception as e: except Exception as e:
+52 -1
View File
@@ -57,6 +57,7 @@ class CompanyAnalysisResponse(BaseModel):
success: bool success: bool
error: str | None = None error: str | None = None
model: str | None = None model: str | None = None
tags: list[str] = []
timestamp: datetime timestamp: datetime
@@ -188,6 +189,7 @@ def _convert_result(result: CompanyAnalysisResult) -> CompanyAnalysisResponse:
success=result.success, success=result.success,
error=result.error, error=result.error,
model=result.model, model=result.model,
tags=result.tags,
timestamp=result.timestamp, timestamp=result.timestamp,
) )
@@ -560,6 +562,38 @@ async def get_analytics(
) )
@app.get("/analytics/tags", tags=["Analytics"])
async def get_tag_distribution(
_: UserResponse = Depends(get_current_user),
):
"""Get the distribution of technology domain tags across all patents.
Returns:
List of tag counts and the canonical tag list
"""
from SPARC.llm import LLMAnalyzer
db = get_db_client()
with db.get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
SELECT tag, COUNT(*) as count
FROM patents, UNNEST(patent_tags) AS tag
WHERE patent_tags IS NOT NULL
GROUP BY tag
ORDER BY count DESC
"""
)
rows = cur.fetchall()
by_tag = [{"tag": row[0], "count": row[1]} for row in rows]
return {
"by_tag": by_tag,
"canonical_tags": LLMAnalyzer.CANONICAL_TAGS,
}
# ============== Model Selection Endpoints ============== # ============== Model Selection Endpoints ==============
# Supported models via OpenRouter # Supported models via OpenRouter
@@ -969,6 +1003,10 @@ async def list_analysis_results(
str | None, str | None,
Query(description="Filter results by company name"), Query(description="Filter results by company name"),
] = None, ] = None,
tags: Annotated[
str | None,
Query(description="Comma-separated technology domain tags to filter by (e.g. 'ai,semiconductors')"),
] = None,
limit: Annotated[int, Query(ge=1, le=200)] = 50, limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[ cursor: Annotated[
str | None, str | None,
@@ -986,14 +1024,27 @@ async def list_analysis_results(
Args: Args:
company_name: Optional filter by company name company_name: Optional filter by company name
tags: Optional comma-separated tag filter (e.g. 'ai,semiconductors')
limit: Maximum number of results to return (default 50, max 200) limit: Maximum number of results to return (default 50, max 200)
cursor: Opaque pagination cursor from a previous response cursor: Opaque pagination cursor from a previous response
Returns: Returns:
Paginated list of analysis results Paginated list of analysis results
""" """
# Parse and validate tags
tag_list = None
if tags:
from SPARC.llm import LLMAnalyzer
tag_list = [t.strip().lower() for t in tags.split(",") if t.strip()]
invalid = [t for t in tag_list if t not in LLMAnalyzer.CANONICAL_TAGS]
if invalid:
raise HTTPException(
status_code=400,
detail=f"Invalid tags: {', '.join(invalid)}. Valid tags: {', '.join(LLMAnalyzer.CANONICAL_TAGS)}",
)
db = _get_job_db() db = _get_job_db()
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor) rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor, tags=tag_list)
has_next = len(rows) > limit has_next = len(rows) > limit
if has_next: if has_next:
+82 -9
View File
@@ -146,15 +146,35 @@ class DatabaseClient:
pdf_link TEXT, pdf_link TEXT,
raw_sections JSONB, raw_sections JSONB,
minimized_content TEXT, minimized_content TEXT,
patent_tags TEXT[],
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) )
""") """)
# Add patent_tags column if it doesn't exist (for existing tables)
cursor.execute("""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'patents' AND column_name = 'patent_tags'
) THEN
ALTER TABLE patents ADD COLUMN patent_tags TEXT[];
END IF;
END $$;
""")
cursor.execute(""" cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_patents_company CREATE INDEX IF NOT EXISTS idx_patents_company
ON patents(company_name) ON patents(company_name)
""") """)
# GIN index for efficient tag array queries
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_patents_tags
ON patents USING GIN(patent_tags)
""")
# Create SERP query cache table # Create SERP query cache table
cursor.execute(""" cursor.execute("""
CREATE TABLE IF NOT EXISTS serp_queries ( CREATE TABLE IF NOT EXISTS serp_queries (
@@ -376,6 +396,7 @@ class DatabaseClient:
company_name: Optional[str] = None, company_name: Optional[str] = None,
limit: int = 50, limit: int = 50,
cursor: Optional[str] = None, cursor: Optional[str] = None,
tags: Optional[List[str]] = None,
) -> List[Dict]: ) -> List[Dict]:
"""List analysis results with cursor-based pagination. """List analysis results with cursor-based pagination.
@@ -383,29 +404,40 @@ class DatabaseClient:
company_name: Optional filter by company name. company_name: Optional filter by company name.
limit: Maximum number of records to return. limit: Maximum number of records to return.
cursor: Opaque cursor (``timestamp|id``) from a previous response. cursor: Opaque cursor (``timestamp|id``) from a previous response.
tags: Optional list of technology domain tags to filter by.
Returns: Returns:
List of analysis dicts ordered by timestamp descending. List of analysis dicts ordered by timestamp descending.
""" """
conditions: list[str] = ["is_cached = FALSE"] conditions: list[str] = ["m.is_cached = FALSE"]
params: list = [] params: list = []
join_clause = ""
if company_name: if company_name:
conditions.append("LOWER(company_name) = LOWER(%s)") conditions.append("LOWER(m.company_name) = LOWER(%s)")
params.append(company_name) params.append(company_name)
if tags:
# Join with patents table to filter by tags
join_clause = (
" INNER JOIN patents p ON LOWER(p.company_name) = LOWER(m.company_name)"
)
conditions.append("p.patent_tags && %s")
params.append(tags)
if cursor: if cursor:
try: try:
ts_str, cursor_id = cursor.rsplit("|", 1) ts_str, cursor_id = cursor.rsplit("|", 1)
conditions.append("(timestamp, id) < (%s, %s)") conditions.append("(m.timestamp, m.id) < (%s, %s)")
params.extend([ts_str, int(cursor_id)]) params.extend([ts_str, int(cursor_id)])
except (ValueError, TypeError): except (ValueError, TypeError):
pass # Ignore malformed cursors; return from start pass # Ignore malformed cursors; return from start
query = "SELECT id, company_name, analysis_type, model, response, timestamp FROM llm_messages" query = "SELECT DISTINCT m.id, m.company_name, m.analysis_type, m.model, m.response, m.timestamp FROM llm_messages m"
query += join_clause
if conditions: if conditions:
query += " WHERE " + " AND ".join(conditions) query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY timestamp DESC, id DESC LIMIT %s" query += " ORDER BY m.timestamp DESC, m.id DESC LIMIT %s"
params.append(limit) params.append(limit)
with self.get_conn() as conn: with self.get_conn() as conn:
@@ -493,22 +525,63 @@ class DatabaseClient:
pdf_link: str, pdf_link: str,
raw_sections: Dict, raw_sections: Dict,
minimized_content: str, minimized_content: str,
patent_tags: Optional[List[str]] = None,
) -> None: ) -> None:
"""Store a processed patent in the cache.""" """Store a processed patent in the cache."""
with self.get_conn() as conn: with self.get_conn() as conn:
with conn.cursor() as cursor: with conn.cursor() as cursor:
cursor.execute( cursor.execute(
""" """
INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content) INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content, patent_tags)
VALUES (%s, %s, %s, %s, %s) VALUES (%s, %s, %s, %s, %s, %s)
ON CONFLICT (patent_id) DO UPDATE SET ON CONFLICT (patent_id) DO UPDATE SET
raw_sections = EXCLUDED.raw_sections, raw_sections = EXCLUDED.raw_sections,
minimized_content = EXCLUDED.minimized_content minimized_content = EXCLUDED.minimized_content,
patent_tags = EXCLUDED.patent_tags
""", """,
(patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content), (patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content, patent_tags),
) )
conn.commit() conn.commit()
def update_patent_tags(self, patent_id: str, tags: List[str]) -> None:
"""Update the technology domain tags for a patent."""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"UPDATE patents SET patent_tags = %s WHERE patent_id = %s",
(tags, patent_id),
)
conn.commit()
def get_patents_by_tags(
self,
tags: List[str],
limit: int = 100,
offset: int = 0,
) -> List[Dict]:
"""Retrieve patents that match any of the given tags.
Args:
tags: List of tag strings to filter by (OR logic)
limit: Max results
offset: Pagination offset
Returns:
List of patent dicts
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT * FROM patents
WHERE patent_tags && %s
ORDER BY created_at DESC
LIMIT %s OFFSET %s
""",
(tags, limit, offset),
)
return [dict(row) for row in cursor.fetchall()]
def get_cached_serp_query(self, query_hash: str) -> Optional[List[str]]: def get_cached_serp_query(self, query_hash: str) -> Optional[List[str]]:
"""Look up cached SERP query results. """Look up cached SERP query results.
+59 -1
View File
@@ -1,7 +1,8 @@
"""LLM integration for patent analysis using OpenRouter.""" """LLM integration for patent analysis using OpenRouter."""
import json
import logging import logging
from typing import Dict from typing import Dict, List
from openai import OpenAI from openai import OpenAI
@@ -246,3 +247,60 @@ Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the co
) )
return placeholder return placeholder
# Canonical technology domain tags
CANONICAL_TAGS = ["ai", "semiconductors", "materials", "biotech", "networking", "other"]
def classify_patent_tags(self, patent_content: str, model: str | None = None) -> List[str]:
"""Classify a patent into one or more technology domain tags.
Sends the patent abstract/claims to the LLM with a classification prompt
and returns a list of canonical tags.
Args:
patent_content: Minimized patent text (abstract, claims, summary)
model: Optional model override
Returns:
List of canonical tag strings from CANONICAL_TAGS
"""
prompt = f"""You are a patent classification system. Analyze the following patent content and assign one or more technology domain tags from this exact list:
Tags: ai, semiconductors, materials, biotech, networking, other
Rules:
- Return ONLY a JSON array of tag strings, e.g. ["ai", "semiconductors"]
- Use ONLY tags from the list above
- Assign "other" only if no other tag fits
- A patent can have multiple tags if it spans domains
Patent Content:
{patent_content}
Return ONLY the JSON array, nothing else."""
effective_model = model or self.model
if self.test_mode:
logger.debug("TEST MODE - Classification prompt:\n%s", prompt)
return ["other"]
if self.client:
try:
response = self.client.chat.completions.create(
model=effective_model,
max_tokens=128,
messages=[{"role": "user", "content": prompt}],
)
raw = response.choices[0].message.content.strip()
tags = json.loads(raw)
# Validate and filter to canonical tags only
valid_tags = [t for t in tags if t in self.CANONICAL_TAGS]
return valid_tags if valid_tags else ["other"]
except (json.JSONDecodeError, AttributeError, TypeError) as e:
logger.warning("Failed to parse classification response: %s", e)
return ["other"]
except Exception as e:
logger.warning("Classification LLM call failed: %s", e)
return ["other"]
return ["other"]
+1
View File
@@ -25,6 +25,7 @@ class CompanyAnalysisResult:
success: bool success: bool
error: str | None = None error: str | None = None
model: str | None = None model: str | None = None
tags: list[str] = field(default_factory=list)
timestamp: datetime = field(default_factory=datetime.now) timestamp: datetime = field(default_factory=datetime.now)
+10
View File
@@ -199,8 +199,18 @@ export const analyticsApi = {
const response = await api.get<TrendData>(`/analytics/trends?days=${days}`); const response = await api.get<TrendData>(`/analytics/trends?days=${days}`);
return response.data; return response.data;
}, },
getTagDistribution: async (): Promise<TagDistribution> => {
const response = await api.get<TagDistribution>('/analytics/tags');
return response.data;
},
}; };
export interface TagDistribution {
by_tag: Array<{ tag: string; count: number }>;
canonical_tags: string[];
}
// Admin API // Admin API
export const adminApi = { export const adminApi = {
listUsers: async (limit = 100, offset = 0): Promise<User[]> => { listUsers: async (limit = 100, offset = 0): Promise<User[]> => {
+95 -1
View File
@@ -1,14 +1,24 @@
import { useState } from 'react'; import { useState } from 'react';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
import { analyticsApi } from '../api/client'; import { analyticsApi } from '../api/client';
import { AlertCircle, Database } from 'lucide-react'; import { AlertCircle, Database, Tag } from 'lucide-react';
import { PieChart, Pie, Cell, BarChart, Bar, LineChart, Line, XAxis, YAxis, Tooltip, ResponsiveContainer, Legend } from 'recharts'; import { PieChart, Pie, Cell, BarChart, Bar, LineChart, Line, XAxis, YAxis, Tooltip, ResponsiveContainer, Legend } from 'recharts';
import { useChartTheme } from '../context/useChartTheme'; import { useChartTheme } from '../context/useChartTheme';
const COLORS = ['#6366f1', '#0ea5e9', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#ec4899', '#14b8a6']; const COLORS = ['#6366f1', '#0ea5e9', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#ec4899', '#14b8a6'];
const TAG_COLORS: Record<string, string> = {
ai: '#6366f1',
semiconductors: '#0ea5e9',
materials: '#10b981',
biotech: '#f59e0b',
networking: '#ec4899',
other: '#8b5cf6',
};
export function AnalyticsPage() { export function AnalyticsPage() {
const [days, setDays] = useState(30); const [days, setDays] = useState(30);
const [selectedTags, setSelectedTags] = useState<string[]>([]);
const chartTheme = useChartTheme(); const chartTheme = useChartTheme();
const { data, isLoading, isError, refetch } = useQuery({ const { data, isLoading, isError, refetch } = useQuery({
@@ -21,6 +31,17 @@ export function AnalyticsPage() {
queryFn: () => analyticsApi.getTrends(days), queryFn: () => analyticsApi.getTrends(days),
}); });
const tagQuery = useQuery({
queryKey: ['analytics-tags'],
queryFn: () => analyticsApi.getTagDistribution(),
});
const toggleTag = (tag: string) => {
setSelectedTags((prev) =>
prev.includes(tag) ? prev.filter((t) => t !== tag) : [...prev, tag]
);
};
if (isLoading) { if (isLoading) {
return ( return (
<div className="space-y-6"> <div className="space-y-6">
@@ -107,6 +128,13 @@ export function AnalyticsPage() {
count: t.count, count: t.count,
})); }));
const tagData = tagQuery.data?.by_tag?.map((t) => ({
name: t.tag,
count: t.count,
})) || [];
const canonicalTags = tagQuery.data?.canonical_tags || [];
return ( return (
<div className="space-y-6"> <div className="space-y-6">
{/* Header */} {/* Header */}
@@ -131,6 +159,50 @@ export function AnalyticsPage() {
</select> </select>
</div> </div>
{/* Tag Filter Controls */}
{canonicalTags.length > 0 && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-4">
<div className="flex items-center gap-2 mb-3">
<Tag size={16} className="text-primary" />
<span className="text-sm font-semibold text-text-primary">Filter by Technology Domain</span>
{selectedTags.length > 0 && (
<button
onClick={() => setSelectedTags([])}
className="ml-auto text-xs text-text-secondary hover:text-primary transition-colors"
>
Clear all
</button>
)}
</div>
<div className="flex flex-wrap gap-2">
{canonicalTags.map((tag) => {
const isActive = selectedTags.includes(tag);
const color = TAG_COLORS[tag] || '#8b5cf6';
const tagCount = tagQuery.data?.by_tag?.find((t) => t.tag === tag)?.count || 0;
return (
<button
key={tag}
onClick={() => toggleTag(tag)}
className={`px-3 py-1.5 rounded-lg text-sm font-medium transition-all ${
isActive
? 'text-white shadow-md'
: 'bg-bg-card/80 text-text-secondary border border-primary/20 hover:border-primary/40'
}`}
style={isActive ? { backgroundColor: color } : {}}
>
{tag}
{tagCount > 0 && (
<span className={`ml-1.5 text-xs ${isActive ? 'opacity-80' : 'opacity-60'}`}>
({tagCount})
</span>
)}
</button>
);
})}
</div>
</div>
)}
{/* Summary Metrics */} {/* Summary Metrics */}
<div className="grid grid-cols-1 md:grid-cols-3 gap-4"> <div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<MetricCard label="Total Analyses" value={data.total_messages} /> <MetricCard label="Total Analyses" value={data.total_messages} />
@@ -187,6 +259,28 @@ export function AnalyticsPage() {
</ResponsiveContainer> </ResponsiveContainer>
</div> </div>
)} )}
{/* Bar Chart - Technology Domain Tags */}
{tagData.length > 0 && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-6">
<h3 className="text-lg font-semibold text-text-primary mb-4">Patents by Technology Domain</h3>
<ResponsiveContainer width="100%" height={300}>
<BarChart data={tagData}>
<XAxis dataKey="name" stroke={chartTheme.axisStroke} fontSize={12} />
<YAxis stroke={chartTheme.axisStroke} fontSize={12} />
<Tooltip
contentStyle={chartTheme.tooltipContentStyle}
labelStyle={chartTheme.tooltipLabelStyle}
/>
<Bar dataKey="count" radius={[4, 4, 0, 0]}>
{tagData.map((entry, index) => (
<Cell key={`tag-cell-${index}`} fill={TAG_COLORS[entry.name] || COLORS[index % COLORS.length]} />
))}
</Bar>
</BarChart>
</ResponsiveContainer>
</div>
)}
</div> </div>
{/* Trend Charts */} {/* Trend Charts */}
+262
View File
@@ -0,0 +1,262 @@
"""Tests for LLM-based patent classification tagging by technology domain."""
import json
from unittest.mock import MagicMock, Mock, patch
import pytest
from SPARC.llm import LLMAnalyzer
class TestClassifyPatentTags:
"""Test the classify_patent_tags method on LLMAnalyzer."""
@pytest.fixture(autouse=True)
def mock_database(self, mocker):
"""Mock the database client for all tests."""
mock_db_client = Mock()
mock_db_client.get_cached_response.return_value = None
mock_db_client.store_message.return_value = 1
mocker.patch("SPARC.llm.DatabaseClient", return_value=mock_db_client)
return mock_db_client
def test_classify_returns_valid_tags(self, mocker, mock_database):
"""Test that classify_patent_tags returns valid canonical tags from LLM."""
mock_openai = mocker.patch("SPARC.llm.OpenAI")
mock_client = Mock()
mock_openai.return_value = mock_client
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content='["ai", "semiconductors"]'))]
mock_client.chat.completions.create.return_value = mock_response
analyzer = LLMAnalyzer(api_key="test-key")
tags = analyzer.classify_patent_tags("A patent about neural network accelerator chips")
assert tags == ["ai", "semiconductors"]
mock_client.chat.completions.create.assert_called_once()
# Verify the prompt contains classification instructions
call_args = mock_client.chat.completions.create.call_args
prompt_text = call_args[1]["messages"][0]["content"]
assert "ai" in prompt_text
assert "semiconductors" in prompt_text
assert "JSON array" in prompt_text
def test_classify_filters_invalid_tags(self, mocker, mock_database):
"""Test that invalid tags from LLM response are filtered out."""
mock_openai = mocker.patch("SPARC.llm.OpenAI")
mock_client = Mock()
mock_openai.return_value = mock_client
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content='["ai", "quantum_computing", "biotech"]'))]
mock_client.chat.completions.create.return_value = mock_response
analyzer = LLMAnalyzer(api_key="test-key")
tags = analyzer.classify_patent_tags("Some patent content")
assert tags == ["ai", "biotech"]
assert "quantum_computing" not in tags
def test_classify_returns_other_when_all_invalid(self, mocker, mock_database):
"""Test fallback to 'other' when all LLM tags are invalid."""
mock_openai = mocker.patch("SPARC.llm.OpenAI")
mock_client = Mock()
mock_openai.return_value = mock_client
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content='["quantum", "robotics"]'))]
mock_client.chat.completions.create.return_value = mock_response
analyzer = LLMAnalyzer(api_key="test-key")
tags = analyzer.classify_patent_tags("Some patent content")
assert tags == ["other"]
def test_classify_handles_malformed_json(self, mocker, mock_database):
"""Test graceful handling of non-JSON LLM response."""
mock_openai = mocker.patch("SPARC.llm.OpenAI")
mock_client = Mock()
mock_openai.return_value = mock_client
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="ai, semiconductors"))]
mock_client.chat.completions.create.return_value = mock_response
analyzer = LLMAnalyzer(api_key="test-key")
tags = analyzer.classify_patent_tags("Some patent content")
assert tags == ["other"]
def test_classify_handles_api_error(self, mocker, mock_database):
"""Test graceful fallback when LLM API call fails."""
mock_openai = mocker.patch("SPARC.llm.OpenAI")
mock_client = Mock()
mock_openai.return_value = mock_client
mock_client.chat.completions.create.side_effect = Exception("API timeout")
analyzer = LLMAnalyzer(api_key="test-key")
tags = analyzer.classify_patent_tags("Some patent content")
assert tags == ["other"]
def test_classify_test_mode(self, mocker, mock_database):
"""Test that test mode returns 'other' without API call."""
mocker.patch("SPARC.llm.config")
analyzer = LLMAnalyzer(test_mode=True)
tags = analyzer.classify_patent_tags("Some patent content")
assert tags == ["other"]
def test_classify_no_api_client(self, mocker, mock_database):
"""Test that without API client, classification returns 'other'."""
mocker.patch("SPARC.llm.config")
analyzer = LLMAnalyzer(use_cache=False)
tags = analyzer.classify_patent_tags("Some patent content")
assert tags == ["other"]
def test_canonical_tags_list(self):
"""Test that the canonical tag list matches requirements."""
expected = ["ai", "semiconductors", "materials", "biotech", "networking", "other"]
assert LLMAnalyzer.CANONICAL_TAGS == expected
def test_classify_uses_model_override(self, mocker, mock_database):
"""Test that model override is passed to the API call."""
mock_openai = mocker.patch("SPARC.llm.OpenAI")
mock_client = Mock()
mock_openai.return_value = mock_client
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content='["ai"]'))]
mock_client.chat.completions.create.return_value = mock_response
analyzer = LLMAnalyzer(api_key="test-key")
analyzer.classify_patent_tags("content", model="openai/gpt-4o")
call_args = mock_client.chat.completions.create.call_args
assert call_args[1]["model"] == "openai/gpt-4o"
class TestPatentTagsStorage:
"""Test that tags are persisted to the database."""
@pytest.fixture(autouse=True)
def mock_db(self, mocker):
"""Mock DatabaseClient for all tests."""
mock_db_cls = mocker.patch("SPARC.analyzer.DatabaseClient")
mock_db_instance = MagicMock()
mock_db_instance.get_cached_patent.return_value = None
mock_db_instance.get_cached_serp_query.return_value = None
mock_db_cls.return_value = mock_db_instance
return mock_db_instance
def test_tags_stored_after_classification(self, mocker, mock_db):
"""Test that classify_patent_tags results are stored via update_patent_tags."""
from SPARC.analyzer import CompanyAnalyzer
from SPARC.types import Patent, Patents
mock_query = mocker.patch("SPARC.analyzer.SERP.query")
mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents")
mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf")
mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm")
mock_llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer")
patent = Patent(patent_id="US123", pdf_link="http://example.com/test.pdf")
mock_query.return_value = Patents(patents=[patent])
def save_side_effect(p):
p.pdf_path = "patents/US123.pdf"
return p
mock_save.side_effect = save_side_effect
mock_parse.return_value = {"abstract": "Test abstract"}
mock_minimize.return_value = "Minimized content"
mock_llm_instance = Mock()
mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis result"
mock_llm_instance.classify_patent_tags.return_value = ["ai", "semiconductors"]
mock_llm_cls.return_value = mock_llm_instance
analyzer = CompanyAnalyzer()
analyzer.analyze_company("TestCorp")
# Verify classification was called
mock_llm_instance.classify_patent_tags.assert_called_once_with(
patent_content="Minimized content", model=None
)
# Verify tags were persisted
mock_db.update_patent_tags.assert_called_once_with("US123", ["ai", "semiconductors"])
def test_cached_tags_skip_classification(self, mocker, mock_db):
"""Test that patents with cached tags skip re-classification."""
from SPARC.analyzer import CompanyAnalyzer
from SPARC.types import Patent, Patents
mocker.patch("SPARC.analyzer.SERP.query")
mocker.patch("SPARC.analyzer.SERP.save_patents")
mock_llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer")
# Simulate DB cache hit with tags
mock_db.get_cached_patent.return_value = {
"patent_id": "US123",
"minimized_content": "Cached content",
"patent_tags": ["biotech"],
}
mock_db.get_cached_serp_query.return_value = ["US123"]
mock_llm_instance = Mock()
mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis"
mock_llm_cls.return_value = mock_llm_instance
analyzer = CompanyAnalyzer()
analyzer.analyze_company("TestCorp")
# Should NOT classify since tags already present from cache
mock_llm_instance.classify_patent_tags.assert_not_called()
mock_db.update_patent_tags.assert_not_called()
def test_tags_in_analysis_result(self, mocker, mock_db):
"""Test that tags appear in the CompanyAnalysisResult."""
from SPARC.analyzer import CompanyAnalyzer
from SPARC.types import Patent, Patents
mock_query = mocker.patch("SPARC.analyzer.SERP.query")
mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents")
mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf")
mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm")
mock_llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer")
patent = Patent(patent_id="US123", pdf_link="http://example.com/test.pdf")
mock_query.return_value = Patents(patents=[patent])
def save_side_effect(p):
p.pdf_path = "patents/US123.pdf"
return p
mock_save.side_effect = save_side_effect
mock_parse.return_value = {"abstract": "Test abstract"}
mock_minimize.return_value = "Minimized content"
mock_llm_instance = Mock()
mock_llm_instance.analyze_patent_portfolio.return_value = "Strong innovation"
mock_llm_instance.classify_patent_tags.return_value = ["ai", "networking"]
mock_llm_cls.return_value = mock_llm_instance
# After analysis, get_cached_patent returns the patent with tags
mock_db.get_cached_serp_query.side_effect = [None, ["US123"]]
mock_db.get_cached_patent.side_effect = [
None, # First call during _process_single_patent
{"patent_id": "US123", "patent_tags": ["ai", "networking"]}, # Second call during _analyze_company_safe
]
analyzer = CompanyAnalyzer()
result = analyzer._analyze_company_safe("TestCorp")
assert result.success is True
assert result.tags == ["ai", "networking"]