forked from 0xWheatyz/SPARC
Add LLM-based patent classification tagging by technology domain
- Add classify_patent_tags() to LLMAnalyzer with canonical tag list (ai, semiconductors, materials, biotech, networking, other)
- Add patent_tags TEXT[] column to patents table with GIN index
- Run classification automatically in the analysis pipeline after patent processing; persist tags via update_patent_tags()
- Include tags in CompanyAnalysisResult and API response models
- Add ?tags= filter to GET /analyze/batch endpoint
- Add GET /analytics/tags endpoint for tag distribution data
- Add tag filter controls and distribution chart to Analytics page
- Add 12 unit tests covering classification, DB storage, and caching

Closes leeworks-agents/SPARC#1672

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+28
-4
@@ -40,8 +40,9 @@ class CompanyAnalyzer:
|
||||
1. Retrieve patents from SERP API
|
||||
2. Download and parse each patent PDF
|
||||
3. Minimize patent content (remove bloat)
|
||||
4. Analyze portfolio with LLM
|
||||
5. Return performance estimation
|
||||
4. Classify patent technology domain tags
|
||||
5. Analyze portfolio with LLM
|
||||
6. Return performance estimation
|
||||
|
||||
Args:
|
||||
company_name: Name of the company to analyze
|
||||
@@ -97,6 +98,17 @@ class CompanyAnalyzer:
|
||||
if not processed_patents:
|
||||
return f"Failed to process any patents for {company_name}"
|
||||
|
||||
# Classify each patent's technology domain tags
|
||||
logger.info("Classifying patent technology domains...")
|
||||
for patent_data in processed_patents:
|
||||
if "tags" not in patent_data or not patent_data["tags"]:
|
||||
tags = self.llm_analyzer.classify_patent_tags(
|
||||
patent_content=patent_data["content"], model=model
|
||||
)
|
||||
patent_data["tags"] = tags
|
||||
# Persist tags to the database
|
||||
self.db.update_patent_tags(patent_data["patent_id"], tags)
|
||||
|
||||
logger.info("Analyzing portfolio with LLM...")
|
||||
|
||||
# Analyze the full portfolio with LLM
|
||||
@@ -181,7 +193,10 @@ class CompanyAnalyzer:
|
||||
if db:
|
||||
cached = db.get_cached_patent(patent.patent_id)
|
||||
if cached and cached.get("minimized_content"):
|
||||
return {"patent_id": patent.patent_id, "content": cached["minimized_content"]}
|
||||
result = {"patent_id": patent.patent_id, "content": cached["minimized_content"]}
|
||||
if cached.get("patent_tags"):
|
||||
result["tags"] = cached["patent_tags"]
|
||||
return result
|
||||
|
||||
# Full processing: download, parse, minimize
|
||||
patent = SERP.save_patents(patent)
|
||||
@@ -217,11 +232,19 @@ class CompanyAnalyzer:
|
||||
# Delegate to analyze_company which handles SERP/patent caching
|
||||
analysis = self.analyze_company(company_name, model=model)
|
||||
|
||||
# Determine patent count from cached SERP query
|
||||
# Determine patent count and aggregate tags from cached SERP query
|
||||
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
|
||||
cached_ids = self.db.get_cached_serp_query(query_hash)
|
||||
patent_count = len(cached_ids) if cached_ids else 0
|
||||
|
||||
# Collect unique tags across all patents for this company
|
||||
all_tags: set[str] = set()
|
||||
if cached_ids:
|
||||
for pid in cached_ids:
|
||||
cached_patent = self.db.get_cached_patent(pid)
|
||||
if cached_patent and cached_patent.get("patent_tags"):
|
||||
all_tags.update(cached_patent["patent_tags"])
|
||||
|
||||
# Check if analysis indicates failure
|
||||
if analysis.startswith("No patents found") or analysis.startswith(
|
||||
"Failed to process"
|
||||
@@ -239,6 +262,7 @@ class CompanyAnalyzer:
|
||||
analysis=analysis,
|
||||
patent_count=patent_count,
|
||||
success=True,
|
||||
tags=sorted(all_tags),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
+52
-1
@@ -57,6 +57,7 @@ class CompanyAnalysisResponse(BaseModel):
|
||||
success: bool
|
||||
error: str | None = None
|
||||
model: str | None = None
|
||||
tags: list[str] = []
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@@ -188,6 +189,7 @@ def _convert_result(result: CompanyAnalysisResult) -> CompanyAnalysisResponse:
|
||||
success=result.success,
|
||||
error=result.error,
|
||||
model=result.model,
|
||||
tags=result.tags,
|
||||
timestamp=result.timestamp,
|
||||
)
|
||||
|
||||
@@ -560,6 +562,38 @@ async def get_analytics(
|
||||
)
|
||||
|
||||
|
||||
@app.get("/analytics/tags", tags=["Analytics"])
async def get_tag_distribution(
    _: UserResponse = Depends(get_current_user),
):
    """Report how many patents carry each technology domain tag.

    Requires an authenticated user (the dependency result itself is unused).

    Returns:
        A dict with ``by_tag`` (tag/count pairs, most frequent first) and
        ``canonical_tags`` (the tag list declared on ``LLMAnalyzer``).
    """
    from SPARC.llm import LLMAnalyzer

    distribution_sql = """
        SELECT tag, COUNT(*) as count
        FROM patents, UNNEST(patent_tags) AS tag
        WHERE patent_tags IS NOT NULL
        GROUP BY tag
        ORDER BY count DESC
        """

    client = get_db_client()
    with client.get_conn() as conn, conn.cursor() as cur:
        cur.execute(distribution_sql)
        counted = cur.fetchall()

    # Each fetched row is a (tag, count) pair in the order selected above.
    return {
        "by_tag": [{"tag": name, "count": total} for name, total in counted],
        "canonical_tags": LLMAnalyzer.CANONICAL_TAGS,
    }
