Compare commits

..

1 Commits

Author SHA1 Message Date
agent-company c4341ca8dc feat: configurable LLM model, SERP cache TTL, structured logging, fix patent_id type
- Make LLM model configurable via MODEL env var, default anthropic/claude-3.5-sonnet (#12)
- Expose SERP cache TTL as SERP_CACHE_TTL_HOURS env var, default 24 hours (#13)
- Fix Patent.patent_id type annotation from int to str in types.py (#14)
- Replace all print() calls with structured logging in analyzer.py and llm.py (#11)
- Add LOG_LEVEL config with basicConfig setup in config.py
- Add model and serp_cache_ttl_hours to config.py

Closes leeworks-agents/SPARC#11
Closes leeworks-agents/SPARC#12
Closes leeworks-agents/SPARC#13
Closes leeworks-agents/SPARC#14

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:12:00 +00:00
7 changed files with 53 additions and 153 deletions
+21 -17
View File
@@ -5,10 +5,13 @@ to provide company performance estimation based on patent portfolios.
""" """
import hashlib import hashlib
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable from typing import Callable
from SPARC import config from SPARC import config
logger = logging.getLogger(__name__)
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
from SPARC.serp_api import SERP from SPARC.serp_api import SERP
from SPARC.llm import LLMAnalyzer from SPARC.llm import LLMAnalyzer
@@ -52,13 +55,13 @@ class CompanyAnalyzer:
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest() query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
cached_ids = self.db.get_cached_serp_query(query_hash) cached_ids = self.db.get_cached_serp_query(query_hash)
if cached_ids is not None: if cached_ids is not None:
print(f"Using cached SERP results for {company_name} ({len(cached_ids)} patents)") logger.info("Using cached SERP results for %s (%d patents)", company_name, len(cached_ids))
patents = Patents(patents=[ patents = Patents(patents=[
Patent(patent_id=pid, pdf_link="") Patent(patent_id=pid, pdf_link="")
for pid in cached_ids for pid in cached_ids
]) ])
else: else:
print(f"Retrieving patents for {company_name}...") logger.info("Retrieving patents for %s...", company_name)
patents = SERP.query(company_name) patents = SERP.query(company_name)
# Cache the SERP results # Cache the SERP results
if patents.patents: if patents.patents:
@@ -66,12 +69,13 @@ class CompanyAnalyzer:
company_name=company_name, company_name=company_name,
query_hash=query_hash, query_hash=query_hash,
patent_ids=[p.patent_id for p in patents.patents], patent_ids=[p.patent_id for p in patents.patents],
ttl_hours=config.serp_cache_ttl_hours,
) )
if not patents.patents: if not patents.patents:
return f"No patents found for {company_name}" return f"No patents found for {company_name}"
print(f"Found {len(patents.patents)} patents. Processing...") logger.info("Found %d patents. Processing...", len(patents.patents))
# Download, parse, and minimize patents in parallel # Download, parse, and minimize patents in parallel
processed_patents = [] processed_patents = []
@@ -87,12 +91,12 @@ class CompanyAnalyzer:
if result: if result:
processed_patents.append(result) processed_patents.append(result)
except Exception as e: except Exception as e:
print(f"Warning: Failed to process {patent.patent_id}: {e}") logger.warning("Failed to process %s: %s", patent.patent_id, e)
if not processed_patents: if not processed_patents:
return f"Failed to process any patents for {company_name}" return f"Failed to process any patents for {company_name}"
print(f"Analyzing portfolio with LLM...") logger.info("Analyzing portfolio with LLM...")
# Analyze the full portfolio with LLM # Analyze the full portfolio with LLM
analysis = self.llm_analyzer.analyze_patent_portfolio( analysis = self.llm_analyzer.analyze_patent_portfolio(
@@ -115,7 +119,7 @@ class CompanyAnalyzer:
""" """
# Note: This simplified version assumes the patent PDF is already downloaded # Note: This simplified version assumes the patent PDF is already downloaded
# A more complete implementation would support direct patent ID lookup # A more complete implementation would support direct patent ID lookup
print(f"Analyzing patent {patent_id} for {company_name}...") logger.info("Analyzing patent %s for %s...", patent_id, company_name)
patent_path = f"patents/{patent_id}.pdf" patent_path = f"patents/{patent_id}.pdf"
@@ -169,7 +173,7 @@ class CompanyAnalyzer:
return {"patent_id": patent.patent_id, "content": minimized_content} return {"patent_id": patent.patent_id, "content": minimized_content}
except Exception as e: except Exception as e:
print(f"Warning: Failed to process {patent.patent_id}: {e}") logger.warning("Failed to process %s: %s", patent.patent_id, e)
return None return None
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult: def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult:
@@ -240,7 +244,7 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = [] results: list[CompanyAnalysisResult] = []
total = len(companies) total = len(companies)
print(f"Starting batch analysis of {total} companies...") logger.info("Starting batch analysis of %d companies...", total)
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_company = { future_to_company = {
@@ -257,8 +261,8 @@ class CompanyAnalyzer:
result = future.result() result = future.result()
results.append(result) results.append(result)
status = "" if result.success else "" status = "OK" if result.success else "FAIL"
print(f"[{completed}/{total}] {status} {company}") logger.info("[%d/%d] %s %s", completed, total, status, company)
if progress_callback: if progress_callback:
progress_callback(company, completed, total) progress_callback(company, completed, total)
@@ -273,12 +277,12 @@ class CompanyAnalyzer:
error=str(e), error=str(e),
) )
) )
print(f"[{completed}/{total}] ✗ {company}: {e}") logger.error("[%d/%d] FAIL %s: %s", completed, total, company, e)
successful = sum(1 for r in results if r.success) successful = sum(1 for r in results if r.success)
failed = total - successful failed = total - successful
print(f"\nBatch complete: {successful} succeeded, {failed} failed") logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
return BatchAnalysisResult( return BatchAnalysisResult(
results=results, results=results,
@@ -304,20 +308,20 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = [] results: list[CompanyAnalysisResult] = []
total = len(companies) total = len(companies)
print(f"Starting sequential analysis of {total} companies...") logger.info("Starting sequential analysis of %d companies...", total)
for idx, company in enumerate(companies, 1): for idx, company in enumerate(companies, 1):
print(f"\n[{idx}/{total}] Analyzing {company}...") logger.info("[%d/%d] Analyzing %s...", idx, total, company)
result = self._analyze_company_safe(company) result = self._analyze_company_safe(company)
results.append(result) results.append(result)
status = "" if result.success else "" status = "OK" if result.success else "FAIL"
print(f"[{idx}/{total}] {status} {company}") logger.info("[%d/%d] %s %s", idx, total, status, company)
successful = sum(1 for r in results if r.success) successful = sum(1 for r in results if r.success)
failed = total - successful failed = total - successful
print(f"\nBatch complete: {successful} succeeded, {failed} failed") logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
return BatchAnalysisResult( return BatchAnalysisResult(
results=results, results=results,
+6 -28
View File
@@ -7,13 +7,9 @@ from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from typing import Annotated, List from typing import Annotated, List
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from SPARC import config from SPARC import config
from SPARC.analyzer import CompanyAnalyzer from SPARC.analyzer import CompanyAnalyzer
@@ -168,22 +164,6 @@ app = FastAPI(
root_path=config.root_path, root_path=config.root_path,
) )
# Rate limiter (in-memory storage, suitable for single-instance deployments)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Return 429 with Retry-After header when rate limit is exceeded."""
retry_after = getattr(exc, "retry_after", 60)
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded. Please try again later."},
headers={"Retry-After": str(retry_after)},
)
# Add CORS middleware for React frontend # Add CORS middleware for React frontend
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@@ -198,8 +178,7 @@ app.add_middleware(
@app.post("/auth/register", response_model=UserResponse, tags=["Auth"]) @app.post("/auth/register", response_model=UserResponse, tags=["Auth"])
@limiter.limit("5/minute") async def register(request: RegisterRequest):
async def register(request: Request, body: RegisterRequest):
"""Register a new user. """Register a new user.
The first registered user automatically becomes an admin. The first registered user automatically becomes an admin.
@@ -211,8 +190,8 @@ async def register(request: Request, body: RegisterRequest):
role = "admin" if user_count == 0 else "user" role = "admin" if user_count == 0 else "user"
user = db.create_user( user = db.create_user(
email=body.email, email=request.email,
password=body.password, password=request.password,
role=role, role=role,
) )
@@ -231,12 +210,11 @@ async def register(request: Request, body: RegisterRequest):
@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"]) @app.post("/auth/login", response_model=TokenResponse, tags=["Auth"])
@limiter.limit("10/minute") async def login(request: LoginRequest):
async def login(request: Request, body: LoginRequest):
"""Authenticate user and return JWT tokens.""" """Authenticate user and return JWT tokens."""
db = get_db_client() db = get_db_client()
user = db.authenticate_user(body.email, body.password) user = db.authenticate_user(request.email, request.password)
if not user: if not user:
raise HTTPException( raise HTTPException(
+16 -1
View File
@@ -2,11 +2,20 @@
Loads environment variables from .env file for API keys and other secrets. Loads environment variables from .env file for API keys and other secrets.
""" """
from dotenv import load_dotenv import logging
import os import os
from dotenv import load_dotenv
load_dotenv() load_dotenv()
# Logging configuration
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
level=getattr(logging, log_level, logging.INFO),
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
# SerpAPI key for patent search # SerpAPI key for patent search
api_key = os.getenv("API_KEY") api_key = os.getenv("API_KEY")
@@ -30,6 +39,12 @@ use_database = os.getenv("USE_DATABASE", "false").lower() in ("true", "1", "yes"
patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90")) patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90"))
patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5")) patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5"))
# LLM model to use via OpenRouter (e.g. "anthropic/claude-3.5-sonnet", "openai/gpt-4o")
model = os.getenv("MODEL", "anthropic/claude-3.5-sonnet")
# SERP cache TTL in hours (how long cached search results are considered fresh)
serp_cache_ttl_hours = int(os.getenv("SERP_CACHE_TTL_HOURS", "24"))
# Root path for running behind a reverse proxy (e.g., "/api" when served at /api/) # Root path for running behind a reverse proxy (e.g., "/api" when served at /api/)
# This ensures OpenAPI docs work correctly when accessed via the proxy # This ensures OpenAPI docs work correctly when accessed via the proxy
root_path = os.getenv("ROOT_PATH", "") root_path = os.getenv("ROOT_PATH", "")
+9 -8
View File
@@ -1,9 +1,14 @@
"""LLM integration for patent analysis using OpenRouter.""" """LLM integration for patent analysis using OpenRouter."""
import logging
from typing import Dict
from openai import OpenAI from openai import OpenAI
from SPARC import config from SPARC import config
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
from typing import Dict
logger = logging.getLogger(__name__)
class LLMAnalyzer: class LLMAnalyzer:
@@ -20,7 +25,7 @@ class LLMAnalyzer:
""" """
self.test_mode = test_mode self.test_mode = test_mode
self.use_cache = use_cache if use_cache is not None else config.use_cache self.use_cache = use_cache if use_cache is not None else config.use_cache
self.model = "anthropic/claude-3.5-sonnet" self.model = config.model
# Always initialize database client for storage and caching # Always initialize database client for storage and caching
self.db_client = DatabaseClient(config.database_url) self.db_client = DatabaseClient(config.database_url)
@@ -59,11 +64,7 @@ Patent Content:
Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage.""" Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage."""
if self.test_mode: if self.test_mode:
print("=" * 80) logger.debug("TEST MODE - Prompt that would be sent to LLM:\n%s", prompt)
print("TEST MODE - Prompt that would be sent to LLM:")
print("=" * 80)
print(prompt)
print("=" * 80)
return "[TEST MODE - No API call made]" return "[TEST MODE - No API call made]"
# Check cache first # Check cache first
@@ -165,7 +166,7 @@ Patent Portfolio:
Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook.""" Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook."""
if self.test_mode: if self.test_mode:
print(prompt) logger.debug("TEST MODE - Portfolio prompt:\n%s", prompt)
return "[TEST MODE]" return "[TEST MODE]"
metadata = { metadata = {
+1 -1
View File
@@ -4,7 +4,7 @@ from datetime import datetime
@dataclass @dataclass
class Patent: class Patent:
patent_id: int patent_id: str
pdf_link: str pdf_link: str
pdf_path: str | None = None pdf_path: str | None = None
summary: dict | None = None summary: dict | None = None
-1
View File
@@ -14,4 +14,3 @@ numpy
pandas pandas
bcrypt bcrypt
PyJWT PyJWT
slowapi
-97
View File
@@ -1,97 +0,0 @@
"""Tests for rate limiting on auth endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from fastapi.testclient import TestClient
from SPARC.api import app
@pytest.fixture
def client():
"""Create test client with rate limiter enabled."""
return TestClient(app)
@pytest.fixture(autouse=True)
def reset_limiter():
"""Reset rate limiter storage between tests."""
from SPARC.api import limiter
limiter.reset()
yield
class TestRateLimiting:
"""Test rate limiting on login and register endpoints."""
@patch("SPARC.api.get_db_client")
def test_login_allows_requests_under_limit(self, mock_db_client, client):
"""Login endpoint allows requests under the rate limit."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Should allow at least a few requests
for _ in range(5):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
# 401 is expected (invalid credentials), not 429
assert response.status_code == 401
@patch("SPARC.api.get_db_client")
def test_login_rate_limited_after_threshold(self, mock_db_client, client):
"""Login endpoint returns 429 after exceeding rate limit."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Send more than the limit (10/minute)
statuses = []
for _ in range(15):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
statuses.append(response.status_code)
# At least one should be 429
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
@patch("SPARC.api.get_db_client")
def test_register_rate_limited_after_threshold(self, mock_db_client, client):
"""Register endpoint returns 429 after exceeding rate limit."""
mock_db = MagicMock()
mock_db.get_user_count.return_value = 1
mock_db.create_user.return_value = None # triggers 400 (email exists)
mock_db_client.return_value = mock_db
# Send more than the limit (5/minute)
statuses = []
for _ in range(10):
response = client.post(
"/auth/register",
json={"email": "test@example.com", "password": "password123"},
)
statuses.append(response.status_code)
# At least one should be 429
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
@patch("SPARC.api.get_db_client")
def test_rate_limit_returns_retry_after_header(self, mock_db_client, client):
"""Rate limited responses include a Retry-After header."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Exhaust the limit
for _ in range(15):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
if response.status_code == 429:
assert "Retry-After" in response.headers
break