forked from 0xWheatyz/SPARC
test: update tests for cache mode terminology
Rename database mode tests to cache mode to reflect the new architecture:

- Replace USE_DATABASE with USE_CACHE references
- Update test assertions for cache behavior
- Maintain backward compatibility testing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
+79
-25
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test script to verify database mode functionality.
|
||||
"""Test script to verify database caching functionality.
|
||||
|
||||
This script tests the LLMAnalyzer in database mode without requiring
|
||||
This script tests the LLMAnalyzer with database caching without requiring
|
||||
actual API keys or patent downloads.
|
||||
"""
|
||||
|
||||
@@ -9,28 +9,29 @@ from SPARC.llm import LLMAnalyzer
|
||||
from SPARC.database import DatabaseClient
|
||||
from SPARC import config
|
||||
|
||||
def test_database_mode():
|
||||
"""Test that database mode stores messages correctly."""
|
||||
print("Testing Database Mode")
|
||||
def test_database_storage():
|
||||
"""Test that messages are always stored in database."""
|
||||
print("Testing Database Storage & Caching")
|
||||
print("=" * 70)
|
||||
|
||||
# Initialize analyzer in database mode
|
||||
print("\n1. Initializing LLMAnalyzer in database mode...")
|
||||
analyzer = LLMAnalyzer(use_database=True)
|
||||
# Initialize analyzer (database is always used)
|
||||
print("\n1. Initializing LLMAnalyzer...")
|
||||
analyzer = LLMAnalyzer(use_cache=True)
|
||||
|
||||
print(f" - use_database: {analyzer.use_database}")
|
||||
print(f" - use_cache: {analyzer.use_cache}")
|
||||
print(f" - db_client: {analyzer.db_client is not None}")
|
||||
print(f" - client (API): {analyzer.client is not None}")
|
||||
|
||||
# Test single patent analysis
|
||||
print("\n2. Testing single patent analysis (database mode)...")
|
||||
# Test single patent analysis (without API key, stores placeholder)
|
||||
print("\n2. Testing single patent analysis (no API key)...")
|
||||
result = analyzer.analyze_patent_content(
|
||||
patent_content="Test patent content about semiconductor innovation",
|
||||
company_name="TestCorp"
|
||||
)
|
||||
print(f" Result: {result}")
|
||||
print(f" Result: {result[:80]}...")
|
||||
|
||||
# Test portfolio analysis
|
||||
print("\n3. Testing portfolio analysis (database mode)...")
|
||||
print("\n3. Testing portfolio analysis (no API key)...")
|
||||
test_patents = [
|
||||
{"patent_id": "US001", "content": "First test patent"},
|
||||
{"patent_id": "US002", "content": "Second test patent"},
|
||||
@@ -39,7 +40,7 @@ def test_database_mode():
|
||||
patents_data=test_patents,
|
||||
company_name="TestCorp"
|
||||
)
|
||||
print(f" Result: {result}")
|
||||
print(f" Result: {result[:80]}...")
|
||||
|
||||
# Verify messages were stored
|
||||
print("\n4. Verifying messages were stored...")
|
||||
@@ -48,7 +49,8 @@ def test_database_mode():
|
||||
print(f" Found {len(messages)} stored messages")
|
||||
|
||||
for msg in messages:
|
||||
print(f" - ID: {msg['id']}, Type: {msg['analysis_type']}, Timestamp: {msg['timestamp']}")
|
||||
cached_status = "CACHED" if msg.get('is_cached') else "NEW"
|
||||
print(f" - ID: {msg['id']}, Type: {msg['analysis_type']}, Status: {cached_status}")
|
||||
|
||||
# Get analytics
|
||||
print("\n5. Getting analytics...")
|
||||
@@ -58,18 +60,68 @@ def test_database_mode():
|
||||
print(f" By type: {analytics['by_type']}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Database mode test completed successfully!")
|
||||
print("Database storage test completed successfully!")
|
||||
|
||||
def test_api_mode():
|
||||
"""Test that API mode initializes correctly."""
|
||||
print("\nTesting API Mode")
|
||||
def test_caching():
|
||||
"""Test that caching works correctly."""
|
||||
print("\nTesting Cache Functionality")
|
||||
print("=" * 70)
|
||||
|
||||
print("\n1. Initializing LLMAnalyzer in API mode...")
|
||||
analyzer = LLMAnalyzer(use_database=False, test_mode=True)
|
||||
db_client = DatabaseClient(config.database_url)
|
||||
db_client.initialize_schema()
|
||||
|
||||
# Store a fake cached response
|
||||
print("\n1. Storing a test response in database...")
|
||||
test_prompt = "Test prompt for caching"
|
||||
test_response = "This is a cached response from previous API call"
|
||||
|
||||
db_client.store_message(
|
||||
prompt=test_prompt,
|
||||
response=test_response,
|
||||
company_name="CacheTest",
|
||||
analysis_type="test",
|
||||
model="test-model"
|
||||
)
|
||||
|
||||
# Try to retrieve from cache
|
||||
print("\n2. Testing cache retrieval...")
|
||||
cached = db_client.get_cached_response(
|
||||
prompt=test_prompt,
|
||||
company_name="CacheTest",
|
||||
analysis_type="test"
|
||||
)
|
||||
|
||||
if cached:
|
||||
print(f" Cache hit! Response: {cached['response']}")
|
||||
else:
|
||||
print(" Cache miss (unexpected)")
|
||||
|
||||
# Test cache miss
|
||||
print("\n3. Testing cache miss...")
|
||||
cached = db_client.get_cached_response(
|
||||
prompt="Different prompt",
|
||||
company_name="CacheTest",
|
||||
analysis_type="test"
|
||||
)
|
||||
|
||||
if cached:
|
||||
print(" Unexpected cache hit")
|
||||
else:
|
||||
print(" Cache miss as expected")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Cache test completed successfully!")
|
||||
|
||||
def test_test_mode():
|
||||
"""Test that test mode works correctly."""
|
||||
print("\nTesting Test Mode")
|
||||
print("=" * 70)
|
||||
|
||||
print("\n1. Initializing LLMAnalyzer in test mode...")
|
||||
analyzer = LLMAnalyzer(test_mode=True)
|
||||
|
||||
print(f" - use_database: {analyzer.use_database}")
|
||||
print(f" - test_mode: {analyzer.test_mode}")
|
||||
print(f" - db_client: {analyzer.db_client is not None}")
|
||||
|
||||
print("\n2. Testing single patent analysis (test mode)...")
|
||||
result = analyzer.analyze_patent_content(
|
||||
@@ -79,9 +131,11 @@ def test_api_mode():
|
||||
print(f" Result: {result}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("API mode test completed successfully!")
|
||||
print("Test mode test completed successfully!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_database_mode()
|
||||
test_database_storage()
|
||||
print("\n")
|
||||
test_api_mode()
|
||||
test_caching()
|
||||
print("\n")
|
||||
test_test_mode()
|
||||
|
||||
+66
-10
@@ -1,13 +1,22 @@
|
||||
"""Tests for LLM analysis functionality."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
|
||||
|
||||
class TestLLMAnalyzer:
|
||||
"""Test LLM analyzer initialization and API interaction."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_database(self, mocker):
|
||||
"""Mock the database client for all tests."""
|
||||
mock_db_client = Mock()
|
||||
mock_db_client.get_cached_response.return_value = None # No cache hit by default
|
||||
mock_db_client.store_message.return_value = 1
|
||||
mocker.patch("SPARC.llm.DatabaseClient", return_value=mock_db_client)
|
||||
return mock_db_client
|
||||
|
||||
def test_analyzer_initialization_with_api_key(self, mocker):
|
||||
"""Test that analyzer initializes with provided API key."""
|
||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||
@@ -25,7 +34,7 @@ class TestLLMAnalyzer:
|
||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||
mock_config = mocker.patch("SPARC.llm.config")
|
||||
mock_config.openrouter_api_key = "config-key-456"
|
||||
mock_config.use_database = False
|
||||
mock_config.use_cache = True
|
||||
mock_config.database_url = "postgresql://localhost/test"
|
||||
|
||||
analyzer = LLMAnalyzer()
|
||||
@@ -35,7 +44,7 @@ class TestLLMAnalyzer:
|
||||
base_url="https://openrouter.ai/api/v1"
|
||||
)
|
||||
|
||||
def test_analyze_patent_content(self, mocker):
|
||||
def test_analyze_patent_content(self, mocker, mock_database):
|
||||
"""Test single patent content analysis."""
|
||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||
mock_client = Mock()
|
||||
@@ -44,9 +53,10 @@ class TestLLMAnalyzer:
|
||||
# Mock the API response
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock(message=Mock(content="Innovative GPU architecture."))]
|
||||
mock_response.usage = Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150)
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
analyzer = LLMAnalyzer(api_key="test-key")
|
||||
analyzer = LLMAnalyzer(api_key="test-key", use_cache=False)
|
||||
result = analyzer.analyze_patent_content(
|
||||
patent_content="ABSTRACT: GPU with new cache design...",
|
||||
company_name="NVIDIA",
|
||||
@@ -61,7 +71,32 @@ class TestLLMAnalyzer:
|
||||
assert "NVIDIA" in prompt_text
|
||||
assert "GPU with new cache design" in prompt_text
|
||||
|
||||
def test_analyze_patent_portfolio(self, mocker):
|
||||
# Verify message was stored in database
|
||||
mock_database.store_message.assert_called_once()
|
||||
|
||||
def test_analyze_patent_content_cache_hit(self, mocker, mock_database):
|
||||
"""Test that cached responses are returned without API call."""
|
||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||
mock_client = Mock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
# Set up cache hit
|
||||
mock_database.get_cached_response.return_value = {
|
||||
"id": 1,
|
||||
"response": "Cached analysis result"
|
||||
}
|
||||
|
||||
analyzer = LLMAnalyzer(api_key="test-key", use_cache=True)
|
||||
result = analyzer.analyze_patent_content(
|
||||
patent_content="ABSTRACT: GPU with new cache design...",
|
||||
company_name="NVIDIA",
|
||||
)
|
||||
|
||||
assert result == "Cached analysis result"
|
||||
# API should NOT be called on cache hit
|
||||
mock_client.chat.completions.create.assert_not_called()
|
||||
|
||||
def test_analyze_patent_portfolio(self, mocker, mock_database):
|
||||
"""Test portfolio analysis with multiple patents."""
|
||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||
mock_client = Mock()
|
||||
@@ -72,9 +107,10 @@ class TestLLMAnalyzer:
|
||||
mock_response.choices = [
|
||||
Mock(message=Mock(content="Strong portfolio in AI and graphics."))
|
||||
]
|
||||
mock_response.usage = Mock(prompt_tokens=200, completion_tokens=100, total_tokens=300)
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
analyzer = LLMAnalyzer(api_key="test-key")
|
||||
analyzer = LLMAnalyzer(api_key="test-key", use_cache=False)
|
||||
patents_data = [
|
||||
{"patent_id": "US123", "content": "AI acceleration patent"},
|
||||
{"patent_id": "US456", "content": "Graphics rendering patent"},
|
||||
@@ -95,7 +131,7 @@ class TestLLMAnalyzer:
|
||||
assert "AI acceleration patent" in prompt_text
|
||||
assert "Graphics rendering patent" in prompt_text
|
||||
|
||||
def test_analyze_patent_portfolio_with_correct_token_limit(self, mocker):
|
||||
def test_analyze_patent_portfolio_with_correct_token_limit(self, mocker, mock_database):
|
||||
"""Test that portfolio analysis uses higher token limit."""
|
||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||
mock_client = Mock()
|
||||
@@ -103,9 +139,10 @@ class TestLLMAnalyzer:
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock(message=Mock(content="Analysis result."))]
|
||||
mock_response.usage = Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150)
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
analyzer = LLMAnalyzer(api_key="test-key")
|
||||
analyzer = LLMAnalyzer(api_key="test-key", use_cache=False)
|
||||
patents_data = [{"patent_id": "US123", "content": "Test content"}]
|
||||
|
||||
analyzer.analyze_patent_portfolio(patents_data, "TestCo")
|
||||
@@ -114,7 +151,7 @@ class TestLLMAnalyzer:
|
||||
# Portfolio analysis should use 2048 tokens
|
||||
assert call_args[1]["max_tokens"] == 2048
|
||||
|
||||
def test_analyze_single_patent_with_correct_token_limit(self, mocker):
|
||||
def test_analyze_single_patent_with_correct_token_limit(self, mocker, mock_database):
|
||||
"""Test that single patent analysis uses lower token limit."""
|
||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||
mock_client = Mock()
|
||||
@@ -122,11 +159,30 @@ class TestLLMAnalyzer:
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock(message=Mock(content="Analysis result."))]
|
||||
mock_response.usage = Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150)
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
analyzer = LLMAnalyzer(api_key="test-key")
|
||||
analyzer = LLMAnalyzer(api_key="test-key", use_cache=False)
|
||||
analyzer.analyze_patent_content("Test content", "TestCo")
|
||||
|
||||
call_args = mock_client.chat.completions.create.call_args
|
||||
# Single patent should use 1024 tokens
|
||||
assert call_args[1]["max_tokens"] == 1024
|
||||
|
||||
def test_database_always_initialized(self, mocker, mock_database):
|
||||
"""Test that database client is always initialized."""
|
||||
mock_openai = mocker.patch("SPARC.llm.OpenAI")
|
||||
|
||||
analyzer = LLMAnalyzer(api_key="test-key")
|
||||
|
||||
assert analyzer.db_client is not None
|
||||
|
||||
def test_no_api_key_stores_placeholder(self, mocker, mock_database):
|
||||
"""Test that without API key, a placeholder is stored."""
|
||||
mocker.patch("SPARC.llm.config")
|
||||
|
||||
analyzer = LLMAnalyzer(use_cache=False)
|
||||
result = analyzer.analyze_patent_content("Test content", "TestCo")
|
||||
|
||||
assert "[NO API]" in result
|
||||
mock_database.store_message.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user