Add LLM-based patent classification tagging by technology domain

- Add classify_patent_tags() to LLMAnalyzer with canonical tag list
  (ai, semiconductors, materials, biotech, networking, other)
- Add patent_tags TEXT[] column to patents table with GIN index
- Run classification automatically in the analysis pipeline after
  patent processing; persist tags via update_patent_tags()
- Include tags in CompanyAnalysisResult and API response models
- Add ?tags= filter to GET /analyze/batch endpoint
- Add GET /analytics/tags endpoint for tag distribution data
- Add tag filter controls and distribution chart to Analytics page
- Add 12 unit tests covering classification, DB storage, and caching

Closes leeworks-agents/SPARC#1672

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
agent-company
2026-05-19 15:27:46 +00:00
parent 313800215c
commit cd81218154
8 changed files with 590 additions and 17 deletions
+262
View File
@@ -0,0 +1,262 @@
"""Tests for LLM-based patent classification tagging by technology domain."""
import json
from unittest.mock import MagicMock, Mock, patch
import pytest
from SPARC.llm import LLMAnalyzer
class TestClassifyPatentTags:
    """Test the classify_patent_tags method on LLMAnalyzer.

    Every test runs with ``SPARC.llm.DatabaseClient`` replaced by a mock
    (see ``mock_database``). Tests that need an LLM reply patch
    ``SPARC.llm.OpenAI`` through the private helpers below so the mock
    scaffolding lives in one place.
    """

    @pytest.fixture(autouse=True)
    def mock_database(self, mocker):
        """Mock the database client for all tests.

        The mock returns ``None`` from ``get_cached_response`` and ``1``
        from ``store_message``. The mock instance is returned so tests can
        inspect it if they need to.
        """
        mock_db_client = Mock()
        mock_db_client.get_cached_response.return_value = None
        mock_db_client.store_message.return_value = 1
        mocker.patch("SPARC.llm.DatabaseClient", return_value=mock_db_client)
        return mock_db_client

    @staticmethod
    def _patch_openai(mocker):
        """Patch ``SPARC.llm.OpenAI`` and return the mocked client instance."""
        mock_openai = mocker.patch("SPARC.llm.OpenAI")
        mock_client = Mock()
        mock_openai.return_value = mock_client
        return mock_client

    @classmethod
    def _stub_llm_reply(cls, mocker, content):
        """Patch OpenAI so chat completion returns a single choice with *content*.

        Returns the mocked client so callers can inspect the recorded call.
        """
        mock_client = cls._patch_openai(mocker)
        mock_response = Mock()
        mock_response.choices = [Mock(message=Mock(content=content))]
        mock_client.chat.completions.create.return_value = mock_response
        return mock_client

    def test_classify_returns_valid_tags(self, mocker, mock_database):
        """Test that classify_patent_tags returns valid canonical tags from LLM."""
        mock_client = self._stub_llm_reply(mocker, '["ai", "semiconductors"]')

        analyzer = LLMAnalyzer(api_key="test-key")
        tags = analyzer.classify_patent_tags(
            "A patent about neural network accelerator chips"
        )

        assert tags == ["ai", "semiconductors"]
        mock_client.chat.completions.create.assert_called_once()

        # Verify the prompt contains classification instructions.
        # call_args[1] is the keyword-argument dict of the recorded call.
        call_args = mock_client.chat.completions.create.call_args
        prompt_text = call_args[1]["messages"][0]["content"]
        # NOTE(review): "ai" is a two-letter substring and would match many
        # unrelated words; the other two checks carry most of the weight.
        assert "ai" in prompt_text
        assert "semiconductors" in prompt_text
        assert "JSON array" in prompt_text

    def test_classify_filters_invalid_tags(self, mocker, mock_database):
        """Test that invalid tags from LLM response are filtered out."""
        self._stub_llm_reply(mocker, '["ai", "quantum_computing", "biotech"]')

        analyzer = LLMAnalyzer(api_key="test-key")
        tags = analyzer.classify_patent_tags("Some patent content")

        assert tags == ["ai", "biotech"]
        assert "quantum_computing" not in tags

    def test_classify_returns_other_when_all_invalid(self, mocker, mock_database):
        """Test fallback to 'other' when all LLM tags are invalid."""
        self._stub_llm_reply(mocker, '["quantum", "robotics"]')

        analyzer = LLMAnalyzer(api_key="test-key")
        tags = analyzer.classify_patent_tags("Some patent content")

        assert tags == ["other"]

    def test_classify_handles_malformed_json(self, mocker, mock_database):
        """Test graceful handling of non-JSON LLM response."""
        self._stub_llm_reply(mocker, "ai, semiconductors")

        analyzer = LLMAnalyzer(api_key="test-key")
        tags = analyzer.classify_patent_tags("Some patent content")

        assert tags == ["other"]

    def test_classify_handles_api_error(self, mocker, mock_database):
        """Test graceful fallback when LLM API call fails."""
        mock_client = self._patch_openai(mocker)
        mock_client.chat.completions.create.side_effect = Exception("API timeout")

        analyzer = LLMAnalyzer(api_key="test-key")
        tags = analyzer.classify_patent_tags("Some patent content")

        assert tags == ["other"]

    def test_classify_test_mode(self, mocker, mock_database):
        """Test that test mode returns 'other' without API call."""
        mocker.patch("SPARC.llm.config")
        # Patch the client class as well so the "without API call" half of
        # the contract is actually asserted, not only the return value.
        mock_client = self._patch_openai(mocker)

        analyzer = LLMAnalyzer(test_mode=True)
        tags = analyzer.classify_patent_tags("Some patent content")

        assert tags == ["other"]
        mock_client.chat.completions.create.assert_not_called()

    def test_classify_no_api_client(self, mocker, mock_database):
        """Test that without API client, classification returns 'other'."""
        mocker.patch("SPARC.llm.config")

        analyzer = LLMAnalyzer(use_cache=False)
        tags = analyzer.classify_patent_tags("Some patent content")

        assert tags == ["other"]

    def test_canonical_tags_list(self):
        """Test that the canonical tag list matches requirements."""
        expected = [
            "ai",
            "semiconductors",
            "materials",
            "biotech",
            "networking",
            "other",
        ]
        assert LLMAnalyzer.CANONICAL_TAGS == expected

    def test_classify_uses_model_override(self, mocker, mock_database):
        """Test that model override is passed to the API call."""
        mock_client = self._stub_llm_reply(mocker, '["ai"]')

        analyzer = LLMAnalyzer(api_key="test-key")
        analyzer.classify_patent_tags("content", model="openai/gpt-4o")

        call_args = mock_client.chat.completions.create.call_args
        assert call_args[1]["model"] == "openai/gpt-4o"
