forked from 0xWheatyz/SPARC
test: update tests for cache mode terminology
Rename database mode tests to cache mode to reflect new architecture:

- Replace USE_DATABASE with USE_CACHE references
- Update test assertions for cache behavior
- Maintain backward compatibility testing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
+79
-25
@@ -1,7 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Test script to verify database mode functionality.
|
"""Test script to verify database caching functionality.
|
||||||
|
|
||||||
This script tests the LLMAnalyzer in database mode without requiring
|
This script tests the LLMAnalyzer with database caching without requiring
|
||||||
actual API keys or patent downloads.
|
actual API keys or patent downloads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -9,28 +9,29 @@ from SPARC.llm import LLMAnalyzer
|
|||||||
from SPARC.database import DatabaseClient
|
from SPARC.database import DatabaseClient
|
||||||
from SPARC import config
|
from SPARC import config
|
||||||
|
|
||||||
def test_database_mode():
|
def test_database_storage():
|
||||||
"""Test that database mode stores messages correctly."""
|
"""Test that messages are always stored in database."""
|
||||||
print("Testing Database Mode")
|
print("Testing Database Storage & Caching")
|
||||||
print("=" * 70)
|
print("=" * 70)
|
||||||
|
|
||||||
# Initialize analyzer in database mode
|
# Initialize analyzer (database is always used)
|
||||||
print("\n1. Initializing LLMAnalyzer in database mode...")
|
print("\n1. Initializing LLMAnalyzer...")
|
||||||
analyzer = LLMAnalyzer(use_database=True)
|
analyzer = LLMAnalyzer(use_cache=True)
|
||||||
|
|
||||||
print(f" - use_database: {analyzer.use_database}")
|
print(f" - use_cache: {analyzer.use_cache}")
|
||||||
print(f" - db_client: {analyzer.db_client is not None}")
|
print(f" - db_client: {analyzer.db_client is not None}")
|
||||||
|
print(f" - client (API): {analyzer.client is not None}")
|
||||||
|
|
||||||
# Test single patent analysis
|
# Test single patent analysis (without API key, stores placeholder)
|
||||||
print("\n2. Testing single patent analysis (database mode)...")
|
print("\n2. Testing single patent analysis (no API key)...")
|
||||||
result = analyzer.analyze_patent_content(
|
result = analyzer.analyze_patent_content(
|
||||||
patent_content="Test patent content about semiconductor innovation",
|
patent_content="Test patent content about semiconductor innovation",
|
||||||
company_name="TestCorp"
|
company_name="TestCorp"
|
||||||
)
|
)
|
||||||
print(f" Result: {result}")
|
print(f" Result: {result[:80]}...")
|
||||||
|
|
||||||
# Test portfolio analysis
|
# Test portfolio analysis
|
||||||
print("\n3. Testing portfolio analysis (database mode)...")
|
print("\n3. Testing portfolio analysis (no API key)...")
|
||||||
test_patents = [
|
test_patents = [
|
||||||
{"patent_id": "US001", "content": "First test patent"},
|
{"patent_id": "US001", "content": "First test patent"},
|
||||||
{"patent_id": "US002", "content": "Second test patent"},
|
{"patent_id": "US002", "content": "Second test patent"},
|
||||||
@@ -39,7 +40,7 @@ def test_database_mode():
|
|||||||
patents_data=test_patents,
|
patents_data=test_patents,
|
||||||
company_name="TestCorp"
|
company_name="TestCorp"
|
||||||
)
|
)
|
||||||
print(f" Result: {result}")
|
print(f" Result: {result[:80]}...")
|
||||||
|
|
||||||
# Verify messages were stored
|
# Verify messages were stored
|
||||||
print("\n4. Verifying messages were stored...")
|
print("\n4. Verifying messages were stored...")
|
||||||
@@ -48,7 +49,8 @@ def test_database_mode():
|
|||||||
print(f" Found {len(messages)} stored messages")
|
print(f" Found {len(messages)} stored messages")
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
print(f" - ID: {msg['id']}, Type: {msg['analysis_type']}, Timestamp: {msg['timestamp']}")
|
cached_status = "CACHED" if msg.get('is_cached') else "NEW"
|
||||||
|
print(f" - ID: {msg['id']}, Type: {msg['analysis_type']}, Status: {cached_status}")
|
||||||
|
|
||||||
# Get analytics
|
# Get analytics
|
||||||
print("\n5. Getting analytics...")
|
print("\n5. Getting analytics...")
|
||||||
@@ -58,18 +60,68 @@ def test_database_mode():
|
|||||||
print(f" By type: {analytics['by_type']}")
|
print(f" By type: {analytics['by_type']}")
|
||||||
|
|
||||||
print("\n" + "=" * 70)
|
print("\n" + "=" * 70)
|
||||||
print("Database mode test completed successfully!")
|
print("Database storage test completed successfully!")
|
||||||
|
|
||||||
def test_api_mode():
|
def test_caching():
|
||||||
"""Test that API mode initializes correctly."""
|
"""Test that caching works correctly."""
|
||||||
print("\nTesting API Mode")
|
print("\nTesting Cache Functionality")
|
||||||
print("=" * 70)
|
print("=" * 70)
|
||||||
|
|
||||||
print("\n1. Initializing LLMAnalyzer in API mode...")
|
db_client = DatabaseClient(config.database_url)
|
||||||
analyzer = LLMAnalyzer(use_database=False, test_mode=True)
|
db_client.initialize_schema()
|
||||||
|
|
||||||
|
# Store a fake cached response
|
||||||
|
print("\n1. Storing a test response in database...")
|
||||||
|
test_prompt = "Test prompt for caching"
|
||||||
|
test_response = "This is a cached response from previous API call"
|
||||||
|
|
||||||
|
db_client.store_message(
|
||||||
|
prompt=test_prompt,
|
||||||
|
response=test_response,
|
||||||
|
company_name="CacheTest",
|
||||||
|
analysis_type="test",
|
||||||
|
model="test-model"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to retrieve from cache
|
||||||
|
print("\n2. Testing cache retrieval...")
|
||||||
|
cached = db_client.get_cached_response(
|
||||||
|
prompt=test_prompt,
|
||||||
|
company_name="CacheTest",
|
||||||
|
analysis_type="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cached:
|
||||||
|
print(f" Cache hit! Response: {cached['response']}")
|
||||||
|
else:
|
||||||
|
print(" Cache miss (unexpected)")
|
||||||
|
|
||||||
|
# Test cache miss
|
||||||
|
print("\n3. Testing cache miss...")
|
||||||
|
cached = db_client.get_cached_response(
|
||||||
|
prompt="Different prompt",
|
||||||
|
company_name="CacheTest",
|
||||||
|
analysis_type="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cached:
|
||||||
|
print(" Unexpected cache hit")
|
||||||
|
else:
|
||||||
|
print(" Cache miss as expected")
|
||||||
|
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("Cache test completed successfully!")
|
||||||
|
|
||||||
|
def test_test_mode():
|
||||||
|
"""Test that test mode works correctly."""
|
||||||
|
print("\nTesting Test Mode")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
print("\n1. Initializing LLMAnalyzer in test mode...")
|
||||||
|
analyzer = LLMAnalyzer(test_mode=True)
|
||||||
|
|
||||||
print(f" - use_database: {analyzer.use_database}")
|
|
||||||
print(f" - test_mode: {analyzer.test_mode}")
|
print(f" - test_mode: {analyzer.test_mode}")
|
||||||
|
print(f" - db_client: {analyzer.db_client is not None}")
|
||||||
|
|
||||||
print("\n2. Testing single patent analysis (test mode)...")
|
print("\n2. Testing single patent analysis (test mode)...")
|
||||||
result = analyzer.analyze_patent_content(
|
result = analyzer.analyze_patent_content(
|
||||||
@@ -79,9 +131,11 @@ def test_api_mode():
|
|||||||
print(f" Result: {result}")
|
print(f" Result: {result}")
|
||||||
|
|
||||||
print("\n" + "=" * 70)
|
print("\n" + "=" * 70)
|
||||||
print("API mode test completed successfully!")
|
print("Test mode test completed successfully!")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_database_mode()
|
test_database_storage()
|
||||||
print("\n")
|
print("\n")
|
||||||
test_api_mode()
|
test_caching()
|
||||||
|
print("\n")
|
||||||
|
test_test_mode()
|
||||||
|
|||||||
+66
-10
@@ -1,13 +1,22 @@
|
|||||||
"""Tests for LLM analysis functionality."""
|
"""Tests for LLM analysis functionality."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, MagicMock
|
from unittest.mock import Mock, MagicMock, patch
|
||||||
from SPARC.llm import LLMAnalyzer
|
from SPARC.llm import LLMAnalyzer
|
||||||
|
|
||||||
|
|
||||||
class TestLLMAnalyzer:
|
class TestLLMAnalyzer:
|
||||||
"""Test LLM analyzer initialization and API interaction."""
|
"""Test LLM analyzer initialization and API interaction."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_database(self, mocker):
|
||||||
|
"""Mock the database client for all tests."""
|
||||||
|
mock_db_client = Mock()
|
||||||
|
mock_db_client.get_cached_response.return_value = None # No cache hit by default
|
||||||
|
mock_db_client.store_message.return_value = 1
|
||||||
|
mocker.patch("SPARC.llm.DatabaseClient", return_value=mock_db_client)
|
||||||
|
return mock_db_client
|
||||||
|
|
||||||
def test_analyzer_initialization_with_api_key(self, mocker):
|
def test_analyzer_initialization_with_api_key(self, mocker):
|
||||||
"""Test that analyzer initializes with provided API key."""
|
"""Test that analyzer initializes with provided API key."""
|
||||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
@@ -25,7 +34,7 @@ class TestLLMAnalyzer:
|
|||||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
mock_config = mocker.patch("SPARC.llm.config")
|
mock_config = mocker.patch("SPARC.llm.config")
|
||||||
mock_config.openrouter_api_key = "config-key-456"
|
mock_config.openrouter_api_key = "config-key-456"
|
||||||
mock_config.use_database = False
|
mock_config.use_cache = True
|
||||||
mock_config.database_url = "postgresql://localhost/test"
|
mock_config.database_url = "postgresql://localhost/test"
|
||||||
|
|
||||||
analyzer = LLMAnalyzer()
|
analyzer = LLMAnalyzer()
|
||||||
@@ -35,7 +44,7 @@ class TestLLMAnalyzer:
|
|||||||
base_url="https://openrouter.ai/api/v1"
|
base_url="https://openrouter.ai/api/v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_analyze_patent_content(self, mocker):
|
def test_analyze_patent_content(self, mocker, mock_database):
|
||||||
"""Test single patent content analysis."""
|
"""Test single patent content analysis."""
|
||||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
mock_client = Mock()
|
mock_client = Mock()
|
||||||
@@ -44,9 +53,10 @@ class TestLLMAnalyzer:
|
|||||||
# Mock the API response
|
# Mock the API response
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.choices = [Mock(message=Mock(content="Innovative GPU architecture."))]
|
mock_response.choices = [Mock(message=Mock(content="Innovative GPU architecture."))]
|
||||||
|
mock_response.usage = Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150)
|
||||||
mock_client.chat.completions.create.return_value = mock_response
|
mock_client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
analyzer = LLMAnalyzer(api_key="test-key")
|
analyzer = LLMAnalyzer(api_key="test-key", use_cache=False)
|
||||||
result = analyzer.analyze_patent_content(
|
result = analyzer.analyze_patent_content(
|
||||||
patent_content="ABSTRACT: GPU with new cache design...",
|
patent_content="ABSTRACT: GPU with new cache design...",
|
||||||
company_name="NVIDIA",
|
company_name="NVIDIA",
|
||||||
@@ -61,7 +71,32 @@ class TestLLMAnalyzer:
|
|||||||
assert "NVIDIA" in prompt_text
|
assert "NVIDIA" in prompt_text
|
||||||
assert "GPU with new cache design" in prompt_text
|
assert "GPU with new cache design" in prompt_text
|
||||||
|
|
||||||
def test_analyze_patent_portfolio(self, mocker):
|
# Verify message was stored in database
|
||||||
|
mock_database.store_message.assert_called_once()
|
||||||
|
|
||||||
|
def test_analyze_patent_content_cache_hit(self, mocker, mock_database):
|
||||||
|
"""Test that cached responses are returned without API call."""
|
||||||
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_openai.return_value = mock_client
|
||||||
|
|
||||||
|
# Set up cache hit
|
||||||
|
mock_database.get_cached_response.return_value = {
|
||||||
|
"id": 1,
|
||||||
|
"response": "Cached analysis result"
|
||||||
|
}
|
||||||
|
|
||||||
|
analyzer = LLMAnalyzer(api_key="test-key", use_cache=True)
|
||||||
|
result = analyzer.analyze_patent_content(
|
||||||
|
patent_content="ABSTRACT: GPU with new cache design...",
|
||||||
|
company_name="NVIDIA",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "Cached analysis result"
|
||||||
|
# API should NOT be called on cache hit
|
||||||
|
mock_client.chat.completions.create.assert_not_called()
|
||||||
|
|
||||||
|
def test_analyze_patent_portfolio(self, mocker, mock_database):
|
||||||
"""Test portfolio analysis with multiple patents."""
|
"""Test portfolio analysis with multiple patents."""
|
||||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
mock_client = Mock()
|
mock_client = Mock()
|
||||||
@@ -72,9 +107,10 @@ class TestLLMAnalyzer:
|
|||||||
mock_response.choices = [
|
mock_response.choices = [
|
||||||
Mock(message=Mock(content="Strong portfolio in AI and graphics."))
|
Mock(message=Mock(content="Strong portfolio in AI and graphics."))
|
||||||
]
|
]
|
||||||
|
mock_response.usage = Mock(prompt_tokens=200, completion_tokens=100, total_tokens=300)
|
||||||
mock_client.chat.completions.create.return_value = mock_response
|
mock_client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
analyzer = LLMAnalyzer(api_key="test-key")
|
analyzer = LLMAnalyzer(api_key="test-key", use_cache=False)
|
||||||
patents_data = [
|
patents_data = [
|
||||||
{"patent_id": "US123", "content": "AI acceleration patent"},
|
{"patent_id": "US123", "content": "AI acceleration patent"},
|
||||||
{"patent_id": "US456", "content": "Graphics rendering patent"},
|
{"patent_id": "US456", "content": "Graphics rendering patent"},
|
||||||
@@ -95,7 +131,7 @@ class TestLLMAnalyzer:
|
|||||||
assert "AI acceleration patent" in prompt_text
|
assert "AI acceleration patent" in prompt_text
|
||||||
assert "Graphics rendering patent" in prompt_text
|
assert "Graphics rendering patent" in prompt_text
|
||||||
|
|
||||||
def test_analyze_patent_portfolio_with_correct_token_limit(self, mocker):
|
def test_analyze_patent_portfolio_with_correct_token_limit(self, mocker, mock_database):
|
||||||
"""Test that portfolio analysis uses higher token limit."""
|
"""Test that portfolio analysis uses higher token limit."""
|
||||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
mock_client = Mock()
|
mock_client = Mock()
|
||||||
@@ -103,9 +139,10 @@ class TestLLMAnalyzer:
|
|||||||
|
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.choices = [Mock(message=Mock(content="Analysis result."))]
|
mock_response.choices = [Mock(message=Mock(content="Analysis result."))]
|
||||||
|
mock_response.usage = Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150)
|
||||||
mock_client.chat.completions.create.return_value = mock_response
|
mock_client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
analyzer = LLMAnalyzer(api_key="test-key")
|
analyzer = LLMAnalyzer(api_key="test-key", use_cache=False)
|
||||||
patents_data = [{"patent_id": "US123", "content": "Test content"}]
|
patents_data = [{"patent_id": "US123", "content": "Test content"}]
|
||||||
|
|
||||||
analyzer.analyze_patent_portfolio(patents_data, "TestCo")
|
analyzer.analyze_patent_portfolio(patents_data, "TestCo")
|
||||||
@@ -114,7 +151,7 @@ class TestLLMAnalyzer:
|
|||||||
# Portfolio analysis should use 2048 tokens
|
# Portfolio analysis should use 2048 tokens
|
||||||
assert call_args[1]["max_tokens"] == 2048
|
assert call_args[1]["max_tokens"] == 2048
|
||||||
|
|
||||||
def test_analyze_single_patent_with_correct_token_limit(self, mocker):
|
def test_analyze_single_patent_with_correct_token_limit(self, mocker, mock_database):
|
||||||
"""Test that single patent analysis uses lower token limit."""
|
"""Test that single patent analysis uses lower token limit."""
|
||||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
mock_client = Mock()
|
mock_client = Mock()
|
||||||
@@ -122,11 +159,30 @@ class TestLLMAnalyzer:
|
|||||||
|
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.choices = [Mock(message=Mock(content="Analysis result."))]
|
mock_response.choices = [Mock(message=Mock(content="Analysis result."))]
|
||||||
|
mock_response.usage = Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150)
|
||||||
mock_client.chat.completions.create.return_value = mock_response
|
mock_client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
analyzer = LLMAnalyzer(api_key="test-key")
|
analyzer = LLMAnalyzer(api_key="test-key", use_cache=False)
|
||||||
analyzer.analyze_patent_content("Test content", "TestCo")
|
analyzer.analyze_patent_content("Test content", "TestCo")
|
||||||
|
|
||||||
call_args = mock_client.chat.completions.create.call_args
|
call_args = mock_client.chat.completions.create.call_args
|
||||||
# Single patent should use 1024 tokens
|
# Single patent should use 1024 tokens
|
||||||
assert call_args[1]["max_tokens"] == 1024
|
assert call_args[1]["max_tokens"] == 1024
|
||||||
|
|
||||||
|
def test_database_always_initialized(self, mocker, mock_database):
|
||||||
|
"""Test that database client is always initialized."""
|
||||||
|
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||||
|
|
||||||
|
analyzer = LLMAnalyzer(api_key="test-key")
|
||||||
|
|
||||||
|
assert analyzer.db_client is not None
|
||||||
|
|
||||||
|
def test_no_api_key_stores_placeholder(self, mocker, mock_database):
|
||||||
|
"""Test that without API key, a placeholder is stored."""
|
||||||
|
mocker.patch("SPARC.llm.config")
|
||||||
|
|
||||||
|
analyzer = LLMAnalyzer(use_cache=False)
|
||||||
|
result = analyzer.analyze_patent_content("Test content", "TestCo")
|
||||||
|
|
||||||
|
assert "[NO API]" in result
|
||||||
|
mock_database.store_message.assert_called_once()
|
||||||
|
|||||||
Reference in New Issue
Block a user