|
||||
|
||||
|
||||
# ============== Model Selection Endpoints ==============
|
||||
|
||||
# Supported models via OpenRouter
|
||||
@@ -969,6 +1003,10 @@ async def list_analysis_results(
|
||||
str | None,
|
||||
Query(description="Filter results by company name"),
|
||||
] = None,
|
||||
tags: Annotated[
|
||||
str | None,
|
||||
Query(description="Comma-separated technology domain tags to filter by (e.g. 'ai,semiconductors')"),
|
||||
] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=200)] = 50,
|
||||
cursor: Annotated[
|
||||
str | None,
|
||||
@@ -986,14 +1024,27 @@ async def list_analysis_results(
|
||||
|
||||
Args:
|
||||
company_name: Optional filter by company name
|
||||
tags: Optional comma-separated tag filter (e.g. 'ai,semiconductors')
|
||||
limit: Maximum number of results to return (default 50, max 200)
|
||||
cursor: Opaque pagination cursor from a previous response
|
||||
|
||||
Returns:
|
||||
Paginated list of analysis results
|
||||
"""
|
||||
# Parse and validate tags
|
||||
tag_list = None
|
||||
if tags:
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
tag_list = [t.strip().lower() for t in tags.split(",") if t.strip()]
|
||||
invalid = [t for t in tag_list if t not in LLMAnalyzer.CANONICAL_TAGS]
|
||||
if invalid:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid tags: {', '.join(invalid)}. Valid tags: {', '.join(LLMAnalyzer.CANONICAL_TAGS)}",
|
||||
)
|
||||
|
||||
db = _get_job_db()
|
||||
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor)
|
||||
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor, tags=tag_list)
|
||||
|
||||
has_next = len(rows) > limit
|
||||
if has_next:
|
||||
|
||||
+82
-9
@@ -146,15 +146,35 @@ class DatabaseClient:
|
||||
pdf_link TEXT,
|
||||
raw_sections JSONB,
|
||||
minimized_content TEXT,
|
||||
patent_tags TEXT[],
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Add patent_tags column if it doesn't exist (for existing tables)
|
||||
cursor.execute("""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'patents' AND column_name = 'patent_tags'
|
||||
) THEN
|
||||
ALTER TABLE patents ADD COLUMN patent_tags TEXT[];
|
||||
END IF;
|
||||
END $$;
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_patents_company
|
||||
ON patents(company_name)
|
||||
""")
|
||||
|
||||
# GIN index for efficient tag array queries
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_patents_tags
|
||||
ON patents USING GIN(patent_tags)
|
||||
""")
|
||||
|
||||
# Create SERP query cache table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS serp_queries (
|
||||
@@ -376,6 +396,7 @@ class DatabaseClient:
|
||||
company_name: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
cursor: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> List[Dict]:
|
||||
"""List analysis results with cursor-based pagination.
|
||||
|
||||
@@ -383,29 +404,40 @@ class DatabaseClient:
|
||||
company_name: Optional filter by company name.
|
||||
limit: Maximum number of records to return.
|
||||
cursor: Opaque cursor (``timestamp|id``) from a previous response.
|
||||
tags: Optional list of technology domain tags to filter by.
|
||||
|
||||
Returns:
|
||||
List of analysis dicts ordered by timestamp descending.
|
||||
"""
|
||||
conditions: list[str] = ["is_cached = FALSE"]
|
||||
conditions: list[str] = ["m.is_cached = FALSE"]
|
||||
params: list = []
|
||||
join_clause = ""
|
||||
|
||||
if company_name:
|
||||
conditions.append("LOWER(company_name) = LOWER(%s)")
|
||||
conditions.append("LOWER(m.company_name) = LOWER(%s)")
|
||||
params.append(company_name)
|
||||
|
||||
if tags:
|
||||
# Join with patents table to filter by tags
|
||||
join_clause = (
|
||||
" INNER JOIN patents p ON LOWER(p.company_name) = LOWER(m.company_name)"
|
||||
)
|
||||
conditions.append("p.patent_tags && %s")
|
||||
params.append(tags)
|
||||
|
||||
if cursor:
|
||||
try:
|
||||
ts_str, cursor_id = cursor.rsplit("|", 1)
|
||||
conditions.append("(timestamp, id) < (%s, %s)")
|
||||
conditions.append("(m.timestamp, m.id) < (%s, %s)")
|
||||
params.extend([ts_str, int(cursor_id)])
|
||||
except (ValueError, TypeError):
|
||||
pass # Ignore malformed cursors; return from start
|
||||
|
||||
query = "SELECT id, company_name, analysis_type, model, response, timestamp FROM llm_messages"
|
||||
query = "SELECT DISTINCT m.id, m.company_name, m.analysis_type, m.model, m.response, m.timestamp FROM llm_messages m"
|
||||
query += join_clause
|
||||
if conditions:
|
||||
query += " WHERE " + " AND ".join(conditions)
|
||||
query += " ORDER BY timestamp DESC, id DESC LIMIT %s"
|
||||
query += " ORDER BY m.timestamp DESC, m.id DESC LIMIT %s"
|
||||
params.append(limit)
|
||||
|
||||
with self.get_conn() as conn:
|
||||
@@ -493,22 +525,63 @@ class DatabaseClient:
|
||||
pdf_link: str,
|
||||
raw_sections: Dict,
|
||||
minimized_content: str,
|
||||
patent_tags: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Store a processed patent in the cache."""
|
||||
with self.get_conn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content, patent_tags)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
ON CONFLICT (patent_id) DO UPDATE SET
|
||||
raw_sections = EXCLUDED.raw_sections,
|
||||
minimized_content = EXCLUDED.minimized_content
|
||||
minimized_content = EXCLUDED.minimized_content,
|
||||
patent_tags = EXCLUDED.patent_tags
|
||||
""",
|
||||
(patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content),
|
||||
(patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content, patent_tags),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def update_patent_tags(self, patent_id: str, tags: List[str]) -> None:
    """Overwrite the stored technology domain tags of one patent.

    Args:
        patent_id: Identifier of the row in ``patents`` to update.
        tags: Tag list written to the ``patent_tags`` column.
    """
    statement = "UPDATE patents SET patent_tags = %s WHERE patent_id = %s"
    bound_values = (tags, patent_id)

    with self.get_conn() as conn:
        with conn.cursor() as cursor:
            cursor.execute(statement, bound_values)
            conn.commit()
