forked from 0xWheatyz/SPARC
Add LLM-based patent classification tagging by technology domain
- Add classify_patent_tags() to LLMAnalyzer with canonical tag list (ai, semiconductors, materials, biotech, networking, other)
- Add patent_tags TEXT[] column to patents table with GIN index
- Run classification automatically in the analysis pipeline after patent processing; persist tags via update_patent_tags()
- Include tags in CompanyAnalysisResult and API response models
- Add ?tags= filter to GET /analyze/batch endpoint
- Add GET /analytics/tags endpoint for tag distribution data
- Add tag filter controls and distribution chart to Analytics page
- Add 12 unit tests covering classification, DB storage, and caching

Closes leeworks-agents/SPARC#1672

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,262 @@
|
||||
"""Tests for LLM-based patent classification tagging by technology domain."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
|
||||
|
||||
class TestClassifyPatentTags:
    """Exercise LLMAnalyzer.classify_patent_tags with the OpenAI client stubbed out."""

    @pytest.fixture(autouse=True)
    def mock_database(self, mocker):
        """Swap SPARC.llm.DatabaseClient for a stub in every test.

        The stub reports no cached response and returns 1 from store_message.
        """
        db_stub = Mock(
            **{
                "get_cached_response.return_value": None,
                "store_message.return_value": 1,
            }
        )
        db_factory = mocker.patch("SPARC.llm.DatabaseClient")
        db_factory.return_value = db_stub
        return db_stub

    @staticmethod
    def _patch_openai(mocker):
        """Patch SPARC.llm.OpenAI and return the client its constructor hands out."""
        client = Mock()
        mocker.patch("SPARC.llm.OpenAI").return_value = client
        return client

    @classmethod
    def _patch_openai_reply(cls, mocker, content):
        """Patch SPARC.llm.OpenAI so the chat completion message carries *content*."""
        client = cls._patch_openai(mocker)
        completion = Mock()
        completion.choices = [Mock(message=Mock(content=content))]
        client.chat.completions.create.return_value = completion
        return client

    def test_classify_returns_valid_tags(self, mocker, mock_database):
        """Canonical tags in the LLM reply are returned as-is, via a single API call."""
        client = self._patch_openai_reply(mocker, '["ai", "semiconductors"]')

        result = LLMAnalyzer(api_key="test-key").classify_patent_tags(
            "A patent about neural network accelerator chips"
        )

        assert result == ["ai", "semiconductors"]
        create = client.chat.completions.create
        create.assert_called_once()

        # The first message of the request must carry the classification instructions.
        prompt_text = create.call_args.kwargs["messages"][0]["content"]
        for fragment in ("ai", "semiconductors", "JSON array"):
            assert fragment in prompt_text

    def test_classify_filters_invalid_tags(self, mocker, mock_database):
        """Tags outside the canonical list are dropped from the result."""
        self._patch_openai_reply(mocker, '["ai", "quantum_computing", "biotech"]')

        result = LLMAnalyzer(api_key="test-key").classify_patent_tags(
            "Some patent content"
        )

        assert result == ["ai", "biotech"]
        assert "quantum_computing" not in result

    def test_classify_returns_other_when_all_invalid(self, mocker, mock_database):
        """A reply with no canonical tag at all collapses to ['other']."""
        self._patch_openai_reply(mocker, '["quantum", "robotics"]')

        result = LLMAnalyzer(api_key="test-key").classify_patent_tags(
            "Some patent content"
        )

        assert result == ["other"]

    def test_classify_handles_malformed_json(self, mocker, mock_database):
        """A reply that is not JSON yields ['other']."""
        self._patch_openai_reply(mocker, "ai, semiconductors")

        result = LLMAnalyzer(api_key="test-key").classify_patent_tags(
            "Some patent content"
        )

        assert result == ["other"]

    def test_classify_handles_api_error(self, mocker, mock_database):
        """An exception raised by the API call yields ['other']."""
        client = self._patch_openai(mocker)
        client.chat.completions.create.side_effect = Exception("API timeout")

        result = LLMAnalyzer(api_key="test-key").classify_patent_tags(
            "Some patent content"
        )

        assert result == ["other"]

    def test_classify_test_mode(self, mocker, mock_database):
        """An analyzer built with test_mode=True classifies everything as ['other']."""
        mocker.patch("SPARC.llm.config")

        result = LLMAnalyzer(test_mode=True).classify_patent_tags(
            "Some patent content"
        )

        assert result == ["other"]

    def test_classify_no_api_client(self, mocker, mock_database):
        """An analyzer built without an api_key argument returns ['other']."""
        mocker.patch("SPARC.llm.config")

        result = LLMAnalyzer(use_cache=False).classify_patent_tags(
            "Some patent content"
        )

        assert result == ["other"]

    def test_canonical_tags_list(self):
        """CANONICAL_TAGS holds exactly the six required domains, in order."""
        assert LLMAnalyzer.CANONICAL_TAGS == [
            "ai",
            "semiconductors",
            "materials",
            "biotech",
            "networking",
            "other",
        ]

    def test_classify_uses_model_override(self, mocker, mock_database):
        """The model keyword is forwarded to the chat completion request."""
        client = self._patch_openai_reply(mocker, '["ai"]')

        LLMAnalyzer(api_key="test-key").classify_patent_tags(
            "content", model="openai/gpt-4o"
        )

        request_kwargs = client.chat.completions.create.call_args.kwargs
        assert request_kwargs["model"] == "openai/gpt-4o"
|
||||
|
||||
|
||||
class TestPatentTagsStorage:
    """Check how CompanyAnalyzer persists, reuses and reports patent tags."""

    @pytest.fixture(autouse=True)
    def mock_db(self, mocker):
        """Swap SPARC.analyzer.DatabaseClient for a stub whose caches start empty."""
        db = MagicMock(
            **{
                "get_cached_patent.return_value": None,
                "get_cached_serp_query.return_value": None,
            }
        )
        mocker.patch("SPARC.analyzer.DatabaseClient", return_value=db)
        return db

    @staticmethod
    def _patch_uncached_pipeline(mocker, analysis, tags):
        """Stub the SERP helpers and LLMAnalyzer for one patent, "US123".

        Returns the LLM stub, whose analyze_patent_portfolio returns *analysis*
        and whose classify_patent_tags returns *tags*.
        """
        from SPARC.types import Patent, Patents

        def attach_pdf_path(item):
            # Stand-in for SERP.save_patents: record a local path, hand the patent back.
            item.pdf_path = "patents/US123.pdf"
            return item

        found = Patents(
            patents=[Patent(patent_id="US123", pdf_link="http://example.com/test.pdf")]
        )
        mocker.patch("SPARC.analyzer.SERP.query").return_value = found
        mocker.patch("SPARC.analyzer.SERP.save_patents").side_effect = attach_pdf_path
        mocker.patch("SPARC.analyzer.SERP.parse_patent_pdf").return_value = {
            "abstract": "Test abstract"
        }
        mocker.patch(
            "SPARC.analyzer.SERP.minimize_patent_for_llm"
        ).return_value = "Minimized content"

        llm = Mock(
            **{
                "analyze_patent_portfolio.return_value": analysis,
                "classify_patent_tags.return_value": tags,
            }
        )
        mocker.patch("SPARC.analyzer.LLMAnalyzer").return_value = llm
        return llm

    def test_tags_stored_after_classification(self, mocker, mock_db):
        """Tags returned by classify_patent_tags are written via update_patent_tags."""
        from SPARC.analyzer import CompanyAnalyzer

        llm = self._patch_uncached_pipeline(
            mocker, "Analysis result", ["ai", "semiconductors"]
        )

        CompanyAnalyzer().analyze_company("TestCorp")

        # Classification ran once, on the minimized patent text.
        llm.classify_patent_tags.assert_called_once_with(
            patent_content="Minimized content", model=None
        )
        # The resulting tags were persisted for that patent.
        mock_db.update_patent_tags.assert_called_once_with(
            "US123", ["ai", "semiconductors"]
        )

    def test_cached_tags_skip_classification(self, mocker, mock_db):
        """A cached patent that already has tags is neither re-classified nor re-stored."""
        from SPARC.analyzer import CompanyAnalyzer
        from SPARC.types import Patent, Patents

        mocker.patch("SPARC.analyzer.SERP.query")
        mocker.patch("SPARC.analyzer.SERP.save_patents")
        llm_factory = mocker.patch("SPARC.analyzer.LLMAnalyzer")

        # Both caches hit, and the cached patent row already carries its tags.
        mock_db.get_cached_serp_query.return_value = ["US123"]
        mock_db.get_cached_patent.return_value = {
            "patent_id": "US123",
            "minimized_content": "Cached content",
            "patent_tags": ["biotech"],
        }

        llm = Mock(**{"analyze_patent_portfolio.return_value": "Analysis"})
        llm_factory.return_value = llm

        CompanyAnalyzer().analyze_company("TestCorp")

        llm.classify_patent_tags.assert_not_called()
        mock_db.update_patent_tags.assert_not_called()

    def test_tags_in_analysis_result(self, mocker, mock_db):
        """The result of _analyze_company_safe exposes the patent's tags."""
        from SPARC.analyzer import CompanyAnalyzer

        self._patch_uncached_pipeline(mocker, "Strong innovation", ["ai", "networking"])

        # First lookups miss; the follow-up lookups return the stored patent and tags.
        mock_db.get_cached_serp_query.side_effect = [None, ["US123"]]
        mock_db.get_cached_patent.side_effect = [
            None,  # first call, during _process_single_patent
            # second call, during _analyze_company_safe
            {"patent_id": "US123", "patent_tags": ["ai", "networking"]},
        ]

        outcome = CompanyAnalyzer()._analyze_company_safe("TestCorp")

        assert outcome.success is True
        assert outcome.tags == ["ai", "networking"]
|
||||
Reference in New Issue
Block a user