Compare commits

..

12 Commits

Author SHA1 Message Date
agent-company b000146585 feat: configurable LLM model, SERP cache TTL, structured logging, fix patent_id type
- Make LLM model configurable via MODEL env var, default anthropic/claude-3.5-sonnet (#12)
- Expose SERP cache TTL as SERP_CACHE_TTL_HOURS env var, default 24 hours (#13)
- Fix Patent.patent_id type annotation from int to str in types.py (#14)
- Replace all print() calls with structured logging in analyzer.py and llm.py (#11)
- Add LOG_LEVEL config with basicConfig setup in config.py
- Add model and serp_cache_ttl_hours to config.py

Closes leeworks-agents/SPARC#11
Closes leeworks-agents/SPARC#12
Closes leeworks-agents/SPARC#13
Closes leeworks-agents/SPARC#14

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 06:03:25 +00:00
AI-Manager 35d105b14e Merge pull request 'feat(auth): add rate limiting to login and register endpoints' (#28) from feature/rate-limiting into main 2026-03-26 05:04:46 +00:00
AI-Manager 6fcf170d93 Merge pull request 'feat(jobs): persist async batch job state in PostgreSQL' (#34) from feature/persist-job-state into main 2026-03-26 05:04:26 +00:00
AI-Manager 5a42e216ba Merge pull request 'docs: patent PDF storage docs, FileNotFoundError, frontend lockfile' (#31) from feature/p2-docs-and-lockfile into main 2026-03-26 05:04:01 +00:00
AI-Manager 24ab341d9b Merge pull request 'test(auth): add comprehensive JWT authentication test suite' (#35) from feature/jwt-auth-tests into main 2026-03-26 05:03:29 +00:00
AI-Manager 878fedfbb8 Merge pull request 'feat(security): JWT startup guard, configurable CORS, externalize DB creds' (#27) from feature/p1-security-hardening into main 2026-03-26 05:03:16 +00:00
agent-company ae9f257dcb test(auth): add comprehensive JWT authentication test suite
Add 17 tests in tests/test_auth.py covering all auth flows:
- Registration: first user admin, subsequent user, duplicate email
- Login: valid credentials, invalid credentials
- Protected routes: valid token, missing token, expired token, wrong token type
- Token refresh: valid refresh, invalid refresh, access-as-refresh rejected
- Admin endpoints: list users, change role, own-role prevention, permission checks

All tests use mocked database (no live DB required).

Closes leeworks-agents/SPARC#10

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:24:12 +00:00
agent-company 96d5d27b17 feat(jobs): persist async batch job state in PostgreSQL
- Add jobs table to database schema (job_id, status, progress, result_json, etc.)
- Add DatabaseClient methods: create_job, update_job, get_job, list_jobs
- Add mark_stale_jobs_failed() called at startup to handle interrupted jobs
- Refactor _run_batch_job and job endpoints to read/write from PostgreSQL
- Remove in-memory _jobs dict; job state now survives API restarts
- Update init_database.py to list all tables in output

Closes leeworks-agents/SPARC#8

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:22:57 +00:00
agent-company 3dac88ec90 docs: document patent PDF storage, add FileNotFoundError, commit lockfile
- Add docstring to analyze_single_patent explaining the PDF prerequisite
- Raise FileNotFoundError with helpful message when PDF is missing
- Add patent PDF storage section to README with Docker volume mount example
- Commit frontend/package-lock.json for reproducible builds

Closes leeworks-agents/SPARC#15
Closes leeworks-agents/SPARC#17

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:17:09 +00:00
agent-company e2d750146c feat(auth): add rate limiting to login and register endpoints
- Add slowapi rate limiter: 10 req/min for /auth/login, 5 req/min for /auth/register
- Return HTTP 429 with Retry-After header when limit is exceeded
- Add slowapi to requirements.txt
- Add 4 passing tests for rate limit behavior

Closes leeworks-agents/SPARC#9

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:08:22 +00:00
agent-company 47cddcbeaf feat(security): add JWT startup guard, configurable CORS, and externalize DB credentials
- Add check_jwt_secret() that refuses default JWT secret when APP_ENV != development
- Make CORS origins configurable via CORS_ORIGINS env var (comma-separated)
- Replace hardcoded postgres credentials in docker-compose.yml with env var references
- Add APP_ENV and cors_origins to config.py
- Update .env.example with all required variables and documentation
- Add tests for JWT startup guard and CORS configuration

Closes leeworks-agents/SPARC#4
Closes leeworks-agents/SPARC#5
Closes leeworks-agents/SPARC#6

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:06:31 +00:00
AI-Manager 6105ba7793 Merge pull request 'chore: add ROADMAP.md for SPARC application development' (#3) from chore/add-roadmap into main 2026-03-26 02:47:54 +00:00
17 changed files with 5640 additions and 88 deletions
+30 -9
View File
@@ -1,21 +1,42 @@
# SPARC Configuration
# ---- Application Environment ----
# Set to "production" or "staging" in deployed environments.
# The API will refuse to start with the default JWT secret unless APP_ENV=development.
APP_ENV=development
# ---- API Keys ----
# SerpAPI key for patent search
API_KEY=your_serpapi_key_here
# OpenRouter API key for LLM analysis
OPENROUTER_API_KEY=your_openrouter_key_here
# Database configuration
# All messages are stored in the database for persistence and caching
DATABASE_URL=postgresql://postgres:postgres@localhost:5432/sparc
# ---- Database ----
# Cache configuration
# When USE_CACHE=true: check database for cached responses before making API calls
# When USE_CACHE=false: always make fresh API calls (still stores results in database)
# Default: true
USE_CACHE=true
# PostgreSQL credentials (used by docker-compose)
POSTGRES_USER=postgres
POSTGRES_PASSWORD=change-me-to-a-secure-password
POSTGRES_DB=sparc
# JWT Secret for authentication
# Full database URL (must match the credentials above)
DATABASE_URL=postgresql://postgres:change-me-to-a-secure-password@localhost:5432/sparc
# ---- Authentication ----
# JWT Secret for signing tokens
# IMPORTANT: Change this to a secure random string in production
JWT_SECRET=your-secure-jwt-secret-change-in-production
# ---- CORS ----
# Comma-separated list of allowed origins for CORS
# Defaults to http://localhost:3000,http://localhost:5173 when unset
# CORS_ORIGINS=https://sparc.example.com,https://app.example.com
# ---- Cache ----
# When USE_CACHE=true: check database for cached responses before making API calls
# When USE_CACHE=false: always make fresh API calls (still stores results in database)
USE_CACHE=true
+15
View File
@@ -54,6 +54,21 @@ docker-compose up -d
# - API Docs: http://localhost:8000/docs
```
#### Patent PDF Storage
The API stores downloaded patent PDFs in a `patents/` directory. In Docker,
this is mounted as a bind mount (`./patents:/app/patents`) so that PDFs persist
across container restarts.
If you deploy to a different environment, ensure the `patents/` directory is a
persistent volume. Without it, PDFs will be re-downloaded on every analysis.
```yaml
# docker-compose.yml excerpt
volumes:
- ./patents:/app/patents
```
### NixOS
```bash
+40 -21
View File
@@ -5,10 +5,13 @@ to provide company performance estimation based on patent portfolios.
"""
import hashlib
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable
from SPARC import config
logger = logging.getLogger(__name__)
from SPARC.database import DatabaseClient
from SPARC.serp_api import SERP
from SPARC.llm import LLMAnalyzer
@@ -52,13 +55,13 @@ class CompanyAnalyzer:
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
cached_ids = self.db.get_cached_serp_query(query_hash)
if cached_ids is not None:
print(f"Using cached SERP results for {company_name} ({len(cached_ids)} patents)")
logger.info("Using cached SERP results for %s (%d patents)", company_name, len(cached_ids))
patents = Patents(patents=[
Patent(patent_id=pid, pdf_link="")
for pid in cached_ids
])
else:
print(f"Retrieving patents for {company_name}...")
logger.info("Retrieving patents for %s...", company_name)
patents = SERP.query(company_name)
# Cache the SERP results
if patents.patents:
@@ -66,12 +69,13 @@ class CompanyAnalyzer:
company_name=company_name,
query_hash=query_hash,
patent_ids=[p.patent_id for p in patents.patents],
ttl_hours=config.serp_cache_ttl_hours,
)
if not patents.patents:
return f"No patents found for {company_name}"
print(f"Found {len(patents.patents)} patents. Processing...")
logger.info("Found %d patents. Processing...", len(patents.patents))
# Download, parse, and minimize patents in parallel
processed_patents = []
@@ -87,12 +91,12 @@ class CompanyAnalyzer:
if result:
processed_patents.append(result)
except Exception as e:
print(f"Warning: Failed to process {patent.patent_id}: {e}")
logger.warning("Failed to process %s: %s", patent.patent_id, e)
if not processed_patents:
return f"Failed to process any patents for {company_name}"
print(f"Analyzing portfolio with LLM...")
logger.info("Analyzing portfolio with LLM...")
# Analyze the full portfolio with LLM
analysis = self.llm_analyzer.analyze_patent_portfolio(
@@ -104,21 +108,34 @@ class CompanyAnalyzer:
def analyze_single_patent(self, patent_id: str, company_name: str) -> str:
"""Analyze a single patent by ID.
Useful for focused analysis of specific innovations.
Prerequisite:
The patent PDF must already exist at ``patents/{patent_id}.pdf``
before calling this method. PDFs are downloaded automatically when
using the batch analysis pipeline (``analyze_company`` or the
``/analyze/batch`` API endpoint). For standalone usage, download
the PDF manually or call ``SERP.save_patents()`` first.
Args:
patent_id: Publication ID of the patent
patent_id: Publication ID of the patent (e.g. "US-11234567-B2")
company_name: Name of the company (for context)
Returns:
Analysis of the specific patent's innovation quality
Raises:
FileNotFoundError: If the patent PDF is not found at the expected path.
"""
# Note: This simplified version assumes the patent PDF is already downloaded
# A more complete implementation would support direct patent ID lookup
print(f"Analyzing patent {patent_id} for {company_name}...")
import os
logger.info("Analyzing patent %s for %s...", patent_id, company_name)
patent_path = f"patents/{patent_id}.pdf"
if not os.path.exists(patent_path):
raise FileNotFoundError(
f"Patent PDF not found at '{patent_path}'. "
f"Download the PDF first using SERP.save_patents() or the batch analysis pipeline."
)
try:
sections = SERP.parse_patent_pdf(patent_path)
minimized_content = SERP.minimize_patent_for_llm(sections)
@@ -129,6 +146,8 @@ class CompanyAnalyzer:
return analysis
except FileNotFoundError:
raise
except Exception as e:
return f"Failed to analyze patent {patent_id}: {e}"
@@ -169,7 +188,7 @@ class CompanyAnalyzer:
return {"patent_id": patent.patent_id, "content": minimized_content}
except Exception as e:
print(f"Warning: Failed to process {patent.patent_id}: {e}")
logger.warning("Failed to process %s: %s", patent.patent_id, e)
return None
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult:
@@ -240,7 +259,7 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = []
total = len(companies)
print(f"Starting batch analysis of {total} companies...")
logger.info("Starting batch analysis of %d companies...", total)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_company = {
@@ -257,8 +276,8 @@ class CompanyAnalyzer:
result = future.result()
results.append(result)
status = "" if result.success else ""
print(f"[{completed}/{total}] {status} {company}")
status = "OK" if result.success else "FAIL"
logger.info("[%d/%d] %s %s", completed, total, status, company)
if progress_callback:
progress_callback(company, completed, total)
@@ -273,12 +292,12 @@ class CompanyAnalyzer:
error=str(e),
)
)
print(f"[{completed}/{total}] ✗ {company}: {e}")
logger.error("[%d/%d] FAIL %s: %s", completed, total, company, e)
successful = sum(1 for r in results if r.success)
failed = total - successful
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
return BatchAnalysisResult(
results=results,
@@ -304,20 +323,20 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = []
total = len(companies)
print(f"Starting sequential analysis of {total} companies...")
logger.info("Starting sequential analysis of %d companies...", total)
for idx, company in enumerate(companies, 1):
print(f"\n[{idx}/{total}] Analyzing {company}...")
logger.info("[%d/%d] Analyzing %s...", idx, total, company)
result = self._analyze_company_safe(company)
results.append(result)
status = "" if result.success else ""
print(f"[{idx}/{total}] {status} {company}")
status = "OK" if result.success else "FAIL"
logger.info("[%d/%d] %s %s", idx, total, status, company)
successful = sum(1 for r in results if r.success)
failed = total - successful
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
return BatchAnalysisResult(
results=results,
+100 -40
View File
@@ -7,15 +7,20 @@ from contextlib import asynccontextmanager
from datetime import datetime
from typing import Annotated, List
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, EmailStr, Field
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from SPARC import config
from SPARC.analyzer import CompanyAnalyzer
from SPARC.auth import (
TokenResponse,
UserResponse,
check_jwt_secret,
create_tokens,
decode_token,
get_current_admin,
@@ -114,8 +119,7 @@ class AnalyticsResponse(BaseModel):
period_days: int
# In-memory job storage (for demo; production would use Redis/DB)
_jobs: dict[str, JobStatus] = {}
# Job counter for generating unique IDs (the actual state is in PostgreSQL)
_job_counter = 0
@@ -148,9 +152,20 @@ _analyzer: CompanyAnalyzer | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize resources on startup."""
"""Initialize resources on startup, clean up on shutdown."""
global _analyzer
check_jwt_secret()
_analyzer = CompanyAnalyzer()
# Mark any jobs that were running/pending before the restart as failed
from SPARC.database import DatabaseClient
_db = DatabaseClient(config.database_url)
_db.connect()
_db.initialize_schema()
stale = _db.mark_stale_jobs_failed()
if stale:
import logging
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
_db.close()
yield
# Cleanup if needed
_analyzer = None
@@ -164,10 +179,26 @@ app = FastAPI(
root_path=config.root_path,
)
# Rate limiter (in-memory storage, suitable for single-instance deployments)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Return 429 with Retry-After header when rate limit is exceeded."""
retry_after = getattr(exc, "retry_after", 60)
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded. Please try again later."},
headers={"Retry-After": str(retry_after)},
)
# Add CORS middleware for React frontend
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000", "http://localhost:5173"],
allow_origins=config.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
@@ -178,7 +209,8 @@ app.add_middleware(
@app.post("/auth/register", response_model=UserResponse, tags=["Auth"])
async def register(request: RegisterRequest):
@limiter.limit("5/minute")
async def register(request: Request, body: RegisterRequest):
"""Register a new user.
The first registered user automatically becomes an admin.
@@ -190,8 +222,8 @@ async def register(request: RegisterRequest):
role = "admin" if user_count == 0 else "user"
user = db.create_user(
email=request.email,
password=request.password,
email=body.email,
password=body.password,
role=role,
)
@@ -210,11 +242,12 @@ async def register(request: RegisterRequest):
@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"])
async def login(request: LoginRequest):
@limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest):
"""Authenticate user and return JWT tokens."""
db = get_db_client()
user = db.authenticate_user(request.email, request.password)
user = db.authenticate_user(body.email, body.password)
if not user:
raise HTTPException(
@@ -422,20 +455,52 @@ async def analyze_companies_batch(
return _convert_batch_result(result)
def _get_job_db() -> "DatabaseClient":
"""Get a DatabaseClient for job persistence."""
from SPARC.database import DatabaseClient
db = DatabaseClient(config.database_url)
return db
def _job_row_to_status(row: dict) -> JobStatus:
"""Convert a database job row to a JobStatus model."""
import json as _json
result = None
if row.get("result_json"):
result_data = row["result_json"]
if isinstance(result_data, str):
result_data = _json.loads(result_data)
result = BatchAnalysisResponse(**result_data)
return JobStatus(
job_id=row["job_id"],
status=row["status"],
progress=row["progress"],
total_companies=row["total_companies"],
completed_companies=row["completed_companies"],
result=result,
error=row.get("error"),
)
def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
"""Background task for batch analysis."""
global _jobs, _analyzer
import json as _json
global _analyzer
db = _get_job_db()
if not _analyzer:
_jobs[job_id].status = "failed"
_jobs[job_id].error = "Analyzer not initialized"
db.update_job(job_id, status="failed", error="Analyzer not initialized")
return
_jobs[job_id].status = "running"
db.update_job(job_id, status="running")
def progress_callback(company: str, completed: int, total: int):
_jobs[job_id].completed_companies = completed
_jobs[job_id].progress = int((completed / total) * 100)
db.update_job(
job_id,
completed_companies=completed,
progress=int((completed / total) * 100),
)
try:
result = _analyzer.analyze_companies(
@@ -443,12 +508,15 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
max_workers=max_workers,
progress_callback=progress_callback,
)
_jobs[job_id].status = "completed"
_jobs[job_id].progress = 100
_jobs[job_id].result = _convert_batch_result(result)
batch_response = _convert_batch_result(result)
db.update_job(
job_id,
status="completed",
progress=100,
result_json=_json.dumps(batch_response.model_dump(), default=str),
)
except Exception as e:
_jobs[job_id].status = "failed"
_jobs[job_id].error = str(e)
db.update_job(job_id, status="failed", error=str(e))
@app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"])
@@ -473,19 +541,14 @@ async def analyze_companies_async(
_job_counter += 1
job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
_jobs[job_id] = JobStatus(
job_id=job_id,
status="pending",
progress=0,
total_companies=len(request.companies),
completed_companies=0,
)
db = _get_job_db()
job_row = db.create_job(job_id=job_id, total_companies=len(request.companies))
background_tasks.add_task(
_run_batch_job, job_id, request.companies, request.max_workers
)
return _jobs[job_id]
return _job_row_to_status(job_row)
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
@@ -501,10 +564,13 @@ async def get_job_status(
Returns:
Current job status including progress and results when complete
"""
if job_id not in _jobs:
db = _get_job_db()
job_row = db.get_job(job_id)
if not job_row:
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
return _jobs[job_id]
return _job_row_to_status(job_row)
@app.get("/jobs", response_model=list[JobStatus], tags=["Jobs"])
@@ -525,12 +591,6 @@ async def list_jobs(
Returns:
List of job statuses
"""
jobs = list(_jobs.values())
if status:
jobs = [j for j in jobs if j.status == status]
# Return most recent first
jobs.sort(key=lambda j: j.job_id, reverse=True)
return jobs[:limit]
db = _get_job_db()
job_rows = db.list_jobs(status=status, limit=limit)
return [_job_row_to_status(row) for row in job_rows]
+15 -1
View File
@@ -13,11 +13,25 @@ from SPARC import config
from SPARC.database import DatabaseClient
# JWT Configuration
JWT_SECRET = os.getenv("JWT_SECRET", "sparc-secret-key-change-in-production")
_DEFAULT_JWT_SECRET = "sparc-secret-key-change-in-production"
JWT_SECRET = os.getenv("JWT_SECRET", _DEFAULT_JWT_SECRET)
JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
def check_jwt_secret() -> None:
"""Refuse to start with the default JWT secret in non-development environments.
Raises:
RuntimeError: If JWT_SECRET is the default value and APP_ENV is not 'development'.
"""
if JWT_SECRET == _DEFAULT_JWT_SECRET and config.app_env != "development":
raise RuntimeError(
f"FATAL: JWT_SECRET is set to the default value and APP_ENV={config.app_env!r}. "
"Set a secure JWT_SECRET environment variable before running in non-development environments."
)
security = HTTPBearer()
+29 -1
View File
@@ -2,11 +2,20 @@
Loads environment variables from .env file for API keys and other secrets.
"""
from dotenv import load_dotenv
import logging
import os
from dotenv import load_dotenv
load_dotenv()
# Logging configuration
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
level=getattr(logging, log_level, logging.INFO),
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
# SerpAPI key for patent search
api_key = os.getenv("API_KEY")
@@ -30,6 +39,25 @@ use_database = os.getenv("USE_DATABASE", "false").lower() in ("true", "1", "yes"
patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90"))
patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5"))
# LLM model to use via OpenRouter (e.g. "anthropic/claude-3.5-sonnet", "openai/gpt-4o")
model = os.getenv("MODEL", "anthropic/claude-3.5-sonnet")
# SERP cache TTL in hours (how long cached search results are considered fresh)
serp_cache_ttl_hours = int(os.getenv("SERP_CACHE_TTL_HOURS", "24"))
# Root path for running behind a reverse proxy (e.g., "/api" when served at /api/)
# This ensures OpenAPI docs work correctly when accessed via the proxy
root_path = os.getenv("ROOT_PATH", "")
# Application environment: "development", "staging", or "production"
# Used for safety checks (e.g., refusing default JWT secret in production)
app_env = os.getenv("APP_ENV", "development")
# CORS allowed origins (comma-separated)
# Defaults to localhost dev origins when unset
_cors_origins_raw = os.getenv("CORS_ORIGINS", "")
cors_origins: list[str] = (
[o.strip() for o in _cors_origins_raw.split(",") if o.strip()]
if _cors_origins_raw
else ["http://localhost:3000", "http://localhost:5173"]
)
+145
View File
@@ -171,6 +171,26 @@ class DatabaseClient:
ON serp_queries(query_hash)
""")
# Create jobs table for persisting async batch job state
cursor.execute("""
CREATE TABLE IF NOT EXISTS jobs (
job_id VARCHAR(128) PRIMARY KEY,
status VARCHAR(20) NOT NULL DEFAULT 'pending',
progress INTEGER NOT NULL DEFAULT 0,
total_companies INTEGER NOT NULL DEFAULT 0,
completed_companies INTEGER NOT NULL DEFAULT 0,
result_json JSONB,
error TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_jobs_status
ON jobs(status)
""")
self.conn.commit()
@staticmethod
@@ -462,6 +482,131 @@ class DatabaseClient:
)
conn.commit()
# Job Persistence Methods
def create_job(
self,
job_id: str,
total_companies: int,
) -> Dict:
"""Create a new job record.
Args:
job_id: Unique job identifier
total_companies: Number of companies in the batch
Returns:
Job dict
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies)
VALUES (%s, 'pending', 0, %s, 0)
RETURNING *
""",
(job_id, total_companies),
)
job = cursor.fetchone()
conn.commit()
return dict(job)
def update_job(
self,
job_id: str,
status: Optional[str] = None,
progress: Optional[int] = None,
completed_companies: Optional[int] = None,
result_json: Optional[str] = None,
error: Optional[str] = None,
) -> Optional[Dict]:
"""Update a job's state.
Only non-None fields are updated.
"""
updates = []
params = []
if status is not None:
updates.append("status = %s")
params.append(status)
if progress is not None:
updates.append("progress = %s")
params.append(progress)
if completed_companies is not None:
updates.append("completed_companies = %s")
params.append(completed_companies)
if result_json is not None:
updates.append("result_json = %s")
params.append(result_json)
if error is not None:
updates.append("error = %s")
params.append(error)
if not updates:
return self.get_job(job_id)
updates.append("updated_at = CURRENT_TIMESTAMP")
params.append(job_id)
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
f"UPDATE jobs SET {', '.join(updates)} WHERE job_id = %s RETURNING *",
params,
)
job = cursor.fetchone()
conn.commit()
return dict(job) if job else None
def get_job(self, job_id: str) -> Optional[Dict]:
"""Get a job by ID."""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,))
job = cursor.fetchone()
return dict(job) if job else None
def list_jobs(
self,
status: Optional[str] = None,
limit: int = 10,
) -> List[Dict]:
"""List jobs, optionally filtered by status."""
query = "SELECT * FROM jobs"
params: list = []
if status:
query += " WHERE status = %s"
params.append(status)
query += " ORDER BY created_at DESC LIMIT %s"
params.append(limit)
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
def mark_stale_jobs_failed(self) -> int:
"""Mark any jobs in 'running' or 'pending' state as 'failed'.
Called at startup to clean up jobs that were interrupted by a restart.
Returns:
Number of jobs marked as failed.
"""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""
UPDATE jobs SET status = 'failed', error = 'Interrupted by server restart',
updated_at = CURRENT_TIMESTAMP
WHERE status IN ('running', 'pending')
"""
)
count = cursor.rowcount
conn.commit()
return count
# User Authentication Methods
@staticmethod
+9 -8
View File
@@ -1,9 +1,14 @@
"""LLM integration for patent analysis using OpenRouter."""
import logging
from typing import Dict
from openai import OpenAI
from SPARC import config
from SPARC.database import DatabaseClient
from typing import Dict
logger = logging.getLogger(__name__)
class LLMAnalyzer:
@@ -20,7 +25,7 @@ class LLMAnalyzer:
"""
self.test_mode = test_mode
self.use_cache = use_cache if use_cache is not None else config.use_cache
self.model = "anthropic/claude-3.5-sonnet"
self.model = config.model
# Always initialize database client for storage and caching
self.db_client = DatabaseClient(config.database_url)
@@ -59,11 +64,7 @@ Patent Content:
Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage."""
if self.test_mode:
print("=" * 80)
print("TEST MODE - Prompt that would be sent to LLM:")
print("=" * 80)
print(prompt)
print("=" * 80)
logger.debug("TEST MODE - Prompt that would be sent to LLM:\n%s", prompt)
return "[TEST MODE - No API call made]"
# Check cache first
@@ -165,7 +166,7 @@ Patent Portfolio:
Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook."""
if self.test_mode:
print(prompt)
logger.debug("TEST MODE - Portfolio prompt:\n%s", prompt)
return "[TEST MODE]"
metadata = {
+1 -1
View File
@@ -4,7 +4,7 @@ from datetime import datetime
@dataclass
class Patent:
patent_id: int
patent_id: str
pdf_link: str
pdf_path: str | None = None
summary: dict | None = None
+8 -6
View File
@@ -3,15 +3,15 @@ services:
image: postgres:16-alpine
container_name: sparc-postgres
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: sparc
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_DB: ${POSTGRES_DB}
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER}"]
interval: 5s
timeout: 5s
retries: 5
@@ -22,7 +22,7 @@ services:
container_name: sparc-init-db
command: python scripts/init_database.py
environment:
DATABASE_URL: postgresql://postgres:postgres@postgres:5432/sparc
DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/${POSTGRES_DB}
depends_on:
postgres:
condition: service_healthy
@@ -35,9 +35,11 @@ services:
environment:
API_KEY: ${API_KEY}
OPENROUTER_API_KEY: ${OPENROUTER_API_KEY}
DATABASE_URL: postgresql://postgres:postgres@postgres:5432/sparc
DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/${POSTGRES_DB}
USE_CACHE: "true"
JWT_SECRET: ${JWT_SECRET:-sparc-secret-key-change-in-production}
CORS_ORIGINS: ${CORS_ORIGINS:-}
APP_ENV: ${APP_ENV:-development}
ROOT_PATH: /api
ports:
- "8000:8000"
+4728
View File
File diff suppressed because it is too large Load Diff
+1
View File
@@ -14,3 +14,4 @@ numpy
pandas
bcrypt
PyJWT
slowapi
+3
View File
@@ -40,6 +40,9 @@ def main():
print("\nTables created:")
print(" - llm_messages: Stores all LLM prompts and responses")
print(" - users: Stores user accounts")
print(" - jobs: Stores async batch job state")
print(" - patents: Patent PDF cache")
print(" - serp_queries: SERP query result cache")
print("\nIndexes created:")
print(" - idx_messages_timestamp: For time-based queries")
print(" - idx_messages_company: For company-specific queries")
+1 -1
View File
@@ -5,7 +5,7 @@ from datetime import datetime
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from SPARC.api import app, _analyzer, _jobs
from SPARC.api import app
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
+302
View File
@@ -0,0 +1,302 @@
"""Tests for JWT authentication flow: register, login, protected routes, refresh, admin access."""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token, create_refresh_token
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db(monkeypatch):
"""Mock the database client used by auth endpoints.
Returns a MagicMock with all DB methods pre-configured.
"""
db = MagicMock()
# Default: no users exist
db.get_user_count.return_value = 0
db.get_user_by_id.return_value = None
db.get_user_by_email.return_value = None
db.authenticate_user.return_value = None
db.create_user.return_value = None
db.get_all_users.return_value = []
db.update_user_role.return_value = None
db.delete_user.return_value = False
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _make_admin_user():
return {
"id": 1,
"email": "admin@test.com",
"role": "admin",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
def _make_regular_user():
return {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
def _auth_header(user_dict):
"""Create an Authorization header with a valid access token for the given user."""
token = create_access_token(user_dict["id"], user_dict["email"], user_dict["role"])
return {"Authorization": f"Bearer {token}"}
class TestRegister:
"""POST /auth/register"""
def test_register_first_user_becomes_admin(self, client, mock_db):
"""First registered user should get admin role."""
mock_db.get_user_count.return_value = 0
mock_db.create_user.return_value = {
"id": 1,
"email": "admin@test.com",
"role": "admin",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
response = client.post(
"/auth/register",
json={"email": "admin@test.com", "password": "securepass123"},
)
assert response.status_code == 200
data = response.json()
assert data["email"] == "admin@test.com"
assert data["role"] == "admin"
mock_db.create_user.assert_called_once_with(
email="admin@test.com", password="securepass123", role="admin"
)
def test_register_subsequent_user_gets_user_role(self, client, mock_db):
"""Non-first user should get regular user role."""
mock_db.get_user_count.return_value = 1
mock_db.create_user.return_value = _make_regular_user()
response = client.post(
"/auth/register",
json={"email": "user@test.com", "password": "securepass123"},
)
assert response.status_code == 200
data = response.json()
assert data["role"] == "user"
def test_register_duplicate_email_returns_400(self, client, mock_db):
"""Registering with an existing email should return 400."""
mock_db.get_user_count.return_value = 1
mock_db.create_user.return_value = None # indicates duplicate
response = client.post(
"/auth/register",
json={"email": "existing@test.com", "password": "securepass123"},
)
assert response.status_code == 400
assert "already registered" in response.json()["detail"].lower()
class TestLogin:
"""POST /auth/login"""
def test_login_valid_credentials_returns_tokens(self, client, mock_db):
"""Valid credentials should return access and refresh tokens."""
user = _make_regular_user()
mock_db.authenticate_user.return_value = user
response = client.post(
"/auth/login",
json={"email": "user@test.com", "password": "correctpassword"},
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
def test_login_invalid_credentials_returns_401(self, client, mock_db):
"""Invalid credentials should return 401."""
mock_db.authenticate_user.return_value = None
response = client.post(
"/auth/login",
json={"email": "user@test.com", "password": "wrongpassword"},
)
assert response.status_code == 401
assert "invalid" in response.json()["detail"].lower()
class TestGetMe:
"""GET /auth/me"""
def test_valid_access_token_returns_user(self, client, mock_db):
"""A valid access token should return the user's data."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = user
response = client.get("/auth/me", headers=_auth_header(user))
assert response.status_code == 200
data = response.json()
assert data["email"] == "user@test.com"
assert data["id"] == 2
def test_missing_token_returns_401(self, client):
"""No token should return 401 (403 from HTTPBearer)."""
response = client.get("/auth/me")
assert response.status_code in (401, 403)
def test_expired_token_returns_401(self, client, mock_db):
"""An expired token should return 401."""
# Create a token that has already expired
from datetime import timedelta
import jwt as pyjwt
from SPARC.auth import JWT_ALGORITHM, JWT_SECRET
payload = {
"sub": "1",
"email": "user@test.com",
"role": "user",
"exp": datetime.now(timezone.utc) - timedelta(hours=1),
"type": "access",
}
expired_token = pyjwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
response = client.get(
"/auth/me", headers={"Authorization": f"Bearer {expired_token}"}
)
assert response.status_code == 401
def test_refresh_token_as_access_returns_401(self, client, mock_db):
"""Using a refresh token as an access token should return 401."""
user = _make_regular_user()
refresh_token = create_refresh_token(user["id"], user["email"], user["role"])
response = client.get(
"/auth/me", headers={"Authorization": f"Bearer {refresh_token}"}
)
assert response.status_code == 401
class TestRefreshToken:
"""POST /auth/refresh"""
def test_valid_refresh_token_returns_new_tokens(self, client, mock_db):
"""A valid refresh token should issue new access and refresh tokens."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = user
refresh = create_refresh_token(user["id"], user["email"], user["role"])
response = client.post(
"/auth/refresh", json={"refresh_token": refresh}
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
def test_invalid_refresh_token_returns_401(self, client, mock_db):
"""An invalid refresh token should return 401."""
response = client.post(
"/auth/refresh", json={"refresh_token": "invalid-token-string"}
)
assert response.status_code == 401
def test_access_token_as_refresh_returns_401(self, client, mock_db):
"""Using an access token as a refresh token should return 401."""
user = _make_regular_user()
access = create_access_token(user["id"], user["email"], user["role"])
response = client.post(
"/auth/refresh", json={"refresh_token": access}
)
assert response.status_code == 401
class TestAdminUsers:
"""GET /admin/users and PATCH /admin/users/{id}/role"""
def test_admin_can_list_users(self, client, mock_db):
"""Admin token should allow listing users."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
mock_db.get_all_users.return_value = [admin, _make_regular_user()]
response = client.get("/admin/users", headers=_auth_header(admin))
assert response.status_code == 200
data = response.json()
assert len(data) == 2
def test_regular_user_cannot_list_users(self, client, mock_db):
"""Regular user token should be rejected with 403."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = user
response = client.get("/admin/users", headers=_auth_header(user))
assert response.status_code == 403
def test_no_token_cannot_list_users(self, client):
"""No token should be rejected."""
response = client.get("/admin/users")
assert response.status_code in (401, 403)
def test_admin_can_change_user_role(self, client, mock_db):
"""Admin should be able to change another user's role."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
mock_db.update_user_role.return_value = {
"id": 2,
"email": "user@test.com",
"role": "admin",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
response = client.patch(
"/admin/users/2/role",
json={"role": "admin"},
headers=_auth_header(admin),
)
assert response.status_code == 200
assert response.json()["role"] == "admin"
def test_admin_cannot_change_own_role(self, client, mock_db):
"""Admin should not be able to change their own role."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
response = client.patch(
"/admin/users/1/role",
json={"role": "user"},
headers=_auth_header(admin),
)
assert response.status_code == 400
assert "own role" in response.json()["detail"].lower()
+97
View File
@@ -0,0 +1,97 @@
"""Tests for rate limiting on auth endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from fastapi.testclient import TestClient
from SPARC.api import app
@pytest.fixture
def client():
"""Create test client with rate limiter enabled."""
return TestClient(app)
@pytest.fixture(autouse=True)
def reset_limiter():
"""Reset rate limiter storage between tests."""
from SPARC.api import limiter
limiter.reset()
yield
class TestRateLimiting:
"""Test rate limiting on login and register endpoints."""
@patch("SPARC.api.get_db_client")
def test_login_allows_requests_under_limit(self, mock_db_client, client):
"""Login endpoint allows requests under the rate limit."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Should allow at least a few requests
for _ in range(5):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
# 401 is expected (invalid credentials), not 429
assert response.status_code == 401
@patch("SPARC.api.get_db_client")
def test_login_rate_limited_after_threshold(self, mock_db_client, client):
"""Login endpoint returns 429 after exceeding rate limit."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Send more than the limit (10/minute)
statuses = []
for _ in range(15):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
statuses.append(response.status_code)
# At least one should be 429
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
@patch("SPARC.api.get_db_client")
def test_register_rate_limited_after_threshold(self, mock_db_client, client):
"""Register endpoint returns 429 after exceeding rate limit."""
mock_db = MagicMock()
mock_db.get_user_count.return_value = 1
mock_db.create_user.return_value = None # triggers 400 (email exists)
mock_db_client.return_value = mock_db
# Send more than the limit (5/minute)
statuses = []
for _ in range(10):
response = client.post(
"/auth/register",
json={"email": "test@example.com", "password": "password123"},
)
statuses.append(response.status_code)
# At least one should be 429
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
@patch("SPARC.api.get_db_client")
def test_rate_limit_returns_retry_after_header(self, mock_db_client, client):
"""Rate limited responses include a Retry-After header."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Exhaust the limit
for _ in range(15):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
if response.status_code == 429:
assert "Retry-After" in response.headers
break
+116
View File
@@ -0,0 +1,116 @@
"""Tests for security hardening: JWT secret startup check, CORS config, credential handling."""
import os
from unittest.mock import patch
import pytest
class TestJWTSecretStartupCheck:
"""Test the startup guard that refuses default JWT secret in non-dev environments."""
def test_default_secret_in_production_raises(self):
"""Starting with default secret and APP_ENV=production must raise RuntimeError."""
with patch.dict(os.environ, {"APP_ENV": "production"}):
# Reload config to pick up the new APP_ENV
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
# Patch JWT_SECRET to the default
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
with pytest.raises(RuntimeError, match="FATAL.*JWT_SECRET"):
check_jwt_secret()
# Restore config
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
def test_default_secret_in_development_succeeds(self):
"""Starting with default secret and APP_ENV=development must not raise."""
with patch.dict(os.environ, {"APP_ENV": "development"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
# Should not raise
check_jwt_secret()
# Restore
importlib.reload(SPARC.config)
def test_custom_secret_in_production_succeeds(self):
"""Starting with a custom secret in production must not raise."""
with patch.dict(os.environ, {"APP_ENV": "production"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", "my-secure-random-secret-abc123"):
# Should not raise
check_jwt_secret()
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
def test_default_secret_unset_env_succeeds(self):
"""When APP_ENV is unset (defaults to development), default secret is allowed."""
with patch.dict(os.environ, {}, clear=False):
# Remove APP_ENV if present
env = os.environ.copy()
env.pop("APP_ENV", None)
with patch.dict(os.environ, env, clear=True):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
# Should not raise (defaults to development)
check_jwt_secret()
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
class TestCORSConfig:
"""Test that CORS origins are configurable via environment variable."""
def test_default_cors_origins(self):
"""When CORS_ORIGINS is unset, defaults to localhost origins."""
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == [
"http://localhost:3000",
"http://localhost:5173",
]
def test_custom_cors_origins(self):
"""Setting CORS_ORIGINS configures allowed origins."""
with patch.dict(os.environ, {"CORS_ORIGINS": "https://sparc.example.com,https://app.example.com"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == [
"https://sparc.example.com",
"https://app.example.com",
]
# Restore
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
importlib.reload(SPARC.config)
def test_single_cors_origin(self):
"""A single origin without comma works correctly."""
with patch.dict(os.environ, {"CORS_ORIGINS": "https://sparc.example.com"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == ["https://sparc.example.com"]
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
importlib.reload(SPARC.config)