forked from 0xWheatyz/SPARC
Add LLM-based patent classification tagging by technology domain #1692
+28
-4
@@ -40,8 +40,9 @@ class CompanyAnalyzer:
|
|||||||
1. Retrieve patents from SERP API
|
1. Retrieve patents from SERP API
|
||||||
2. Download and parse each patent PDF
|
2. Download and parse each patent PDF
|
||||||
3. Minimize patent content (remove bloat)
|
3. Minimize patent content (remove bloat)
|
||||||
4. Analyze portfolio with LLM
|
4. Classify patent technology domain tags
|
||||||
5. Return performance estimation
|
5. Analyze portfolio with LLM
|
||||||
|
6. Return performance estimation
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
company_name: Name of the company to analyze
|
company_name: Name of the company to analyze
|
||||||
@@ -97,6 +98,17 @@ class CompanyAnalyzer:
|
|||||||
if not processed_patents:
|
if not processed_patents:
|
||||||
return f"Failed to process any patents for {company_name}"
|
return f"Failed to process any patents for {company_name}"
|
||||||
|
|
||||||
|
# Classify each patent's technology domain tags
|
||||||
|
logger.info("Classifying patent technology domains...")
|
||||||
|
for patent_data in processed_patents:
|
||||||
|
if "tags" not in patent_data or not patent_data["tags"]:
|
||||||
|
tags = self.llm_analyzer.classify_patent_tags(
|
||||||
|
patent_content=patent_data["content"], model=model
|
||||||
|
)
|
||||||
|
patent_data["tags"] = tags
|
||||||
|
# Persist tags to the database
|
||||||
|
self.db.update_patent_tags(patent_data["patent_id"], tags)
|
||||||
|
|
||||||
logger.info("Analyzing portfolio with LLM...")
|
logger.info("Analyzing portfolio with LLM...")
|
||||||
|
|
||||||
# Analyze the full portfolio with LLM
|
# Analyze the full portfolio with LLM
|
||||||
@@ -181,7 +193,10 @@ class CompanyAnalyzer:
|
|||||||
if db:
|
if db:
|
||||||
cached = db.get_cached_patent(patent.patent_id)
|
cached = db.get_cached_patent(patent.patent_id)
|
||||||
if cached and cached.get("minimized_content"):
|
if cached and cached.get("minimized_content"):
|
||||||
return {"patent_id": patent.patent_id, "content": cached["minimized_content"]}
|
result = {"patent_id": patent.patent_id, "content": cached["minimized_content"]}
|
||||||
|
if cached.get("patent_tags"):
|
||||||
|
result["tags"] = cached["patent_tags"]
|
||||||
|
return result
|
||||||
|
|
||||||
# Full processing: download, parse, minimize
|
# Full processing: download, parse, minimize
|
||||||
patent = SERP.save_patents(patent)
|
patent = SERP.save_patents(patent)
|
||||||
@@ -217,11 +232,19 @@ class CompanyAnalyzer:
|
|||||||
# Delegate to analyze_company which handles SERP/patent caching
|
# Delegate to analyze_company which handles SERP/patent caching
|
||||||
analysis = self.analyze_company(company_name, model=model)
|
analysis = self.analyze_company(company_name, model=model)
|
||||||
|
|
||||||
# Determine patent count from cached SERP query
|
# Determine patent count and aggregate tags from cached SERP query
|
||||||
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
|
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
|
||||||
cached_ids = self.db.get_cached_serp_query(query_hash)
|
cached_ids = self.db.get_cached_serp_query(query_hash)
|
||||||
patent_count = len(cached_ids) if cached_ids else 0
|
patent_count = len(cached_ids) if cached_ids else 0
|
||||||
|
|
||||||
|
# Collect unique tags across all patents for this company
|
||||||
|
all_tags: set[str] = set()
|
||||||
|
if cached_ids:
|
||||||
|
for pid in cached_ids:
|
||||||
|
cached_patent = self.db.get_cached_patent(pid)
|
||||||
|
if cached_patent and cached_patent.get("patent_tags"):
|
||||||
|
all_tags.update(cached_patent["patent_tags"])
|
||||||
|
|
||||||
# Check if analysis indicates failure
|
# Check if analysis indicates failure
|
||||||
if analysis.startswith("No patents found") or analysis.startswith(
|
if analysis.startswith("No patents found") or analysis.startswith(
|
||||||
"Failed to process"
|
"Failed to process"
|
||||||
@@ -239,6 +262,7 @@ class CompanyAnalyzer:
|
|||||||
analysis=analysis,
|
analysis=analysis,
|
||||||
patent_count=patent_count,
|
patent_count=patent_count,
|
||||||
success=True,
|
success=True,
|
||||||
|
tags=sorted(all_tags),
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
+52
-1
@@ -57,6 +57,7 @@ class CompanyAnalysisResponse(BaseModel):
|
|||||||
success: bool
|
success: bool
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
model: str | None = None
|
model: str | None = None
|
||||||
|
tags: list[str] = []
|
||||||
timestamp: datetime
|
timestamp: datetime
|
||||||
|
|
||||||
|
|
||||||
@@ -188,6 +189,7 @@ def _convert_result(result: CompanyAnalysisResult) -> CompanyAnalysisResponse:
|
|||||||
success=result.success,
|
success=result.success,
|
||||||
error=result.error,
|
error=result.error,
|
||||||
model=result.model,
|
model=result.model,
|
||||||
|
tags=result.tags,
|
||||||
timestamp=result.timestamp,
|
timestamp=result.timestamp,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -560,6 +562,38 @@ async def get_analytics(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/analytics/tags", tags=["Analytics"])
|
||||||
|
async def get_tag_distribution(
|
||||||
|
_: UserResponse = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Get the distribution of technology domain tags across all patents.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tag counts and the canonical tag list
|
||||||
|
"""
|
||||||
|
from SPARC.llm import LLMAnalyzer
|
||||||
|
db = get_db_client()
|
||||||
|
|
||||||
|
with db.get_conn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT tag, COUNT(*) as count
|
||||||
|
FROM patents, UNNEST(patent_tags) AS tag
|
||||||
|
WHERE patent_tags IS NOT NULL
|
||||||
|
GROUP BY tag
|
||||||
|
ORDER BY count DESC
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
rows = cur.fetchall()
|
||||||
|
|
||||||
|
by_tag = [{"tag": row[0], "count": row[1]} for row in rows]
|
||||||
|
return {
|
||||||
|
"by_tag": by_tag,
|
||||||
|
"canonical_tags": LLMAnalyzer.CANONICAL_TAGS,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ============== Model Selection Endpoints ==============
|
# ============== Model Selection Endpoints ==============
|
||||||
|
|
||||||
# Supported models via OpenRouter
|
# Supported models via OpenRouter
|
||||||
@@ -969,6 +1003,10 @@ async def list_analysis_results(
|
|||||||
str | None,
|
str | None,
|
||||||
Query(description="Filter results by company name"),
|
Query(description="Filter results by company name"),
|
||||||
] = None,
|
] = None,
|
||||||
|
tags: Annotated[
|
||||||
|
str | None,
|
||||||
|
Query(description="Comma-separated technology domain tags to filter by (e.g. 'ai,semiconductors')"),
|
||||||
|
] = None,
|
||||||
limit: Annotated[int, Query(ge=1, le=200)] = 50,
|
limit: Annotated[int, Query(ge=1, le=200)] = 50,
|
||||||
cursor: Annotated[
|
cursor: Annotated[
|
||||||
str | None,
|
str | None,
|
||||||
@@ -986,14 +1024,27 @@ async def list_analysis_results(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
company_name: Optional filter by company name
|
company_name: Optional filter by company name
|
||||||
|
tags: Optional comma-separated tag filter (e.g. 'ai,semiconductors')
|
||||||
limit: Maximum number of results to return (default 50, max 200)
|
limit: Maximum number of results to return (default 50, max 200)
|
||||||
cursor: Opaque pagination cursor from a previous response
|
cursor: Opaque pagination cursor from a previous response
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Paginated list of analysis results
|
Paginated list of analysis results
|
||||||
"""
|
"""
|
||||||
|
# Parse and validate tags
|
||||||
|
tag_list = None
|
||||||
|
if tags:
|
||||||
|
from SPARC.llm import LLMAnalyzer
|
||||||
|
tag_list = [t.strip().lower() for t in tags.split(",") if t.strip()]
|
||||||
|
invalid = [t for t in tag_list if t not in LLMAnalyzer.CANONICAL_TAGS]
|
||||||
|
if invalid:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Invalid tags: {', '.join(invalid)}. Valid tags: {', '.join(LLMAnalyzer.CANONICAL_TAGS)}",
|
||||||
|
)
|
||||||
|
|
||||||
db = _get_job_db()
|
db = _get_job_db()
|
||||||
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor)
|
rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor, tags=tag_list)
|
||||||
|
|
||||||
has_next = len(rows) > limit
|
has_next = len(rows) > limit
|
||||||
if has_next:
|
if has_next:
|
||||||
|
|||||||
+82
-9
@@ -146,15 +146,35 @@ class DatabaseClient:
|
|||||||
pdf_link TEXT,
|
pdf_link TEXT,
|
||||||
raw_sections JSONB,
|
raw_sections JSONB,
|
||||||
minimized_content TEXT,
|
minimized_content TEXT,
|
||||||
|
patent_tags TEXT[],
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# Add patent_tags column if it doesn't exist (for existing tables)
|
||||||
|
cursor.execute("""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_name = 'patents' AND column_name = 'patent_tags'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE patents ADD COLUMN patent_tags TEXT[];
|
||||||
|
END IF;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE INDEX IF NOT EXISTS idx_patents_company
|
CREATE INDEX IF NOT EXISTS idx_patents_company
|
||||||
ON patents(company_name)
|
ON patents(company_name)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# GIN index for efficient tag array queries
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_patents_tags
|
||||||
|
ON patents USING GIN(patent_tags)
|
||||||
|
""")
|
||||||
|
|
||||||
# Create SERP query cache table
|
# Create SERP query cache table
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS serp_queries (
|
CREATE TABLE IF NOT EXISTS serp_queries (
|
||||||
@@ -376,6 +396,7 @@ class DatabaseClient:
|
|||||||
company_name: Optional[str] = None,
|
company_name: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
cursor: Optional[str] = None,
|
cursor: Optional[str] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""List analysis results with cursor-based pagination.
|
"""List analysis results with cursor-based pagination.
|
||||||
|
|
||||||
@@ -383,29 +404,40 @@ class DatabaseClient:
|
|||||||
company_name: Optional filter by company name.
|
company_name: Optional filter by company name.
|
||||||
limit: Maximum number of records to return.
|
limit: Maximum number of records to return.
|
||||||
cursor: Opaque cursor (``timestamp|id``) from a previous response.
|
cursor: Opaque cursor (``timestamp|id``) from a previous response.
|
||||||
|
tags: Optional list of technology domain tags to filter by.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of analysis dicts ordered by timestamp descending.
|
List of analysis dicts ordered by timestamp descending.
|
||||||
"""
|
"""
|
||||||
conditions: list[str] = ["is_cached = FALSE"]
|
conditions: list[str] = ["m.is_cached = FALSE"]
|
||||||
params: list = []
|
params: list = []
|
||||||
|
join_clause = ""
|
||||||
|
|
||||||
if company_name:
|
if company_name:
|
||||||
conditions.append("LOWER(company_name) = LOWER(%s)")
|
conditions.append("LOWER(m.company_name) = LOWER(%s)")
|
||||||
params.append(company_name)
|
params.append(company_name)
|
||||||
|
|
||||||
|
if tags:
|
||||||
|
# Join with patents table to filter by tags
|
||||||
|
join_clause = (
|
||||||
|
" INNER JOIN patents p ON LOWER(p.company_name) = LOWER(m.company_name)"
|
||||||
|
)
|
||||||
|
conditions.append("p.patent_tags && %s")
|
||||||
|
params.append(tags)
|
||||||
|
|
||||||
if cursor:
|
if cursor:
|
||||||
try:
|
try:
|
||||||
ts_str, cursor_id = cursor.rsplit("|", 1)
|
ts_str, cursor_id = cursor.rsplit("|", 1)
|
||||||
conditions.append("(timestamp, id) < (%s, %s)")
|
conditions.append("(m.timestamp, m.id) < (%s, %s)")
|
||||||
params.extend([ts_str, int(cursor_id)])
|
params.extend([ts_str, int(cursor_id)])
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
pass # Ignore malformed cursors; return from start
|
pass # Ignore malformed cursors; return from start
|
||||||
|
|
||||||
query = "SELECT id, company_name, analysis_type, model, response, timestamp FROM llm_messages"
|
query = "SELECT DISTINCT m.id, m.company_name, m.analysis_type, m.model, m.response, m.timestamp FROM llm_messages m"
|
||||||
|
query += join_clause
|
||||||
if conditions:
|
if conditions:
|
||||||
query += " WHERE " + " AND ".join(conditions)
|
query += " WHERE " + " AND ".join(conditions)
|
||||||
query += " ORDER BY timestamp DESC, id DESC LIMIT %s"
|
query += " ORDER BY m.timestamp DESC, m.id DESC LIMIT %s"
|
||||||
params.append(limit)
|
params.append(limit)
|
||||||
|
|
||||||
with self.get_conn() as conn:
|
with self.get_conn() as conn:
|
||||||
@@ -493,22 +525,63 @@ class DatabaseClient:
|
|||||||
pdf_link: str,
|
pdf_link: str,
|
||||||
raw_sections: Dict,
|
raw_sections: Dict,
|
||||||
minimized_content: str,
|
minimized_content: str,
|
||||||
|
patent_tags: Optional[List[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Store a processed patent in the cache."""
|
"""Store a processed patent in the cache."""
|
||||||
with self.get_conn() as conn:
|
with self.get_conn() as conn:
|
||||||
with conn.cursor() as cursor:
|
with conn.cursor() as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content)
|
INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content, patent_tags)
|
||||||
VALUES (%s, %s, %s, %s, %s)
|
VALUES (%s, %s, %s, %s, %s, %s)
|
||||||
ON CONFLICT (patent_id) DO UPDATE SET
|
ON CONFLICT (patent_id) DO UPDATE SET
|
||||||
raw_sections = EXCLUDED.raw_sections,
|
raw_sections = EXCLUDED.raw_sections,
|
||||||
minimized_content = EXCLUDED.minimized_content
|
minimized_content = EXCLUDED.minimized_content,
|
||||||
|
patent_tags = EXCLUDED.patent_tags
|
||||||
""",
|
""",
|
||||||
(patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content),
|
(patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content, patent_tags),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
|
def update_patent_tags(self, patent_id: str, tags: List[str]) -> None:
|
||||||
|
"""Update the technology domain tags for a patent."""
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"UPDATE patents SET patent_tags = %s WHERE patent_id = %s",
|
||||||
|
(tags, patent_id),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_patents_by_tags(
|
||||||
|
self,
|
||||||
|
tags: List[str],
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""Retrieve patents that match any of the given tags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tags: List of tag strings to filter by (OR logic)
|
||||||
|
limit: Max results
|
||||||
|
offset: Pagination offset
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of patent dicts
|
||||||
|
"""
|
||||||
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM patents
|
||||||
|
WHERE patent_tags && %s
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT %s OFFSET %s
|
||||||
|
""",
|
||||||
|
(tags, limit, offset),
|
||||||
|
)
|
||||||
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
def get_cached_serp_query(self, query_hash: str) -> Optional[List[str]]:
|
def get_cached_serp_query(self, query_hash: str) -> Optional[List[str]]:
|
||||||
"""Look up cached SERP query results.
|
"""Look up cached SERP query results.
|
||||||
|
|
||||||
|
|||||||
+59
-1
@@ -1,7 +1,8 @@
|
|||||||
"""LLM integration for patent analysis using OpenRouter."""
|
"""LLM integration for patent analysis using OpenRouter."""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
@@ -246,3 +247,60 @@ Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the co
|
|||||||
)
|
)
|
||||||
return placeholder
|
return placeholder
|
||||||
|
|
||||||
|
# Canonical technology domain tags
|
||||||
|
CANONICAL_TAGS = ["ai", "semiconductors", "materials", "biotech", "networking", "other"]
|
||||||
|
|
||||||
|
def classify_patent_tags(self, patent_content: str, model: str | None = None) -> List[str]:
|
||||||
|
"""Classify a patent into one or more technology domain tags.
|
||||||
|
|
||||||
|
Sends the patent abstract/claims to the LLM with a classification prompt
|
||||||
|
and returns a list of canonical tags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patent_content: Minimized patent text (abstract, claims, summary)
|
||||||
|
model: Optional model override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of canonical tag strings from CANONICAL_TAGS
|
||||||
|
"""
|
||||||
|
prompt = f"""You are a patent classification system. Analyze the following patent content and assign one or more technology domain tags from this exact list:
|
||||||
|
|
||||||
|
Tags: ai, semiconductors, materials, biotech, networking, other
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- Return ONLY a JSON array of tag strings, e.g. ["ai", "semiconductors"]
|
||||||
|
- Use ONLY tags from the list above
|
||||||
|
- Assign "other" only if no other tag fits
|
||||||
|
- A patent can have multiple tags if it spans domains
|
||||||
|
|
||||||
|
Patent Content:
|
||||||
|
{patent_content}
|
||||||
|
|
||||||
|
Return ONLY the JSON array, nothing else."""
|
||||||
|
|
||||||
|
effective_model = model or self.model
|
||||||
|
|
||||||
|
if self.test_mode:
|
||||||
|
logger.debug("TEST MODE - Classification prompt:\n%s", prompt)
|
||||||
|
return ["other"]
|
||||||
|
|
||||||
|
if self.client:
|
||||||
|
try:
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=effective_model,
|
||||||
|
max_tokens=128,
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
)
|
||||||
|
raw = response.choices[0].message.content.strip()
|
||||||
|
tags = json.loads(raw)
|
||||||
|
# Validate and filter to canonical tags only
|
||||||
|
valid_tags = [t for t in tags if t in self.CANONICAL_TAGS]
|
||||||
|
return valid_tags if valid_tags else ["other"]
|
||||||
|
except (json.JSONDecodeError, AttributeError, TypeError) as e:
|
||||||
|
logger.warning("Failed to parse classification response: %s", e)
|
||||||
|
return ["other"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Classification LLM call failed: %s", e)
|
||||||
|
return ["other"]
|
||||||
|
|
||||||
|
return ["other"]
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ class CompanyAnalysisResult:
|
|||||||
success: bool
|
success: bool
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
model: str | None = None
|
model: str | None = None
|
||||||
|
tags: list[str] = field(default_factory=list)
|
||||||
timestamp: datetime = field(default_factory=datetime.now)
|
timestamp: datetime = field(default_factory=datetime.now)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -199,8 +199,18 @@ export const analyticsApi = {
|
|||||||
const response = await api.get<TrendData>(`/analytics/trends?days=${days}`);
|
const response = await api.get<TrendData>(`/analytics/trends?days=${days}`);
|
||||||
return response.data;
|
return response.data;
|
||||||
},
|
},
|
||||||
|
|
||||||
|
getTagDistribution: async (): Promise<TagDistribution> => {
|
||||||
|
const response = await api.get<TagDistribution>('/analytics/tags');
|
||||||
|
return response.data;
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export interface TagDistribution {
|
||||||
|
by_tag: Array<{ tag: string; count: number }>;
|
||||||
|
canonical_tags: string[];
|
||||||
|
}
|
||||||
|
|
||||||
// Admin API
|
// Admin API
|
||||||
export const adminApi = {
|
export const adminApi = {
|
||||||
listUsers: async (limit = 100, offset = 0): Promise<User[]> => {
|
listUsers: async (limit = 100, offset = 0): Promise<User[]> => {
|
||||||
|
|||||||
@@ -1,14 +1,24 @@
|
|||||||
import { useState } from 'react';
|
import { useState } from 'react';
|
||||||
import { useQuery } from '@tanstack/react-query';
|
import { useQuery } from '@tanstack/react-query';
|
||||||
import { analyticsApi } from '../api/client';
|
import { analyticsApi } from '../api/client';
|
||||||
import { AlertCircle, Database } from 'lucide-react';
|
import { AlertCircle, Database, Tag } from 'lucide-react';
|
||||||
import { PieChart, Pie, Cell, BarChart, Bar, LineChart, Line, XAxis, YAxis, Tooltip, ResponsiveContainer, Legend } from 'recharts';
|
import { PieChart, Pie, Cell, BarChart, Bar, LineChart, Line, XAxis, YAxis, Tooltip, ResponsiveContainer, Legend } from 'recharts';
|
||||||
import { useChartTheme } from '../context/useChartTheme';
|
import { useChartTheme } from '../context/useChartTheme';
|
||||||
|
|
||||||
const COLORS = ['#6366f1', '#0ea5e9', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#ec4899', '#14b8a6'];
|
const COLORS = ['#6366f1', '#0ea5e9', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#ec4899', '#14b8a6'];
|
||||||
|
|
||||||
|
const TAG_COLORS: Record<string, string> = {
|
||||||
|
ai: '#6366f1',
|
||||||
|
semiconductors: '#0ea5e9',
|
||||||
|
materials: '#10b981',
|
||||||
|
biotech: '#f59e0b',
|
||||||
|
networking: '#ec4899',
|
||||||
|
other: '#8b5cf6',
|
||||||
|
};
|
||||||
|
|
||||||
export function AnalyticsPage() {
|
export function AnalyticsPage() {
|
||||||
const [days, setDays] = useState(30);
|
const [days, setDays] = useState(30);
|
||||||
|
const [selectedTags, setSelectedTags] = useState<string[]>([]);
|
||||||
const chartTheme = useChartTheme();
|
const chartTheme = useChartTheme();
|
||||||
|
|
||||||
const { data, isLoading, isError, refetch } = useQuery({
|
const { data, isLoading, isError, refetch } = useQuery({
|
||||||
@@ -21,6 +31,17 @@ export function AnalyticsPage() {
|
|||||||
queryFn: () => analyticsApi.getTrends(days),
|
queryFn: () => analyticsApi.getTrends(days),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const tagQuery = useQuery({
|
||||||
|
queryKey: ['analytics-tags'],
|
||||||
|
queryFn: () => analyticsApi.getTagDistribution(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const toggleTag = (tag: string) => {
|
||||||
|
setSelectedTags((prev) =>
|
||||||
|
prev.includes(tag) ? prev.filter((t) => t !== tag) : [...prev, tag]
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
if (isLoading) {
|
if (isLoading) {
|
||||||
return (
|
return (
|
||||||
<div className="space-y-6">
|
<div className="space-y-6">
|
||||||
@@ -107,6 +128,13 @@ export function AnalyticsPage() {
|
|||||||
count: t.count,
|
count: t.count,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
const tagData = tagQuery.data?.by_tag?.map((t) => ({
|
||||||
|
name: t.tag,
|
||||||
|
count: t.count,
|
||||||
|
})) || [];
|
||||||
|
|
||||||
|
const canonicalTags = tagQuery.data?.canonical_tags || [];
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-6">
|
<div className="space-y-6">
|
||||||
{/* Header */}
|
{/* Header */}
|
||||||
@@ -131,6 +159,50 @@ export function AnalyticsPage() {
|
|||||||
</select>
|
</select>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* Tag Filter Controls */}
|
||||||
|
{canonicalTags.length > 0 && (
|
||||||
|
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-4">
|
||||||
|
<div className="flex items-center gap-2 mb-3">
|
||||||
|
<Tag size={16} className="text-primary" />
|
||||||
|
<span className="text-sm font-semibold text-text-primary">Filter by Technology Domain</span>
|
||||||
|
{selectedTags.length > 0 && (
|
||||||
|
<button
|
||||||
|
onClick={() => setSelectedTags([])}
|
||||||
|
className="ml-auto text-xs text-text-secondary hover:text-primary transition-colors"
|
||||||
|
>
|
||||||
|
Clear all
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="flex flex-wrap gap-2">
|
||||||
|
{canonicalTags.map((tag) => {
|
||||||
|
const isActive = selectedTags.includes(tag);
|
||||||
|
const color = TAG_COLORS[tag] || '#8b5cf6';
|
||||||
|
const tagCount = tagQuery.data?.by_tag?.find((t) => t.tag === tag)?.count || 0;
|
||||||
|
return (
|
||||||
|
<button
|
||||||
|
key={tag}
|
||||||
|
onClick={() => toggleTag(tag)}
|
||||||
|
className={`px-3 py-1.5 rounded-lg text-sm font-medium transition-all ${
|
||||||
|
isActive
|
||||||
|
? 'text-white shadow-md'
|
||||||
|
: 'bg-bg-card/80 text-text-secondary border border-primary/20 hover:border-primary/40'
|
||||||
|
}`}
|
||||||
|
style={isActive ? { backgroundColor: color } : {}}
|
||||||
|
>
|
||||||
|
{tag}
|
||||||
|
{tagCount > 0 && (
|
||||||
|
<span className={`ml-1.5 text-xs ${isActive ? 'opacity-80' : 'opacity-60'}`}>
|
||||||
|
({tagCount})
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* Summary Metrics */}
|
{/* Summary Metrics */}
|
||||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
|
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
|
||||||
<MetricCard label="Total Analyses" value={data.total_messages} />
|
<MetricCard label="Total Analyses" value={data.total_messages} />
|
||||||
@@ -187,6 +259,28 @@ export function AnalyticsPage() {
|
|||||||
</ResponsiveContainer>
|
</ResponsiveContainer>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{/* Bar Chart - Technology Domain Tags */}
|
||||||
|
{tagData.length > 0 && (
|
||||||
|
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-6">
|
||||||
|
<h3 className="text-lg font-semibold text-text-primary mb-4">Patents by Technology Domain</h3>
|
||||||
|
<ResponsiveContainer width="100%" height={300}>
|
||||||
|
<BarChart data={tagData}>
|
||||||
|
<XAxis dataKey="name" stroke={chartTheme.axisStroke} fontSize={12} />
|
||||||
|
<YAxis stroke={chartTheme.axisStroke} fontSize={12} />
|
||||||
|
<Tooltip
|
||||||
|
contentStyle={chartTheme.tooltipContentStyle}
|
||||||
|
labelStyle={chartTheme.tooltipLabelStyle}
|
||||||
|
/>
|
||||||
|
<Bar dataKey="count" radius={[4, 4, 0, 0]}>
|
||||||
|
{tagData.map((entry, index) => (
|
||||||
|
<Cell key={`tag-cell-${index}`} fill={TAG_COLORS[entry.name] || COLORS[index % COLORS.length]} />
|
||||||
|
))}
|
||||||
|
</Bar>
|
||||||
|
</BarChart>
|
||||||
|
</ResponsiveContainer>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Trend Charts */}
|
{/* Trend Charts */}
|
||||||
|
|||||||
@@ -0,0 +1,262 @@
|
|||||||
|
"""Tests for LLM-based patent classification tagging by technology domain."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from SPARC.llm import LLMAnalyzer
|
||||||
|
|
||||||
|
|
||||||
|
class TestClassifyPatentTags:
|
||||||
|
"""Test the classify_patent_tags method on LLMAnalyzer."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_database(self, mocker):
|
||||||
|
"""Mock the database client for all tests."""
|
||||||
|
mock_db_client = Mock()
|
||||||
|
mock_db_client.get_cached_response.return_value = None
|
||||||
|
mock_db_client.store_message.return_value = 1
|
||||||
|
mocker.patch("SPARC.llm.DatabaseClient", return_value=mock_db_client)
|
||||||
|
return mock_db_client
|
||||||
|
|
||||||
|
def test_classify_returns_valid_tags(self, mocker, mock_database):
|
||||||
|
"""Test that classify_patent_tags returns valid canonical tags from LLM."""
|
||||||
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_openai.return_value = mock_client
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock(message=Mock(content='["ai", "semiconductors"]'))]
|
||||||
|
mock_client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
|
analyzer = LLMAnalyzer(api_key="test-key")
|
||||||
|
tags = analyzer.classify_patent_tags("A patent about neural network accelerator chips")
|
||||||
|
|
||||||
|
assert tags == ["ai", "semiconductors"]
|
||||||
|
mock_client.chat.completions.create.assert_called_once()
|
||||||
|
|
||||||
|
# Verify the prompt contains classification instructions
|
||||||
|
call_args = mock_client.chat.completions.create.call_args
|
||||||
|
prompt_text = call_args[1]["messages"][0]["content"]
|
||||||
|
assert "ai" in prompt_text
|
||||||
|
assert "semiconductors" in prompt_text
|
||||||
|
assert "JSON array" in prompt_text
|
||||||
|
|
||||||
|
def test_classify_filters_invalid_tags(self, mocker, mock_database):
|
||||||
|
"""Test that invalid tags from LLM response are filtered out."""
|
||||||
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_openai.return_value = mock_client
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock(message=Mock(content='["ai", "quantum_computing", "biotech"]'))]
|
||||||
|
mock_client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
|
analyzer = LLMAnalyzer(api_key="test-key")
|
||||||
|
tags = analyzer.classify_patent_tags("Some patent content")
|
||||||
|
|
||||||
|
assert tags == ["ai", "biotech"]
|
||||||
|
assert "quantum_computing" not in tags
|
||||||
|
|
||||||
|
def test_classify_returns_other_when_all_invalid(self, mocker, mock_database):
|
||||||
|
"""Test fallback to 'other' when all LLM tags are invalid."""
|
||||||
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_openai.return_value = mock_client
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock(message=Mock(content='["quantum", "robotics"]'))]
|
||||||
|
mock_client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
|
analyzer = LLMAnalyzer(api_key="test-key")
|
||||||
|
tags = analyzer.classify_patent_tags("Some patent content")
|
||||||
|
|
||||||
|
assert tags == ["other"]
|
||||||
|
|
||||||
|
def test_classify_handles_malformed_json(self, mocker, mock_database):
|
||||||
|
"""Test graceful handling of non-JSON LLM response."""
|
||||||
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_openai.return_value = mock_client
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock(message=Mock(content="ai, semiconductors"))]
|
||||||
|
mock_client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
|
analyzer = LLMAnalyzer(api_key="test-key")
|
||||||
|
tags = analyzer.classify_patent_tags("Some patent content")
|
||||||
|
|
||||||
|
assert tags == ["other"]
|
||||||
|
|
||||||
|
def test_classify_handles_api_error(self, mocker, mock_database):
|
||||||
|
"""Test graceful fallback when LLM API call fails."""
|
||||||
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_openai.return_value = mock_client
|
||||||
|
|
||||||
|
mock_client.chat.completions.create.side_effect = Exception("API timeout")
|
||||||
|
|
||||||
|
analyzer = LLMAnalyzer(api_key="test-key")
|
||||||
|
tags = analyzer.classify_patent_tags("Some patent content")
|
||||||
|
|
||||||
|
assert tags == ["other"]
|
||||||
|
|
||||||
|
def test_classify_test_mode(self, mocker, mock_database):
|
||||||
|
"""Test that test mode returns 'other' without API call."""
|
||||||
|
mocker.patch("SPARC.llm.config")
|
||||||
|
|
||||||
|
analyzer = LLMAnalyzer(test_mode=True)
|
||||||
|
tags = analyzer.classify_patent_tags("Some patent content")
|
||||||
|
|
||||||
|
assert tags == ["other"]
|
||||||
|
|
||||||
|
def test_classify_no_api_client(self, mocker, mock_database):
|
||||||
|
"""Test that without API client, classification returns 'other'."""
|
||||||
|
mocker.patch("SPARC.llm.config")
|
||||||
|
|
||||||
|
analyzer = LLMAnalyzer(use_cache=False)
|
||||||
|
tags = analyzer.classify_patent_tags("Some patent content")
|
||||||
|
|
||||||
|
assert tags == ["other"]
|
||||||
|
|
||||||
|
def test_canonical_tags_list(self):
|
||||||
|
"""Test that the canonical tag list matches requirements."""
|
||||||
|
expected = ["ai", "semiconductors", "materials", "biotech", "networking", "other"]
|
||||||
|
assert LLMAnalyzer.CANONICAL_TAGS == expected
|
||||||
|
|
||||||
|
def test_classify_uses_model_override(self, mocker, mock_database):
|
||||||
|
"""Test that model override is passed to the API call."""
|
||||||
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_openai.return_value = mock_client
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock(message=Mock(content='["ai"]'))]
|
||||||
|
mock_client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
|
analyzer = LLMAnalyzer(api_key="test-key")
|
||||||
|
analyzer.classify_patent_tags("content", model="openai/gpt-4o")
|
||||||
|
|
||||||
|
call_args = mock_client.chat.completions.create.call_args
|
||||||
|
assert call_args[1]["model"] == "openai/gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPatentTagsStorage:
|
||||||
|
"""Test that tags are persisted to the database."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_db(self, mocker):
|
||||||
|
"""Mock DatabaseClient for all tests."""
|
||||||
|
mock_db_cls = mocker.patch("SPARC.analyzer.DatabaseClient")
|
||||||
|
mock_db_instance = MagicMock()
|
||||||
|
mock_db_instance.get_cached_patent.return_value = None
|
||||||
|
mock_db_instance.get_cached_serp_query.return_value = None
|
||||||
|
mock_db_cls.return_value = mock_db_instance
|
||||||
|
return mock_db_instance
|
||||||
|
|
||||||
|
def test_tags_stored_after_classification(self, mocker, mock_db):
|
||||||
|
"""Test that classify_patent_tags results are stored via update_patent_tags."""
|
||||||
|
from SPARC.analyzer import CompanyAnalyzer
|
||||||
|
from SPARC.types import Patent, Patents
|
||||||
|
|
||||||
|
mock_query = mocker.patch("SPARC.analyzer.SERP.query")
|
||||||
|
mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents")
|
||||||
|
mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf")
|
||||||
|
mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm")
|
||||||
|
mock_llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||||
|
|
||||||
|
patent = Patent(patent_id="US123", pdf_link="http://example.com/test.pdf")
|
||||||
|
mock_query.return_value = Patents(patents=[patent])
|
||||||
|
|
||||||
|
def save_side_effect(p):
|
||||||
|
p.pdf_path = "patents/US123.pdf"
|
||||||
|
return p
|
||||||
|
|
||||||
|
mock_save.side_effect = save_side_effect
|
||||||
|
mock_parse.return_value = {"abstract": "Test abstract"}
|
||||||
|
mock_minimize.return_value = "Minimized content"
|
||||||
|
|
||||||
|
mock_llm_instance = Mock()
|
||||||
|
mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis result"
|
||||||
|
mock_llm_instance.classify_patent_tags.return_value = ["ai", "semiconductors"]
|
||||||
|
mock_llm_cls.return_value = mock_llm_instance
|
||||||
|
|
||||||
|
analyzer = CompanyAnalyzer()
|
||||||
|
analyzer.analyze_company("TestCorp")
|
||||||
|
|
||||||
|
# Verify classification was called
|
||||||
|
mock_llm_instance.classify_patent_tags.assert_called_once_with(
|
||||||
|
patent_content="Minimized content", model=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify tags were persisted
|
||||||
|
mock_db.update_patent_tags.assert_called_once_with("US123", ["ai", "semiconductors"])
|
||||||
|
|
||||||
|
def test_cached_tags_skip_classification(self, mocker, mock_db):
|
||||||
|
"""Test that patents with cached tags skip re-classification."""
|
||||||
|
from SPARC.analyzer import CompanyAnalyzer
|
||||||
|
from SPARC.types import Patent, Patents
|
||||||
|
|
||||||
|
mocker.patch("SPARC.analyzer.SERP.query")
|
||||||
|
mocker.patch("SPARC.analyzer.SERP.save_patents")
|
||||||
|
mock_llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||||
|
|
||||||
|
# Simulate DB cache hit with tags
|
||||||
|
mock_db.get_cached_patent.return_value = {
|
||||||
|
"patent_id": "US123",
|
||||||
|
"minimized_content": "Cached content",
|
||||||
|
"patent_tags": ["biotech"],
|
||||||
|
}
|
||||||
|
mock_db.get_cached_serp_query.return_value = ["US123"]
|
||||||
|
|
||||||
|
mock_llm_instance = Mock()
|
||||||
|
mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis"
|
||||||
|
mock_llm_cls.return_value = mock_llm_instance
|
||||||
|
|
||||||
|
analyzer = CompanyAnalyzer()
|
||||||
|
analyzer.analyze_company("TestCorp")
|
||||||
|
|
||||||
|
# Should NOT classify since tags already present from cache
|
||||||
|
mock_llm_instance.classify_patent_tags.assert_not_called()
|
||||||
|
mock_db.update_patent_tags.assert_not_called()
|
||||||
|
|
||||||
|
def test_tags_in_analysis_result(self, mocker, mock_db):
|
||||||
|
"""Test that tags appear in the CompanyAnalysisResult."""
|
||||||
|
from SPARC.analyzer import CompanyAnalyzer
|
||||||
|
from SPARC.types import Patent, Patents
|
||||||
|
|
||||||
|
mock_query = mocker.patch("SPARC.analyzer.SERP.query")
|
||||||
|
mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents")
|
||||||
|
mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf")
|
||||||
|
mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm")
|
||||||
|
mock_llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||||
|
|
||||||
|
patent = Patent(patent_id="US123", pdf_link="http://example.com/test.pdf")
|
||||||
|
mock_query.return_value = Patents(patents=[patent])
|
||||||
|
|
||||||
|
def save_side_effect(p):
|
||||||
|
p.pdf_path = "patents/US123.pdf"
|
||||||
|
return p
|
||||||
|
|
||||||
|
mock_save.side_effect = save_side_effect
|
||||||
|
mock_parse.return_value = {"abstract": "Test abstract"}
|
||||||
|
mock_minimize.return_value = "Minimized content"
|
||||||
|
|
||||||
|
mock_llm_instance = Mock()
|
||||||
|
mock_llm_instance.analyze_patent_portfolio.return_value = "Strong innovation"
|
||||||
|
mock_llm_instance.classify_patent_tags.return_value = ["ai", "networking"]
|
||||||
|
mock_llm_cls.return_value = mock_llm_instance
|
||||||
|
|
||||||
|
# After analysis, get_cached_patent returns the patent with tags
|
||||||
|
mock_db.get_cached_serp_query.side_effect = [None, ["US123"]]
|
||||||
|
mock_db.get_cached_patent.side_effect = [
|
||||||
|
None, # First call during _process_single_patent
|
||||||
|
{"patent_id": "US123", "patent_tags": ["ai", "networking"]}, # Second call during _analyze_company_safe
|
||||||
|
]
|
||||||
|
|
||||||
|
analyzer = CompanyAnalyzer()
|
||||||
|
result = analyzer._analyze_company_safe("TestCorp")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.tags == ["ai", "networking"]
|
||||||
Reference in New Issue
Block a user