Compare commits

..

1 Commits

Author SHA1 Message Date
agent-company 47cddcbeaf feat(security): add JWT startup guard, configurable CORS, and externalize DB credentials
- Add check_jwt_secret() that refuses default JWT secret when APP_ENV != development
- Make CORS origins configurable via CORS_ORIGINS env var (comma-separated)
- Replace hardcoded postgres credentials in docker-compose.yml with env var references
- Add APP_ENV and cors_origins to config.py
- Update .env.example with all required variables and documentation
- Add tests for JWT startup guard and CORS configuration

Closes leeworks-agents/SPARC#4
Closes leeworks-agents/SPARC#5
Closes leeworks-agents/SPARC#6

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:06:31 +00:00
9 changed files with 212 additions and 64 deletions
+30 -9
View File
@@ -1,21 +1,42 @@
# SPARC Configuration # SPARC Configuration
# ---- Application Environment ----
# Set to "production" or "staging" in deployed environments.
# The API will refuse to start with the default JWT secret unless APP_ENV=development.
APP_ENV=development
# ---- API Keys ----
# SerpAPI key for patent search # SerpAPI key for patent search
API_KEY=your_serpapi_key_here API_KEY=your_serpapi_key_here
# OpenRouter API key for LLM analysis # OpenRouter API key for LLM analysis
OPENROUTER_API_KEY=your_openrouter_key_here OPENROUTER_API_KEY=your_openrouter_key_here
# Database configuration # ---- Database ----
# All messages are stored in the database for persistence and caching
DATABASE_URL=postgresql://postgres:postgres@localhost:5432/sparc
# Cache configuration # PostgreSQL credentials (used by docker-compose)
# When USE_CACHE=true: check database for cached responses before making API calls POSTGRES_USER=postgres
# When USE_CACHE=false: always make fresh API calls (still stores results in database) POSTGRES_PASSWORD=change-me-to-a-secure-password
# Default: true POSTGRES_DB=sparc
USE_CACHE=true
# JWT Secret for authentication # Full database URL (must match the credentials above)
DATABASE_URL=postgresql://postgres:change-me-to-a-secure-password@localhost:5432/sparc
# ---- Authentication ----
# JWT Secret for signing tokens
# IMPORTANT: Change this to a secure random string in production # IMPORTANT: Change this to a secure random string in production
JWT_SECRET=your-secure-jwt-secret-change-in-production JWT_SECRET=your-secure-jwt-secret-change-in-production
# ---- CORS ----
# Comma-separated list of allowed origins for CORS
# Defaults to http://localhost:3000,http://localhost:5173 when unset
# CORS_ORIGINS=https://sparc.example.com,https://app.example.com
# ---- Cache ----
# When USE_CACHE=true: check database for cached responses before making API calls
# When USE_CACHE=false: always make fresh API calls (still stores results in database)
USE_CACHE=true
+17 -21
View File
@@ -5,13 +5,10 @@ to provide company performance estimation based on patent portfolios.
""" """
import hashlib import hashlib
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable from typing import Callable
from SPARC import config from SPARC import config
logger = logging.getLogger(__name__)
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
from SPARC.serp_api import SERP from SPARC.serp_api import SERP
from SPARC.llm import LLMAnalyzer from SPARC.llm import LLMAnalyzer
@@ -55,13 +52,13 @@ class CompanyAnalyzer:
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest() query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
cached_ids = self.db.get_cached_serp_query(query_hash) cached_ids = self.db.get_cached_serp_query(query_hash)
if cached_ids is not None: if cached_ids is not None:
logger.info("Using cached SERP results for %s (%d patents)", company_name, len(cached_ids)) print(f"Using cached SERP results for {company_name} ({len(cached_ids)} patents)")
patents = Patents(patents=[ patents = Patents(patents=[
Patent(patent_id=pid, pdf_link="") Patent(patent_id=pid, pdf_link="")
for pid in cached_ids for pid in cached_ids
]) ])
else: else:
logger.info("Retrieving patents for %s...", company_name) print(f"Retrieving patents for {company_name}...")
patents = SERP.query(company_name) patents = SERP.query(company_name)
# Cache the SERP results # Cache the SERP results
if patents.patents: if patents.patents:
@@ -69,13 +66,12 @@ class CompanyAnalyzer:
company_name=company_name, company_name=company_name,
query_hash=query_hash, query_hash=query_hash,
patent_ids=[p.patent_id for p in patents.patents], patent_ids=[p.patent_id for p in patents.patents],
ttl_hours=config.serp_cache_ttl_hours,
) )
if not patents.patents: if not patents.patents:
return f"No patents found for {company_name}" return f"No patents found for {company_name}"
logger.info("Found %d patents. Processing...", len(patents.patents)) print(f"Found {len(patents.patents)} patents. Processing...")
# Download, parse, and minimize patents in parallel # Download, parse, and minimize patents in parallel
processed_patents = [] processed_patents = []
@@ -91,12 +87,12 @@ class CompanyAnalyzer:
if result: if result:
processed_patents.append(result) processed_patents.append(result)
except Exception as e: except Exception as e:
logger.warning("Failed to process %s: %s", patent.patent_id, e) print(f"Warning: Failed to process {patent.patent_id}: {e}")
if not processed_patents: if not processed_patents:
return f"Failed to process any patents for {company_name}" return f"Failed to process any patents for {company_name}"
logger.info("Analyzing portfolio with LLM...") print(f"Analyzing portfolio with LLM...")
# Analyze the full portfolio with LLM # Analyze the full portfolio with LLM
analysis = self.llm_analyzer.analyze_patent_portfolio( analysis = self.llm_analyzer.analyze_patent_portfolio(
@@ -119,7 +115,7 @@ class CompanyAnalyzer:
""" """
# Note: This simplified version assumes the patent PDF is already downloaded # Note: This simplified version assumes the patent PDF is already downloaded
# A more complete implementation would support direct patent ID lookup # A more complete implementation would support direct patent ID lookup
logger.info("Analyzing patent %s for %s...", patent_id, company_name) print(f"Analyzing patent {patent_id} for {company_name}...")
patent_path = f"patents/{patent_id}.pdf" patent_path = f"patents/{patent_id}.pdf"
@@ -173,7 +169,7 @@ class CompanyAnalyzer:
return {"patent_id": patent.patent_id, "content": minimized_content} return {"patent_id": patent.patent_id, "content": minimized_content}
except Exception as e: except Exception as e:
logger.warning("Failed to process %s: %s", patent.patent_id, e) print(f"Warning: Failed to process {patent.patent_id}: {e}")
return None return None
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult: def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult:
@@ -244,7 +240,7 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = [] results: list[CompanyAnalysisResult] = []
total = len(companies) total = len(companies)
logger.info("Starting batch analysis of %d companies...", total) print(f"Starting batch analysis of {total} companies...")
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_company = { future_to_company = {
@@ -261,8 +257,8 @@ class CompanyAnalyzer:
result = future.result() result = future.result()
results.append(result) results.append(result)
status = "OK" if result.success else "FAIL" status = "" if result.success else ""
logger.info("[%d/%d] %s %s", completed, total, status, company) print(f"[{completed}/{total}] {status} {company}")
if progress_callback: if progress_callback:
progress_callback(company, completed, total) progress_callback(company, completed, total)
@@ -277,12 +273,12 @@ class CompanyAnalyzer:
error=str(e), error=str(e),
) )
) )
logger.error("[%d/%d] FAIL %s: %s", completed, total, company, e) print(f"[{completed}/{total}] ✗ {company}: {e}")
successful = sum(1 for r in results if r.success) successful = sum(1 for r in results if r.success)
failed = total - successful failed = total - successful
logger.info("Batch complete: %d succeeded, %d failed", successful, failed) print(f"\nBatch complete: {successful} succeeded, {failed} failed")
return BatchAnalysisResult( return BatchAnalysisResult(
results=results, results=results,
@@ -308,20 +304,20 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = [] results: list[CompanyAnalysisResult] = []
total = len(companies) total = len(companies)
logger.info("Starting sequential analysis of %d companies...", total) print(f"Starting sequential analysis of {total} companies...")
for idx, company in enumerate(companies, 1): for idx, company in enumerate(companies, 1):
logger.info("[%d/%d] Analyzing %s...", idx, total, company) print(f"\n[{idx}/{total}] Analyzing {company}...")
result = self._analyze_company_safe(company) result = self._analyze_company_safe(company)
results.append(result) results.append(result)
status = "OK" if result.success else "FAIL" status = "" if result.success else ""
logger.info("[%d/%d] %s %s", idx, total, status, company) print(f"[{idx}/{total}] {status} {company}")
successful = sum(1 for r in results if r.success) successful = sum(1 for r in results if r.success)
failed = total - successful failed = total - successful
logger.info("Batch complete: %d succeeded, %d failed", successful, failed) print(f"\nBatch complete: {successful} succeeded, {failed} failed")
return BatchAnalysisResult( return BatchAnalysisResult(
results=results, results=results,
+3 -1
View File
@@ -16,6 +16,7 @@ from SPARC.analyzer import CompanyAnalyzer
from SPARC.auth import ( from SPARC.auth import (
TokenResponse, TokenResponse,
UserResponse, UserResponse,
check_jwt_secret,
create_tokens, create_tokens,
decode_token, decode_token,
get_current_admin, get_current_admin,
@@ -150,6 +151,7 @@ _analyzer: CompanyAnalyzer | None = None
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Initialize resources on startup.""" """Initialize resources on startup."""
global _analyzer global _analyzer
check_jwt_secret()
_analyzer = CompanyAnalyzer() _analyzer = CompanyAnalyzer()
yield yield
# Cleanup if needed # Cleanup if needed
@@ -167,7 +169,7 @@ app = FastAPI(
# Add CORS middleware for React frontend # Add CORS middleware for React frontend
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["http://localhost:3000", "http://localhost:5173"], allow_origins=config.cors_origins,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
+15 -1
View File
@@ -13,11 +13,25 @@ from SPARC import config
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
# JWT Configuration # JWT Configuration
JWT_SECRET = os.getenv("JWT_SECRET", "sparc-secret-key-change-in-production") _DEFAULT_JWT_SECRET = "sparc-secret-key-change-in-production"
JWT_SECRET = os.getenv("JWT_SECRET", _DEFAULT_JWT_SECRET)
JWT_ALGORITHM = "HS256" JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7 REFRESH_TOKEN_EXPIRE_DAYS = 7
def check_jwt_secret() -> None:
"""Refuse to start with the default JWT secret in non-development environments.
Raises:
RuntimeError: If JWT_SECRET is the default value and APP_ENV is not 'development'.
"""
if JWT_SECRET == _DEFAULT_JWT_SECRET and config.app_env != "development":
raise RuntimeError(
f"FATAL: JWT_SECRET is set to the default value and APP_ENV={config.app_env!r}. "
"Set a secure JWT_SECRET environment variable before running in non-development environments."
)
security = HTTPBearer() security = HTTPBearer()
+14 -16
View File
@@ -2,20 +2,11 @@
Loads environment variables from .env file for API keys and other secrets. Loads environment variables from .env file for API keys and other secrets.
""" """
import logging from dotenv import load_dotenv
import os import os
from dotenv import load_dotenv
load_dotenv() load_dotenv()
# Logging configuration
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
level=getattr(logging, log_level, logging.INFO),
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
# SerpAPI key for patent search # SerpAPI key for patent search
api_key = os.getenv("API_KEY") api_key = os.getenv("API_KEY")
@@ -39,12 +30,19 @@ use_database = os.getenv("USE_DATABASE", "false").lower() in ("true", "1", "yes"
patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90")) patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90"))
patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5")) patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5"))
# LLM model to use via OpenRouter (e.g. "anthropic/claude-3.5-sonnet", "openai/gpt-4o")
model = os.getenv("MODEL", "anthropic/claude-3.5-sonnet")
# SERP cache TTL in hours (how long cached search results are considered fresh)
serp_cache_ttl_hours = int(os.getenv("SERP_CACHE_TTL_HOURS", "24"))
# Root path for running behind a reverse proxy (e.g., "/api" when served at /api/) # Root path for running behind a reverse proxy (e.g., "/api" when served at /api/)
# This ensures OpenAPI docs work correctly when accessed via the proxy # This ensures OpenAPI docs work correctly when accessed via the proxy
root_path = os.getenv("ROOT_PATH", "") root_path = os.getenv("ROOT_PATH", "")
# Application environment: "development", "staging", or "production"
# Used for safety checks (e.g., refusing default JWT secret in production)
app_env = os.getenv("APP_ENV", "development")
# CORS allowed origins (comma-separated)
# Defaults to localhost dev origins when unset
_cors_origins_raw = os.getenv("CORS_ORIGINS", "")
cors_origins: list[str] = (
[o.strip() for o in _cors_origins_raw.split(",") if o.strip()]
if _cors_origins_raw
else ["http://localhost:3000", "http://localhost:5173"]
)
+8 -9
View File
@@ -1,14 +1,9 @@
"""LLM integration for patent analysis using OpenRouter.""" """LLM integration for patent analysis using OpenRouter."""
import logging
from typing import Dict
from openai import OpenAI from openai import OpenAI
from SPARC import config from SPARC import config
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
from typing import Dict
logger = logging.getLogger(__name__)
class LLMAnalyzer: class LLMAnalyzer:
@@ -25,7 +20,7 @@ class LLMAnalyzer:
""" """
self.test_mode = test_mode self.test_mode = test_mode
self.use_cache = use_cache if use_cache is not None else config.use_cache self.use_cache = use_cache if use_cache is not None else config.use_cache
self.model = config.model self.model = "anthropic/claude-3.5-sonnet"
# Always initialize database client for storage and caching # Always initialize database client for storage and caching
self.db_client = DatabaseClient(config.database_url) self.db_client = DatabaseClient(config.database_url)
@@ -64,7 +59,11 @@ Patent Content:
Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage.""" Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage."""
if self.test_mode: if self.test_mode:
logger.debug("TEST MODE - Prompt that would be sent to LLM:\n%s", prompt) print("=" * 80)
print("TEST MODE - Prompt that would be sent to LLM:")
print("=" * 80)
print(prompt)
print("=" * 80)
return "[TEST MODE - No API call made]" return "[TEST MODE - No API call made]"
# Check cache first # Check cache first
@@ -166,7 +165,7 @@ Patent Portfolio:
Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook.""" Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook."""
if self.test_mode: if self.test_mode:
logger.debug("TEST MODE - Portfolio prompt:\n%s", prompt) print(prompt)
return "[TEST MODE]" return "[TEST MODE]"
metadata = { metadata = {
+1 -1
View File
@@ -4,7 +4,7 @@ from datetime import datetime
@dataclass @dataclass
class Patent: class Patent:
patent_id: str patent_id: int
pdf_link: str pdf_link: str
pdf_path: str | None = None pdf_path: str | None = None
summary: dict | None = None summary: dict | None = None
+8 -6
View File
@@ -3,15 +3,15 @@ services:
image: postgres:16-alpine image: postgres:16-alpine
container_name: sparc-postgres container_name: sparc-postgres
environment: environment:
POSTGRES_USER: postgres POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_DB: sparc POSTGRES_DB: ${POSTGRES_DB}
ports: ports:
- "5432:5432" - "5432:5432"
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"] test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER}"]
interval: 5s interval: 5s
timeout: 5s timeout: 5s
retries: 5 retries: 5
@@ -22,7 +22,7 @@ services:
container_name: sparc-init-db container_name: sparc-init-db
command: python scripts/init_database.py command: python scripts/init_database.py
environment: environment:
DATABASE_URL: postgresql://postgres:postgres@postgres:5432/sparc DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/${POSTGRES_DB}
depends_on: depends_on:
postgres: postgres:
condition: service_healthy condition: service_healthy
@@ -35,9 +35,11 @@ services:
environment: environment:
API_KEY: ${API_KEY} API_KEY: ${API_KEY}
OPENROUTER_API_KEY: ${OPENROUTER_API_KEY} OPENROUTER_API_KEY: ${OPENROUTER_API_KEY}
DATABASE_URL: postgresql://postgres:postgres@postgres:5432/sparc DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/${POSTGRES_DB}
USE_CACHE: "true" USE_CACHE: "true"
JWT_SECRET: ${JWT_SECRET:-sparc-secret-key-change-in-production} JWT_SECRET: ${JWT_SECRET:-sparc-secret-key-change-in-production}
CORS_ORIGINS: ${CORS_ORIGINS:-}
APP_ENV: ${APP_ENV:-development}
ROOT_PATH: /api ROOT_PATH: /api
ports: ports:
- "8000:8000" - "8000:8000"
+116
View File
@@ -0,0 +1,116 @@
"""Tests for security hardening: JWT secret startup check, CORS config, credential handling."""
import os
from unittest.mock import patch
import pytest
class TestJWTSecretStartupCheck:
"""Test the startup guard that refuses default JWT secret in non-dev environments."""
def test_default_secret_in_production_raises(self):
"""Starting with default secret and APP_ENV=production must raise RuntimeError."""
with patch.dict(os.environ, {"APP_ENV": "production"}):
# Reload config to pick up the new APP_ENV
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
# Patch JWT_SECRET to the default
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
with pytest.raises(RuntimeError, match="FATAL.*JWT_SECRET"):
check_jwt_secret()
# Restore config
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
def test_default_secret_in_development_succeeds(self):
"""Starting with default secret and APP_ENV=development must not raise."""
with patch.dict(os.environ, {"APP_ENV": "development"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
# Should not raise
check_jwt_secret()
# Restore
importlib.reload(SPARC.config)
def test_custom_secret_in_production_succeeds(self):
"""Starting with a custom secret in production must not raise."""
with patch.dict(os.environ, {"APP_ENV": "production"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", "my-secure-random-secret-abc123"):
# Should not raise
check_jwt_secret()
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
def test_default_secret_unset_env_succeeds(self):
"""When APP_ENV is unset (defaults to development), default secret is allowed."""
with patch.dict(os.environ, {}, clear=False):
# Remove APP_ENV if present
env = os.environ.copy()
env.pop("APP_ENV", None)
with patch.dict(os.environ, env, clear=True):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
# Should not raise (defaults to development)
check_jwt_secret()
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
class TestCORSConfig:
"""Test that CORS origins are configurable via environment variable."""
def test_default_cors_origins(self):
"""When CORS_ORIGINS is unset, defaults to localhost origins."""
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == [
"http://localhost:3000",
"http://localhost:5173",
]
def test_custom_cors_origins(self):
"""Setting CORS_ORIGINS configures allowed origins."""
with patch.dict(os.environ, {"CORS_ORIGINS": "https://sparc.example.com,https://app.example.com"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == [
"https://sparc.example.com",
"https://app.example.com",
]
# Restore
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
importlib.reload(SPARC.config)
def test_single_cors_origin(self):
"""A single origin without comma works correctly."""
with patch.dict(os.environ, {"CORS_ORIGINS": "https://sparc.example.com"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == ["https://sparc.example.com"]
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
importlib.reload(SPARC.config)