forked from 0xWheatyz/SPARC
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b000146585 | |||
| 35d105b14e | |||
| 6fcf170d93 | |||
| 5a42e216ba | |||
| 24ab341d9b | |||
| 878fedfbb8 | |||
| ae9f257dcb | |||
| 3dac88ec90 | |||
| e2d750146c | |||
| 47cddcbeaf |
+30
-9
@@ -1,21 +1,42 @@
|
||||
# SPARC Configuration
|
||||
|
||||
# ---- Application Environment ----
|
||||
# Set to "production" or "staging" in deployed environments.
|
||||
# The API will refuse to start with the default JWT secret unless APP_ENV=development.
|
||||
APP_ENV=development
|
||||
|
||||
# ---- API Keys ----
|
||||
|
||||
# SerpAPI key for patent search
|
||||
API_KEY=your_serpapi_key_here
|
||||
|
||||
# OpenRouter API key for LLM analysis
|
||||
OPENROUTER_API_KEY=your_openrouter_key_here
|
||||
|
||||
# Database configuration
|
||||
# All messages are stored in the database for persistence and caching
|
||||
DATABASE_URL=postgresql://postgres:postgres@localhost:5432/sparc
|
||||
# ---- Database ----
|
||||
|
||||
# Cache configuration
|
||||
# When USE_CACHE=true: check database for cached responses before making API calls
|
||||
# When USE_CACHE=false: always make fresh API calls (still stores results in database)
|
||||
# Default: true
|
||||
USE_CACHE=true
|
||||
# PostgreSQL credentials (used by docker-compose)
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=change-me-to-a-secure-password
|
||||
POSTGRES_DB=sparc
|
||||
|
||||
# JWT Secret for authentication
|
||||
# Full database URL (must match the credentials above)
|
||||
DATABASE_URL=postgresql://postgres:change-me-to-a-secure-password@localhost:5432/sparc
|
||||
|
||||
# ---- Authentication ----
|
||||
|
||||
# JWT Secret for signing tokens
|
||||
# IMPORTANT: Change this to a secure random string in production
|
||||
JWT_SECRET=your-secure-jwt-secret-change-in-production
|
||||
|
||||
# ---- CORS ----
|
||||
|
||||
# Comma-separated list of allowed origins for CORS
|
||||
# Defaults to http://localhost:3000,http://localhost:5173 when unset
|
||||
# CORS_ORIGINS=https://sparc.example.com,https://app.example.com
|
||||
|
||||
# ---- Cache ----
|
||||
|
||||
# When USE_CACHE=true: check database for cached responses before making API calls
|
||||
# When USE_CACHE=false: always make fresh API calls (still stores results in database)
|
||||
USE_CACHE=true
|
||||
|
||||
@@ -54,6 +54,21 @@ docker-compose up -d
|
||||
# - API Docs: http://localhost:8000/docs
|
||||
```
|
||||
|
||||
#### Patent PDF Storage
|
||||
|
||||
The API stores downloaded patent PDFs in a `patents/` directory. In Docker,
|
||||
this is mounted as a bind mount (`./patents:/app/patents`) so that PDFs persist
|
||||
across container restarts.
|
||||
|
||||
If you deploy to a different environment, ensure the `patents/` directory is a
|
||||
persistent volume. Without it, PDFs will be re-downloaded on every analysis.
|
||||
|
||||
```yaml
|
||||
# docker-compose.yml excerpt
|
||||
volumes:
|
||||
- ./patents:/app/patents
|
||||
```
|
||||
|
||||
### NixOS
|
||||
|
||||
```bash
|
||||
|
||||
+40
-21
@@ -5,10 +5,13 @@ to provide company performance estimation based on patent portfolios.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Callable
|
||||
|
||||
from SPARC import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from SPARC.database import DatabaseClient
|
||||
from SPARC.serp_api import SERP
|
||||
from SPARC.llm import LLMAnalyzer
|
||||
@@ -52,13 +55,13 @@ class CompanyAnalyzer:
|
||||
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
|
||||
cached_ids = self.db.get_cached_serp_query(query_hash)
|
||||
if cached_ids is not None:
|
||||
print(f"Using cached SERP results for {company_name} ({len(cached_ids)} patents)")
|
||||
logger.info("Using cached SERP results for %s (%d patents)", company_name, len(cached_ids))
|
||||
patents = Patents(patents=[
|
||||
Patent(patent_id=pid, pdf_link="")
|
||||
for pid in cached_ids
|
||||
])
|
||||
else:
|
||||
print(f"Retrieving patents for {company_name}...")
|
||||
logger.info("Retrieving patents for %s...", company_name)
|
||||
patents = SERP.query(company_name)
|
||||
# Cache the SERP results
|
||||
if patents.patents:
|
||||
@@ -66,12 +69,13 @@ class CompanyAnalyzer:
|
||||
company_name=company_name,
|
||||
query_hash=query_hash,
|
||||
patent_ids=[p.patent_id for p in patents.patents],
|
||||
ttl_hours=config.serp_cache_ttl_hours,
|
||||
)
|
||||
|
||||
if not patents.patents:
|
||||
return f"No patents found for {company_name}"
|
||||
|
||||
print(f"Found {len(patents.patents)} patents. Processing...")
|
||||
logger.info("Found %d patents. Processing...", len(patents.patents))
|
||||
|
||||
# Download, parse, and minimize patents in parallel
|
||||
processed_patents = []
|
||||
@@ -87,12 +91,12 @@ class CompanyAnalyzer:
|
||||
if result:
|
||||
processed_patents.append(result)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to process {patent.patent_id}: {e}")
|
||||
logger.warning("Failed to process %s: %s", patent.patent_id, e)
|
||||
|
||||
if not processed_patents:
|
||||
return f"Failed to process any patents for {company_name}"
|
||||
|
||||
print(f"Analyzing portfolio with LLM...")
|
||||
logger.info("Analyzing portfolio with LLM...")
|
||||
|
||||
# Analyze the full portfolio with LLM
|
||||
analysis = self.llm_analyzer.analyze_patent_portfolio(
|
||||
@@ -104,21 +108,34 @@ class CompanyAnalyzer:
|
||||
def analyze_single_patent(self, patent_id: str, company_name: str) -> str:
|
||||
"""Analyze a single patent by ID.
|
||||
|
||||
Useful for focused analysis of specific innovations.
|
||||
Prerequisite:
|
||||
The patent PDF must already exist at ``patents/{patent_id}.pdf``
|
||||
before calling this method. PDFs are downloaded automatically when
|
||||
using the batch analysis pipeline (``analyze_company`` or the
|
||||
``/analyze/batch`` API endpoint). For standalone usage, download
|
||||
the PDF manually or call ``SERP.save_patents()`` first.
|
||||
|
||||
Args:
|
||||
patent_id: Publication ID of the patent
|
||||
patent_id: Publication ID of the patent (e.g. "US-11234567-B2")
|
||||
company_name: Name of the company (for context)
|
||||
|
||||
Returns:
|
||||
Analysis of the specific patent's innovation quality
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the patent PDF is not found at the expected path.
|
||||
"""
|
||||
# Note: This simplified version assumes the patent PDF is already downloaded
|
||||
# A more complete implementation would support direct patent ID lookup
|
||||
print(f"Analyzing patent {patent_id} for {company_name}...")
|
||||
import os
|
||||
logger.info("Analyzing patent %s for %s...", patent_id, company_name)
|
||||
|
||||
patent_path = f"patents/{patent_id}.pdf"
|
||||
|
||||
if not os.path.exists(patent_path):
|
||||
raise FileNotFoundError(
|
||||
f"Patent PDF not found at '{patent_path}'. "
|
||||
f"Download the PDF first using SERP.save_patents() or the batch analysis pipeline."
|
||||
)
|
||||
|
||||
try:
|
||||
sections = SERP.parse_patent_pdf(patent_path)
|
||||
minimized_content = SERP.minimize_patent_for_llm(sections)
|
||||
@@ -129,6 +146,8 @@ class CompanyAnalyzer:
|
||||
|
||||
return analysis
|
||||
|
||||
except FileNotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
return f"Failed to analyze patent {patent_id}: {e}"
|
||||
|
||||
@@ -169,7 +188,7 @@ class CompanyAnalyzer:
|
||||
|
||||
return {"patent_id": patent.patent_id, "content": minimized_content}
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to process {patent.patent_id}: {e}")
|
||||
logger.warning("Failed to process %s: %s", patent.patent_id, e)
|
||||
return None
|
||||
|
||||
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult:
|
||||
@@ -240,7 +259,7 @@ class CompanyAnalyzer:
|
||||
results: list[CompanyAnalysisResult] = []
|
||||
total = len(companies)
|
||||
|
||||
print(f"Starting batch analysis of {total} companies...")
|
||||
logger.info("Starting batch analysis of %d companies...", total)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_company = {
|
||||
@@ -257,8 +276,8 @@ class CompanyAnalyzer:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
|
||||
status = "✓" if result.success else "✗"
|
||||
print(f"[{completed}/{total}] {status} {company}")
|
||||
status = "OK" if result.success else "FAIL"
|
||||
logger.info("[%d/%d] %s %s", completed, total, status, company)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(company, completed, total)
|
||||
@@ -273,12 +292,12 @@ class CompanyAnalyzer:
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
print(f"[{completed}/{total}] ✗ {company}: {e}")
|
||||
logger.error("[%d/%d] FAIL %s: %s", completed, total, company, e)
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = total - successful
|
||||
|
||||
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
|
||||
logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
|
||||
|
||||
return BatchAnalysisResult(
|
||||
results=results,
|
||||
@@ -304,20 +323,20 @@ class CompanyAnalyzer:
|
||||
results: list[CompanyAnalysisResult] = []
|
||||
total = len(companies)
|
||||
|
||||
print(f"Starting sequential analysis of {total} companies...")
|
||||
logger.info("Starting sequential analysis of %d companies...", total)
|
||||
|
||||
for idx, company in enumerate(companies, 1):
|
||||
print(f"\n[{idx}/{total}] Analyzing {company}...")
|
||||
logger.info("[%d/%d] Analyzing %s...", idx, total, company)
|
||||
result = self._analyze_company_safe(company)
|
||||
results.append(result)
|
||||
|
||||
status = "✓" if result.success else "✗"
|
||||
print(f"[{idx}/{total}] {status} {company}")
|
||||
status = "OK" if result.success else "FAIL"
|
||||
logger.info("[%d/%d] %s %s", idx, total, status, company)
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = total - successful
|
||||
|
||||
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
|
||||
logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
|
||||
|
||||
return BatchAnalysisResult(
|
||||
results=results,
|
||||
|
||||
+31
-7
@@ -7,15 +7,20 @@ from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Annotated, List
|
||||
|
||||
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query
|
||||
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from slowapi import Limiter
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from SPARC import config
|
||||
from SPARC.analyzer import CompanyAnalyzer
|
||||
from SPARC.auth import (
|
||||
TokenResponse,
|
||||
UserResponse,
|
||||
check_jwt_secret,
|
||||
create_tokens,
|
||||
decode_token,
|
||||
get_current_admin,
|
||||
@@ -149,6 +154,7 @@ _analyzer: CompanyAnalyzer | None = None
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Initialize resources on startup, clean up on shutdown."""
|
||||
global _analyzer
|
||||
check_jwt_secret()
|
||||
_analyzer = CompanyAnalyzer()
|
||||
# Mark any jobs that were running/pending before the restart as failed
|
||||
from SPARC.database import DatabaseClient
|
||||
@@ -173,10 +179,26 @@ app = FastAPI(
|
||||
root_path=config.root_path,
|
||||
)
|
||||
|
||||
# Rate limiter (in-memory storage, suitable for single-instance deployments)
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
app.state.limiter = limiter
|
||||
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
|
||||
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
|
||||
"""Return 429 with Retry-After header when rate limit is exceeded."""
|
||||
retry_after = getattr(exc, "retry_after", 60)
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded. Please try again later."},
|
||||
headers={"Retry-After": str(retry_after)},
|
||||
)
|
||||
|
||||
|
||||
# Add CORS middleware for React frontend
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:3000", "http://localhost:5173"],
|
||||
allow_origins=config.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
@@ -187,7 +209,8 @@ app.add_middleware(
|
||||
|
||||
|
||||
@app.post("/auth/register", response_model=UserResponse, tags=["Auth"])
|
||||
async def register(request: RegisterRequest):
|
||||
@limiter.limit("5/minute")
|
||||
async def register(request: Request, body: RegisterRequest):
|
||||
"""Register a new user.
|
||||
|
||||
The first registered user automatically becomes an admin.
|
||||
@@ -199,8 +222,8 @@ async def register(request: RegisterRequest):
|
||||
role = "admin" if user_count == 0 else "user"
|
||||
|
||||
user = db.create_user(
|
||||
email=request.email,
|
||||
password=request.password,
|
||||
email=body.email,
|
||||
password=body.password,
|
||||
role=role,
|
||||
)
|
||||
|
||||
@@ -219,11 +242,12 @@ async def register(request: RegisterRequest):
|
||||
|
||||
|
||||
@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"])
|
||||
async def login(request: LoginRequest):
|
||||
@limiter.limit("10/minute")
|
||||
async def login(request: Request, body: LoginRequest):
|
||||
"""Authenticate user and return JWT tokens."""
|
||||
db = get_db_client()
|
||||
|
||||
user = db.authenticate_user(request.email, request.password)
|
||||
user = db.authenticate_user(body.email, body.password)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
|
||||
+15
-1
@@ -13,11 +13,25 @@ from SPARC import config
|
||||
from SPARC.database import DatabaseClient
|
||||
|
||||
# JWT Configuration
|
||||
JWT_SECRET = os.getenv("JWT_SECRET", "sparc-secret-key-change-in-production")
|
||||
_DEFAULT_JWT_SECRET = "sparc-secret-key-change-in-production"
|
||||
JWT_SECRET = os.getenv("JWT_SECRET", _DEFAULT_JWT_SECRET)
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 7
|
||||
|
||||
|
||||
def check_jwt_secret() -> None:
|
||||
"""Refuse to start with the default JWT secret in non-development environments.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If JWT_SECRET is the default value and APP_ENV is not 'development'.
|
||||
"""
|
||||
if JWT_SECRET == _DEFAULT_JWT_SECRET and config.app_env != "development":
|
||||
raise RuntimeError(
|
||||
f"FATAL: JWT_SECRET is set to the default value and APP_ENV={config.app_env!r}. "
|
||||
"Set a secure JWT_SECRET environment variable before running in non-development environments."
|
||||
)
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
|
||||
+29
-1
@@ -2,11 +2,20 @@
|
||||
|
||||
Loads environment variables from .env file for API keys and other secrets.
|
||||
"""
|
||||
from dotenv import load_dotenv
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Logging configuration
|
||||
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level, logging.INFO),
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
|
||||
# SerpAPI key for patent search
|
||||
api_key = os.getenv("API_KEY")
|
||||
|
||||
@@ -30,6 +39,25 @@ use_database = os.getenv("USE_DATABASE", "false").lower() in ("true", "1", "yes"
|
||||
patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90"))
|
||||
patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5"))
|
||||
|
||||
# LLM model to use via OpenRouter (e.g. "anthropic/claude-3.5-sonnet", "openai/gpt-4o")
|
||||
model = os.getenv("MODEL", "anthropic/claude-3.5-sonnet")
|
||||
|
||||
# SERP cache TTL in hours (how long cached search results are considered fresh)
|
||||
serp_cache_ttl_hours = int(os.getenv("SERP_CACHE_TTL_HOURS", "24"))
|
||||
|
||||
# Root path for running behind a reverse proxy (e.g., "/api" when served at /api/)
|
||||
# This ensures OpenAPI docs work correctly when accessed via the proxy
|
||||
root_path = os.getenv("ROOT_PATH", "")
|
||||
|
||||
# Application environment: "development", "staging", or "production"
|
||||
# Used for safety checks (e.g., refusing default JWT secret in production)
|
||||
app_env = os.getenv("APP_ENV", "development")
|
||||
|
||||
# CORS allowed origins (comma-separated)
|
||||
# Defaults to localhost dev origins when unset
|
||||
_cors_origins_raw = os.getenv("CORS_ORIGINS", "")
|
||||
cors_origins: list[str] = (
|
||||
[o.strip() for o in _cors_origins_raw.split(",") if o.strip()]
|
||||
if _cors_origins_raw
|
||||
else ["http://localhost:3000", "http://localhost:5173"]
|
||||
)
|
||||
|
||||
+9
-8
@@ -1,9 +1,14 @@
|
||||
"""LLM integration for patent analysis using OpenRouter."""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from SPARC import config
|
||||
from SPARC.database import DatabaseClient
|
||||
from typing import Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMAnalyzer:
|
||||
@@ -20,7 +25,7 @@ class LLMAnalyzer:
|
||||
"""
|
||||
self.test_mode = test_mode
|
||||
self.use_cache = use_cache if use_cache is not None else config.use_cache
|
||||
self.model = "anthropic/claude-3.5-sonnet"
|
||||
self.model = config.model
|
||||
|
||||
# Always initialize database client for storage and caching
|
||||
self.db_client = DatabaseClient(config.database_url)
|
||||
@@ -59,11 +64,7 @@ Patent Content:
|
||||
Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage."""
|
||||
|
||||
if self.test_mode:
|
||||
print("=" * 80)
|
||||
print("TEST MODE - Prompt that would be sent to LLM:")
|
||||
print("=" * 80)
|
||||
print(prompt)
|
||||
print("=" * 80)
|
||||
logger.debug("TEST MODE - Prompt that would be sent to LLM:\n%s", prompt)
|
||||
return "[TEST MODE - No API call made]"
|
||||
|
||||
# Check cache first
|
||||
@@ -165,7 +166,7 @@ Patent Portfolio:
|
||||
Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook."""
|
||||
|
||||
if self.test_mode:
|
||||
print(prompt)
|
||||
logger.debug("TEST MODE - Portfolio prompt:\n%s", prompt)
|
||||
return "[TEST MODE]"
|
||||
|
||||
metadata = {
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
|
||||
@dataclass
|
||||
class Patent:
|
||||
patent_id: int
|
||||
patent_id: str
|
||||
pdf_link: str
|
||||
pdf_path: str | None = None
|
||||
summary: dict | None = None
|
||||
|
||||
+8
-6
@@ -3,15 +3,15 @@ services:
|
||||
image: postgres:16-alpine
|
||||
container_name: sparc-postgres
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: sparc
|
||||
POSTGRES_USER: ${POSTGRES_USER}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
POSTGRES_DB: ${POSTGRES_DB}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER}"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
@@ -22,7 +22,7 @@ services:
|
||||
container_name: sparc-init-db
|
||||
command: python scripts/init_database.py
|
||||
environment:
|
||||
DATABASE_URL: postgresql://postgres:postgres@postgres:5432/sparc
|
||||
DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/${POSTGRES_DB}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
@@ -35,9 +35,11 @@ services:
|
||||
environment:
|
||||
API_KEY: ${API_KEY}
|
||||
OPENROUTER_API_KEY: ${OPENROUTER_API_KEY}
|
||||
DATABASE_URL: postgresql://postgres:postgres@postgres:5432/sparc
|
||||
DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/${POSTGRES_DB}
|
||||
USE_CACHE: "true"
|
||||
JWT_SECRET: ${JWT_SECRET:-sparc-secret-key-change-in-production}
|
||||
CORS_ORIGINS: ${CORS_ORIGINS:-}
|
||||
APP_ENV: ${APP_ENV:-development}
|
||||
ROOT_PATH: /api
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
Generated
+4728
File diff suppressed because it is too large
Load Diff
@@ -14,3 +14,4 @@ numpy
|
||||
pandas
|
||||
bcrypt
|
||||
PyJWT
|
||||
slowapi
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
"""Tests for JWT authentication flow: register, login, protected routes, refresh, admin access."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from SPARC.api import app
|
||||
from SPARC.auth import create_access_token, create_refresh_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db(monkeypatch):
|
||||
"""Mock the database client used by auth endpoints.
|
||||
|
||||
Returns a MagicMock with all DB methods pre-configured.
|
||||
"""
|
||||
db = MagicMock()
|
||||
|
||||
# Default: no users exist
|
||||
db.get_user_count.return_value = 0
|
||||
db.get_user_by_id.return_value = None
|
||||
db.get_user_by_email.return_value = None
|
||||
db.authenticate_user.return_value = None
|
||||
db.create_user.return_value = None
|
||||
db.get_all_users.return_value = []
|
||||
db.update_user_role.return_value = None
|
||||
db.delete_user.return_value = False
|
||||
|
||||
with patch("SPARC.api.get_db_client", return_value=db), \
|
||||
patch("SPARC.auth.get_db_client", return_value=db):
|
||||
yield db
|
||||
|
||||
|
||||
def _make_admin_user():
|
||||
return {
|
||||
"id": 1,
|
||||
"email": "admin@test.com",
|
||||
"role": "admin",
|
||||
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
|
||||
|
||||
def _make_regular_user():
|
||||
return {
|
||||
"id": 2,
|
||||
"email": "user@test.com",
|
||||
"role": "user",
|
||||
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
|
||||
|
||||
def _auth_header(user_dict):
|
||||
"""Create an Authorization header with a valid access token for the given user."""
|
||||
token = create_access_token(user_dict["id"], user_dict["email"], user_dict["role"])
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
class TestRegister:
|
||||
"""POST /auth/register"""
|
||||
|
||||
def test_register_first_user_becomes_admin(self, client, mock_db):
|
||||
"""First registered user should get admin role."""
|
||||
mock_db.get_user_count.return_value = 0
|
||||
mock_db.create_user.return_value = {
|
||||
"id": 1,
|
||||
"email": "admin@test.com",
|
||||
"role": "admin",
|
||||
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
"/auth/register",
|
||||
json={"email": "admin@test.com", "password": "securepass123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "admin@test.com"
|
||||
assert data["role"] == "admin"
|
||||
mock_db.create_user.assert_called_once_with(
|
||||
email="admin@test.com", password="securepass123", role="admin"
|
||||
)
|
||||
|
||||
def test_register_subsequent_user_gets_user_role(self, client, mock_db):
|
||||
"""Non-first user should get regular user role."""
|
||||
mock_db.get_user_count.return_value = 1
|
||||
mock_db.create_user.return_value = _make_regular_user()
|
||||
|
||||
response = client.post(
|
||||
"/auth/register",
|
||||
json={"email": "user@test.com", "password": "securepass123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["role"] == "user"
|
||||
|
||||
def test_register_duplicate_email_returns_400(self, client, mock_db):
|
||||
"""Registering with an existing email should return 400."""
|
||||
mock_db.get_user_count.return_value = 1
|
||||
mock_db.create_user.return_value = None # indicates duplicate
|
||||
|
||||
response = client.post(
|
||||
"/auth/register",
|
||||
json={"email": "existing@test.com", "password": "securepass123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "already registered" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestLogin:
|
||||
"""POST /auth/login"""
|
||||
|
||||
def test_login_valid_credentials_returns_tokens(self, client, mock_db):
|
||||
"""Valid credentials should return access and refresh tokens."""
|
||||
user = _make_regular_user()
|
||||
mock_db.authenticate_user.return_value = user
|
||||
|
||||
response = client.post(
|
||||
"/auth/login",
|
||||
json={"email": "user@test.com", "password": "correctpassword"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
def test_login_invalid_credentials_returns_401(self, client, mock_db):
|
||||
"""Invalid credentials should return 401."""
|
||||
mock_db.authenticate_user.return_value = None
|
||||
|
||||
response = client.post(
|
||||
"/auth/login",
|
||||
json={"email": "user@test.com", "password": "wrongpassword"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert "invalid" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestGetMe:
|
||||
"""GET /auth/me"""
|
||||
|
||||
def test_valid_access_token_returns_user(self, client, mock_db):
|
||||
"""A valid access token should return the user's data."""
|
||||
user = _make_regular_user()
|
||||
mock_db.get_user_by_id.return_value = user
|
||||
|
||||
response = client.get("/auth/me", headers=_auth_header(user))
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "user@test.com"
|
||||
assert data["id"] == 2
|
||||
|
||||
def test_missing_token_returns_401(self, client):
|
||||
"""No token should return 401 (403 from HTTPBearer)."""
|
||||
response = client.get("/auth/me")
|
||||
assert response.status_code in (401, 403)
|
||||
|
||||
def test_expired_token_returns_401(self, client, mock_db):
|
||||
"""An expired token should return 401."""
|
||||
# Create a token that has already expired
|
||||
from datetime import timedelta
|
||||
|
||||
import jwt as pyjwt
|
||||
from SPARC.auth import JWT_ALGORITHM, JWT_SECRET
|
||||
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"email": "user@test.com",
|
||||
"role": "user",
|
||||
"exp": datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
"type": "access",
|
||||
}
|
||||
expired_token = pyjwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
||||
|
||||
response = client.get(
|
||||
"/auth/me", headers={"Authorization": f"Bearer {expired_token}"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_refresh_token_as_access_returns_401(self, client, mock_db):
|
||||
"""Using a refresh token as an access token should return 401."""
|
||||
user = _make_regular_user()
|
||||
refresh_token = create_refresh_token(user["id"], user["email"], user["role"])
|
||||
|
||||
response = client.get(
|
||||
"/auth/me", headers={"Authorization": f"Bearer {refresh_token}"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
class TestRefreshToken:
|
||||
"""POST /auth/refresh"""
|
||||
|
||||
def test_valid_refresh_token_returns_new_tokens(self, client, mock_db):
|
||||
"""A valid refresh token should issue new access and refresh tokens."""
|
||||
user = _make_regular_user()
|
||||
mock_db.get_user_by_id.return_value = user
|
||||
refresh = create_refresh_token(user["id"], user["email"], user["role"])
|
||||
|
||||
response = client.post(
|
||||
"/auth/refresh", json={"refresh_token": refresh}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
|
||||
def test_invalid_refresh_token_returns_401(self, client, mock_db):
|
||||
"""An invalid refresh token should return 401."""
|
||||
response = client.post(
|
||||
"/auth/refresh", json={"refresh_token": "invalid-token-string"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_access_token_as_refresh_returns_401(self, client, mock_db):
|
||||
"""Using an access token as a refresh token should return 401."""
|
||||
user = _make_regular_user()
|
||||
access = create_access_token(user["id"], user["email"], user["role"])
|
||||
|
||||
response = client.post(
|
||||
"/auth/refresh", json={"refresh_token": access}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
class TestAdminUsers:
|
||||
"""GET /admin/users and PATCH /admin/users/{id}/role"""
|
||||
|
||||
def test_admin_can_list_users(self, client, mock_db):
|
||||
"""Admin token should allow listing users."""
|
||||
admin = _make_admin_user()
|
||||
mock_db.get_user_by_id.return_value = admin
|
||||
mock_db.get_all_users.return_value = [admin, _make_regular_user()]
|
||||
|
||||
response = client.get("/admin/users", headers=_auth_header(admin))
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
|
||||
def test_regular_user_cannot_list_users(self, client, mock_db):
|
||||
"""Regular user token should be rejected with 403."""
|
||||
user = _make_regular_user()
|
||||
mock_db.get_user_by_id.return_value = user
|
||||
|
||||
response = client.get("/admin/users", headers=_auth_header(user))
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_no_token_cannot_list_users(self, client):
|
||||
"""No token should be rejected."""
|
||||
response = client.get("/admin/users")
|
||||
assert response.status_code in (401, 403)
|
||||
|
||||
def test_admin_can_change_user_role(self, client, mock_db):
|
||||
"""Admin should be able to change another user's role."""
|
||||
admin = _make_admin_user()
|
||||
mock_db.get_user_by_id.return_value = admin
|
||||
mock_db.update_user_role.return_value = {
|
||||
"id": 2,
|
||||
"email": "user@test.com",
|
||||
"role": "admin",
|
||||
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
|
||||
response = client.patch(
|
||||
"/admin/users/2/role",
|
||||
json={"role": "admin"},
|
||||
headers=_auth_header(admin),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["role"] == "admin"
|
||||
|
||||
def test_admin_cannot_change_own_role(self, client, mock_db):
|
||||
"""Admin should not be able to change their own role."""
|
||||
admin = _make_admin_user()
|
||||
mock_db.get_user_by_id.return_value = admin
|
||||
|
||||
response = client.patch(
|
||||
"/admin/users/1/role",
|
||||
json={"role": "user"},
|
||||
headers=_auth_header(admin),
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "own role" in response.json()["detail"].lower()
|
||||
@@ -0,0 +1,97 @@
|
||||
"""Tests for rate limiting on auth endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from SPARC.api import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create test client with rate limiter enabled."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_limiter():
|
||||
"""Reset rate limiter storage between tests."""
|
||||
from SPARC.api import limiter
|
||||
limiter.reset()
|
||||
yield
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
"""Test rate limiting on login and register endpoints."""
|
||||
|
||||
@patch("SPARC.api.get_db_client")
|
||||
def test_login_allows_requests_under_limit(self, mock_db_client, client):
|
||||
"""Login endpoint allows requests under the rate limit."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.authenticate_user.return_value = None
|
||||
mock_db_client.return_value = mock_db
|
||||
|
||||
# Should allow at least a few requests
|
||||
for _ in range(5):
|
||||
response = client.post(
|
||||
"/auth/login",
|
||||
json={"email": "test@example.com", "password": "password123"},
|
||||
)
|
||||
# 401 is expected (invalid credentials), not 429
|
||||
assert response.status_code == 401
|
||||
|
||||
@patch("SPARC.api.get_db_client")
|
||||
def test_login_rate_limited_after_threshold(self, mock_db_client, client):
|
||||
"""Login endpoint returns 429 after exceeding rate limit."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.authenticate_user.return_value = None
|
||||
mock_db_client.return_value = mock_db
|
||||
|
||||
# Send more than the limit (10/minute)
|
||||
statuses = []
|
||||
for _ in range(15):
|
||||
response = client.post(
|
||||
"/auth/login",
|
||||
json={"email": "test@example.com", "password": "password123"},
|
||||
)
|
||||
statuses.append(response.status_code)
|
||||
|
||||
# At least one should be 429
|
||||
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
|
||||
|
||||
@patch("SPARC.api.get_db_client")
|
||||
def test_register_rate_limited_after_threshold(self, mock_db_client, client):
|
||||
"""Register endpoint returns 429 after exceeding rate limit."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_user_count.return_value = 1
|
||||
mock_db.create_user.return_value = None # triggers 400 (email exists)
|
||||
mock_db_client.return_value = mock_db
|
||||
|
||||
# Send more than the limit (5/minute)
|
||||
statuses = []
|
||||
for _ in range(10):
|
||||
response = client.post(
|
||||
"/auth/register",
|
||||
json={"email": "test@example.com", "password": "password123"},
|
||||
)
|
||||
statuses.append(response.status_code)
|
||||
|
||||
# At least one should be 429
|
||||
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
|
||||
|
||||
@patch("SPARC.api.get_db_client")
|
||||
def test_rate_limit_returns_retry_after_header(self, mock_db_client, client):
|
||||
"""Rate limited responses include a Retry-After header."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.authenticate_user.return_value = None
|
||||
mock_db_client.return_value = mock_db
|
||||
|
||||
# Exhaust the limit
|
||||
for _ in range(15):
|
||||
response = client.post(
|
||||
"/auth/login",
|
||||
json={"email": "test@example.com", "password": "password123"},
|
||||
)
|
||||
if response.status_code == 429:
|
||||
assert "Retry-After" in response.headers
|
||||
break
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Tests for security hardening: JWT secret startup check, CORS config, credential handling."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestJWTSecretStartupCheck:
|
||||
"""Test the startup guard that refuses default JWT secret in non-dev environments."""
|
||||
|
||||
def test_default_secret_in_production_raises(self):
|
||||
"""Starting with default secret and APP_ENV=production must raise RuntimeError."""
|
||||
with patch.dict(os.environ, {"APP_ENV": "production"}):
|
||||
# Reload config to pick up the new APP_ENV
|
||||
import importlib
|
||||
import SPARC.config
|
||||
importlib.reload(SPARC.config)
|
||||
|
||||
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
|
||||
# Patch JWT_SECRET to the default
|
||||
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
|
||||
with pytest.raises(RuntimeError, match="FATAL.*JWT_SECRET"):
|
||||
check_jwt_secret()
|
||||
|
||||
# Restore config
|
||||
with patch.dict(os.environ, {"APP_ENV": "development"}):
|
||||
importlib.reload(SPARC.config)
|
||||
|
||||
def test_default_secret_in_development_succeeds(self):
|
||||
"""Starting with default secret and APP_ENV=development must not raise."""
|
||||
with patch.dict(os.environ, {"APP_ENV": "development"}):
|
||||
import importlib
|
||||
import SPARC.config
|
||||
importlib.reload(SPARC.config)
|
||||
|
||||
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
|
||||
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
|
||||
# Should not raise
|
||||
check_jwt_secret()
|
||||
|
||||
# Restore
|
||||
importlib.reload(SPARC.config)
|
||||
|
||||
def test_custom_secret_in_production_succeeds(self):
|
||||
"""Starting with a custom secret in production must not raise."""
|
||||
with patch.dict(os.environ, {"APP_ENV": "production"}):
|
||||
import importlib
|
||||
import SPARC.config
|
||||
importlib.reload(SPARC.config)
|
||||
|
||||
from SPARC.auth import check_jwt_secret
|
||||
with patch("SPARC.auth.JWT_SECRET", "my-secure-random-secret-abc123"):
|
||||
# Should not raise
|
||||
check_jwt_secret()
|
||||
|
||||
with patch.dict(os.environ, {"APP_ENV": "development"}):
|
||||
importlib.reload(SPARC.config)
|
||||
|
||||
def test_default_secret_unset_env_succeeds(self):
|
||||
"""When APP_ENV is unset (defaults to development), default secret is allowed."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
# Remove APP_ENV if present
|
||||
env = os.environ.copy()
|
||||
env.pop("APP_ENV", None)
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
import importlib
|
||||
import SPARC.config
|
||||
importlib.reload(SPARC.config)
|
||||
|
||||
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
|
||||
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
|
||||
# Should not raise (defaults to development)
|
||||
check_jwt_secret()
|
||||
|
||||
with patch.dict(os.environ, {"APP_ENV": "development"}):
|
||||
importlib.reload(SPARC.config)
|
||||
|
||||
|
||||
class TestCORSConfig:
|
||||
"""Test that CORS origins are configurable via environment variable."""
|
||||
|
||||
def test_default_cors_origins(self):
|
||||
"""When CORS_ORIGINS is unset, defaults to localhost origins."""
|
||||
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
|
||||
import importlib
|
||||
import SPARC.config
|
||||
importlib.reload(SPARC.config)
|
||||
assert SPARC.config.cors_origins == [
|
||||
"http://localhost:3000",
|
||||
"http://localhost:5173",
|
||||
]
|
||||
|
||||
def test_custom_cors_origins(self):
|
||||
"""Setting CORS_ORIGINS configures allowed origins."""
|
||||
with patch.dict(os.environ, {"CORS_ORIGINS": "https://sparc.example.com,https://app.example.com"}):
|
||||
import importlib
|
||||
import SPARC.config
|
||||
importlib.reload(SPARC.config)
|
||||
assert SPARC.config.cors_origins == [
|
||||
"https://sparc.example.com",
|
||||
"https://app.example.com",
|
||||
]
|
||||
# Restore
|
||||
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
|
||||
importlib.reload(SPARC.config)
|
||||
|
||||
def test_single_cors_origin(self):
|
||||
"""A single origin without comma works correctly."""
|
||||
with patch.dict(os.environ, {"CORS_ORIGINS": "https://sparc.example.com"}):
|
||||
import importlib
|
||||
import SPARC.config
|
||||
importlib.reload(SPARC.config)
|
||||
assert SPARC.config.cors_origins == ["https://sparc.example.com"]
|
||||
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
|
||||
importlib.reload(SPARC.config)
|
||||
Reference in New Issue
Block a user