diff --git a/test_database_mode.py b/test_database_mode.py index 7508ef2..172162e 100644 --- a/test_database_mode.py +++ b/test_database_mode.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 -"""Test script to verify database mode functionality. +"""Test script to verify database caching functionality. -This script tests the LLMAnalyzer in database mode without requiring +This script tests the LLMAnalyzer with database caching without requiring actual API keys or patent downloads. """ @@ -9,28 +9,29 @@ from SPARC.llm import LLMAnalyzer from SPARC.database import DatabaseClient from SPARC import config -def test_database_mode(): - """Test that database mode stores messages correctly.""" - print("Testing Database Mode") +def test_database_storage(): + """Test that messages are always stored in database.""" + print("Testing Database Storage & Caching") print("=" * 70) - # Initialize analyzer in database mode - print("\n1. Initializing LLMAnalyzer in database mode...") - analyzer = LLMAnalyzer(use_database=True) + # Initialize analyzer (database is always used) + print("\n1. Initializing LLMAnalyzer...") + analyzer = LLMAnalyzer(use_cache=True) - print(f" - use_database: {analyzer.use_database}") + print(f" - use_cache: {analyzer.use_cache}") print(f" - db_client: {analyzer.db_client is not None}") + print(f" - client (API): {analyzer.client is not None}") - # Test single patent analysis - print("\n2. Testing single patent analysis (database mode)...") + # Test single patent analysis (without API key, stores placeholder) + print("\n2. Testing single patent analysis (no API key)...") result = analyzer.analyze_patent_content( patent_content="Test patent content about semiconductor innovation", company_name="TestCorp" ) - print(f" Result: {result}") + print(f" Result: {result[:80]}...") # Test portfolio analysis - print("\n3. Testing portfolio analysis (database mode)...") + print("\n3. Testing portfolio analysis (no API key)...") test_patents = [ {"patent_id": "US001", "content": "First test patent"}, {"patent_id": "US002", "content": "Second test patent"}, @@ -39,7 +40,7 @@ def test_database_mode(): patents_data=test_patents, company_name="TestCorp" ) - print(f" Result: {result}") + print(f" Result: {result[:80]}...") # Verify messages were stored print("\n4. Verifying messages were stored...") @@ -48,7 +49,8 @@ def test_database_mode(): print(f" Found {len(messages)} stored messages") for msg in messages: - print(f" - ID: {msg['id']}, Type: {msg['analysis_type']}, Timestamp: {msg['timestamp']}") + cached_status = "CACHED" if msg.get('is_cached') else "NEW" + print(f" - ID: {msg['id']}, Type: {msg['analysis_type']}, Status: {cached_status}") # Get analytics print("\n5. Getting analytics...") @@ -58,18 +60,68 @@ def test_database_mode(): print(f" By type: {analytics['by_type']}") print("\n" + "=" * 70) - print("Database mode test completed successfully!") + print("Database storage test completed successfully!") -def test_api_mode(): - """Test that API mode initializes correctly.""" - print("\nTesting API Mode") +def test_caching(): + """Test that caching works correctly.""" + print("\nTesting Cache Functionality") print("=" * 70) - print("\n1. Initializing LLMAnalyzer in API mode...") - analyzer = LLMAnalyzer(use_database=False, test_mode=True) + db_client = DatabaseClient(config.database_url) + db_client.initialize_schema() + + # Store a fake cached response + print("\n1. Storing a test response in database...") + test_prompt = "Test prompt for caching" + test_response = "This is a cached response from previous API call" + + db_client.store_message( + prompt=test_prompt, + response=test_response, + company_name="CacheTest", + analysis_type="test", + model="test-model" + ) + + # Try to retrieve from cache + print("\n2. Testing cache retrieval...") + cached = db_client.get_cached_response( + prompt=test_prompt, + company_name="CacheTest", + analysis_type="test" + ) + + if cached: + print(f" Cache hit! Response: {cached['response']}") + else: + print(" Cache miss (unexpected)") + + # Test cache miss + print("\n3. Testing cache miss...") + cached = db_client.get_cached_response( + prompt="Different prompt", + company_name="CacheTest", + analysis_type="test" + ) + + if cached: + print(" Unexpected cache hit") + else: + print(" Cache miss as expected") + + print("\n" + "=" * 70) + print("Cache test completed successfully!") + +def test_test_mode(): + """Test that test mode works correctly.""" + print("\nTesting Test Mode") + print("=" * 70) + + print("\n1. Initializing LLMAnalyzer in test mode...") + analyzer = LLMAnalyzer(test_mode=True) - print(f" - use_database: {analyzer.use_database}") print(f" - test_mode: {analyzer.test_mode}") + print(f" - db_client: {analyzer.db_client is not None}") print("\n2. Testing single patent analysis (test mode)...") result = analyzer.analyze_patent_content( @@ -79,9 +131,11 @@ def test_api_mode(): print(f" Result: {result}") print("\n" + "=" * 70) - print("API mode test completed successfully!") + print("Test mode test completed successfully!") if __name__ == "__main__": - test_database_mode() + test_database_storage() print("\n") - test_api_mode() + test_caching() + print("\n") + test_test_mode() diff --git a/tests/test_llm.py b/tests/test_llm.py index 1cc255a..154fdac 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -1,13 +1,22 @@ """Tests for LLM analysis functionality.""" import pytest -from unittest.mock import Mock, MagicMock +from unittest.mock import Mock, MagicMock, patch from SPARC.llm import LLMAnalyzer class TestLLMAnalyzer: """Test LLM analyzer initialization and API interaction.""" + @pytest.fixture(autouse=True) + def mock_database(self, mocker): + """Mock the database client for all tests.""" + mock_db_client = Mock() + mock_db_client.get_cached_response.return_value = None # No cache hit by default + mock_db_client.store_message.return_value = 1 + mocker.patch("SPARC.llm.DatabaseClient", return_value=mock_db_client) + return mock_db_client + def test_analyzer_initialization_with_api_key(self, mocker): """Test that analyzer initializes with provided API key.""" mock_openai = mocker.patch("SPARC.llm.OpenAI") @@ -25,7 +34,7 @@ class TestLLMAnalyzer: mock_openai = mocker.patch("SPARC.llm.OpenAI") mock_config = mocker.patch("SPARC.llm.config") mock_config.openrouter_api_key = "config-key-456" - mock_config.use_database = False + mock_config.use_cache = True mock_config.database_url = "postgresql://localhost/test" analyzer = LLMAnalyzer() @@ -35,7 +44,7 @@ class TestLLMAnalyzer: base_url="https://openrouter.ai/api/v1" ) - def test_analyze_patent_content(self, mocker): + def test_analyze_patent_content(self, mocker, mock_database): """Test single patent content analysis.""" mock_openai = mocker.patch("SPARC.llm.OpenAI") mock_client = Mock() @@ -44,9 +53,10 @@ class TestLLMAnalyzer: # Mock the API response mock_response = Mock() mock_response.choices = [Mock(message=Mock(content="Innovative GPU architecture."))] + mock_response.usage = Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150) mock_client.chat.completions.create.return_value = mock_response - analyzer = LLMAnalyzer(api_key="test-key") + analyzer = LLMAnalyzer(api_key="test-key", use_cache=False) result = analyzer.analyze_patent_content( patent_content="ABSTRACT: GPU with new cache design...", company_name="NVIDIA", @@ -61,7 +71,32 @@ class TestLLMAnalyzer: assert "NVIDIA" in prompt_text assert "GPU with new cache design" in prompt_text - def test_analyze_patent_portfolio(self, mocker): + # Verify message was stored in database + mock_database.store_message.assert_called_once() + + def test_analyze_patent_content_cache_hit(self, mocker, mock_database): + """Test that cached responses are returned without API call.""" + mock_openai = mocker.patch("SPARC.llm.OpenAI") + mock_client = Mock() + mock_openai.return_value = mock_client + + # Set up cache hit + mock_database.get_cached_response.return_value = { + "id": 1, + "response": "Cached analysis result" + } + + analyzer = LLMAnalyzer(api_key="test-key", use_cache=True) + result = analyzer.analyze_patent_content( + patent_content="ABSTRACT: GPU with new cache design...", + company_name="NVIDIA", + ) + + assert result == "Cached analysis result" + # API should NOT be called on cache hit + mock_client.chat.completions.create.assert_not_called() + + def test_analyze_patent_portfolio(self, mocker, mock_database): """Test portfolio analysis with multiple patents.""" mock_openai = mocker.patch("SPARC.llm.OpenAI") mock_client = Mock() @@ -72,9 +107,10 @@ class TestLLMAnalyzer: mock_response.choices = [ Mock(message=Mock(content="Strong portfolio in AI and graphics.")) ] + mock_response.usage = Mock(prompt_tokens=200, completion_tokens=100, total_tokens=300) mock_client.chat.completions.create.return_value = mock_response - analyzer = LLMAnalyzer(api_key="test-key") + analyzer = LLMAnalyzer(api_key="test-key", use_cache=False) patents_data = [ {"patent_id": "US123", "content": "AI acceleration patent"}, {"patent_id": "US456", "content": "Graphics rendering patent"}, @@ -95,7 +131,7 @@ class TestLLMAnalyzer: assert "AI acceleration patent" in prompt_text assert "Graphics rendering patent" in prompt_text - def test_analyze_patent_portfolio_with_correct_token_limit(self, mocker): + def test_analyze_patent_portfolio_with_correct_token_limit(self, mocker, mock_database): """Test that portfolio analysis uses higher token limit.""" mock_openai = mocker.patch("SPARC.llm.OpenAI") mock_client = Mock() @@ -103,9 +139,10 @@ class TestLLMAnalyzer: mock_response = Mock() mock_response.choices = [Mock(message=Mock(content="Analysis result."))] + mock_response.usage = Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150) mock_client.chat.completions.create.return_value = mock_response - analyzer = LLMAnalyzer(api_key="test-key") + analyzer = LLMAnalyzer(api_key="test-key", use_cache=False) patents_data = [{"patent_id": "US123", "content": "Test content"}] analyzer.analyze_patent_portfolio(patents_data, "TestCo") @@ -114,7 +151,7 @@ class TestLLMAnalyzer: # Portfolio analysis should use 2048 tokens assert call_args[1]["max_tokens"] == 2048 - def test_analyze_single_patent_with_correct_token_limit(self, mocker): + def test_analyze_single_patent_with_correct_token_limit(self, mocker, mock_database): """Test that single patent analysis uses lower token limit.""" mock_openai = mocker.patch("SPARC.llm.OpenAI") mock_client = Mock() @@ -122,11 +159,30 @@ class TestLLMAnalyzer: mock_response = Mock() mock_response.choices = [Mock(message=Mock(content="Analysis result."))] + mock_response.usage = Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150) mock_client.chat.completions.create.return_value = mock_response - analyzer = LLMAnalyzer(api_key="test-key") + analyzer = LLMAnalyzer(api_key="test-key", use_cache=False) analyzer.analyze_patent_content("Test content", "TestCo") call_args = mock_client.chat.completions.create.call_args # Single patent should use 1024 tokens assert call_args[1]["max_tokens"] == 1024 + + def test_database_always_initialized(self, mocker, mock_database): + """Test that database client is always initialized.""" + mock_openai = mocker.patch("SPARC.llm.OpenAI") + + analyzer = LLMAnalyzer(api_key="test-key") + + assert analyzer.db_client is not None + + def test_no_api_key_stores_placeholder(self, mocker, mock_database): + """Test that without API key, a placeholder is stored.""" + mocker.patch("SPARC.llm.config") + + analyzer = LLMAnalyzer(use_cache=False) + result = analyzer.analyze_patent_content("Test content", "TestCo") + + assert "[NO API]" in result + mock_database.store_message.assert_called_once()