|
||||
|
||||
def get_patents_by_tags(
    self,
    tags: List[str],
    limit: int = 100,
    offset: int = 0,
) -> List[Dict]:
    """Retrieve patents that match any of the given tags.

    Args:
        tags: Tag strings to filter by (OR logic via the ``&&`` array
            overlap operator). An empty value yields an empty result
            without querying the database.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip (pagination offset).

    Returns:
        List of patent dicts, newest first.
    """
    # An overlap test against an empty array can never match, so skip the
    # database round trip entirely.
    if not tags:
        return []

    with self.get_conn() as conn:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(
                # patent_id breaks created_at ties so that LIMIT/OFFSET pages
                # stay stable and never repeat or skip a row between calls.
                """
                SELECT * FROM patents
                WHERE patent_tags && %s
                ORDER BY created_at DESC, patent_id DESC
                LIMIT %s OFFSET %s
                """,
                # Force a list: the driver adapts lists to SQL arrays, while a
                # tuple would be rendered as a row/IN-list and break ``&&``.
                (list(tags), limit, offset),
            )
            return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_cached_serp_query(self, query_hash: str) -> Optional[List[str]]:
|
||||
"""Look up cached SERP query results.
|
||||
|
||||
|
||||
+60
-2
@@ -1,7 +1,8 @@
|
||||
"""LLM integration for patent analysis using OpenRouter."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -245,4 +246,61 @@ Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the co
|
||||
metadata={**metadata, "pending": True}
|
||||
)
|
||||
return placeholder
|
||||
|
||||
|
||||
# Canonical technology domain tags
|
||||
CANONICAL_TAGS = ["ai", "semiconductors", "materials", "biotech", "networking", "other"]
|
||||
|
||||
def classify_patent_tags(self, patent_content: str, model: str | None = None) -> List[str]:
|
||||
"""Classify a patent into one or more technology domain tags.
|
||||
|
||||
Sends the patent abstract/claims to the LLM with a classification prompt
|
||||
and returns a list of canonical tags.
|
||||
|
||||
Args:
|
||||
patent_content: Minimized patent text (abstract, claims, summary)
|
||||
model: Optional model override
|
||||
|
||||
Returns:
|
||||
List of canonical tag strings from CANONICAL_TAGS
|
||||
"""
|
||||
prompt = f"""You are a patent classification system. Analyze the following patent content and assign one or more technology domain tags from this exact list:
|
||||
|
||||
Tags: ai, semiconductors, materials, biotech, networking, other
|
||||
|
||||
Rules:
|
||||
- Return ONLY a JSON array of tag strings, e.g. ["ai", "semiconductors"]
|
||||
- Use ONLY tags from the list above
|
||||
- Assign "other" only if no other tag fits
|
||||
- A patent can have multiple tags if it spans domains
|
||||
|
||||
Patent Content:
|
||||
{patent_content}
|
||||
|
||||
Return ONLY the JSON array, nothing else."""
|
||||
|
||||
effective_model = model or self.model
|
||||
|
||||
if self.test_mode:
|
||||
logger.debug("TEST MODE - Classification prompt:\n%s", prompt)
|
||||
return ["other"]
|
||||
|
||||
if self.client:
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=effective_model,
|
||||
max_tokens=128,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
)
|
||||
raw = response.choices[0].message.content.strip()
|
||||
tags = json.loads(raw)
|
||||
# Validate and filter to canonical tags only
|
||||
valid_tags = [t for t in tags if t in self.CANONICAL_TAGS]
|
||||
return valid_tags if valid_tags else ["other"]
|
||||
except (json.JSONDecodeError, AttributeError, TypeError) as e:
|
||||
logger.warning("Failed to parse classification response: %s", e)
|
||||
return ["other"]
|
||||
except Exception as e:
|
||||
logger.warning("Classification LLM call failed: %s", e)
|
||||
return ["other"]
|
||||
|
||||
return ["other"]
|
||||
|
||||
@@ -25,6 +25,7 @@ class CompanyAnalysisResult:
|
||||
success: bool
|
||||
error: str | None = None
|
||||
model: str | None = None
|
||||
tags: list[str] = field(default_factory=list)
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
|
||||
@@ -199,8 +199,18 @@ export const analyticsApi = {
|
||||
const response = await api.get<TrendData>(`/analytics/trends?days=${days}`);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
// Fetch per-tag patent counts plus the canonical tag list from the backend.
getTagDistribution: async (): Promise<TagDistribution> => {
  const { data: distribution } = await api.get<TagDistribution>('/analytics/tags');
  return distribution;
},
|
||||
};
|
||||
|
||||
/** One technology-domain tag together with the number of patents carrying it. */
interface TagCount {
  tag: string;
  count: number;
}

