Compare commits

..

18 Commits

Author SHA1 Message Date
agent-company 2e6b8c7445 feat: add webhook notification support for job completion and alerts
Send HTTP POST notifications to configured webhook URLs when batch
jobs complete or when scheduled analysis detects significant changes.

- Add SPARC/webhooks.py with retry logic (3 attempts, exponential backoff)
- Support generic HTTP POST and Slack-compatible text payloads
- Integrate into batch job completion handler in api.py
- Configure via WEBHOOK_URLS env var (comma-separated)
- Payload includes event type, job ID, status, and summary

Closes leeworks-agents/SPARC#23

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:32:07 +00:00
AI-Manager 55c131cb32 Merge pull request 'ci: add pytest and ruff linting to CI workflow' (#32) from feature/ci-testing-linting into main 2026-03-26 07:04:31 +00:00
agent-company fbb72fe2a5 ci: add pytest and ruff linting to CI, fix all lint errors
- Add test job to build.yaml that runs pytest and ruff before building images
- Add standalone test.yaml workflow for PRs
- Add ruff.toml with E/F/I rules configured
- Fix all ruff lint errors: sort imports, remove unused imports, fix re-exports
- Build jobs now depend on test job passing (needs: test)

Closes leeworks-agents/SPARC#18
Closes leeworks-agents/SPARC#19

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 07:04:00 +00:00
AI-Manager e484baaf5f Merge pull request 'feat: configurable LLM model, SERP cache TTL, structured logging, fix type' (#29) from feature/p2-config-improvements into main 2026-03-26 07:03:08 +00:00
AI-Manager 069f1c343c Merge pull request 'refactor(db): shared pooled DatabaseClient singleton' (#30) from feature/db-client-pooling into main 2026-03-26 07:02:46 +00:00
agent-company d366443b38 refactor(db): use shared pooled DatabaseClient singleton instead of per-call instances
- Replace get_db_client() creating new DatabaseClient on every call with a
  module-level singleton initialized once at startup via init_db_client()
- Add init_db_client() and close_db_client() lifecycle functions called
  from FastAPI lifespan handler
- Migrate all DatabaseClient methods from legacy self.connect()/self.conn
  to pooled self.get_conn() context manager for thread-safe connection reuse
- Pool is properly torn down on application shutdown

Closes leeworks-agents/SPARC#7

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 06:03:56 +00:00
agent-company b000146585 feat: configurable LLM model, SERP cache TTL, structured logging, fix patent_id type
- Make LLM model configurable via MODEL env var, default anthropic/claude-3.5-sonnet (#12)
- Expose SERP cache TTL as SERP_CACHE_TTL_HOURS env var, default 24 hours (#13)
- Fix Patent.patent_id type annotation from int to str in types.py (#14)
- Replace all print() calls with structured logging in analyzer.py and llm.py (#11)
- Add LOG_LEVEL config with basicConfig setup in config.py
- Add model and serp_cache_ttl_hours to config.py

Closes leeworks-agents/SPARC#11
Closes leeworks-agents/SPARC#12
Closes leeworks-agents/SPARC#13
Closes leeworks-agents/SPARC#14

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 06:03:25 +00:00
AI-Manager 35d105b14e Merge pull request 'feat(auth): add rate limiting to login and register endpoints' (#28) from feature/rate-limiting into main 2026-03-26 05:04:46 +00:00
AI-Manager 6fcf170d93 Merge pull request 'feat(jobs): persist async batch job state in PostgreSQL' (#34) from feature/persist-job-state into main 2026-03-26 05:04:26 +00:00
AI-Manager 5a42e216ba Merge pull request 'docs: patent PDF storage docs, FileNotFoundError, frontend lockfile' (#31) from feature/p2-docs-and-lockfile into main 2026-03-26 05:04:01 +00:00
AI-Manager 24ab341d9b Merge pull request 'test(auth): add comprehensive JWT authentication test suite' (#35) from feature/jwt-auth-tests into main 2026-03-26 05:03:29 +00:00
AI-Manager 878fedfbb8 Merge pull request 'feat(security): JWT startup guard, configurable CORS, externalize DB creds' (#27) from feature/p1-security-hardening into main 2026-03-26 05:03:16 +00:00
agent-company ae9f257dcb test(auth): add comprehensive JWT authentication test suite
Add 17 tests in tests/test_auth.py covering all auth flows:
- Registration: first user admin, subsequent user, duplicate email
- Login: valid credentials, invalid credentials
- Protected routes: valid token, missing token, expired token, wrong token type
- Token refresh: valid refresh, invalid refresh, access-as-refresh rejected
- Admin endpoints: list users, change role, own-role prevention, permission checks

All tests use mocked database (no live DB required).

Closes leeworks-agents/SPARC#10

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:24:12 +00:00
agent-company 96d5d27b17 feat(jobs): persist async batch job state in PostgreSQL
- Add jobs table to database schema (job_id, status, progress, result_json, etc.)
- Add DatabaseClient methods: create_job, update_job, get_job, list_jobs
- Add mark_stale_jobs_failed() called at startup to handle interrupted jobs
- Refactor _run_batch_job and job endpoints to read/write from PostgreSQL
- Remove in-memory _jobs dict; job state now survives API restarts
- Update init_database.py to list all tables in output

Closes leeworks-agents/SPARC#8

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:22:57 +00:00
agent-company 3dac88ec90 docs: document patent PDF storage, add FileNotFoundError, commit lockfile
- Add docstring to analyze_single_patent explaining the PDF prerequisite
- Raise FileNotFoundError with helpful message when PDF is missing
- Add patent PDF storage section to README with Docker volume mount example
- Commit frontend/package-lock.json for reproducible builds

Closes leeworks-agents/SPARC#15
Closes leeworks-agents/SPARC#17

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:17:09 +00:00
agent-company e2d750146c feat(auth): add rate limiting to login and register endpoints
- Add slowapi rate limiter: 10 req/min for /auth/login, 5 req/min for /auth/register
- Return HTTP 429 with Retry-After header when limit is exceeded
- Add slowapi to requirements.txt
- Add 4 passing tests for rate limit behavior

Closes leeworks-agents/SPARC#9

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:08:22 +00:00
agent-company 47cddcbeaf feat(security): add JWT startup guard, configurable CORS, and externalize DB credentials
- Add check_jwt_secret() that refuses default JWT secret when APP_ENV != development
- Make CORS origins configurable via CORS_ORIGINS env var (comma-separated)
- Replace hardcoded postgres credentials in docker-compose.yml with env var references
- Add APP_ENV and cors_origins to config.py
- Update .env.example with all required variables and documentation
- Add tests for JWT startup guard and CORS configuration

Closes leeworks-agents/SPARC#4
Closes leeworks-agents/SPARC#5
Closes leeworks-agents/SPARC#6

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 04:06:31 +00:00
AI-Manager 6105ba7793 Merge pull request 'chore: add ROADMAP.md for SPARC application development' (#3) from chore/add-roadmap into main 2026-03-26 02:47:54 +00:00
26 changed files with 6113 additions and 283 deletions
+36 -9
View File
@@ -1,21 +1,48 @@
# SPARC Configuration
# ---- Application Environment ----
# Set to "production" or "staging" in deployed environments.
# The API will refuse to start with the default JWT secret unless APP_ENV=development.
APP_ENV=development
# ---- API Keys ----
# SerpAPI key for patent search
API_KEY=your_serpapi_key_here
# OpenRouter API key for LLM analysis
OPENROUTER_API_KEY=your_openrouter_key_here
# Database configuration
# All messages are stored in the database for persistence and caching
DATABASE_URL=postgresql://postgres:postgres@localhost:5432/sparc
# ---- Database ----
# Cache configuration
# When USE_CACHE=true: check database for cached responses before making API calls
# When USE_CACHE=false: always make fresh API calls (still stores results in database)
# Default: true
USE_CACHE=true
# PostgreSQL credentials (used by docker-compose)
POSTGRES_USER=postgres
POSTGRES_PASSWORD=change-me-to-a-secure-password
POSTGRES_DB=sparc
# JWT Secret for authentication
# Full database URL (must match the credentials above)
DATABASE_URL=postgresql://postgres:change-me-to-a-secure-password@localhost:5432/sparc
# ---- Authentication ----
# JWT Secret for signing tokens
# IMPORTANT: Change this to a secure random string in production
JWT_SECRET=your-secure-jwt-secret-change-in-production
# ---- CORS ----
# Comma-separated list of allowed origins for CORS
# Defaults to http://localhost:3000,http://localhost:5173 when unset
# CORS_ORIGINS=https://sparc.example.com,https://app.example.com
# ---- Cache ----
# When USE_CACHE=true: check database for cached responses before making API calls
# When USE_CACHE=false: always make fresh API calls (still stores results in database)
USE_CACHE=true
# ---- Webhooks ----
# Comma-separated list of webhook URLs for job completion and alert notifications
# Supports generic HTTP POST and Slack/Discord incoming webhooks
# WEBHOOK_URLS=https://hooks.slack.com/services/XXX,https://example.com/webhook
+37
View File
@@ -9,7 +9,43 @@ on:
workflow_dispatch:
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Install system dependencies
shell: sh
run: |
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
- name: Checkout code
shell: sh
run: |
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
git checkout ${{ gitea.sha }}
- name: Install Python dependencies
shell: sh
run: |
pip3 install --break-system-packages -r requirements.txt ruff
- name: Run ruff linter
shell: sh
run: |
ruff check SPARC/ tests/
- name: Run pytest
shell: sh
env:
DATABASE_URL: "sqlite://"
API_KEY: "test-key"
OPENROUTER_API_KEY: "test-key"
JWT_SECRET: "test-secret-for-ci"
APP_ENV: "development"
run: |
python3 -m pytest tests/ -v --tb=short -x
build-api:
needs: test
runs-on: ubuntu-latest
steps:
- name: Install dependencies
@@ -81,6 +117,7 @@ jobs:
echo "API image available at ${{ steps.tags.outputs.IMAGE_TAG }}"
build-frontend:
needs: test
runs-on: ubuntu-latest
steps:
- name: Install dependencies
+46
View File
@@ -0,0 +1,46 @@
name: Test and Lint
on:
push:
branches:
- main
pull_request:
branches:
- main
workflow_dispatch:
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Install system dependencies
shell: sh
run: |
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
- name: Checkout code
shell: sh
run: |
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
git checkout ${{ gitea.sha }}
- name: Install Python dependencies
shell: sh
run: |
pip3 install --break-system-packages -r requirements.txt ruff
- name: Run ruff linter
shell: sh
run: |
ruff check SPARC/ tests/
- name: Run pytest
shell: sh
env:
DATABASE_URL: "sqlite://"
API_KEY: "test-key"
OPENROUTER_API_KEY: "test-key"
JWT_SECRET: "test-secret-for-ci"
APP_ENV: "development"
run: |
python3 -m pytest tests/ -v --tb=short -x
+15
View File
@@ -54,6 +54,21 @@ docker-compose up -d
# - API Docs: http://localhost:8000/docs
```
#### Patent PDF Storage
The API stores downloaded patent PDFs in a `patents/` directory. In Docker,
this is mounted as a bind mount (`./patents:/app/patents`) so that PDFs persist
across container restarts.
If you deploy to a different environment, ensure the `patents/` directory is a
persistent volume. Without it, PDFs will be re-downloaded on every analysis.
```yaml
# docker-compose.yml excerpt
volumes:
- ./patents:/app/patents
```
### NixOS
```bash
+3 -2
View File
@@ -1,3 +1,4 @@
from .types import Patents, Patent
from .types import Patent as Patent
from .types import Patents as Patents
all = ["Patents", "Patent"]
__all__ = ["Patents", "Patent"]
+42 -23
View File
@@ -5,14 +5,17 @@ to provide company performance estimation based on patent portfolios.
"""
import hashlib
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable
from SPARC import config
logger = logging.getLogger(__name__)
from SPARC.database import DatabaseClient
from SPARC.serp_api import SERP
from SPARC.llm import LLMAnalyzer
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult
from SPARC.serp_api import SERP
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult, Patent, Patents
class CompanyAnalyzer:
@@ -52,13 +55,13 @@ class CompanyAnalyzer:
query_hash = hashlib.sha256(company_name.lower().encode()).hexdigest()
cached_ids = self.db.get_cached_serp_query(query_hash)
if cached_ids is not None:
print(f"Using cached SERP results for {company_name} ({len(cached_ids)} patents)")
logger.info("Using cached SERP results for %s (%d patents)", company_name, len(cached_ids))
patents = Patents(patents=[
Patent(patent_id=pid, pdf_link="")
for pid in cached_ids
])
else:
print(f"Retrieving patents for {company_name}...")
logger.info("Retrieving patents for %s...", company_name)
patents = SERP.query(company_name)
# Cache the SERP results
if patents.patents:
@@ -66,12 +69,13 @@ class CompanyAnalyzer:
company_name=company_name,
query_hash=query_hash,
patent_ids=[p.patent_id for p in patents.patents],
ttl_hours=config.serp_cache_ttl_hours,
)
if not patents.patents:
return f"No patents found for {company_name}"
print(f"Found {len(patents.patents)} patents. Processing...")
logger.info("Found %d patents. Processing...", len(patents.patents))
# Download, parse, and minimize patents in parallel
processed_patents = []
@@ -87,12 +91,12 @@ class CompanyAnalyzer:
if result:
processed_patents.append(result)
except Exception as e:
print(f"Warning: Failed to process {patent.patent_id}: {e}")
logger.warning("Failed to process %s: %s", patent.patent_id, e)
if not processed_patents:
return f"Failed to process any patents for {company_name}"
print(f"Analyzing portfolio with LLM...")
logger.info("Analyzing portfolio with LLM...")
# Analyze the full portfolio with LLM
analysis = self.llm_analyzer.analyze_patent_portfolio(
@@ -104,21 +108,34 @@ class CompanyAnalyzer:
def analyze_single_patent(self, patent_id: str, company_name: str) -> str:
"""Analyze a single patent by ID.
Useful for focused analysis of specific innovations.
Prerequisite:
The patent PDF must already exist at ``patents/{patent_id}.pdf``
before calling this method. PDFs are downloaded automatically when
using the batch analysis pipeline (``analyze_company`` or the
``/analyze/batch`` API endpoint). For standalone usage, download
the PDF manually or call ``SERP.save_patents()`` first.
Args:
patent_id: Publication ID of the patent
patent_id: Publication ID of the patent (e.g. "US-11234567-B2")
company_name: Name of the company (for context)
Returns:
Analysis of the specific patent's innovation quality
Raises:
FileNotFoundError: If the patent PDF is not found at the expected path.
"""
# Note: This simplified version assumes the patent PDF is already downloaded
# A more complete implementation would support direct patent ID lookup
print(f"Analyzing patent {patent_id} for {company_name}...")
import os
logger.info("Analyzing patent %s for %s...", patent_id, company_name)
patent_path = f"patents/{patent_id}.pdf"
if not os.path.exists(patent_path):
raise FileNotFoundError(
f"Patent PDF not found at '{patent_path}'. "
f"Download the PDF first using SERP.save_patents() or the batch analysis pipeline."
)
try:
sections = SERP.parse_patent_pdf(patent_path)
minimized_content = SERP.minimize_patent_for_llm(sections)
@@ -129,6 +146,8 @@ class CompanyAnalyzer:
return analysis
except FileNotFoundError:
raise
except Exception as e:
return f"Failed to analyze patent {patent_id}: {e}"
@@ -169,7 +188,7 @@ class CompanyAnalyzer:
return {"patent_id": patent.patent_id, "content": minimized_content}
except Exception as e:
print(f"Warning: Failed to process {patent.patent_id}: {e}")
logger.warning("Failed to process %s: %s", patent.patent_id, e)
return None
def _analyze_company_safe(self, company_name: str) -> CompanyAnalysisResult:
@@ -240,7 +259,7 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = []
total = len(companies)
print(f"Starting batch analysis of {total} companies...")
logger.info("Starting batch analysis of %d companies...", total)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_company = {
@@ -257,8 +276,8 @@ class CompanyAnalyzer:
result = future.result()
results.append(result)
status = "" if result.success else ""
print(f"[{completed}/{total}] {status} {company}")
status = "OK" if result.success else "FAIL"
logger.info("[%d/%d] %s %s", completed, total, status, company)
if progress_callback:
progress_callback(company, completed, total)
@@ -273,12 +292,12 @@ class CompanyAnalyzer:
error=str(e),
)
)
print(f"[{completed}/{total}] ✗ {company}: {e}")
logger.error("[%d/%d] FAIL %s: %s", completed, total, company, e)
successful = sum(1 for r in results if r.success)
failed = total - successful
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
return BatchAnalysisResult(
results=results,
@@ -304,20 +323,20 @@ class CompanyAnalyzer:
results: list[CompanyAnalysisResult] = []
total = len(companies)
print(f"Starting sequential analysis of {total} companies...")
logger.info("Starting sequential analysis of %d companies...", total)
for idx, company in enumerate(companies, 1):
print(f"\n[{idx}/{total}] Analyzing {company}...")
logger.info("[%d/%d] Analyzing %s...", idx, total, company)
result = self._analyze_company_safe(company)
results.append(result)
status = "" if result.success else ""
print(f"[{idx}/{total}] {status} {company}")
status = "OK" if result.success else "FAIL"
logger.info("[%d/%d] %s %s", idx, total, status, company)
successful = sum(1 for r in results if r.success)
failed = total - successful
print(f"\nBatch complete: {successful} succeeded, {failed} failed")
logger.info("Batch complete: %d succeeded, %d failed", successful, failed)
return BatchAnalysisResult(
results=results,
+122 -41
View File
@@ -7,20 +7,27 @@ from contextlib import asynccontextmanager
from datetime import datetime
from typing import Annotated, List
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, EmailStr, Field
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from SPARC import config
from SPARC.analyzer import CompanyAnalyzer
from SPARC.auth import (
TokenResponse,
UserResponse,
check_jwt_secret,
close_db_client,
create_tokens,
decode_token,
get_current_admin,
get_current_user,
get_db_client,
init_db_client,
)
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@@ -114,8 +121,7 @@ class AnalyticsResponse(BaseModel):
period_days: int
# In-memory job storage (for demo; production would use Redis/DB)
_jobs: dict[str, JobStatus] = {}
# Job counter for generating unique IDs (the actual state is in PostgreSQL)
_job_counter = 0
@@ -148,12 +154,25 @@ _analyzer: CompanyAnalyzer | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize resources on startup."""
"""Initialize resources on startup, clean up on shutdown."""
global _analyzer
check_jwt_secret()
init_db_client()
_analyzer = CompanyAnalyzer()
# Mark any jobs that were running/pending before the restart as failed
from SPARC.database import DatabaseClient
_db = DatabaseClient(config.database_url)
_db.connect()
_db.initialize_schema()
stale = _db.mark_stale_jobs_failed()
if stale:
import logging
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
_db.close()
yield
# Cleanup if needed
# Cleanup
_analyzer = None
close_db_client()
app = FastAPI(
@@ -164,10 +183,26 @@ app = FastAPI(
root_path=config.root_path,
)
# Rate limiter (in-memory storage, suitable for single-instance deployments)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Return 429 with Retry-After header when rate limit is exceeded."""
retry_after = getattr(exc, "retry_after", 60)
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded. Please try again later."},
headers={"Retry-After": str(retry_after)},
)
# Add CORS middleware for React frontend
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000", "http://localhost:5173"],
allow_origins=config.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
@@ -178,7 +213,8 @@ app.add_middleware(
@app.post("/auth/register", response_model=UserResponse, tags=["Auth"])
async def register(request: RegisterRequest):
@limiter.limit("5/minute")
async def register(request: Request, body: RegisterRequest):
"""Register a new user.
The first registered user automatically becomes an admin.
@@ -190,8 +226,8 @@ async def register(request: RegisterRequest):
role = "admin" if user_count == 0 else "user"
user = db.create_user(
email=request.email,
password=request.password,
email=body.email,
password=body.password,
role=role,
)
@@ -210,11 +246,12 @@ async def register(request: RegisterRequest):
@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"])
async def login(request: LoginRequest):
@limiter.limit("10/minute")
async def login(request: Request, body: LoginRequest):
"""Authenticate user and return JWT tokens."""
db = get_db_client()
user = db.authenticate_user(request.email, request.password)
user = db.authenticate_user(body.email, body.password)
if not user:
raise HTTPException(
@@ -422,20 +459,52 @@ async def analyze_companies_batch(
return _convert_batch_result(result)
def _get_job_db() -> "DatabaseClient":
"""Get a DatabaseClient for job persistence."""
from SPARC.database import DatabaseClient
db = DatabaseClient(config.database_url)
return db
def _job_row_to_status(row: dict) -> JobStatus:
"""Convert a database job row to a JobStatus model."""
import json as _json
result = None
if row.get("result_json"):
result_data = row["result_json"]
if isinstance(result_data, str):
result_data = _json.loads(result_data)
result = BatchAnalysisResponse(**result_data)
return JobStatus(
job_id=row["job_id"],
status=row["status"],
progress=row["progress"],
total_companies=row["total_companies"],
completed_companies=row["completed_companies"],
result=result,
error=row.get("error"),
)
def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
"""Background task for batch analysis."""
global _jobs, _analyzer
import json as _json
global _analyzer
db = _get_job_db()
if not _analyzer:
_jobs[job_id].status = "failed"
_jobs[job_id].error = "Analyzer not initialized"
db.update_job(job_id, status="failed", error="Analyzer not initialized")
return
_jobs[job_id].status = "running"
db.update_job(job_id, status="running")
def progress_callback(company: str, completed: int, total: int):
_jobs[job_id].completed_companies = completed
_jobs[job_id].progress = int((completed / total) * 100)
db.update_job(
job_id,
completed_companies=completed,
progress=int((completed / total) * 100),
)
try:
result = _analyzer.analyze_companies(
@@ -443,12 +512,32 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int):
max_workers=max_workers,
progress_callback=progress_callback,
)
_jobs[job_id].status = "completed"
_jobs[job_id].progress = 100
_jobs[job_id].result = _convert_batch_result(result)
batch_response = _convert_batch_result(result)
db.update_job(
job_id,
status="completed",
progress=100,
result_json=_json.dumps(batch_response.model_dump(), default=str),
)
# Fire webhook notification
from SPARC.webhooks import notify_job_completed
notify_job_completed(
job_id=job_id,
status="completed",
total_companies=result.total_companies,
successful=result.successful,
failed=result.failed,
)
except Exception as e:
_jobs[job_id].status = "failed"
_jobs[job_id].error = str(e)
db.update_job(job_id, status="failed", error=str(e))
from SPARC.webhooks import notify_job_completed
notify_job_completed(
job_id=job_id,
status="failed",
total_companies=len(companies),
successful=0,
failed=len(companies),
)
@app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"])
@@ -473,19 +562,14 @@ async def analyze_companies_async(
_job_counter += 1
job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
_jobs[job_id] = JobStatus(
job_id=job_id,
status="pending",
progress=0,
total_companies=len(request.companies),
completed_companies=0,
)
db = _get_job_db()
job_row = db.create_job(job_id=job_id, total_companies=len(request.companies))
background_tasks.add_task(
_run_batch_job, job_id, request.companies, request.max_workers
)
return _jobs[job_id]
return _job_row_to_status(job_row)
@app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"])
@@ -501,10 +585,13 @@ async def get_job_status(
Returns:
Current job status including progress and results when complete
"""
if job_id not in _jobs:
db = _get_job_db()
job_row = db.get_job(job_id)
if not job_row:
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
return _jobs[job_id]
return _job_row_to_status(job_row)
@app.get("/jobs", response_model=list[JobStatus], tags=["Jobs"])
@@ -525,12 +612,6 @@ async def list_jobs(
Returns:
List of job statuses
"""
jobs = list(_jobs.values())
if status:
jobs = [j for j in jobs if j.status == status]
# Return most recent first
jobs.sort(key=lambda j: j.job_id, reverse=True)
return jobs[:limit]
db = _get_job_db()
job_rows = db.list_jobs(status=status, limit=limit)
return [_job_row_to_status(row) for row in job_rows]
+44 -5
View File
@@ -13,11 +13,25 @@ from SPARC import config
from SPARC.database import DatabaseClient
# JWT Configuration
JWT_SECRET = os.getenv("JWT_SECRET", "sparc-secret-key-change-in-production")
_DEFAULT_JWT_SECRET = "sparc-secret-key-change-in-production"
JWT_SECRET = os.getenv("JWT_SECRET", _DEFAULT_JWT_SECRET)
JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
def check_jwt_secret() -> None:
"""Refuse to start with the default JWT secret in non-development environments.
Raises:
RuntimeError: If JWT_SECRET is the default value and APP_ENV is not 'development'.
"""
if JWT_SECRET == _DEFAULT_JWT_SECRET and config.app_env != "development":
raise RuntimeError(
f"FATAL: JWT_SECRET is set to the default value and APP_ENV={config.app_env!r}. "
"Set a secure JWT_SECRET environment variable before running in non-development environments."
)
security = HTTPBearer()
@@ -132,11 +146,36 @@ def decode_token(token: str) -> Optional[TokenPayload]:
return None
# Shared database client singleton, initialized at startup via init_db_client()
_db_client: DatabaseClient | None = None
def init_db_client() -> None:
"""Initialize the shared database client. Call once at app startup."""
global _db_client
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
def close_db_client() -> None:
"""Close the shared database client. Call at app shutdown."""
global _db_client
if _db_client:
_db_client.close()
_db_client = None
def get_db_client() -> DatabaseClient:
"""Get database client for auth operations."""
client = DatabaseClient(config.database_url)
client.connect()
return client
"""Get the shared pooled database client for auth operations.
Returns the module-level singleton DatabaseClient. If not yet initialized
(e.g., during tests), creates a new instance as a fallback.
"""
global _db_client
if _db_client is None:
_db_client = DatabaseClient(config.database_url)
_db_client.connect()
return _db_client
async def get_current_user(
+29 -1
View File
@@ -2,11 +2,20 @@
Loads environment variables from .env file for API keys and other secrets.
"""
from dotenv import load_dotenv
import logging
import os
from dotenv import load_dotenv
load_dotenv()
# Logging configuration
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
level=getattr(logging, log_level, logging.INFO),
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
# SerpAPI key for patent search
api_key = os.getenv("API_KEY")
@@ -30,6 +39,25 @@ use_database = os.getenv("USE_DATABASE", "false").lower() in ("true", "1", "yes"
patent_search_days = int(os.getenv("PATENT_SEARCH_DAYS", "90"))
patent_thread_workers = int(os.getenv("PATENT_THREAD_WORKERS", "5"))
# LLM model to use via OpenRouter (e.g. "anthropic/claude-3.5-sonnet", "openai/gpt-4o")
model = os.getenv("MODEL", "anthropic/claude-3.5-sonnet")
# SERP cache TTL in hours (how long cached search results are considered fresh)
serp_cache_ttl_hours = int(os.getenv("SERP_CACHE_TTL_HOURS", "24"))
# Root path for running behind a reverse proxy (e.g., "/api" when served at /api/)
# This ensures OpenAPI docs work correctly when accessed via the proxy
root_path = os.getenv("ROOT_PATH", "")
# Application environment: "development", "staging", or "production"
# Used for safety checks (e.g., refusing default JWT secret in production)
app_env = os.getenv("APP_ENV", "development")
# CORS allowed origins (comma-separated)
# Defaults to localhost dev origins when unset
_cors_origins_raw = os.getenv("CORS_ORIGINS", "")
cors_origins: list[str] = (
[o.strip() for o in _cors_origins_raw.split(",") if o.strip()]
if _cors_origins_raw
else ["http://localhost:3000", "http://localhost:5173"]
)
+304 -171
View File
@@ -1,14 +1,15 @@
"""Database client for storing and retrieving LLM messages and user authentication."""
import contextlib
import psycopg2
from psycopg2.pool import ThreadedConnectionPool
from psycopg2.extras import RealDictCursor
from typing import Dict, List, Optional
from datetime import datetime, timedelta
import json
import hashlib
import json
from datetime import datetime, timedelta
from typing import Dict, List, Optional
import bcrypt
import psycopg2
from psycopg2.extras import RealDictCursor
from psycopg2.pool import ThreadedConnectionPool
class DatabaseClient:
@@ -171,6 +172,26 @@ class DatabaseClient:
ON serp_queries(query_hash)
""")
# Create jobs table for persisting async batch job state
cursor.execute("""
CREATE TABLE IF NOT EXISTS jobs (
job_id VARCHAR(128) PRIMARY KEY,
status VARCHAR(20) NOT NULL DEFAULT 'pending',
progress INTEGER NOT NULL DEFAULT 0,
total_companies INTEGER NOT NULL DEFAULT 0,
completed_companies INTEGER NOT NULL DEFAULT 0,
result_json JSONB,
error TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_jobs_status
ON jobs(status)
""")
self.conn.commit()
@staticmethod
@@ -201,8 +222,6 @@ class DatabaseClient:
Returns:
Cached message dict if found, None otherwise
"""
self.connect()
prompt_hash = self.hash_prompt(prompt)
query = """
@@ -225,10 +244,11 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT 1"
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params)
result = cursor.fetchone()
return dict(result) if result else None
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params)
result = cursor.fetchone()
return dict(result) if result else None
def store_message(
self,
@@ -256,33 +276,32 @@ class DatabaseClient:
Returns:
The ID of the inserted record
"""
self.connect()
prompt_hash = self.hash_prompt(prompt)
with self.conn.cursor() as cursor:
cursor.execute(
"""
INSERT INTO llm_messages
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING id
""",
(
prompt,
prompt_hash,
response,
company_name,
analysis_type,
model,
json.dumps(metadata) if metadata else None,
json.dumps(token_usage) if token_usage else None,
is_cached,
),
)
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""
INSERT INTO llm_messages
(prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING id
""",
(
prompt,
prompt_hash,
response,
company_name,
analysis_type,
model,
json.dumps(metadata) if metadata else None,
json.dumps(token_usage) if token_usage else None,
is_cached,
),
)
message_id = cursor.fetchone()[0]
self.conn.commit()
message_id = cursor.fetchone()[0]
conn.commit()
return message_id
@@ -304,8 +323,6 @@ class DatabaseClient:
Returns:
List of message dictionaries
"""
self.connect()
query = "SELECT * FROM llm_messages WHERE 1=1"
params = []
@@ -320,9 +337,10 @@ class DatabaseClient:
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
params.extend([limit, offset])
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
def get_analytics(self, days: int = 30) -> Dict:
"""Get analytics on message usage.
@@ -333,53 +351,52 @@ class DatabaseClient:
Returns:
Dictionary with analytics data
"""
self.connect()
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Total messages
cursor.execute(
"""
SELECT COUNT(*) as total_messages
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
""",
(days,),
)
total = cursor.fetchone()["total_messages"]
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Total messages
cursor.execute(
"""
SELECT COUNT(*) as total_messages
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
""",
(days,),
)
total = cursor.fetchone()["total_messages"]
# Messages by company
cursor.execute(
"""
SELECT company_name, COUNT(*) as count
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
GROUP BY company_name
ORDER BY count DESC
LIMIT 10
""",
(days,),
)
by_company = cursor.fetchall()
# Messages by company
cursor.execute(
"""
SELECT company_name, COUNT(*) as count
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
GROUP BY company_name
ORDER BY count DESC
LIMIT 10
""",
(days,),
)
by_company = cursor.fetchall()
# Messages by type
cursor.execute(
"""
SELECT analysis_type, COUNT(*) as count
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
GROUP BY analysis_type
ORDER BY count DESC
""",
(days,),
)
by_type = cursor.fetchall()
# Messages by type
cursor.execute(
"""
SELECT analysis_type, COUNT(*) as count
FROM llm_messages
WHERE timestamp >= NOW() - INTERVAL '%s days'
GROUP BY analysis_type
ORDER BY count DESC
""",
(days,),
)
by_type = cursor.fetchall()
return {
"total_messages": total,
"by_company": [dict(row) for row in by_company],
"by_type": [dict(row) for row in by_type],
"period_days": days,
}
return {
"total_messages": total,
"by_company": [dict(row) for row in by_company],
"by_type": [dict(row) for row in by_type],
"period_days": days,
}
# Patent Cache Methods
@@ -462,6 +479,131 @@ class DatabaseClient:
)
conn.commit()
# Job Persistence Methods
def create_job(
self,
job_id: str,
total_companies: int,
) -> Dict:
"""Create a new job record.
Args:
job_id: Unique job identifier
total_companies: Number of companies in the batch
Returns:
Job dict
"""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies)
VALUES (%s, 'pending', 0, %s, 0)
RETURNING *
""",
(job_id, total_companies),
)
job = cursor.fetchone()
conn.commit()
return dict(job)
def update_job(
self,
job_id: str,
status: Optional[str] = None,
progress: Optional[int] = None,
completed_companies: Optional[int] = None,
result_json: Optional[str] = None,
error: Optional[str] = None,
) -> Optional[Dict]:
"""Update a job's state.
Only non-None fields are updated.
"""
updates = []
params = []
if status is not None:
updates.append("status = %s")
params.append(status)
if progress is not None:
updates.append("progress = %s")
params.append(progress)
if completed_companies is not None:
updates.append("completed_companies = %s")
params.append(completed_companies)
if result_json is not None:
updates.append("result_json = %s")
params.append(result_json)
if error is not None:
updates.append("error = %s")
params.append(error)
if not updates:
return self.get_job(job_id)
updates.append("updated_at = CURRENT_TIMESTAMP")
params.append(job_id)
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
f"UPDATE jobs SET {', '.join(updates)} WHERE job_id = %s RETURNING *",
params,
)
job = cursor.fetchone()
conn.commit()
return dict(job) if job else None
def get_job(self, job_id: str) -> Optional[Dict]:
"""Get a job by ID."""
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,))
job = cursor.fetchone()
return dict(job) if job else None
def list_jobs(
self,
status: Optional[str] = None,
limit: int = 10,
) -> List[Dict]:
"""List jobs, optionally filtered by status."""
query = "SELECT * FROM jobs"
params: list = []
if status:
query += " WHERE status = %s"
params.append(status)
query += " ORDER BY created_at DESC LIMIT %s"
params.append(limit)
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
def mark_stale_jobs_failed(self) -> int:
"""Mark any jobs in 'running' or 'pending' state as 'failed'.
Called at startup to clean up jobs that were interrupted by a restart.
Returns:
Number of jobs marked as failed.
"""
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""
UPDATE jobs SET status = 'failed', error = 'Interrupted by server restart',
updated_at = CURRENT_TIMESTAMP
WHERE status IN ('running', 'pending')
"""
)
count = cursor.rowcount
conn.commit()
return count
# User Authentication Methods
@staticmethod
@@ -505,25 +647,23 @@ class DatabaseClient:
Returns:
Created user dict or None if email exists
"""
self.connect()
password_hash = self.hash_password(password)
try:
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
INSERT INTO users (email, password_hash, role)
VALUES (%s, %s, %s)
RETURNING id, email, role, created_at
""",
(email, password_hash, role),
)
user = cursor.fetchone()
self.conn.commit()
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
INSERT INTO users (email, password_hash, role)
VALUES (%s, %s, %s)
RETURNING id, email, role, created_at
""",
(email, password_hash, role),
)
user = cursor.fetchone()
conn.commit()
return dict(user) if user else None
except psycopg2.errors.UniqueViolation:
self.conn.rollback()
return None
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
@@ -536,23 +676,22 @@ class DatabaseClient:
Returns:
User dict if authenticated, None otherwise
"""
self.connect()
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT * FROM users WHERE email = %s",
(email,),
)
user = cursor.fetchone()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT * FROM users WHERE email = %s",
(email,),
)
user = cursor.fetchone()
if user and self.verify_password(password, user["password_hash"]):
return {
"id": user["id"],
"email": user["email"],
"role": user["role"],
"created_at": user["created_at"],
}
return None
if user and self.verify_password(password, user["password_hash"]):
return {
"id": user["id"],
"email": user["email"],
"role": user["role"],
"created_at": user["created_at"],
}
return None
def get_user_by_id(self, user_id: int) -> Optional[Dict]:
"""Get a user by ID.
@@ -563,15 +702,14 @@ class DatabaseClient:
Returns:
User dict or None
"""
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE id = %s",
(user_id,),
)
user = cursor.fetchone()
return dict(user) if user else None
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE id = %s",
(user_id,),
)
user = cursor.fetchone()
return dict(user) if user else None
def get_user_by_email(self, email: str) -> Optional[Dict]:
"""Get a user by email.
@@ -582,15 +720,14 @@ class DatabaseClient:
Returns:
User dict or None
"""
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE email = %s",
(email,),
)
user = cursor.fetchone()
return dict(user) if user else None
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"SELECT id, email, role, created_at FROM users WHERE email = %s",
(email,),
)
user = cursor.fetchone()
return dict(user) if user else None
def get_all_users(self, limit: int = 100, offset: int = 0) -> List[Dict]:
"""Get all users (admin only).
@@ -602,19 +739,18 @@ class DatabaseClient:
Returns:
List of user dicts
"""
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT id, email, role, created_at
FROM users
ORDER BY created_at DESC
LIMIT %s OFFSET %s
""",
(limit, offset),
)
return [dict(row) for row in cursor.fetchall()]
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
SELECT id, email, role, created_at
FROM users
ORDER BY created_at DESC
LIMIT %s OFFSET %s
""",
(limit, offset),
)
return [dict(row) for row in cursor.fetchall()]
def update_user_role(self, user_id: int, role: str) -> Optional[Dict]:
"""Update a user's role (admin only).
@@ -626,20 +762,19 @@ class DatabaseClient:
Returns:
Updated user dict or None
"""
self.connect()
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
UPDATE users
SET role = %s, updated_at = CURRENT_TIMESTAMP
WHERE id = %s
RETURNING id, email, role, created_at
""",
(role, user_id),
)
user = cursor.fetchone()
self.conn.commit()
with self.get_conn() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.execute(
"""
UPDATE users
SET role = %s, updated_at = CURRENT_TIMESTAMP
WHERE id = %s
RETURNING id, email, role, created_at
""",
(role, user_id),
)
user = cursor.fetchone()
conn.commit()
return dict(user) if user else None
def delete_user(self, user_id: int) -> bool:
@@ -651,12 +786,11 @@ class DatabaseClient:
Returns:
True if deleted
"""
self.connect()
with self.conn.cursor() as cursor:
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
deleted = cursor.rowcount > 0
self.conn.commit()
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
deleted = cursor.rowcount > 0
conn.commit()
return deleted
def get_user_count(self) -> int:
@@ -665,8 +799,7 @@ class DatabaseClient:
Returns:
Number of users
"""
self.connect()
with self.conn.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM users")
return cursor.fetchone()[0]
with self.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM users")
return cursor.fetchone()[0]
+9 -8
View File
@@ -1,9 +1,14 @@
"""LLM integration for patent analysis using OpenRouter."""
import logging
from typing import Dict
from openai import OpenAI
from SPARC import config
from SPARC.database import DatabaseClient
from typing import Dict
logger = logging.getLogger(__name__)
class LLMAnalyzer:
@@ -20,7 +25,7 @@ class LLMAnalyzer:
"""
self.test_mode = test_mode
self.use_cache = use_cache if use_cache is not None else config.use_cache
self.model = "anthropic/claude-3.5-sonnet"
self.model = config.model
# Always initialize database client for storage and caching
self.db_client = DatabaseClient(config.database_url)
@@ -59,11 +64,7 @@ Patent Content:
Provide a concise analysis (2-3 paragraphs) focusing on what this patent reveals about the company's technical direction and competitive advantage."""
if self.test_mode:
print("=" * 80)
print("TEST MODE - Prompt that would be sent to LLM:")
print("=" * 80)
print(prompt)
print("=" * 80)
logger.debug("TEST MODE - Prompt that would be sent to LLM:\n%s", prompt)
return "[TEST MODE - No API call made]"
# Check cache first
@@ -165,7 +166,7 @@ Patent Portfolio:
Provide a comprehensive analysis (4-5 paragraphs) with a final verdict on the company's innovation strength and performance outlook."""
if self.test_mode:
print(prompt)
logger.debug("TEST MODE - Portfolio prompt:\n%s", prompt)
return "[TEST MODE]"
metadata = {
+8 -5
View File
@@ -1,12 +1,15 @@
import os
import serpapi
from SPARC import config
import re
import pdfplumber # pip install pdfplumber
import requests
from datetime import datetime, timedelta
from typing import Dict
from SPARC.types import Patents, Patent
import pdfplumber # pip install pdfplumber
import requests
import serpapi
from SPARC import config
from SPARC.types import Patent, Patents
class SERP:
def query(company: str, days_back: int = None) -> Patents:
+1 -1
View File
@@ -4,7 +4,7 @@ from datetime import datetime
@dataclass
class Patent:
patent_id: int
patent_id: str
pdf_link: str
pdf_path: str | None = None
summary: dict | None = None
+139
View File
@@ -0,0 +1,139 @@
"""Webhook notifications for job completion and alert events.
Sends JSON payloads to configured webhook URLs with retry logic.
Supports generic HTTP POST and Slack-compatible text payloads.
"""
import logging
import os
import time
from datetime import datetime
from typing import Any
import requests
logger = logging.getLogger(__name__)
# Comma-separated list of webhook URLs (env var based config)
_WEBHOOK_URLS_RAW = os.getenv("WEBHOOK_URLS", "")
WEBHOOK_URLS: list[str] = [
url.strip() for url in _WEBHOOK_URLS_RAW.split(",") if url.strip()
]
MAX_RETRIES = 3
BACKOFF_BASE = 2 # seconds
def _is_slack_url(url: str) -> bool:
"""Check if a URL looks like a Slack incoming webhook."""
return "hooks.slack.com" in url or "discord.com/api/webhooks" in url
def _build_payload(event_type: str, data: dict[str, Any], slack: bool = False) -> dict:
"""Build the webhook payload.
Args:
event_type: Type of event (e.g., "job_completed", "alert")
data: Event-specific data
slack: If True, wrap in Slack-compatible ``text`` format
Returns:
JSON-serializable payload dict
"""
payload = {
"event": event_type,
"timestamp": datetime.utcnow().isoformat() + "Z",
**data,
}
if slack:
# Build a human-readable summary for Slack/Discord
lines = [f"*[SPARC] {event_type}*"]
for key, value in data.items():
lines.append(f" {key}: {value}")
return {"text": "\n".join(lines)}
return payload
def _send_with_retry(url: str, payload: dict) -> bool:
"""Send a POST request with exponential backoff retry.
Args:
url: Webhook URL
payload: JSON payload to send
Returns:
True if delivered successfully, False after all retries exhausted
"""
for attempt in range(1, MAX_RETRIES + 1):
try:
response = requests.post(url, json=payload, timeout=10)
if response.status_code < 300:
logger.debug("Webhook delivered to %s (attempt %d)", url, attempt)
return True
logger.warning(
"Webhook %s returned %d (attempt %d/%d)",
url, response.status_code, attempt, MAX_RETRIES,
)
except requests.RequestException as e:
logger.warning(
"Webhook delivery failed for %s (attempt %d/%d): %s",
url, attempt, MAX_RETRIES, e,
)
if attempt < MAX_RETRIES:
wait = BACKOFF_BASE ** attempt
time.sleep(wait)
logger.error("Webhook permanently failed for %s after %d attempts", url, MAX_RETRIES)
return False
def notify(event_type: str, data: dict[str, Any]) -> None:
"""Fire all configured webhooks for an event.
Safe to call even when no webhooks are configured (returns immediately).
Args:
event_type: Event identifier (e.g., "job_completed", "patent_alert")
data: Event data to include in the payload
"""
if not WEBHOOK_URLS:
return
for url in WEBHOOK_URLS:
slack = _is_slack_url(url)
payload = _build_payload(event_type, data, slack=slack)
_send_with_retry(url, payload)
def notify_job_completed(
job_id: str,
status: str,
total_companies: int,
successful: int,
failed: int,
) -> None:
"""Send notification when a batch job completes."""
notify("job_completed", {
"job_id": job_id,
"status": status,
"total_companies": total_companies,
"successful": successful,
"failed": failed,
"summary": f"Batch job {job_id}: {successful}/{total_companies} succeeded",
})
def notify_alert(
company_name: str,
alert_type: str,
message: str,
) -> None:
"""Send notification for a tracked company alert."""
notify("patent_alert", {
"company_name": company_name,
"alert_type": alert_type,
"message": message,
})
+8 -6
View File
@@ -3,15 +3,15 @@ services:
image: postgres:16-alpine
container_name: sparc-postgres
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: sparc
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_DB: ${POSTGRES_DB}
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER}"]
interval: 5s
timeout: 5s
retries: 5
@@ -22,7 +22,7 @@ services:
container_name: sparc-init-db
command: python scripts/init_database.py
environment:
DATABASE_URL: postgresql://postgres:postgres@postgres:5432/sparc
DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/${POSTGRES_DB}
depends_on:
postgres:
condition: service_healthy
@@ -35,9 +35,11 @@ services:
environment:
API_KEY: ${API_KEY}
OPENROUTER_API_KEY: ${OPENROUTER_API_KEY}
DATABASE_URL: postgresql://postgres:postgres@postgres:5432/sparc
DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/${POSTGRES_DB}
USE_CACHE: "true"
JWT_SECRET: ${JWT_SECRET:-sparc-secret-key-change-in-production}
CORS_ORIGINS: ${CORS_ORIGINS:-}
APP_ENV: ${APP_ENV:-development}
ROOT_PATH: /api
ports:
- "8000:8000"
+4728
View File
File diff suppressed because it is too large Load Diff
+1
View File
@@ -14,3 +14,4 @@ numpy
pandas
bcrypt
PyJWT
slowapi
+8
View File
@@ -0,0 +1,8 @@
[lint]
select = ["E", "F", "I"]
ignore = [
"E501", # line too long (handled by formatter)
]
[lint.per-file-ignores]
"tests/*" = ["E402", "F841"] # allow import not at top of file, unused vars (mocks) in tests
+3
View File
@@ -40,6 +40,9 @@ def main():
print("\nTables created:")
print(" - llm_messages: Stores all LLM prompts and responses")
print(" - users: Stores user accounts")
print(" - jobs: Stores async batch job state")
print(" - patents: Patent PDF cache")
print(" - serp_queries: SERP query result cache")
print("\nIndexes created:")
print(" - idx_messages_timestamp: For time-based queries")
print(" - idx_messages_company: For company-specific queries")
+5 -3
View File
@@ -1,9 +1,11 @@
"""Tests for the high-level company analyzer orchestration."""
from unittest.mock import MagicMock, Mock
import pytest
from unittest.mock import Mock, patch, call, MagicMock
from SPARC.analyzer import CompanyAnalyzer
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult
from SPARC.types import BatchAnalysisResult, Patent, Patents
@pytest.fixture(autouse=True)
@@ -24,7 +26,7 @@ class TestCompanyAnalyzer:
"""Test analyzer initialization with API key."""
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
analyzer = CompanyAnalyzer(openrouter_api_key="test-key")
_analyzer = CompanyAnalyzer(openrouter_api_key="test-key") # noqa: F841
mock_llm.assert_called_once_with(api_key="test-key")
+5 -4
View File
@@ -1,12 +1,13 @@
"""Tests for FastAPI web service endpoints."""
import pytest
from datetime import datetime
from unittest.mock import Mock, patch
from unittest.mock import Mock
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app, _analyzer, _jobs
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
from SPARC.api import app
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
@pytest.fixture
+302
View File
@@ -0,0 +1,302 @@
"""Tests for JWT authentication flow: register, login, protected routes, refresh, admin access."""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token, create_refresh_token
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db(monkeypatch):
"""Mock the database client used by auth endpoints.
Returns a MagicMock with all DB methods pre-configured.
"""
db = MagicMock()
# Default: no users exist
db.get_user_count.return_value = 0
db.get_user_by_id.return_value = None
db.get_user_by_email.return_value = None
db.authenticate_user.return_value = None
db.create_user.return_value = None
db.get_all_users.return_value = []
db.update_user_role.return_value = None
db.delete_user.return_value = False
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _make_admin_user():
return {
"id": 1,
"email": "admin@test.com",
"role": "admin",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
def _make_regular_user():
return {
"id": 2,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
def _auth_header(user_dict):
"""Create an Authorization header with a valid access token for the given user."""
token = create_access_token(user_dict["id"], user_dict["email"], user_dict["role"])
return {"Authorization": f"Bearer {token}"}
class TestRegister:
"""POST /auth/register"""
def test_register_first_user_becomes_admin(self, client, mock_db):
"""First registered user should get admin role."""
mock_db.get_user_count.return_value = 0
mock_db.create_user.return_value = {
"id": 1,
"email": "admin@test.com",
"role": "admin",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
response = client.post(
"/auth/register",
json={"email": "admin@test.com", "password": "securepass123"},
)
assert response.status_code == 200
data = response.json()
assert data["email"] == "admin@test.com"
assert data["role"] == "admin"
mock_db.create_user.assert_called_once_with(
email="admin@test.com", password="securepass123", role="admin"
)
def test_register_subsequent_user_gets_user_role(self, client, mock_db):
"""Non-first user should get regular user role."""
mock_db.get_user_count.return_value = 1
mock_db.create_user.return_value = _make_regular_user()
response = client.post(
"/auth/register",
json={"email": "user@test.com", "password": "securepass123"},
)
assert response.status_code == 200
data = response.json()
assert data["role"] == "user"
def test_register_duplicate_email_returns_400(self, client, mock_db):
"""Registering with an existing email should return 400."""
mock_db.get_user_count.return_value = 1
mock_db.create_user.return_value = None # indicates duplicate
response = client.post(
"/auth/register",
json={"email": "existing@test.com", "password": "securepass123"},
)
assert response.status_code == 400
assert "already registered" in response.json()["detail"].lower()
class TestLogin:
"""POST /auth/login"""
def test_login_valid_credentials_returns_tokens(self, client, mock_db):
"""Valid credentials should return access and refresh tokens."""
user = _make_regular_user()
mock_db.authenticate_user.return_value = user
response = client.post(
"/auth/login",
json={"email": "user@test.com", "password": "correctpassword"},
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
def test_login_invalid_credentials_returns_401(self, client, mock_db):
"""Invalid credentials should return 401."""
mock_db.authenticate_user.return_value = None
response = client.post(
"/auth/login",
json={"email": "user@test.com", "password": "wrongpassword"},
)
assert response.status_code == 401
assert "invalid" in response.json()["detail"].lower()
class TestGetMe:
"""GET /auth/me"""
def test_valid_access_token_returns_user(self, client, mock_db):
"""A valid access token should return the user's data."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = user
response = client.get("/auth/me", headers=_auth_header(user))
assert response.status_code == 200
data = response.json()
assert data["email"] == "user@test.com"
assert data["id"] == 2
def test_missing_token_returns_401(self, client):
"""No token should return 401 (403 from HTTPBearer)."""
response = client.get("/auth/me")
assert response.status_code in (401, 403)
def test_expired_token_returns_401(self, client, mock_db):
"""An expired token should return 401."""
# Create a token that has already expired
from datetime import timedelta
import jwt as pyjwt
from SPARC.auth import JWT_ALGORITHM, JWT_SECRET
payload = {
"sub": "1",
"email": "user@test.com",
"role": "user",
"exp": datetime.now(timezone.utc) - timedelta(hours=1),
"type": "access",
}
expired_token = pyjwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
response = client.get(
"/auth/me", headers={"Authorization": f"Bearer {expired_token}"}
)
assert response.status_code == 401
def test_refresh_token_as_access_returns_401(self, client, mock_db):
"""Using a refresh token as an access token should return 401."""
user = _make_regular_user()
refresh_token = create_refresh_token(user["id"], user["email"], user["role"])
response = client.get(
"/auth/me", headers={"Authorization": f"Bearer {refresh_token}"}
)
assert response.status_code == 401
class TestRefreshToken:
"""POST /auth/refresh"""
def test_valid_refresh_token_returns_new_tokens(self, client, mock_db):
"""A valid refresh token should issue new access and refresh tokens."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = user
refresh = create_refresh_token(user["id"], user["email"], user["role"])
response = client.post(
"/auth/refresh", json={"refresh_token": refresh}
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
def test_invalid_refresh_token_returns_401(self, client, mock_db):
"""An invalid refresh token should return 401."""
response = client.post(
"/auth/refresh", json={"refresh_token": "invalid-token-string"}
)
assert response.status_code == 401
def test_access_token_as_refresh_returns_401(self, client, mock_db):
"""Using an access token as a refresh token should return 401."""
user = _make_regular_user()
access = create_access_token(user["id"], user["email"], user["role"])
response = client.post(
"/auth/refresh", json={"refresh_token": access}
)
assert response.status_code == 401
class TestAdminUsers:
"""GET /admin/users and PATCH /admin/users/{id}/role"""
def test_admin_can_list_users(self, client, mock_db):
"""Admin token should allow listing users."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
mock_db.get_all_users.return_value = [admin, _make_regular_user()]
response = client.get("/admin/users", headers=_auth_header(admin))
assert response.status_code == 200
data = response.json()
assert len(data) == 2
def test_regular_user_cannot_list_users(self, client, mock_db):
"""Regular user token should be rejected with 403."""
user = _make_regular_user()
mock_db.get_user_by_id.return_value = user
response = client.get("/admin/users", headers=_auth_header(user))
assert response.status_code == 403
def test_no_token_cannot_list_users(self, client):
"""No token should be rejected."""
response = client.get("/admin/users")
assert response.status_code in (401, 403)
def test_admin_can_change_user_role(self, client, mock_db):
"""Admin should be able to change another user's role."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
mock_db.update_user_role.return_value = {
"id": 2,
"email": "user@test.com",
"role": "admin",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
response = client.patch(
"/admin/users/2/role",
json={"role": "admin"},
headers=_auth_header(admin),
)
assert response.status_code == 200
assert response.json()["role"] == "admin"
def test_admin_cannot_change_own_role(self, client, mock_db):
"""Admin should not be able to change their own role."""
admin = _make_admin_user()
mock_db.get_user_by_id.return_value = admin
response = client.patch(
"/admin/users/1/role",
json={"role": "user"},
headers=_auth_header(admin),
)
assert response.status_code == 400
assert "own role" in response.json()["detail"].lower()
+3 -1
View File
@@ -1,7 +1,9 @@
"""Tests for LLM analysis functionality."""
from unittest.mock import Mock
import pytest
from unittest.mock import Mock, MagicMock, patch
from SPARC.llm import LLMAnalyzer
+97
View File
@@ -0,0 +1,97 @@
"""Tests for rate limiting on auth endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from fastapi.testclient import TestClient
from SPARC.api import app
@pytest.fixture
def client():
"""Create test client with rate limiter enabled."""
return TestClient(app)
@pytest.fixture(autouse=True)
def reset_limiter():
"""Reset rate limiter storage between tests."""
from SPARC.api import limiter
limiter.reset()
yield
class TestRateLimiting:
"""Test rate limiting on login and register endpoints."""
@patch("SPARC.api.get_db_client")
def test_login_allows_requests_under_limit(self, mock_db_client, client):
"""Login endpoint allows requests under the rate limit."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Should allow at least a few requests
for _ in range(5):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
# 401 is expected (invalid credentials), not 429
assert response.status_code == 401
@patch("SPARC.api.get_db_client")
def test_login_rate_limited_after_threshold(self, mock_db_client, client):
"""Login endpoint returns 429 after exceeding rate limit."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Send more than the limit (10/minute)
statuses = []
for _ in range(15):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
statuses.append(response.status_code)
# At least one should be 429
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
@patch("SPARC.api.get_db_client")
def test_register_rate_limited_after_threshold(self, mock_db_client, client):
"""Register endpoint returns 429 after exceeding rate limit."""
mock_db = MagicMock()
mock_db.get_user_count.return_value = 1
mock_db.create_user.return_value = None # triggers 400 (email exists)
mock_db_client.return_value = mock_db
# Send more than the limit (5/minute)
statuses = []
for _ in range(10):
response = client.post(
"/auth/register",
json={"email": "test@example.com", "password": "password123"},
)
statuses.append(response.status_code)
# At least one should be 429
assert 429 in statuses, f"Expected 429 in statuses but got: {set(statuses)}"
@patch("SPARC.api.get_db_client")
def test_rate_limit_returns_retry_after_header(self, mock_db_client, client):
"""Rate limited responses include a Retry-After header."""
mock_db = MagicMock()
mock_db.authenticate_user.return_value = None
mock_db_client.return_value = mock_db
# Exhaust the limit
for _ in range(15):
response = client.post(
"/auth/login",
json={"email": "test@example.com", "password": "password123"},
)
if response.status_code == 429:
assert "Retry-After" in response.headers
break
+116
View File
@@ -0,0 +1,116 @@
"""Tests for security hardening: JWT secret startup check, CORS config, credential handling."""
import os
from unittest.mock import patch
import pytest
class TestJWTSecretStartupCheck:
"""Test the startup guard that refuses default JWT secret in non-dev environments."""
def test_default_secret_in_production_raises(self):
"""Starting with default secret and APP_ENV=production must raise RuntimeError."""
with patch.dict(os.environ, {"APP_ENV": "production"}):
# Reload config to pick up the new APP_ENV
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
# Patch JWT_SECRET to the default
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
with pytest.raises(RuntimeError, match="FATAL.*JWT_SECRET"):
check_jwt_secret()
# Restore config
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
def test_default_secret_in_development_succeeds(self):
"""Starting with default secret and APP_ENV=development must not raise."""
with patch.dict(os.environ, {"APP_ENV": "development"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
# Should not raise
check_jwt_secret()
# Restore
importlib.reload(SPARC.config)
def test_custom_secret_in_production_succeeds(self):
"""Starting with a custom secret in production must not raise."""
with patch.dict(os.environ, {"APP_ENV": "production"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", "my-secure-random-secret-abc123"):
# Should not raise
check_jwt_secret()
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
def test_default_secret_unset_env_succeeds(self):
"""When APP_ENV is unset (defaults to development), default secret is allowed."""
with patch.dict(os.environ, {}, clear=False):
# Remove APP_ENV if present
env = os.environ.copy()
env.pop("APP_ENV", None)
with patch.dict(os.environ, env, clear=True):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
from SPARC.auth import _DEFAULT_JWT_SECRET, check_jwt_secret
with patch("SPARC.auth.JWT_SECRET", _DEFAULT_JWT_SECRET):
# Should not raise (defaults to development)
check_jwt_secret()
with patch.dict(os.environ, {"APP_ENV": "development"}):
importlib.reload(SPARC.config)
class TestCORSConfig:
"""Test that CORS origins are configurable via environment variable."""
def test_default_cors_origins(self):
"""When CORS_ORIGINS is unset, defaults to localhost origins."""
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == [
"http://localhost:3000",
"http://localhost:5173",
]
def test_custom_cors_origins(self):
"""Setting CORS_ORIGINS configures allowed origins."""
with patch.dict(os.environ, {"CORS_ORIGINS": "https://sparc.example.com,https://app.example.com"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == [
"https://sparc.example.com",
"https://app.example.com",
]
# Restore
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
importlib.reload(SPARC.config)
def test_single_cors_origin(self):
"""A single origin without comma works correctly."""
with patch.dict(os.environ, {"CORS_ORIGINS": "https://sparc.example.com"}):
import importlib
import SPARC.config
importlib.reload(SPARC.config)
assert SPARC.config.cors_origins == ["https://sparc.example.com"]
with patch.dict(os.environ, {"CORS_ORIGINS": ""}):
importlib.reload(SPARC.config)
+2 -3
View File
@@ -1,9 +1,8 @@
"""Tests for SERP API patent retrieval and parsing functionality."""
import os
import pytest
from unittest.mock import patch, Mock
from datetime import datetime, timedelta
from unittest.mock import Mock
from SPARC.serp_api import SERP
from SPARC.types import Patent