forked from 0xWheatyz/SPARC
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 417b7ab31e |
+12
-7
@@ -2,14 +2,17 @@
|
|||||||
|
|
||||||
Uses APScheduler to periodically re-analyze tracked companies and
|
Uses APScheduler to periodically re-analyze tracked companies and
|
||||||
detect significant changes in patent counts.
|
detect significant changes in patent counts.
|
||||||
|
|
||||||
|
The scheduler reuses the application-level pooled DatabaseClient
|
||||||
|
(from ``SPARC.auth``) instead of creating its own connection, which
|
||||||
|
avoids exhausting the database connection pool under load.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from SPARC import config
|
|
||||||
from SPARC.analyzer import CompanyAnalyzer
|
from SPARC.analyzer import CompanyAnalyzer
|
||||||
from SPARC.database import DatabaseClient
|
from SPARC.auth import get_db_client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -21,10 +24,13 @@ CHANGE_THRESHOLD_PERCENT = int(os.getenv("CHANGE_THRESHOLD_PERCENT", "20"))
|
|||||||
|
|
||||||
|
|
||||||
def run_scheduled_analysis() -> None:
|
def run_scheduled_analysis() -> None:
|
||||||
"""Re-analyze all tracked companies and check for significant changes."""
|
"""Re-analyze all tracked companies and check for significant changes.
|
||||||
db = DatabaseClient(config.database_url)
|
|
||||||
db.connect()
|
Uses the shared pooled DatabaseClient from ``SPARC.auth.get_db_client()``
|
||||||
db.initialize_schema()
|
rather than creating a disposable connection, so the scheduler participates
|
||||||
|
in the same connection pool as the rest of the application.
|
||||||
|
"""
|
||||||
|
db = get_db_client()
|
||||||
|
|
||||||
tracked = db.list_tracked_companies()
|
tracked = db.list_tracked_companies()
|
||||||
if not tracked:
|
if not tracked:
|
||||||
@@ -74,7 +80,6 @@ def run_scheduled_analysis() -> None:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error analyzing tracked company %s: %s", name, e)
|
logger.error("Error analyzing tracked company %s: %s", name, e)
|
||||||
|
|
||||||
db.close()
|
|
||||||
logger.info("Scheduled analysis complete")
|
logger.info("Scheduled analysis complete")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,224 +0,0 @@
|
|||||||
"""Tests for export endpoints: CSV and PDF export of analysis results.
|
|
||||||
|
|
||||||
Covers issue #1655:
|
|
||||||
- GET /export/{company_name} (CSV export)
|
|
||||||
- GET /export/{company_name}/pdf (PDF export)
|
|
||||||
|
|
||||||
All tests mock the database layer and use JWT auth fixtures from test_auth patterns.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
from SPARC.api import app
|
|
||||||
from SPARC.auth import create_access_token
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client():
|
|
||||||
"""Create test client."""
|
|
||||||
return TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def mock_db():
|
|
||||||
"""Mock the database client used by export and auth endpoints."""
|
|
||||||
db = MagicMock()
|
|
||||||
|
|
||||||
# Default: user exists for auth
|
|
||||||
db.get_user_by_id.return_value = {
|
|
||||||
"id": 1,
|
|
||||||
"email": "user@test.com",
|
|
||||||
"role": "user",
|
|
||||||
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Mock get_conn for export queries
|
|
||||||
mock_cursor = MagicMock()
|
|
||||||
mock_conn = MagicMock()
|
|
||||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
|
||||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
|
||||||
db.get_conn.return_value.__enter__ = MagicMock(return_value=mock_conn)
|
|
||||||
db.get_conn.return_value.__exit__ = MagicMock(return_value=False)
|
|
||||||
db._mock_cursor = mock_cursor
|
|
||||||
|
|
||||||
with patch("SPARC.api.get_db_client", return_value=db), \
|
|
||||||
patch("SPARC.auth.get_db_client", return_value=db):
|
|
||||||
yield db
|
|
||||||
|
|
||||||
|
|
||||||
def _auth_header():
|
|
||||||
"""Create an Authorization header with a valid access token."""
|
|
||||||
token = create_access_token(1, "user@test.com", "user")
|
|
||||||
return {"Authorization": f"Bearer {token}"}
|
|
||||||
|
|
||||||
|
|
||||||
def _sample_rows():
|
|
||||||
"""Return sample llm_messages rows as tuples (matching cursor.fetchall format)."""
|
|
||||||
return [
|
|
||||||
(
|
|
||||||
"NVIDIA",
|
|
||||||
"company_analysis",
|
|
||||||
"anthropic/claude-3.5-sonnet",
|
|
||||||
"Strong AI patent portfolio with focus on GPU architectures.",
|
|
||||||
datetime(2025, 6, 15, 10, 30, 0),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"NVIDIA",
|
|
||||||
"patent_analysis",
|
|
||||||
"openai/gpt-4o",
|
|
||||||
"Patent US-12345678-B2 covers novel tensor core design.",
|
|
||||||
datetime(2025, 6, 14, 9, 0, 0),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TestCSVExport:
|
|
||||||
"""GET /export/{company_name} -- CSV export."""
|
|
||||||
|
|
||||||
def test_csv_export_success(self, client, mock_db):
|
|
||||||
"""Valid company with results returns a CSV file."""
|
|
||||||
mock_db._mock_cursor.fetchall.return_value = _sample_rows()
|
|
||||||
|
|
||||||
response = client.get("/export/NVIDIA", headers=_auth_header())
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.headers["content-type"].startswith("text/csv")
|
|
||||||
assert "attachment" in response.headers.get("content-disposition", "")
|
|
||||||
assert "sparc_nvidia_export.csv" in response.headers["content-disposition"]
|
|
||||||
|
|
||||||
# Verify CSV content (CSV uses \r\n line endings)
|
|
||||||
lines = response.text.strip().split("\n")
|
|
||||||
assert len(lines) == 3 # header + 2 data rows
|
|
||||||
assert lines[0].strip() == "company_name,analysis_type,model,analysis,timestamp"
|
|
||||||
assert "NVIDIA" in lines[1]
|
|
||||||
assert "company_analysis" in lines[1]
|
|
||||||
|
|
||||||
def test_csv_export_no_results_returns_404(self, client, mock_db):
|
|
||||||
"""Unknown company returns 404."""
|
|
||||||
mock_db._mock_cursor.fetchall.return_value = []
|
|
||||||
|
|
||||||
response = client.get("/export/nonexistent", headers=_auth_header())
|
|
||||||
|
|
||||||
assert response.status_code == 404
|
|
||||||
assert "No analysis results found" in response.json()["detail"]
|
|
||||||
|
|
||||||
def test_csv_export_unauthenticated_returns_401(self, client):
|
|
||||||
"""Request without token returns 401."""
|
|
||||||
response = client.get("/export/NVIDIA")
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
def test_csv_export_invalid_token_returns_401(self, client):
|
|
||||||
"""Request with invalid token returns 401."""
|
|
||||||
response = client.get(
|
|
||||||
"/export/NVIDIA",
|
|
||||||
headers={"Authorization": "Bearer invalid.token.here"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
def test_csv_export_filename_sanitization(self, client, mock_db):
|
|
||||||
"""Company names with spaces get sanitized in the filename."""
|
|
||||||
mock_db._mock_cursor.fetchall.return_value = [
|
|
||||||
(
|
|
||||||
"Tesla Motors",
|
|
||||||
"company_analysis",
|
|
||||||
"anthropic/claude-3.5-sonnet",
|
|
||||||
"EV patent portfolio analysis.",
|
|
||||||
datetime(2025, 6, 15, 10, 0, 0),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
response = client.get("/export/Tesla Motors", headers=_auth_header())
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert "tesla_motors" in response.headers["content-disposition"]
|
|
||||||
|
|
||||||
def test_csv_export_single_row(self, client, mock_db):
|
|
||||||
"""Single analysis result produces valid CSV with one data row."""
|
|
||||||
mock_db._mock_cursor.fetchall.return_value = [_sample_rows()[0]]
|
|
||||||
|
|
||||||
response = client.get("/export/NVIDIA", headers=_auth_header())
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
lines = response.text.strip().split("\n")
|
|
||||||
assert len(lines) == 2 # header + 1 data row
|
|
||||||
|
|
||||||
|
|
||||||
class TestPDFExport:
|
|
||||||
"""GET /export/{company_name}/pdf -- PDF report export."""
|
|
||||||
|
|
||||||
def test_pdf_export_success(self, client, mock_db):
|
|
||||||
"""Valid company with results returns a PDF file."""
|
|
||||||
mock_db._mock_cursor.fetchall.return_value = _sample_rows()
|
|
||||||
|
|
||||||
response = client.get("/export/NVIDIA/pdf", headers=_auth_header())
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.headers["content-type"] == "application/pdf"
|
|
||||||
assert "attachment" in response.headers.get("content-disposition", "")
|
|
||||||
# PDF files start with %PDF
|
|
||||||
assert response.content[:4] == b"%PDF"
|
|
||||||
|
|
||||||
def test_pdf_export_no_results_returns_404(self, client, mock_db):
|
|
||||||
"""Unknown company returns 404."""
|
|
||||||
mock_db._mock_cursor.fetchall.return_value = []
|
|
||||||
|
|
||||||
response = client.get("/export/nonexistent/pdf", headers=_auth_header())
|
|
||||||
|
|
||||||
assert response.status_code == 404
|
|
||||||
assert "No analysis results found" in response.json()["detail"]
|
|
||||||
|
|
||||||
def test_pdf_export_unauthenticated_returns_401(self, client):
|
|
||||||
"""Request without token returns 401."""
|
|
||||||
response = client.get("/export/NVIDIA/pdf")
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
def test_pdf_export_invalid_token_returns_401(self, client):
|
|
||||||
"""Request with invalid token returns 401."""
|
|
||||||
response = client.get(
|
|
||||||
"/export/NVIDIA/pdf",
|
|
||||||
headers={"Authorization": "Bearer invalid.token.here"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
def test_pdf_export_filename_contains_date(self, client, mock_db):
|
|
||||||
"""PDF filename includes the analysis date."""
|
|
||||||
mock_db._mock_cursor.fetchall.return_value = _sample_rows()
|
|
||||||
|
|
||||||
response = client.get("/export/NVIDIA/pdf", headers=_auth_header())
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
disposition = response.headers["content-disposition"]
|
|
||||||
assert "nvidia-analysis-" in disposition
|
|
||||||
assert ".pdf" in disposition
|
|
||||||
|
|
||||||
def test_pdf_export_special_chars_in_response(self, client, mock_db):
|
|
||||||
"""Analysis text with XML-special chars (<, >, &) does not break PDF generation."""
|
|
||||||
rows = [
|
|
||||||
(
|
|
||||||
"TestCo",
|
|
||||||
"company_analysis",
|
|
||||||
"anthropic/claude-3.5-sonnet",
|
|
||||||
"Revenue > $1B & growth <20% for Q4. Test <html> escaping.",
|
|
||||||
datetime(2025, 6, 15, 10, 0, 0),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
mock_db._mock_cursor.fetchall.return_value = rows
|
|
||||||
|
|
||||||
response = client.get("/export/TestCo/pdf", headers=_auth_header())
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.content[:4] == b"%PDF"
|
|
||||||
|
|
||||||
def test_pdf_export_multiple_analyses(self, client, mock_db):
|
|
||||||
"""Multiple analysis records produce a valid PDF with content."""
|
|
||||||
mock_db._mock_cursor.fetchall.return_value = _sample_rows()
|
|
||||||
|
|
||||||
response = client.get("/export/NVIDIA/pdf", headers=_auth_header())
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
# PDF should have reasonable size (more than just headers)
|
|
||||||
assert len(response.content) > 500
|
|
||||||
Reference in New Issue
Block a user