/** Payload returned by GET /analytics/tags. */
export interface TagDistribution {
  by_tag: TagCount[];
  canonical_tags: string[];
}
|
||||
|
||||
// Admin API
|
||||
export const adminApi = {
|
||||
listUsers: async (limit = 100, offset = 0): Promise<User[]> => {
|
||||
|
||||
@@ -1,14 +1,24 @@
|
||||
import { useState } from 'react';
|
||||
import { useQuery } from '@tanstack/react-query';
|
||||
import { analyticsApi } from '../api/client';
|
||||
import { AlertCircle, Database } from 'lucide-react';
|
||||
import { AlertCircle, Database, Tag } from 'lucide-react';
|
||||
import { PieChart, Pie, Cell, BarChart, Bar, LineChart, Line, XAxis, YAxis, Tooltip, ResponsiveContainer, Legend } from 'recharts';
|
||||
import { useChartTheme } from '../context/useChartTheme';
|
||||
|
||||
const COLORS = ['#6366f1', '#0ea5e9', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#ec4899', '#14b8a6'];
|
||||
|
||||
// Fixed colour per technology-domain tag, used for the filter chips and the tag bar chart.
const TAG_COLORS: Record<string, string> = {
  'ai': '#6366f1',
  'semiconductors': '#0ea5e9',
  'materials': '#10b981',
  'biotech': '#f59e0b',
  'networking': '#ec4899',
  'other': '#8b5cf6',
};
|
||||
|
||||
export function AnalyticsPage() {
|
||||
const [days, setDays] = useState(30);
|
||||
const [selectedTags, setSelectedTags] = useState<string[]>([]);
|
||||
const chartTheme = useChartTheme();
|
||||
|
||||
const { data, isLoading, isError, refetch } = useQuery({
|
||||
@@ -21,6 +31,17 @@ export function AnalyticsPage() {
|
||||
queryFn: () => analyticsApi.getTrends(days),
|
||||
});
|
||||
|
||||
const tagQuery = useQuery({
|
||||
queryKey: ['analytics-tags'],
|
||||
queryFn: () => analyticsApi.getTagDistribution(),
|
||||
});
|
||||
|
||||
// Add the tag to the selection when absent, remove it when already selected.
const toggleTag = (tag: string) => {
  setSelectedTags((current) => {
    if (current.includes(tag)) {
      return current.filter((existing) => existing !== tag);
    }
    return [...current, tag];
  });
};
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
@@ -107,6 +128,13 @@ export function AnalyticsPage() {
|
||||
count: t.count,
|
||||
}));
|
||||
|
||||
const tagData = tagQuery.data?.by_tag?.map((t) => ({
|
||||
name: t.tag,
|
||||
count: t.count,
|
||||
})) || [];
|
||||
|
||||
const canonicalTags = tagQuery.data?.canonical_tags || [];
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
{/* Header */}
|
||||
@@ -131,6 +159,50 @@ export function AnalyticsPage() {
|
||||
</select>
|
||||
</div>
|
||||
|
||||
{/* Tag Filter Controls */}
|
||||
{canonicalTags.length > 0 && (
|
||||
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-4">
|
||||
<div className="flex items-center gap-2 mb-3">
|
||||
<Tag size={16} className="text-primary" />
|
||||
<span className="text-sm font-semibold text-text-primary">Filter by Technology Domain</span>
|
||||
{selectedTags.length > 0 && (
|
||||
<button
|
||||
onClick={() => setSelectedTags([])}
|
||||
className="ml-auto text-xs text-text-secondary hover:text-primary transition-colors"
|
||||
>
|
||||
Clear all
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{canonicalTags.map((tag) => {
|
||||
const isActive = selectedTags.includes(tag);
|
||||
const color = TAG_COLORS[tag] || '#8b5cf6';
|
||||
const tagCount = tagQuery.data?.by_tag?.find((t) => t.tag === tag)?.count || 0;
|
||||
return (
|
||||
<button
|
||||
key={tag}
|
||||
onClick={() => toggleTag(tag)}
|
||||
className={`px-3 py-1.5 rounded-lg text-sm font-medium transition-all ${
|
||||
isActive
|
||||
? 'text-white shadow-md'
|
||||
: 'bg-bg-card/80 text-text-secondary border border-primary/20 hover:border-primary/40'
|
||||
}`}
|
||||
style={isActive ? { backgroundColor: color } : {}}
|
||||
>
|
||||
{tag}
|
||||
{tagCount > 0 && (
|
||||
<span className={`ml-1.5 text-xs ${isActive ? 'opacity-80' : 'opacity-60'}`}>
|
||||
({tagCount})
|
||||
</span>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Summary Metrics */}
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
|
||||
<MetricCard label="Total Analyses" value={data.total_messages} />
|
||||
@@ -187,6 +259,28 @@ export function AnalyticsPage() {
|
||||
</ResponsiveContainer>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Bar Chart - Technology Domain Tags */}
|
||||
{tagData.length > 0 && (
|
||||
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-6">
|
||||
<h3 className="text-lg font-semibold text-text-primary mb-4">Patents by Technology Domain</h3>
|
||||
<ResponsiveContainer width="100%" height={300}>
|
||||
<BarChart data={tagData}>
|
||||
<XAxis dataKey="name" stroke={chartTheme.axisStroke} fontSize={12} />
|
||||
<YAxis stroke={chartTheme.axisStroke} fontSize={12} />
|
||||
<Tooltip
|
||||
contentStyle={chartTheme.tooltipContentStyle}
|
||||
labelStyle={chartTheme.tooltipLabelStyle}
|
||||
/>
|
||||
<Bar dataKey="count" radius={[4, 4, 0, 0]}>
|
||||
{tagData.map((entry, index) => (
|
||||
<Cell key={`tag-cell-${index}`} fill={TAG_COLORS[entry.name] || COLORS[index % COLORS.length]} />
|
||||
))}
|
||||
</Bar>
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Trend Charts */}
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
"""Tests for LLM-based patent classification tagging by technology domain."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
|
||||
|
||||
class TestClassifyPatentTags:
    """Exercise LLMAnalyzer.classify_patent_tags against a mocked OpenAI client."""

    @pytest.fixture(autouse=True)
    def mock_database(self, mocker):
        """Swap SPARC.llm.DatabaseClient for a stub in every test."""
        db_stub = Mock()
        db_stub.get_cached_response.return_value = None
        db_stub.store_message.return_value = 1
        mocker.patch("SPARC.llm.DatabaseClient", return_value=db_stub)
        return db_stub

    @staticmethod
    def _analyzer_with_reply(mocker, content):
        """Build an analyzer whose mocked client answers with ``content``.

        Returns the analyzer together with the mocked client so tests can
        inspect the recorded API call.
        """
        client = Mock()
        mocker.patch("SPARC.llm.OpenAI").return_value = client
        client.chat.completions.create.return_value = Mock(
            choices=[Mock(message=Mock(content=content))]
        )
        return LLMAnalyzer(api_key="test-key"), client

    def test_classify_returns_valid_tags(self, mocker, mock_database):
        """Canonical tags in the LLM reply are returned unchanged."""
        analyzer, client = self._analyzer_with_reply(mocker, '["ai", "semiconductors"]')

        tags = analyzer.classify_patent_tags("A patent about neural network accelerator chips")

        assert tags == ["ai", "semiconductors"]
        client.chat.completions.create.assert_called_once()

        # The prompt must carry the classification instructions.
        sent_prompt = client.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        for fragment in ("ai", "semiconductors", "JSON array"):
            assert fragment in sent_prompt

    def test_classify_filters_invalid_tags(self, mocker, mock_database):
        """Non-canonical tags in the LLM reply are dropped."""
        analyzer, _ = self._analyzer_with_reply(mocker, '["ai", "quantum_computing", "biotech"]')

        tags = analyzer.classify_patent_tags("Some patent content")

        assert tags == ["ai", "biotech"]
        assert "quantum_computing" not in tags

    def test_classify_returns_other_when_all_invalid(self, mocker, mock_database):
        """A reply with no canonical tag at all falls back to 'other'."""
        analyzer, _ = self._analyzer_with_reply(mocker, '["quantum", "robotics"]')

        assert analyzer.classify_patent_tags("Some patent content") == ["other"]

    def test_classify_handles_malformed_json(self, mocker, mock_database):
        """A non-JSON reply falls back to 'other' instead of raising."""
        analyzer, _ = self._analyzer_with_reply(mocker, "ai, semiconductors")

        assert analyzer.classify_patent_tags("Some patent content") == ["other"]

    def test_classify_handles_api_error(self, mocker, mock_database):
        """An exception from the API call falls back to 'other'."""
        client = Mock()
        mocker.patch("SPARC.llm.OpenAI").return_value = client
        client.chat.completions.create.side_effect = Exception("API timeout")

        analyzer = LLMAnalyzer(api_key="test-key")

        assert analyzer.classify_patent_tags("Some patent content") == ["other"]

    def test_classify_test_mode(self, mocker, mock_database):
        """Test mode yields 'other' without any API call."""
        mocker.patch("SPARC.llm.config")

        analyzer = LLMAnalyzer(test_mode=True)

        assert analyzer.classify_patent_tags("Some patent content") == ["other"]

    def test_classify_no_api_client(self, mocker, mock_database):
        """Without an API client, classification yields 'other'."""
        mocker.patch("SPARC.llm.config")

        analyzer = LLMAnalyzer(use_cache=False)

        assert analyzer.classify_patent_tags("Some patent content") == ["other"]

    def test_canonical_tags_list(self):
        """The canonical tag list matches the required set and order."""
        assert LLMAnalyzer.CANONICAL_TAGS == [
            "ai",
            "semiconductors",
            "materials",
            "biotech",
            "networking",
            "other",
        ]

    def test_classify_uses_model_override(self, mocker, mock_database):
        """An explicit model argument is forwarded to the API call."""
        analyzer, client = self._analyzer_with_reply(mocker, '["ai"]')

        analyzer.classify_patent_tags("content", model="openai/gpt-4o")

        assert client.chat.completions.create.call_args.kwargs["model"] == "openai/gpt-4o"
|
||||
|
||||
|
||||
class TestPatentTagsStorage:
    """Check that classification results reach the database and the result object."""

    @pytest.fixture(autouse=True)
    def mock_db(self, mocker):
        """Swap SPARC.analyzer.DatabaseClient for a cache-miss stub in every test."""
        db_stub = MagicMock()
        db_stub.get_cached_patent.return_value = None
        db_stub.get_cached_serp_query.return_value = None
        mocker.patch("SPARC.analyzer.DatabaseClient").return_value = db_stub
        return db_stub

    @staticmethod
    def _stub_pipeline(mocker, *, tags, analysis):
        """Patch the SERP steps and LLMAnalyzer for a single uncached patent US123.

        Returns the mocked LLM analyzer instance, configured to classify
        every patent as ``tags`` and to produce ``analysis`` for the portfolio.
        """
        from SPARC.types import Patent, Patents

        def attach_pdf_path(p):
            p.pdf_path = "patents/US123.pdf"
            return p

        query = mocker.patch("SPARC.analyzer.SERP.query")
        save = mocker.patch("SPARC.analyzer.SERP.save_patents")
        parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf")
        minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm")
        llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer")

        query.return_value = Patents(
            patents=[Patent(patent_id="US123", pdf_link="http://example.com/test.pdf")]
        )
        save.side_effect = attach_pdf_path
        parse.return_value = {"abstract": "Test abstract"}
        minimize.return_value = "Minimized content"

        llm = Mock()
        llm.analyze_patent_portfolio.return_value = analysis
        llm.classify_patent_tags.return_value = tags
        llm_cls.return_value = llm
        return llm

    def test_tags_stored_after_classification(self, mocker, mock_db):
        """Tags produced by classify_patent_tags are written via update_patent_tags."""
        from SPARC.analyzer import CompanyAnalyzer

        llm = self._stub_pipeline(
            mocker, tags=["ai", "semiconductors"], analysis="Analysis result"
        )

        CompanyAnalyzer().analyze_company("TestCorp")

        # Classification ran once, on the minimized content.
        llm.classify_patent_tags.assert_called_once_with(
            patent_content="Minimized content", model=None
        )
        # The resulting tags were persisted for that patent.
        mock_db.update_patent_tags.assert_called_once_with("US123", ["ai", "semiconductors"])

    def test_cached_tags_skip_classification(self, mocker, mock_db):
        """Patents whose cached row already has tags are not re-classified."""
        from SPARC.analyzer import CompanyAnalyzer
        from SPARC.types import Patent, Patents

        mocker.patch("SPARC.analyzer.SERP.query")
        mocker.patch("SPARC.analyzer.SERP.save_patents")
        llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer")

        # Simulate a DB cache hit that already carries tags.
        mock_db.get_cached_patent.return_value = {
            "patent_id": "US123",
            "minimized_content": "Cached content",
            "patent_tags": ["biotech"],
        }
        mock_db.get_cached_serp_query.return_value = ["US123"]

        llm = Mock()
        llm.analyze_patent_portfolio.return_value = "Analysis"
        llm_cls.return_value = llm

        CompanyAnalyzer().analyze_company("TestCorp")

        # Tags came from the cache, so nothing is classified or rewritten.
        llm.classify_patent_tags.assert_not_called()
        mock_db.update_patent_tags.assert_not_called()

    def test_tags_in_analysis_result(self, mocker, mock_db):
        """Tags stored for the company's patents show up on CompanyAnalysisResult."""
        from SPARC.analyzer import CompanyAnalyzer

        self._stub_pipeline(mocker, tags=["ai", "networking"], analysis="Strong innovation")

        # First lookups miss (processing path); the later lookups made by
        # _analyze_company_safe see the cached query and the tagged patent.
        mock_db.get_cached_serp_query.side_effect = [None, ["US123"]]
        mock_db.get_cached_patent.side_effect = [
            None,
            {"patent_id": "US123", "patent_tags": ["ai", "networking"]},
        ]

        result = CompanyAnalyzer()._analyze_company_safe("TestCorp")

        assert result.success is True
        assert result.tags == ["ai", "networking"]
|
||||
Reference in New Issue
Block a user