class TestPatentTagsStorage:
    """Test that tags are persisted to the database.

    Every test runs with ``SPARC.analyzer.DatabaseClient`` replaced by a
    ``MagicMock`` (see ``mock_db``). The two tests that process a
    not-yet-cached patent share ``_patch_fresh_patent_pipeline``.
    """

    @pytest.fixture(autouse=True)
    def mock_db(self, mocker):
        """Mock DatabaseClient for all tests.

        By default both ``get_cached_patent`` and ``get_cached_serp_query``
        return ``None``; individual tests override these where needed.
        """
        mock_db_cls = mocker.patch("SPARC.analyzer.DatabaseClient")
        mock_db_instance = MagicMock()
        mock_db_instance.get_cached_patent.return_value = None
        mock_db_instance.get_cached_serp_query.return_value = None
        mock_db_cls.return_value = mock_db_instance
        return mock_db_instance

    @staticmethod
    def _patch_fresh_patent_pipeline(mocker, *, analysis, tags):
        """Patch SERP and LLMAnalyzer for one uncached patent ``US123``.

        ``SERP.query`` yields a single patent, ``SERP.save_patents`` assigns
        it a PDF path, parsing and minimizing return fixed values, and the
        patched ``LLMAnalyzer`` returns *analysis* from
        ``analyze_patent_portfolio`` and *tags* from ``classify_patent_tags``.

        Returns the mocked LLMAnalyzer instance so callers can assert on it.
        """
        from SPARC.types import Patent, Patents

        def save_side_effect(p):
            # Mimic a successful download by filling in the local PDF path.
            p.pdf_path = "patents/US123.pdf"
            return p

        patent = Patent(patent_id="US123", pdf_link="http://example.com/test.pdf")
        mocker.patch(
            "SPARC.analyzer.SERP.query",
            return_value=Patents(patents=[patent]),
        )
        mocker.patch(
            "SPARC.analyzer.SERP.save_patents",
            side_effect=save_side_effect,
        )
        mocker.patch(
            "SPARC.analyzer.SERP.parse_patent_pdf",
            return_value={"abstract": "Test abstract"},
        )
        mocker.patch(
            "SPARC.analyzer.SERP.minimize_patent_for_llm",
            return_value="Minimized content",
        )

        mock_llm_instance = Mock()
        mock_llm_instance.analyze_patent_portfolio.return_value = analysis
        mock_llm_instance.classify_patent_tags.return_value = tags
        mocker.patch("SPARC.analyzer.LLMAnalyzer", return_value=mock_llm_instance)
        return mock_llm_instance

    def test_tags_stored_after_classification(self, mocker, mock_db):
        """Test that classify_patent_tags results are stored via update_patent_tags."""
        from SPARC.analyzer import CompanyAnalyzer

        mock_llm_instance = self._patch_fresh_patent_pipeline(
            mocker,
            analysis="Analysis result",
            tags=["ai", "semiconductors"],
        )

        analyzer = CompanyAnalyzer()
        analyzer.analyze_company("TestCorp")

        # Verify classification was called
        mock_llm_instance.classify_patent_tags.assert_called_once_with(
            patent_content="Minimized content", model=None
        )
        # Verify tags were persisted
        mock_db.update_patent_tags.assert_called_once_with(
            "US123", ["ai", "semiconductors"]
        )

    def test_cached_tags_skip_classification(self, mocker, mock_db):
        """Test that patents with cached tags skip re-classification."""
        from SPARC.analyzer import CompanyAnalyzer

        mocker.patch("SPARC.analyzer.SERP.query")
        mocker.patch("SPARC.analyzer.SERP.save_patents")
        mock_llm_cls = mocker.patch("SPARC.analyzer.LLMAnalyzer")

        # Simulate DB cache hit with tags
        mock_db.get_cached_patent.return_value = {
            "patent_id": "US123",
            "minimized_content": "Cached content",
            "patent_tags": ["biotech"],
        }
        mock_db.get_cached_serp_query.return_value = ["US123"]

        mock_llm_instance = Mock()
        mock_llm_instance.analyze_patent_portfolio.return_value = "Analysis"
        mock_llm_cls.return_value = mock_llm_instance

        analyzer = CompanyAnalyzer()
        analyzer.analyze_company("TestCorp")

        # Should NOT classify since tags already present from cache
        mock_llm_instance.classify_patent_tags.assert_not_called()
        mock_db.update_patent_tags.assert_not_called()

    def test_tags_in_analysis_result(self, mocker, mock_db):
        """Test that tags appear in the CompanyAnalysisResult."""
        from SPARC.analyzer import CompanyAnalyzer

        self._patch_fresh_patent_pipeline(
            mocker,
            analysis="Strong innovation",
            tags=["ai", "networking"],
        )

        # After analysis, get_cached_patent returns the patent with tags
        mock_db.get_cached_serp_query.side_effect = [None, ["US123"]]
        mock_db.get_cached_patent.side_effect = [
            None,  # First call during _process_single_patent
            # Second call during _analyze_company_safe
            {"patent_id": "US123", "patent_tags": ["ai", "networking"]},
        ]

        analyzer = CompanyAnalyzer()
        result = analyzer._analyze_company_safe("TestCorp")

        assert result.success is True
        assert result.tags == ["ai", "networking"]