Compare commits

..

1 Commits

Author SHA1 Message Date
agent-company ab3964b18d Move webhook delivery to background task queue
Introduce a lightweight in-process task queue (thread + queue.Queue) so
that webhook HTTP delivery no longer blocks the scheduler or batch-job
background tasks.  The worker thread preserves the existing exponential-
backoff retry logic from _send_with_retry.

- Add SPARC/task_queue.py: WebhookTask, start/stop worker, enqueue, drain
- Add enqueue_notify / enqueue_job_completed / enqueue_alert to webhooks.py
- Update api.py lifespan to start/stop the webhook worker
- Update _run_batch_job to use enqueue_job_completed (non-blocking)
- Update scheduler to fire enqueue_alert on patent count changes
- Add 13 tests covering worker lifecycle, async delivery, retry in worker
  context, and integration via enqueue helpers
- All 22 existing webhook tests continue to pass unchanged

Closes leeworks-agents/SPARC#1676

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 15:22:21 +00:00
6 changed files with 519 additions and 529 deletions
+79 -153
View File
@@ -224,11 +224,16 @@ async def lifespan(app: FastAPI):
import logging
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
_db.close()
# Start webhook background worker
from SPARC.task_queue import start_worker as start_webhook_worker
from SPARC.task_queue import stop_worker as stop_webhook_worker
start_webhook_worker()
# Start scheduled analysis if tracked companies are configured
from SPARC.scheduler import start_scheduler
start_scheduler()
yield
# Cleanup
stop_webhook_worker()
_analyzer = None
close_db_client()
@@ -675,25 +680,27 @@ async def get_analytics_trends(
# ============== Export Endpoints ==============
class BatchExportRequest(BaseModel):
"""Request model for batch ZIP export of analysis results."""
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a CSV file.
companies: list[CompanyName] = Field(
..., min_length=1, max_length=50, description="List of company names to export"
)
format: str = Field(
default="csv",
pattern="^(csv|pdf)$",
description="Export format: 'csv' or 'pdf'",
)
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp.
Args:
company_name: Company name to export results for
def _fetch_company_rows(db, company_name: str) -> list:
"""Fetch all non-cached analysis rows for *company_name* from the DB.
Returns a list of tuples: (company_name, analysis_type, model, response, timestamp).
Returns an empty list when no results exist.
Returns:
CSV file download
"""
import csv
import io
db = get_db_client()
# Query all non-cached analysis results for this company
with db.get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
@@ -705,24 +712,43 @@ def _fetch_company_rows(db, company_name: str) -> list:
""",
(company_name,),
)
return cur.fetchall()
rows = cur.fetchall()
def _build_company_csv(rows) -> bytes:
"""Render *rows* as CSV bytes."""
import csv
import io
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
output = io.StringIO()
writer = csv.writer(output)
writer.writerow(["company_name", "analysis_type", "model", "analysis", "timestamp"])
for row in rows:
writer.writerow(row)
return output.getvalue().encode("utf-8")
output.seek(0)
safe_name = company_name.replace(" ", "_").lower()
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f'attachment; filename="sparc_{safe_name}_export.csv"'},
)
def _build_company_pdf(rows, company_name: str) -> bytes:
"""Render *rows* as PDF bytes using reportlab."""
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted
as a downloadable PDF document.
Args:
company_name: Company name to export results for
Returns:
PDF file download
"""
import io
from reportlab.lib import colors
@@ -737,6 +763,23 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
TableStyle,
)
db = get_db_client()
with db.get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
SELECT company_name, analysis_type, model, response, timestamp
FROM llm_messages
WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE
ORDER BY timestamp DESC
""",
(company_name,),
)
rows = cur.fetchall()
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
@@ -779,11 +822,13 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
elements = []
display_name = rows[0][0]
# Title and date
display_name = rows[0][0] # Use the casing from the database
analysis_date = datetime.now().strftime("%Y-%m-%d")
elements.append(Paragraph(f"SPARC Analysis Report: {display_name}", title_style))
elements.append(Paragraph(f"Generated on {analysis_date}", subtitle_style))
# Summary table
summary_data = [
["Total Analyses", str(len(rows))],
["Analysis Types", ", ".join(sorted(set(r[1] for r in rows)))],
@@ -805,6 +850,7 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
elements.append(summary_table)
elements.append(Spacer(1, 16))
# Individual analysis sections
for i, row in enumerate(rows, 1):
_, analysis_type, model, response, timestamp = row
ts_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") if hasattr(timestamp, "strftime") else str(timestamp)
@@ -816,11 +862,13 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
Paragraph(f"<i>Performed: {ts_str}</i>", body_style)
)
# Wrap long response text into paragraphs, escaping XML special chars
safe_response = (
response.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
# Split into manageable paragraphs to avoid overflow
for line in safe_response.split("\n"):
if line.strip():
elements.append(Paragraph(line, body_style))
@@ -831,133 +879,11 @@ def _build_company_pdf(rows, company_name: str) -> bytes:
doc.build(elements)
buffer.seek(0)
return buffer.getvalue()
@app.post("/export/batch", tags=["Export"])
async def export_batch_zip(
request: BatchExportRequest,
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for multiple companies as a ZIP archive.
For each company in the request, fetches all stored analysis records and
adds a per-company file (CSV or PDF) to the archive. Companies with no
stored results are skipped; a ``manifest.json`` inside the ZIP lists both
the exported and skipped companies.
Args:
request: List of company names and desired export format ('csv' or 'pdf')
Returns:
ZIP archive download containing one file per found company plus a manifest
"""
import io
import json
import zipfile
db = get_db_client()
export_date = datetime.now().strftime("%Y-%m-%d")
fmt = request.format
exported: list[str] = []
skipped: list[str] = []
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
for company_name in request.companies:
rows = _fetch_company_rows(db, company_name)
if not rows:
skipped.append(company_name)
continue
safe_name = company_name.replace(" ", "_").lower()
if fmt == "pdf":
file_bytes = _build_company_pdf(rows, company_name)
filename = f"{safe_name}-analysis-{export_date}.pdf"
else:
file_bytes = _build_company_csv(rows)
filename = f"sparc_{safe_name}_export.csv"
zf.writestr(filename, file_bytes)
exported.append(company_name)
# Always include a manifest
manifest = {
"export_date": export_date,
"format": fmt,
"exported": exported,
"skipped": skipped,
}
zf.writestr("manifest.json", json.dumps(manifest, indent=2))
zip_buffer.seek(0)
zip_filename = f"sparc-export-{export_date}.zip"
return StreamingResponse(
iter([zip_buffer.getvalue()]),
media_type="application/zip",
headers={"Content-Disposition": f'attachment; filename="{zip_filename}"'},
)
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a CSV file.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp.
Args:
company_name: Company name to export results for
Returns:
CSV file download
"""
db = get_db_client()
rows = _fetch_company_rows(db, company_name)
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
safe_name = company_name.replace(" ", "_").lower()
return StreamingResponse(
iter([_build_company_csv(rows)]),
media_type="text/csv",
headers={"Content-Disposition": f'attachment; filename="sparc_{safe_name}_export.csv"'},
)
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
Returns all stored analysis records for the given company, including
analysis type, model used, response text, and timestamp, formatted
as a downloadable PDF document.
Args:
company_name: Company name to export results for
Returns:
PDF file download
"""
db = get_db_client()
rows = _fetch_company_rows(db, company_name)
if not rows:
raise HTTPException(status_code=404, detail=f"No analysis results found for '{company_name}'")
safe_name = company_name.replace(" ", "_").lower()
analysis_date = datetime.now().strftime("%Y-%m-%d")
filename = f"{safe_name}-analysis-{analysis_date}.pdf"
return StreamingResponse(
iter([_build_company_pdf(rows, company_name)]),
iter([buffer.getvalue()]),
media_type="application/pdf",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
@@ -1183,9 +1109,9 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: s
progress=100,
result_json=_json.dumps(batch_response.model_dump(), default=str),
)
# Fire webhook notification
from SPARC.webhooks import notify_job_completed
notify_job_completed(
# Fire webhook notification (non-blocking via task queue)
from SPARC.webhooks import enqueue_job_completed
enqueue_job_completed(
job_id=job_id,
status="completed",
total_companies=result.total_companies,
@@ -1194,8 +1120,8 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: s
)
except Exception as e:
db.update_job(job_id, status="failed", error=str(e))
from SPARC.webhooks import notify_job_completed
notify_job_completed(
from SPARC.webhooks import enqueue_job_completed
enqueue_job_completed(
job_id=job_id,
status="failed",
total_companies=len(companies),
+7
View File
@@ -71,6 +71,13 @@ def run_scheduled_analysis() -> None:
old_value=old_count,
new_value=new_count,
)
# Fire non-blocking webhook notification
from SPARC.webhooks import enqueue_alert
enqueue_alert(
company_name=name,
alert_type="patent_count_change",
message=message,
)
elif new_count > 0:
# First analysis -- record baseline
logger.info("Baseline for %s: %d patents", name, new_count)
+113
View File
@@ -0,0 +1,113 @@
"""Lightweight in-process task queue for non-blocking webhook delivery.
Uses a daemon thread and a :class:`queue.Queue` so that the scheduler and
background jobs can enqueue webhook deliveries without blocking on HTTP
round-trips and retry backoff.
No external dependencies (Redis, etc.) are required.
"""
import logging
import queue
import threading
from dataclasses import dataclass
from typing import Any
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class WebhookTask:
"""A single webhook delivery request."""
url: str
payload: dict[str, Any]
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_queue: queue.Queue[WebhookTask | None] = queue.Queue()
_worker_thread: threading.Thread | None = None
_started = threading.Event()
def _worker_loop() -> None:
"""Process webhook tasks until a ``None`` sentinel is received."""
import SPARC.webhooks as _webhooks # deferred to avoid circular import
logger.info("Webhook worker thread started")
_started.set()
while True:
task = _queue.get()
if task is None:
# Sentinel — shut down
logger.info("Webhook worker thread stopping")
_queue.task_done()
break
try:
# Look up dynamically so that tests can patch the function
_webhooks._send_with_retry(task.url, task.payload)
except Exception:
logger.exception("Unexpected error delivering webhook to %s", task.url)
finally:
_queue.task_done()
def start_worker() -> None:
"""Start the background worker thread (idempotent)."""
global _worker_thread
if _worker_thread is not None and _worker_thread.is_alive():
return
_started.clear()
_worker_thread = threading.Thread(target=_worker_loop, daemon=True, name="webhook-worker")
_worker_thread.start()
_started.wait() # block until the worker is actually running
logger.info("Webhook task queue ready")
def stop_worker(timeout: float = 5.0) -> None:
"""Send the stop sentinel and wait for the worker to finish.
Args:
timeout: Maximum seconds to wait for the worker thread to join.
"""
global _worker_thread
if _worker_thread is None or not _worker_thread.is_alive():
_worker_thread = None
return
_queue.put(None) # sentinel
_worker_thread.join(timeout=timeout)
_worker_thread = None
logger.info("Webhook task queue stopped")
def enqueue(task: WebhookTask) -> None:
"""Add a webhook delivery task to the queue.
If the worker has not been started the task is still accepted into the
queue and will be processed once :func:`start_worker` is called.
"""
_queue.put(task)
def queue_size() -> int:
"""Return the approximate number of pending tasks."""
return _queue.qsize()
def drain(timeout: float = 10.0) -> None:
"""Block until all currently-enqueued tasks have been processed.
Useful in tests and graceful shutdown to ensure pending deliveries
complete before the process exits.
Args:
timeout: Maximum seconds to wait.
"""
_queue.join()
+58 -3
View File
@@ -91,9 +91,10 @@ def _send_with_retry(url: str, payload: dict) -> bool:
def notify(event_type: str, data: dict[str, Any]) -> None:
"""Fire all configured webhooks for an event.
"""Fire all configured webhooks for an event (**blocking**).
Safe to call even when no webhooks are configured (returns immediately).
For non-blocking delivery, use :func:`enqueue_notify` instead.
Args:
event_type: Event identifier (e.g., "job_completed", "patent_alert")
@@ -108,6 +109,29 @@ def notify(event_type: str, data: dict[str, Any]) -> None:
_send_with_retry(url, payload)
def enqueue_notify(event_type: str, data: dict[str, Any]) -> None:
"""Enqueue webhook delivery for all configured URLs (non-blocking).
Returns immediately after placing tasks on the background queue.
The worker thread handles retry logic asynchronously.
Safe to call even when no webhooks are configured.
Args:
event_type: Event identifier (e.g., "job_completed", "patent_alert")
data: Event data to include in the payload
"""
if not WEBHOOK_URLS:
return
from SPARC.task_queue import WebhookTask, enqueue
for url in WEBHOOK_URLS:
slack = _is_slack_url(url)
payload = _build_payload(event_type, data, slack=slack)
enqueue(WebhookTask(url=url, payload=payload))
def notify_job_completed(
job_id: str,
status: str,
@@ -115,7 +139,7 @@ def notify_job_completed(
successful: int,
failed: int,
) -> None:
"""Send notification when a batch job completes."""
"""Send notification when a batch job completes (blocking)."""
notify("job_completed", {
"job_id": job_id,
"status": status,
@@ -126,14 +150,45 @@ def notify_job_completed(
})
def enqueue_job_completed(
job_id: str,
status: str,
total_companies: int,
successful: int,
failed: int,
) -> None:
"""Enqueue notification when a batch job completes (non-blocking)."""
enqueue_notify("job_completed", {
"job_id": job_id,
"status": status,
"total_companies": total_companies,
"successful": successful,
"failed": failed,
"summary": f"Batch job {job_id}: {successful}/{total_companies} succeeded",
})
def notify_alert(
company_name: str,
alert_type: str,
message: str,
) -> None:
"""Send notification for a tracked company alert."""
"""Send notification for a tracked company alert (blocking)."""
notify("patent_alert", {
"company_name": company_name,
"alert_type": alert_type,
"message": message,
})
def enqueue_alert(
company_name: str,
alert_type: str,
message: str,
) -> None:
"""Enqueue notification for a tracked company alert (non-blocking)."""
enqueue_notify("patent_alert", {
"company_name": company_name,
"alert_type": alert_type,
"message": message,
})
-373
View File
@@ -1,373 +0,0 @@
"""Tests for POST /export/batch endpoint (issue #1674).
Covers:
- Single company export (CSV + PDF)
- Multiple company export
- All-missing companies (every requested company is skipped)
- Unauthenticated / invalid-token requests
- Manifest content validation
- Invalid format rejection
"""
import io
import json
import zipfile
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.auth import create_access_token
@pytest.fixture
def client():
"""Create a FastAPI test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def mock_db():
"""Mock database client for all tests in this module."""
db = MagicMock()
# Auth: user always exists
db.get_user_by_id.return_value = {
"id": 1,
"email": "user@test.com",
"role": "user",
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
# Default cursor mock (overridden per-test via side_effect or return_value)
mock_cursor = MagicMock()
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
db.get_conn.return_value.__enter__ = MagicMock(return_value=mock_conn)
db.get_conn.return_value.__exit__ = MagicMock(return_value=False)
db._mock_cursor = mock_cursor
with patch("SPARC.api.get_db_client", return_value=db), \
patch("SPARC.auth.get_db_client", return_value=db):
yield db
def _auth_header():
token = create_access_token(1, "user@test.com", "user")
return {"Authorization": f"Bearer {token}"}
def _rows_for(company_name: str):
"""Return a single sample row for the given company."""
return [
(
company_name,
"company_analysis",
"anthropic/claude-3.5-sonnet",
f"Strong patent portfolio for {company_name}.",
datetime(2025, 6, 15, 10, 30, 0),
)
]
def _open_zip(content: bytes) -> zipfile.ZipFile:
"""Helper: wrap response bytes as a ZipFile."""
return zipfile.ZipFile(io.BytesIO(content))
# ---------------------------------------------------------------------------
# Authentication
# ---------------------------------------------------------------------------
class TestBatchExportAuth:
"""Unauthenticated and invalid-token requests must be rejected."""
def test_unauthenticated_returns_401(self, client):
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
)
assert response.status_code == 401
def test_invalid_token_returns_401(self, client):
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers={"Authorization": "Bearer totally.invalid.token"},
)
assert response.status_code == 401
# ---------------------------------------------------------------------------
# Single company
# ---------------------------------------------------------------------------
class TestBatchExportSingleCompany:
"""POST /export/batch with a single company name."""
def test_single_company_csv_returns_zip(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/zip"
assert "attachment" in response.headers["content-disposition"]
assert "sparc-export-" in response.headers["content-disposition"]
assert response.headers["content-disposition"].endswith('.zip"')
def test_single_company_csv_zip_contains_csv_file(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
names = zf.namelist()
csv_files = [n for n in names if n.endswith(".csv")]
assert len(csv_files) == 1
assert "nvidia" in csv_files[0]
def test_single_company_csv_content_is_valid_csv(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
csv_name = [n for n in zf.namelist() if n.endswith(".csv")][0]
csv_text = zf.read(csv_name).decode("utf-8")
lines = csv_text.strip().split("\n")
assert lines[0].strip() == "company_name,analysis_type,model,analysis,timestamp"
assert "NVIDIA" in lines[1]
def test_single_company_pdf_zip_contains_pdf_file(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "pdf"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
pdf_files = [n for n in zf.namelist() if n.endswith(".pdf")]
assert len(pdf_files) == 1
# Verify it is actually a PDF (starts with %PDF)
pdf_bytes = zf.read(pdf_files[0])
assert pdf_bytes[:4] == b"%PDF"
# ---------------------------------------------------------------------------
# Multiple companies
# ---------------------------------------------------------------------------
class TestBatchExportMultipleCompanies:
"""POST /export/batch with several companies."""
def test_multiple_companies_each_gets_a_file(self, client, mock_db):
companies = ["NVIDIA", "Intel", "AMD"]
mock_db._mock_cursor.fetchall.side_effect = [
_rows_for("NVIDIA"),
_rows_for("Intel"),
_rows_for("AMD"),
]
response = client.post(
"/export/batch",
json={"companies": companies, "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
csv_files = [n for n in zf.namelist() if n.endswith(".csv")]
assert len(csv_files) == 3
def test_multiple_companies_manifest_lists_all_exported(self, client, mock_db):
companies = ["NVIDIA", "Intel"]
mock_db._mock_cursor.fetchall.side_effect = [
_rows_for("NVIDIA"),
_rows_for("Intel"),
]
response = client.post(
"/export/batch",
json={"companies": companies, "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert set(manifest["exported"]) == {"NVIDIA", "Intel"}
assert manifest["skipped"] == []
assert manifest["format"] == "csv"
def test_partial_missing_companies_skipped(self, client, mock_db):
"""Companies with no data are skipped; others are exported."""
mock_db._mock_cursor.fetchall.side_effect = [
_rows_for("NVIDIA"),
[], # no data for "UnknownCo"
]
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA", "UnknownCo"], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert manifest["exported"] == ["NVIDIA"]
assert manifest["skipped"] == ["UnknownCo"]
csv_files = [n for n in zf.namelist() if n.endswith(".csv")]
assert len(csv_files) == 1
# ---------------------------------------------------------------------------
# All-missing companies
# ---------------------------------------------------------------------------
class TestBatchExportAllMissing:
"""When every requested company has no data, the ZIP still returns 200
with only a manifest (no per-company files, all listed in skipped)."""
def test_all_missing_returns_200_with_manifest_only(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = []
response = client.post(
"/export/batch",
json={"companies": ["GhostCo", "PhantomInc"], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
assert "manifest.json" in zf.namelist()
manifest = json.loads(zf.read("manifest.json"))
assert manifest["exported"] == []
assert set(manifest["skipped"]) == {"GhostCo", "PhantomInc"}
def test_all_missing_zip_has_no_data_files(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = []
response = client.post(
"/export/batch",
json={"companies": ["GhostCo"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
data_files = [n for n in zf.namelist() if n != "manifest.json"]
assert data_files == []
# ---------------------------------------------------------------------------
# Manifest validation
# ---------------------------------------------------------------------------
class TestBatchExportManifest:
"""The manifest.json inside every ZIP must be well-formed."""
def test_manifest_always_present(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
assert "manifest.json" in zf.namelist()
def test_manifest_contains_required_keys(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "csv"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert "export_date" in manifest
assert "format" in manifest
assert "exported" in manifest
assert "skipped" in manifest
def test_manifest_format_field_matches_request(self, client, mock_db):
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "pdf"},
headers=_auth_header(),
)
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert manifest["format"] == "pdf"
# ---------------------------------------------------------------------------
# Input validation
# ---------------------------------------------------------------------------
class TestBatchExportInputValidation:
"""Invalid request bodies must return 422."""
def test_invalid_format_returns_422(self, client):
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"], "format": "xlsx"},
headers=_auth_header(),
)
assert response.status_code == 422
def test_empty_companies_list_returns_422(self, client):
response = client.post(
"/export/batch",
json={"companies": [], "format": "csv"},
headers=_auth_header(),
)
assert response.status_code == 422
def test_default_format_is_csv(self, client, mock_db):
"""Omitting `format` should default to CSV."""
mock_db._mock_cursor.fetchall.return_value = _rows_for("NVIDIA")
response = client.post(
"/export/batch",
json={"companies": ["NVIDIA"]},
headers=_auth_header(),
)
assert response.status_code == 200
zf = _open_zip(response.content)
manifest = json.loads(zf.read("manifest.json"))
assert manifest["format"] == "csv"
+262
View File
@@ -0,0 +1,262 @@
"""Tests for the webhook background task queue.
Covers:
- Worker lifecycle (start / stop / idempotent start)
- Tasks are processed asynchronously by the worker
- Retry logic is preserved (executed inside the worker thread)
- enqueue_notify / enqueue_job_completed / enqueue_alert non-blocking helpers
- Integration: queued webhook task is eventually delivered (mocked HTTP)
"""
import threading
import time
from unittest.mock import MagicMock, call, patch
import pytest
from SPARC.task_queue import (
WebhookTask,
drain,
enqueue,
queue_size,
start_worker,
stop_worker,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _worker_lifecycle():
"""Start the worker before each test and stop it after."""
start_worker()
yield
stop_worker(timeout=3)
# ---------------------------------------------------------------------------
# Worker lifecycle
# ---------------------------------------------------------------------------
class TestWorkerLifecycle:
def test_start_is_idempotent(self):
"""Calling start_worker() twice does not create a second thread."""
import SPARC.task_queue as tq
first = tq._worker_thread
start_worker()
assert tq._worker_thread is first
def test_stop_worker_gracefully(self):
"""stop_worker joins the thread cleanly."""
import SPARC.task_queue as tq
assert tq._worker_thread is not None
stop_worker(timeout=3)
assert tq._worker_thread is None
# ---------------------------------------------------------------------------
# Task processing
# ---------------------------------------------------------------------------
class TestTaskProcessing:
@patch("SPARC.webhooks._send_with_retry")
def test_enqueued_task_is_delivered(self, mock_send):
"""A task put on the queue is eventually processed by the worker."""
mock_send.return_value = True
task = WebhookTask(url="https://example.com/hook", payload={"event": "test"})
enqueue(task)
drain(timeout=5)
mock_send.assert_called_once_with("https://example.com/hook", {"event": "test"})
@patch("SPARC.webhooks._send_with_retry")
def test_multiple_tasks_processed_in_order(self, mock_send):
"""Tasks are processed FIFO."""
mock_send.return_value = True
for i in range(3):
enqueue(WebhookTask(url=f"https://example.com/{i}", payload={"n": i}))
drain(timeout=5)
assert mock_send.call_count == 3
urls = [c[0][0] for c in mock_send.call_args_list]
assert urls == [
"https://example.com/0",
"https://example.com/1",
"https://example.com/2",
]
@patch("SPARC.webhooks._send_with_retry")
def test_enqueue_returns_immediately(self, mock_send):
"""enqueue() does not block even if the worker is slow."""
event = threading.Event()
def slow_send(url, payload):
event.wait(timeout=5)
return True
mock_send.side_effect = slow_send
start = time.monotonic()
enqueue(WebhookTask(url="https://slow.example.com", payload={}))
elapsed = time.monotonic() - start
# enqueue should return in well under 1 second
assert elapsed < 0.5
# Let the worker finish
event.set()
drain(timeout=5)
@patch("SPARC.webhooks._send_with_retry", side_effect=RuntimeError("boom"))
def test_worker_survives_unexpected_error(self, mock_send):
"""An unexpected exception in delivery does not kill the worker."""
enqueue(WebhookTask(url="https://example.com/bad", payload={}))
drain(timeout=5)
# Worker is still alive; enqueue another task
mock_send.side_effect = None
mock_send.return_value = True
enqueue(WebhookTask(url="https://example.com/good", payload={"ok": True}))
drain(timeout=5)
assert mock_send.call_count == 2
# ---------------------------------------------------------------------------
# Retry logic preserved in worker context
# ---------------------------------------------------------------------------
class TestRetryInWorker:
@patch("SPARC.webhooks.time.sleep")
@patch("SPARC.webhooks.requests.post")
def test_retry_logic_runs_inside_worker(self, mock_post, mock_sleep):
"""The worker thread uses _send_with_retry, which retries on failure."""
mock_post.side_effect = [
MagicMock(status_code=500),
MagicMock(status_code=200),
]
enqueue(WebhookTask(
url="https://example.com/retry",
payload={"event": "test"},
))
drain(timeout=10)
assert mock_post.call_count == 2
mock_sleep.assert_called_once()
@patch("SPARC.webhooks.time.sleep")
@patch("SPARC.webhooks.requests.post")
def test_all_retries_exhausted_in_worker(self, mock_post, mock_sleep):
"""Worker handles permanent failure gracefully."""
mock_post.return_value = MagicMock(status_code=500)
enqueue(WebhookTask(
url="https://example.com/fail",
payload={"event": "test"},
))
drain(timeout=10)
from SPARC.webhooks import MAX_RETRIES
assert mock_post.call_count == MAX_RETRIES
# ---------------------------------------------------------------------------
# Integration: enqueue_notify and convenience helpers
# ---------------------------------------------------------------------------
class TestEnqueueHelpers:
@patch("SPARC.webhooks._send_with_retry")
@patch("SPARC.webhooks.WEBHOOK_URLS", ["https://example.com/hook"])
def test_enqueue_notify_delivers_via_worker(self, mock_send):
"""enqueue_notify puts a task on the queue and the worker delivers it."""
mock_send.return_value = True
from SPARC.webhooks import enqueue_notify
enqueue_notify("test_event", {"key": "val"})
drain(timeout=5)
mock_send.assert_called_once()
url, payload = mock_send.call_args[0]
assert url == "https://example.com/hook"
assert payload["event"] == "test_event"
assert payload["key"] == "val"
@patch("SPARC.webhooks._send_with_retry")
@patch("SPARC.webhooks.WEBHOOK_URLS", ["https://example.com/hook"])
def test_enqueue_job_completed(self, mock_send):
"""enqueue_job_completed sends job completion data via the queue."""
mock_send.return_value = True
from SPARC.webhooks import enqueue_job_completed
enqueue_job_completed(
job_id="job-1",
status="completed",
total_companies=5,
successful=4,
failed=1,
)
drain(timeout=5)
mock_send.assert_called_once()
payload = mock_send.call_args[0][1]
assert payload["event"] == "job_completed"
assert payload["job_id"] == "job-1"
assert payload["successful"] == 4
@patch("SPARC.webhooks._send_with_retry")
@patch("SPARC.webhooks.WEBHOOK_URLS", ["https://example.com/hook"])
def test_enqueue_alert(self, mock_send):
"""enqueue_alert sends alert data via the queue."""
mock_send.return_value = True
from SPARC.webhooks import enqueue_alert
enqueue_alert(
company_name="NVIDIA",
alert_type="patent_count_change",
message="Patent count increased by 30%",
)
drain(timeout=5)
mock_send.assert_called_once()
payload = mock_send.call_args[0][1]
assert payload["event"] == "patent_alert"
assert payload["company_name"] == "NVIDIA"
@patch("SPARC.webhooks._send_with_retry")
@patch("SPARC.webhooks.WEBHOOK_URLS", [])
def test_enqueue_notify_noop_when_no_urls(self, mock_send):
"""enqueue_notify is a no-op when WEBHOOK_URLS is empty."""
from SPARC.webhooks import enqueue_notify
enqueue_notify("test_event", {"key": "val"})
drain(timeout=2)
mock_send.assert_not_called()
@patch("SPARC.webhooks._send_with_retry")
@patch("SPARC.webhooks.WEBHOOK_URLS", [
"https://hooks.slack.com/services/T00/B00/xxx",
"https://example.com/generic",
])
def test_enqueue_notify_slack_formatting(self, mock_send):
"""Slack URLs get Slack-formatted payloads even via the queue."""
mock_send.return_value = True
from SPARC.webhooks import enqueue_notify
enqueue_notify("test_event", {"key": "val"})
drain(timeout=5)
assert mock_send.call_count == 2
slack_payload = mock_send.call_args_list[0][0][1]
assert "text" in slack_payload
generic_payload = mock_send.call_args_list[1][0][1]
assert "event" in generic_payload