From cd81218154a770ae061e7ac029b40587d99139a5 Mon Sep 17 00:00:00 2001 From: agent-company Date: Tue, 19 May 2026 15:27:46 +0000 Subject: [PATCH] Add LLM-based patent classification tagging by technology domain - Add classify_patent_tags() to LLMAnalyzer with canonical tag list (ai, semiconductors, materials, biotech, networking, other) - Add patent_tags TEXT[] column to patents table with GIN index - Run classification automatically in the analysis pipeline after patent processing; persist tags via update_patent_tags() - Include tags in CompanyAnalysisResult and API response models - Add ?tags= filter to GET /analyze/batch endpoint - Add GET /analytics/tags endpoint for tag distribution data - Add tag filter controls and distribution chart to Analytics page - Add 12 unit tests covering classification, DB storage, and caching Closes leeworks-agents/SPARC#1672 Co-Authored-By: Claude Opus 4.6 --- SPARC/analyzer.py | 32 +++- SPARC/api.py | 53 ++++++- SPARC/database.py | 91 +++++++++-- SPARC/llm.py | 62 +++++++- SPARC/types.py | 1 + frontend/src/api/client.ts | 10 ++ frontend/src/pages/Analytics.tsx | 96 ++++++++++- tests/test_patent_tags.py | 262 +++++++++++++++++++++++++++++++ 8 files changed, 590 insertions(+), 17 deletions(-) create mode 100644 tests/test_patent_tags.py diff --git a/SPARC/analyzer.py b/SPARC/analyzer.py index 1ebceaf..2751688 100644 --- a/SPARC/analyzer.py +++ b/SPARC/analyzer.py @@ -40,8 +40,9 @@ class CompanyAnalyzer: 1. Retrieve patents from SERP API 2. Download and parse each patent PDF 3. Minimize patent content (remove bloat) - 4. Analyze portfolio with LLM - 5. Return performance estimation + 4. Classify patent technology domain tags + 5. Analyze portfolio with LLM + 6. Return performance estimation Args: company_name: Name of the company to analyze @@ -97,6 +98,17 @@ class CompanyAnalyzer: if not processed_patents: return f"Failed to process any patents for {company_name}" + # Classify each patent's technology domain tags + logger.info("Classifying patent technology domains...") + for patent_data in processed_patents: + if "tags" not in patent_data or not patent_data["tags"]: + tags = self.llm_analyzer.classify_patent_tags( + patent_content=patent_data["content"], model=model + ) + patent_data["tags"] = tags + # Persist tags to the database + self.db.update_patent_tags(patent_data["patent_id"], tags) + logger.info("Analyzing portfolio with LLM...") # Analyze the full portfolio with LLM @@ -181,7 +193,10 @@ class CompanyAnalyzer: if db: cached = db.get_cached_patent(patent.patent_id) if cached and cached.get("minimized_content"): - return {"patent_id": patent.patent_id, "content": cached["minimized_content"]} + result = {"patent_id": patent.patent_id, "content": cached["minimized_content"]} + if cached.get("patent_tags"): + result["tags"] = cached["patent_tags"] + return result # Full processing: download, parse, minimize patent = SERP.save_patents(patent) @@ -217,11 +232,19 @@ class CompanyAnalyzer: # Delegate to analyze_company which handles SERP/patent caching analysis = self.analyze_company(company_name, model=model) - # Determine patent count from cached SERP query + # Determine patent count and aggregate tags from cached SERP query query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest() cached_ids = self.db.get_cached_serp_query(query_hash) patent_count = len(cached_ids) if cached_ids else 0 + # Collect unique tags across all patents for this company + all_tags: set[str] = set() + if cached_ids: + for pid in cached_ids: + cached_patent = self.db.get_cached_patent(pid) + if cached_patent and cached_patent.get("patent_tags"): + all_tags.update(cached_patent["patent_tags"]) + # Check if analysis indicates failure if analysis.startswith("No patents found") or analysis.startswith( "Failed to process" @@ -239,6 +262,7 @@ class CompanyAnalyzer: analysis=analysis, patent_count=patent_count, success=True, + tags=sorted(all_tags), ) except Exception as e: diff --git a/SPARC/api.py b/SPARC/api.py index 1b29d38..41cdd13 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -57,6 +57,7 @@ class CompanyAnalysisResponse(BaseModel): success: bool error: str | None = None model: str | None = None + tags: list[str] = [] timestamp: datetime @@ -188,6 +189,7 @@ def _convert_result(result: CompanyAnalysisResult) -> CompanyAnalysisResponse: success=result.success, error=result.error, model=result.model, + tags=result.tags, timestamp=result.timestamp, ) @@ -560,6 +562,38 @@ async def get_analytics( ) +@app.get("/analytics/tags", tags=["Analytics"]) +async def get_tag_distribution( + _: UserResponse = Depends(get_current_user), +): + """Get the distribution of technology domain tags across all patents. + + Returns: + List of tag counts and the canonical tag list + """ + from SPARC.llm import LLMAnalyzer + db = get_db_client() + + with db.get_conn() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT tag, COUNT(*) as count + FROM patents, UNNEST(patent_tags) AS tag + WHERE patent_tags IS NOT NULL + GROUP BY tag + ORDER BY count DESC + """ + ) + rows = cur.fetchall() + + by_tag = [{"tag": row[0], "count": row[1]} for row in rows] + return { + "by_tag": by_tag, + "canonical_tags": LLMAnalyzer.CANONICAL_TAGS, + } + + # ============== Model Selection Endpoints ============== # Supported models via OpenRouter @@ -969,6 +1003,10 @@ async def list_analysis_results( str | None, Query(description="Filter results by company name"), ] = None, + tags: Annotated[ + str | None, + Query(description="Comma-separated technology domain tags to filter by (e.g. 'ai,semiconductors')"), + ] = None, limit: Annotated[int, Query(ge=1, le=200)] = 50, cursor: Annotated[ str | None, @@ -986,14 +1024,27 @@ async def list_analysis_results( Args: company_name: Optional filter by company name + tags: Optional comma-separated tag filter (e.g. 'ai,semiconductors') limit: Maximum number of results to return (default 50, max 200) cursor: Opaque pagination cursor from a previous response Returns: Paginated list of analysis results """ + # Parse and validate tags + tag_list = None + if tags: + from SPARC.llm import LLMAnalyzer + tag_list = [t.strip().lower() for t in tags.split(",") if t.strip()] + invalid = [t for t in tag_list if t not in LLMAnalyzer.CANONICAL_TAGS] + if invalid: + raise HTTPException( + status_code=400, + detail=f"Invalid tags: {', '.join(invalid)}. Valid tags: {', '.join(LLMAnalyzer.CANONICAL_TAGS)}", + ) + db = _get_job_db() - rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor) + rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor, tags=tag_list) has_next = len(rows) > limit if has_next: diff --git a/SPARC/database.py b/SPARC/database.py index 0759a66..0480e0c 100644 --- a/SPARC/database.py +++ b/SPARC/database.py @@ -146,15 +146,35 @@ class DatabaseClient: pdf_link TEXT, raw_sections JSONB, minimized_content TEXT, + patent_tags TEXT[], created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) + # Add patent_tags column if it doesn't exist (for existing tables) + cursor.execute(""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'patents' AND column_name = 'patent_tags' + ) THEN + ALTER TABLE patents ADD COLUMN patent_tags TEXT[]; + END IF; + END $$; + """) + cursor.execute(""" CREATE INDEX IF NOT EXISTS idx_patents_company ON patents(company_name) """) + # GIN index for efficient tag array queries + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_patents_tags + ON patents USING GIN(patent_tags) + """) + # Create SERP query cache table cursor.execute(""" CREATE TABLE IF NOT EXISTS serp_queries ( @@ -376,6 +396,7 @@ class DatabaseClient: company_name: Optional[str] = None, limit: int = 50, cursor: Optional[str] = None, + tags: Optional[List[str]] = None, ) -> List[Dict]: """List analysis results with cursor-based pagination. @@ -383,29 +404,40 @@ class DatabaseClient: company_name: Optional filter by company name. limit: Maximum number of records to return. cursor: Opaque cursor (``timestamp|id``) from a previous response. + tags: Optional list of technology domain tags to filter by. Returns: List of analysis dicts ordered by timestamp descending. """ - conditions: list[str] = ["is_cached = FALSE"] + conditions: list[str] = ["m.is_cached = FALSE"] params: list = [] + join_clause = "" if company_name: - conditions.append("LOWER(company_name) = LOWER(%s)") + conditions.append("LOWER(m.company_name) = LOWER(%s)") params.append(company_name) + if tags: + # Join with patents table to filter by tags + join_clause = ( + " INNER JOIN patents p ON LOWER(p.company_name) = LOWER(m.company_name)" + ) + conditions.append("p.patent_tags && %s") + params.append(tags) + if cursor: try: ts_str, cursor_id = cursor.rsplit("|", 1) - conditions.append("(timestamp, id) < (%s, %s)") + conditions.append("(m.timestamp, m.id) < (%s, %s)") params.extend([ts_str, int(cursor_id)]) except (ValueError, TypeError): pass # Ignore malformed cursors; return from start - query = "SELECT id, company_name, analysis_type, model, response, timestamp FROM llm_messages" + query = "SELECT DISTINCT m.id, m.company_name, m.analysis_type, m.model, m.response, m.timestamp FROM llm_messages m" + query += join_clause if conditions: query += " WHERE " + " AND ".join(conditions) - query += " ORDER BY timestamp DESC, id DESC LIMIT %s" + query += " ORDER BY m.timestamp DESC, m.id DESC LIMIT %s" params.append(limit) with self.get_conn() as conn: @@ -493,22 +525,63 @@ class DatabaseClient: pdf_link: str, raw_sections: Dict, minimized_content: str, + patent_tags: Optional[List[str]] = None, ) -> None: """Store a processed patent in the cache.""" with self.get_conn() as conn: with conn.cursor() as cursor: cursor.execute( """ - INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content) - VALUES (%s, %s, %s, %s, %s) + INSERT INTO patents (patent_id, company_name, pdf_link, raw_sections, minimized_content, patent_tags) + VALUES (%s, %s, %s, %s, %s, %s) ON CONFLICT (patent_id) DO UPDATE SET raw_sections = EXCLUDED.raw_sections, - minimized_content = EXCLUDED.minimized_content + minimized_content = EXCLUDED.minimized_content, + patent_tags = EXCLUDED.patent_tags """, - (patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content), + (patent_id, company_name, pdf_link, json.dumps(raw_sections), minimized_content, patent_tags), ) conn.commit() + def update_patent_tags(self, patent_id: str, tags: List[str]) -> None: + """Update the technology domain tags for a patent.""" + with self.get_conn() as conn: + with conn.cursor() as cursor: + cursor.execute( + "UPDATE patents SET patent_tags = %s WHERE patent_id = %s", + (tags, patent_id), + ) + conn.commit() + + def get_patents_by_tags( + self, + tags: List[str], + limit: int = 100, + offset: int = 0, + ) -> List[Dict]: + """Retrieve patents that match any of the given tags. + + Args: + tags: List of tag strings to filter by (OR logic) + limit: Max results + offset: Pagination offset + + Returns: + List of patent dicts + """ + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cursor: + cursor.execute( + """ + SELECT * FROM patents + WHERE patent_tags && %s + ORDER BY created_at DESC + LIMIT %s OFFSET %s + """, + (tags, limit, offset), + ) + return [dict(row) for row in cursor.fetchall()] + def get_cached_serp_query(self, query_hash: str) -> Optional[List[str]]: """Look up cached SERP query results. diff --git a/SPARC/llm.py b/SPARC/llm.py index 9214cee..2f7399b 100644 --- a/SPARC/llm.py +++ b/SPARC/llm.py @@ -1,7 +1,8 @@ """LLM integration for patent analysis using OpenRouter.""" +import json import logging -from typing import Dict +from typing import Dict, List from openai import OpenAI @@ -245,4 +246,61 @@ Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the co metadata={**metadata, "pending": True} ) return placeholder - + + # Canonical technology domain tags + CANONICAL_TAGS = ["ai", "semiconductors", "materials", "biotech", "networking", "other"] + + def classify_patent_tags(self, patent_content: str, model: str | None = None) -> List[str]: + """Classify a patent into one or more technology domain tags. + + Sends the patent abstract/claims to the LLM with a classification prompt + and returns a list of canonical tags. + + Args: + patent_content: Minimized patent text (abstract, claims, summary) + model: Optional model override + + Returns: + List of canonical tag strings from CANONICAL_TAGS + """ + prompt = f"""You are a patent classification system. Analyze the following patent content and assign one or more technology domain tags from this exact list: + +Tags: ai, semiconductors, materials, biotech, networking, other + +Rules: +- Return ONLY a JSON array of tag strings, e.g. ["ai", "semiconductors"] +- Use ONLY tags from the list above +- Assign "other" only if no other tag fits +- A patent can have multiple tags if it spans domains + +Patent Content: +{patent_content} + +Return ONLY the JSON array, nothing else.""" + + effective_model = model or self.model + + if self.test_mode: + logger.debug("TEST MODE - Classification prompt:\n%s", prompt) + return ["other"] + + if self.client: + try: + response = self.client.chat.completions.create( + model=effective_model, + max_tokens=128, + messages=[{"role": "user", "content": prompt}], + ) + raw = response.choices[0].message.content.strip() + tags = json.loads(raw) + # Validate and filter to canonical tags only + valid_tags = [t for t in tags if t in self.CANONICAL_TAGS] + return valid_tags if valid_tags else ["other"] + except (json.JSONDecodeError, AttributeError, TypeError) as e: + logger.warning("Failed to parse classification response: %s", e) + return ["other"] + except Exception as e: + logger.warning("Classification LLM call failed: %s", e) + return ["other"] + + return ["other"] diff --git a/SPARC/types.py b/SPARC/types.py index fd11073..198f2bb 100644 --- a/SPARC/types.py +++ b/SPARC/types.py @@ -25,6 +25,7 @@ class CompanyAnalysisResult: success: bool error: str | None = None model: str | None = None + tags: list[str] = field(default_factory=list) timestamp: datetime = field(default_factory=datetime.now) diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 09a4ae6..27b6e0e 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -199,8 +199,18 @@ export const analyticsApi = { const response = await api.get(`/analytics/trends?days=${days}`); return response.data; }, + + getTagDistribution: async (): Promise => { + const response = await api.get('/analytics/tags'); + return response.data; + }, }; +export interface TagDistribution { + by_tag: Array<{ tag: string; count: number }>; + canonical_tags: string[]; +} + // Admin API export const adminApi = { listUsers: async (limit = 100, offset = 0): Promise => { diff --git a/frontend/src/pages/Analytics.tsx b/frontend/src/pages/Analytics.tsx index b7c4604..0362cef 100644 --- a/frontend/src/pages/Analytics.tsx +++ b/frontend/src/pages/Analytics.tsx @@ -1,14 +1,24 @@ import { useState } from 'react'; import { useQuery } from '@tanstack/react-query'; import { analyticsApi } from '../api/client'; -import { AlertCircle, Database } from 'lucide-react'; +import { AlertCircle, Database, Tag } from 'lucide-react'; import { PieChart, Pie, Cell, BarChart, Bar, LineChart, Line, XAxis, YAxis, Tooltip, ResponsiveContainer, Legend } from 'recharts'; import { useChartTheme } from '../context/useChartTheme'; const COLORS = ['#6366f1', '#0ea5e9', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#ec4899', '#14b8a6']; +const TAG_COLORS: Record = { + ai: '#6366f1', + semiconductors: '#0ea5e9', + materials: '#10b981', + biotech: '#f59e0b', + networking: '#ec4899', + other: '#8b5cf6', +}; + export function AnalyticsPage() { const [days, setDays] = useState(30); + const [selectedTags, setSelectedTags] = useState([]); const chartTheme = useChartTheme(); const { data, isLoading, isError, refetch } = useQuery({ @@ -21,6 +31,17 @@ export function AnalyticsPage() { queryFn: () => analyticsApi.getTrends(days), }); + const tagQuery = useQuery({ + queryKey: ['analytics-tags'], + queryFn: () => analyticsApi.getTagDistribution(), + }); + + const toggleTag = (tag: string) => { + setSelectedTags((prev) => + prev.includes(tag) ? prev.filter((t) => t !== tag) : [...prev, tag] + ); + }; + if (isLoading) { return (
@@ -107,6 +128,13 @@ export function AnalyticsPage() { count: t.count, })); + const tagData = tagQuery.data?.by_tag?.map((t) => ({ + name: t.tag, + count: t.count, + })) || []; + + const canonicalTags = tagQuery.data?.canonical_tags || []; + return (
{/* Header */} @@ -131,6 +159,50 @@ export function AnalyticsPage() {
+ {/* Tag Filter Controls */} + {canonicalTags.length > 0 && ( +
+
+ + Filter by Technology Domain + {selectedTags.length > 0 && ( + + )} +
+
+ {canonicalTags.map((tag) => { + const isActive = selectedTags.includes(tag); + const color = TAG_COLORS[tag] || '#8b5cf6'; + const tagCount = tagQuery.data?.by_tag?.find((t) => t.tag === tag)?.count || 0; + return ( + + ); + })} +
+
+ )} + {/* Summary Metrics */}
@@ -187,6 +259,28 @@ export function AnalyticsPage() {
)} + + {/* Bar Chart - Technology Domain Tags */} + {tagData.length > 0 && ( +
+

Patents by Technology Domain

+ + + + + + + {tagData.map((entry, index) => ( + + ))} + + + +
+ )}
{/* Trend Charts */} diff --git a/tests/test_patent_tags.py b/tests/test_patent_tags.py new file mode 100644 index 0000000..1880002 --- /dev/null +++ b/tests/test_patent_tags.py @@ -0,0 +1,262 @@ +"""Tests for LLM-based patent classification tagging by technology domain.""" + +import json +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from SPARC.llm import LLMAnalyzer + + +class TestClassifyPatentTags: + """Test the classify_patent_tags method on LLMAnalyzer.""" + + @pytest.fixture(autouse=True) + def mock_database(self, mocker): + """Mock the database client for all tests.""" + mock_db_client = Mock() + mock_db_client.get_cached_response.return_value = None + mock_db_client.store_message.return_value = 1 + mocker.patch("SPARC.llm.DatabaseClient", return_value=mock_db_client) + return mock_db_client + + def test_classify_returns_valid_tags(self, mocker, mock_database): + """Test that classify_patent_tags returns valid canonical tags from LLM.""" + mock_openai = mocker.patch("SPARC.llm.OpenAI") + mock_client = Mock() + mock_openai.return_value = mock_client + + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content='["ai", "semiconductors"]'))] + mock_client.chat.completions.create.return_value = mock_response + + analyzer = LLMAnalyzer(api_key="test-key") + tags = analyzer.classify_patent_tags("A patent about neural network accelerator chips") + + assert tags == ["ai", "semiconductors"] + mock_client.chat.completions.create.assert_called_once() + + # Verify the prompt contains classification instructions + call_args = mock_client.chat.completions.create.call_args + prompt_text = call_args[1]["messages"][0]["content"] + assert "ai" in prompt_text + assert "semiconductors" in prompt_text + assert "JSON array" in prompt_text + + def test_classify_filters_invalid_tags(self, mocker, mock_database): + """Test that invalid tags from LLM response are filtered out.""" + mock_openai = mocker.patch("SPARC.llm.OpenAI") + mock_client = Mock() + mock_openai.return_value = mock_client + + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content='["ai", "quantum_computing", "biotech"]'))] + mock_client.chat.completions.create.return_value = mock_response + + analyzer = LLMAnalyzer(api_key="test-key") + tags = analyzer.classify_patent_tags("Some patent content") + + assert tags == ["ai", "biotech"] + assert "quantum_computing" not in tags + + def test_classify_returns_other_when_all_invalid(self, mocker, mock_database): + """Test fallback to 'other' when all LLM tags are invalid.""" + mock_openai = mocker.patch("SPARC.llm.OpenAI") + mock_client = Mock() + mock_openai.return_value = mock_client + + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content='["quantum", "robotics"]'))] + mock_client.chat.completions.create.return_value = mock_response + + analyzer = LLMAnalyzer(api_key="test-key") + tags = analyzer.classify_patent_tags("Some patent content") + + assert tags == ["other"] + + def test_classify_handles_malformed_json(self, mocker, mock_database): + """Test graceful handling of non-JSON LLM response.""" + mock_openai = mocker.patch("SPARC.llm.OpenAI") + mock_client = Mock() + mock_openai.return_value = mock_client + + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content="ai, semiconductors"))] + mock_client.chat.completions.create.return_value = mock_response + + analyzer = LLMAnalyzer(api_key="test-key") + tags = analyzer.classify_patent_tags("Some patent content") + + assert tags == ["other"] + + def test_classify_handles_api_error(self, mocker, mock_database): + """Test graceful fallback when LLM API call fails.""" + mock_openai = mocker.patch("SPARC.llm.OpenAI") + mock_client = Mock() + mock_openai.return_value = mock_client + + mock_client.chat.completions.create.side_effect = Exception("API timeout") + + analyzer = LLMAnalyzer(api_key="test-key") + tags = analyzer.classify_patent_tags("Some patent content") + + assert tags == ["other"] + + def test_classify_test_mode(self, mocker, mock_database): + """Test that test mode returns 'other' without API call.""" + mocker.patch("SPARC.llm.config") + + analyzer = LLMAnalyzer(test_mode=True) + tags = analyzer.classify_patent_tags("Some patent content") + + assert tags == ["other"] + + def test_classify_no_api_client(self, mocker, mock_database): + """Test that without API client, classification returns 'other'.""" + mocker.patch("SPARC.llm.config") + + analyzer = LLMAnalyzer(use_cache=False) + tags = analyzer.classify_patent_tags("Some patent content") + + assert tags == ["other"] + + def test_canonical_tags_list(self): + """Test that the canonical tag list matches requirements.""" + expected = ["ai", "semiconductors", "materials", "biotech", "networking", "other"] + assert LLMAnalyzer.CANONICAL_TAGS == expected + + def test_classify_uses_model_override(self, mocker, mock_database): + """Test that model override is passed to the API call.""" + mock_openai = mocker.patch("SPARC.llm.OpenAI") + mock_client = Mock() + mock_openai.return_value = mock_client + + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content='["ai"]'))] + mock_client.chat.completions.create.return_value = mock_response + + analyzer = LLMAnalyzer(api_key="test-key") + analyzer.classify_patent_tags("content", model="openai/gpt-4o") + + call_args = mock_client.chat.completions.create.call_args + assert call_args[1]["model"] == "openai/gpt-4o" + + +class TestPatentTagsStorage: + """Test that tags are persisted to the database.""" + + @pytest.fixture(autouse=True) + def mock_db(self, mocker): + """Mock DatabaseClient for all tests.""" + mock_db_cls = mocker.patch("SPARC.analyzer.DatabaseClient") + mock_db_instance = MagicMock() + mock_db_instance.get_cached_patent.return_value = None + mock_db_instance.get_cached_serp_query.return_value = None + mock_db_cls.return_value = mock_db_instance + return mock_db_instance + + def test_tags_stored_after_classification(self, mocker, mock_db): + """Test that classify_patent_tags results are stored via update_patent_tags.""" + from SPARC.analyzer import CompanyAnalyzer + from SPARC.types import Patent, Patents + + mock_query = mocker.patch("SPARC.analyzer.SERP.query") + mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents") + mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf") + mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm") + mock_llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer") + + patent = Patent(patent_id="US123", pdf_link="http://example.com/test.pdf") + mock_query.return_value = Patents(patents=[patent]) + + def save_side_effect(p): + p.pdf_path = "patents/US123.pdf" + return p + + mock_save.side_effect = save_side_effect + mock_parse.return_value = {"abstract": "Test abstract"} + mock_minimize.return_value = "Minimized content" + + mock_llm_instance = Mock() + mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis result" + mock_llm_instance.classify_patent_tags.return_value = ["ai", "semiconductors"] + mock_llm_cls.return_value = mock_llm_instance + + analyzer = CompanyAnalyzer() + analyzer.analyze_company("TestCorp") + + # Verify classification was called + mock_llm_instance.classify_patent_tags.assert_called_once_with( + patent_content="Minimized content", model=None + ) + + # Verify tags were persisted + mock_db.update_patent_tags.assert_called_once_with("US123", ["ai", "semiconductors"]) + + def test_cached_tags_skip_classification(self, mocker, mock_db): + """Test that patents with cached tags skip re-classification.""" + from SPARC.analyzer import CompanyAnalyzer + from SPARC.types import Patent, Patents + + mocker.patch("SPARC.analyzer.SERP.query") + mocker.patch("SPARC.analyzer.SERP.save_patents") + mock_llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer") + + # Simulate DB cache hit with tags + mock_db.get_cached_patent.return_value = { + "patent_id": "US123", + "minimized_content": "Cached content", + "patent_tags": ["biotech"], + } + mock_db.get_cached_serp_query.return_value = ["US123"] + + mock_llm_instance = Mock() + mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis" + mock_llm_cls.return_value = mock_llm_instance + + analyzer = CompanyAnalyzer() + analyzer.analyze_company("TestCorp") + + # Should NOT classify since tags already present from cache + mock_llm_instance.classify_patent_tags.assert_not_called() + mock_db.update_patent_tags.assert_not_called() + + def test_tags_in_analysis_result(self, mocker, mock_db): + """Test that tags appear in the CompanyAnalysisResult.""" + from SPARC.analyzer import CompanyAnalyzer + from SPARC.types import Patent, Patents + + mock_query = mocker.patch("SPARC.analyzer.SERP.query") + mock_save = mocker.patch("SPARC.analyzer.SERP.save_patents") + mock_parse = mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf") + mock_minimize = mocker.patch("SPARC.analyzer.SERP.minimize_patent_for_llm") + mock_llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer") + + patent = Patent(patent_id="US123", pdf_link="http://example.com/test.pdf") + mock_query.return_value = Patents(patents=[patent]) + + def save_side_effect(p): + p.pdf_path = "patents/US123.pdf" + return p + + mock_save.side_effect = save_side_effect + mock_parse.return_value = {"abstract": "Test abstract"} + mock_minimize.return_value = "Minimized content" + + mock_llm_instance = Mock() + mock_llm_instance.analyze_patent_portfolio.return_value = "Strong innovation" + mock_llm_instance.classify_patent_tags.return_value = ["ai", "networking"] + mock_llm_cls.return_value = mock_llm_instance + + # After analysis, get_cached_patent returns the patent with tags + mock_db.get_cached_serp_query.side_effect = [None, ["US123"]] + mock_db.get_cached_patent.side_effect = [ + None, # First call during _process_single_patent + {"patent_id": "US123", "patent_tags": ["ai", "networking"]}, # Second call during _analyze_company_safe + ] + + analyzer = CompanyAnalyzer() + result = analyzer._analyze_company_safe("TestCorp") + + assert result.success is True + assert result.tags == ["ai", "networking"] -- 2.52.0