Compare commits

..

1 Commits

Author SHA1 Message Date
agent-company ab3964b18d Move webhook delivery to background task queue
Introduce a lightweight in-process task queue (thread + queue.Queue) so
that webhook HTTP delivery no longer blocks the scheduler or batch-job
background tasks.  The worker thread preserves the existing exponential-
backoff retry logic from _send_with_retry.

- Add SPARC/task_queue.py: WebhookTask, start/stop worker, enqueue, drain
- Add enqueue_notify / enqueue_job_completed / enqueue_alert to webhooks.py
- Update api.py lifespan to start/stop the webhook worker
- Update _run_batch_job to use enqueue_job_completed (non-blocking)
- Update scheduler to fire enqueue_alert on patent count changes
- Add 13 tests covering worker lifecycle, async delivery, retry in worker
  context, and integration via enqueue helpers
- All 22 existing webhook tests continue to pass unchanged

Closes leeworks-agents/SPARC#1676

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 15:22:21 +00:00
10 changed files with 457 additions and 407 deletions
+14 -51
View File
@@ -5,9 +5,8 @@ Provides REST API endpoints for analyzing company patent portfolios.
from __future__ import annotations from __future__ import annotations
from collections import deque
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone from datetime import datetime
from typing import TYPE_CHECKING, Annotated, List from typing import TYPE_CHECKING, Annotated, List
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -225,11 +224,16 @@ async def lifespan(app: FastAPI):
import logging import logging
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale) logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
_db.close() _db.close()
# Start webhook background worker
from SPARC.task_queue import start_worker as start_webhook_worker
from SPARC.task_queue import stop_worker as stop_webhook_worker
start_webhook_worker()
# Start scheduled analysis if tracked companies are configured # Start scheduled analysis if tracked companies are configured
from SPARC.scheduler import start_scheduler from SPARC.scheduler import start_scheduler
start_scheduler() start_scheduler()
yield yield
# Cleanup # Cleanup
stop_webhook_worker()
_analyzer = None _analyzer = None
close_db_client() close_db_client()
@@ -249,9 +253,6 @@ app.state.limiter = limiter
# In-memory rate limit statistics # In-memory rate limit statistics
_rate_limit_stats: dict[str, dict] = {} _rate_limit_stats: dict[str, dict] = {}
# Time-series log of rejected requests (capped to last 24 h worth of entries).
_rejected_log: deque[dict] = deque(maxlen=100_000)
def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None: def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None:
"""Record a request against a rate-limited endpoint.""" """Record a request against a rate-limited endpoint."""
@@ -266,11 +267,6 @@ def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) ->
_rate_limit_stats[key]["total_requests"] += 1 _rate_limit_stats[key]["total_requests"] += 1
if rejected: if rejected:
_rate_limit_stats[key]["rejected_requests"] += 1 _rate_limit_stats[key]["rejected_requests"] += 1
_rejected_log.append({
"endpoint": endpoint,
"ip": ip,
"timestamp": datetime.now(timezone.utc).isoformat(),
})
ip_stats = _rate_limit_stats[key].setdefault("by_ip", {}) ip_stats = _rate_limit_stats[key].setdefault("by_ip", {})
if ip not in ip_stats: if ip not in ip_stats:
ip_stats[ip] = {"total": 0, "rejected": 0} ip_stats[ip] = {"total": 0, "rejected": 0}
@@ -516,12 +512,10 @@ async def get_rate_limit_stats(
"""Get rate limit status and usage statistics (admin only). """Get rate limit status and usage statistics (admin only).
Returns current rate limit configuration and request statistics Returns current rate limit configuration and request statistics
for all rate-limited endpoints, including per-IP breakdown and for all rate-limited endpoints.
a time-series of throttled (rejected) requests in the last 24 hours.
Returns: Returns:
Rate limit stats per endpoint, per-IP breakdown, and throttled List of rate limit stats per endpoint with total/rejected counts
request history bucketed by hour.
""" """
rate_limits_config = { rate_limits_config = {
"/auth/register": {"limit": "5/minute"}, "/auth/register": {"limit": "5/minute"},
@@ -531,45 +525,14 @@ async def get_rate_limit_stats(
results = [] results = []
for endpoint, conf in rate_limits_config.items(): for endpoint, conf in rate_limits_config.items():
stats = _rate_limit_stats.get(endpoint, {}) stats = _rate_limit_stats.get(endpoint, {})
by_ip_raw = stats.get("by_ip", {})
by_ip = [
{"ip": ip, "total": counts["total"], "rejected": counts["rejected"]}
for ip, counts in by_ip_raw.items()
]
results.append({ results.append({
"endpoint": endpoint, "endpoint": endpoint,
"limit": conf["limit"], "limit": conf["limit"],
"total_requests": stats.get("total_requests", 0), "total_requests": stats.get("total_requests", 0),
"rejected_requests": stats.get("rejected_requests", 0), "rejected_requests": stats.get("rejected_requests", 0),
"by_ip": by_ip,
}) })
# Build hourly buckets of throttled requests for the last 24 hours return {"rate_limits": results}
now = datetime.now(timezone.utc)
cutoff = now - timedelta(hours=24)
hourly_buckets: dict[str, int] = {}
throttled_24h = 0
for entry in _rejected_log:
ts_str = entry["timestamp"]
try:
ts = datetime.fromisoformat(ts_str)
except (ValueError, TypeError):
continue
if ts >= cutoff:
throttled_24h += 1
bucket = ts.strftime("%Y-%m-%dT%H:00:00Z")
hourly_buckets[bucket] = hourly_buckets.get(bucket, 0) + 1
throttled_over_time = [
{"timestamp": k, "count": v}
for k, v in sorted(hourly_buckets.items())
]
return {
"rate_limits": results,
"throttled_24h": throttled_24h,
"throttled_over_time": throttled_over_time,
}
@app.get("/admin/alerts", tags=["Admin"]) @app.get("/admin/alerts", tags=["Admin"])
@@ -1146,9 +1109,9 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: s
progress=100, progress=100,
result_json=_json.dumps(batch_response.model_dump(), default=str), result_json=_json.dumps(batch_response.model_dump(), default=str),
) )
# Fire webhook notification # Fire webhook notification (non-blocking via task queue)
from SPARC.webhooks import notify_job_completed from SPARC.webhooks import enqueue_job_completed
notify_job_completed( enqueue_job_completed(
job_id=job_id, job_id=job_id,
status="completed", status="completed",
total_companies=result.total_companies, total_companies=result.total_companies,
@@ -1157,8 +1120,8 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: s
) )
except Exception as e: except Exception as e:
db.update_job(job_id, status="failed", error=str(e)) db.update_job(job_id, status="failed", error=str(e))
from SPARC.webhooks import notify_job_completed from SPARC.webhooks import enqueue_job_completed
notify_job_completed( enqueue_job_completed(
job_id=job_id, job_id=job_id,
status="failed", status="failed",
total_companies=len(companies), total_companies=len(companies),
+7
View File
@@ -71,6 +71,13 @@ def run_scheduled_analysis() -> None:
old_value=old_count, old_value=old_count,
new_value=new_count, new_value=new_count,
) )
# Fire non-blocking webhook notification
from SPARC.webhooks import enqueue_alert
enqueue_alert(
company_name=name,
alert_type="patent_count_change",
message=message,
)
elif new_count > 0: elif new_count > 0:
# First analysis -- record baseline # First analysis -- record baseline
logger.info("Baseline for %s: %d patents", name, new_count) logger.info("Baseline for %s: %d patents", name, new_count)
+113
View File
@@ -0,0 +1,113 @@
"""Lightweight in-process task queue for non-blocking webhook delivery.
Uses a daemon thread and a :class:`queue.Queue` so that the scheduler and
background jobs can enqueue webhook deliveries without blocking on HTTP
round-trips and retry backoff.
No external dependencies (Redis, etc.) are required.
"""
import logging
import queue
import threading
from dataclasses import dataclass
from typing import Any
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class WebhookTask:
    """Immutable record describing one webhook delivery.

    Instances are the items carried by the module-level queue; the worker
    passes ``url`` and ``payload`` on to the delivery routine.  ``frozen=True``
    only prevents re-binding the fields -- the ``payload`` dict itself can
    still be mutated by whoever holds a reference to it.
    """

    # Destination endpoint for this delivery.
    url: str
    # Body handed to the sender together with ``url``.
    payload: dict[str, Any]
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
# Pending deliveries.  A ``None`` item is the shutdown sentinel that tells the
# worker loop to exit.
_queue: queue.Queue[WebhookTask | None] = queue.Queue()
# The current worker thread, or ``None`` when no worker has been started or it
# has been stopped.
_worker_thread: threading.Thread | None = None
# Set by the worker thread once it is running; start_worker() waits on it.
_started = threading.Event()
def _worker_loop() -> None:
    """Deliver queued webhook tasks until the ``None`` sentinel arrives.

    This is the target of the worker thread.  It signals readiness through
    ``_started`` and then consumes ``_queue`` one item at a time.  Any
    exception raised by a delivery is logged and swallowed so that a single
    bad task cannot terminate the thread.
    """
    import SPARC.webhooks as webhooks_module  # deferred to avoid circular import

    logger.info("Webhook worker thread started")
    _started.set()

    # ``iter(callable, sentinel)`` keeps calling ``_queue.get()`` and ends the
    # loop as soon as the call returns the ``None`` sentinel.
    for task in iter(_queue.get, None):
        try:
            # Resolve the attribute on every call so that tests can patch it.
            webhooks_module._send_with_retry(task.url, task.payload)
        except Exception:
            logger.exception("Unexpected error delivering webhook to %s", task.url)
        finally:
            _queue.task_done()

    # Sentinel received: shut down, and account for the sentinel itself so
    # that anyone blocked on the queue's join is released.
    logger.info("Webhook worker thread stopping")
    _queue.task_done()
def start_worker() -> None:
    """Start the background worker thread (idempotent).

    Returns immediately when a live worker already exists.  Otherwise a daemon
    thread running :func:`_worker_loop` is spawned and this call blocks until
    that thread signals readiness through ``_started``.

    Raises:
        RuntimeError: If the worker thread exits before signalling readiness
            (for example because the deferred import at the top of the worker
            loop failed).  An unbounded wait on the event would block the
            caller forever in that situation.
    """
    global _worker_thread
    if _worker_thread is not None and _worker_thread.is_alive():
        return
    _started.clear()
    worker = threading.Thread(target=_worker_loop, daemon=True, name="webhook-worker")
    _worker_thread = worker
    worker.start()
    # Wait in short slices instead of indefinitely so that a worker which died
    # during start-up is noticed.  The event is re-checked after the liveness
    # test to cover a worker that set it and exited between the two checks.
    while not _started.wait(timeout=0.1):
        if not worker.is_alive() and not _started.is_set():
            _worker_thread = None
            raise RuntimeError("Webhook worker thread exited before it finished starting")
    logger.info("Webhook task queue ready")
def stop_worker(timeout: float = 5.0) -> None:
    """Send the stop sentinel and wait for the worker to finish.

    The sentinel is appended behind any tasks that are already queued, so the
    worker processes those first.  If it has not exited when ``timeout``
    elapses, a warning is logged instead of reporting a clean stop; the thread
    keeps running until it reaches the sentinel.  In every case the
    module-level worker reference is cleared.

    Args:
        timeout: Maximum seconds to wait for the worker thread to join.
    """
    global _worker_thread
    worker = _worker_thread
    _worker_thread = None
    if worker is None or not worker.is_alive():
        return
    _queue.put(None)  # sentinel
    worker.join(timeout=timeout)
    if worker.is_alive():
        # join() returns silently on timeout; do not claim the worker stopped.
        logger.warning("Webhook worker thread did not stop within %s seconds", timeout)
        return
    logger.info("Webhook task queue stopped")
def enqueue(task: WebhookTask) -> None:
    """Place ``task`` on the delivery queue.

    A running worker is not required: the task simply waits in the queue
    until :func:`start_worker` has been called and the worker reaches it.

    Args:
        task: The delivery to schedule.
    """
    _queue.put(task)
def queue_size() -> int:
    """Report roughly how many items are waiting in the queue.

    Returns:
        The value of ``Queue.qsize()``, which is only approximate while other
        threads are adding or removing items.
    """
    pending = _queue.qsize()
    return pending
def drain(timeout: float = 10.0) -> None:
    """Block until all enqueued tasks have been processed or ``timeout`` elapses.

    Useful in tests and graceful shutdown to ensure pending deliveries
    complete before the process exits.  Tasks enqueued while waiting are
    waited for as well, exactly as with ``Queue.join()``.

    If the deadline passes with work still outstanding, a warning is logged
    and the function returns; it never raises on timeout.

    Args:
        timeout: Maximum seconds to wait.
    """
    import time  # local import: nothing else in this module needs it

    deadline = time.monotonic() + timeout
    timed_out = False
    # Queue.join() accepts no timeout, so wait on the queue's own
    # "all tasks done" condition with a deadline instead.
    with _queue.all_tasks_done:
        while _queue.unfinished_tasks:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                timed_out = True
                break
            _queue.all_tasks_done.wait(remaining)
    # Log outside the lock so the worker's task_done() is never held up.
    if timed_out:
        logger.warning(
            "Timed out after %s seconds waiting for the webhook queue to drain", timeout
        )
+58 -3
View File
@@ -91,9 +91,10 @@ def _send_with_retry(url: str, payload: dict) -> bool:
def notify(event_type: str, data: dict[str, Any]) -> None: def notify(event_type: str, data: dict[str, Any]) -> None:
"""Fire all configured webhooks for an event. """Fire all configured webhooks for an event (**blocking**).
Safe to call even when no webhooks are configured (returns immediately). Safe to call even when no webhooks are configured (returns immediately).
For non-blocking delivery, use :func:`enqueue_notify` instead.
Args: Args:
event_type: Event identifier (e.g., "job_completed", "patent_alert") event_type: Event identifier (e.g., "job_completed", "patent_alert")
@@ -108,6 +109,29 @@ def notify(event_type: str, data: dict[str, Any]) -> None:
_send_with_retry(url, payload) _send_with_retry(url, payload)
def enqueue_notify(event_type: str, data: dict[str, Any]) -> None:
    """Queue one delivery task per configured webhook URL (non-blocking).

    Nothing is sent from the calling thread: a payload is built for each URL
    and handed to the background queue, whose worker performs the delivery
    and any retries.  When no webhook URLs are configured this is a no-op.

    Args:
        event_type: Event identifier (e.g., "job_completed", "patent_alert")
        data: Event data to include in the payload
    """
    if not WEBHOOK_URLS:
        return

    from SPARC.task_queue import WebhookTask, enqueue

    for target_url in WEBHOOK_URLS:
        # The payload is built per URL because its form depends on whether
        # the target is a Slack URL.
        body = _build_payload(event_type, data, slack=_is_slack_url(target_url))
        enqueue(WebhookTask(url=target_url, payload=body))
def notify_job_completed( def notify_job_completed(
job_id: str, job_id: str,
status: str, status: str,
@@ -115,7 +139,7 @@ def notify_job_completed(
successful: int, successful: int,
failed: int, failed: int,
) -> None: ) -> None:
"""Send notification when a batch job completes.""" """Send notification when a batch job completes (blocking)."""
notify("job_completed", { notify("job_completed", {
"job_id": job_id, "job_id": job_id,
"status": status, "status": status,
@@ -126,14 +150,45 @@ def notify_job_completed(
}) })
def enqueue_job_completed(
    job_id: str,
    status: str,
    total_companies: int,
    successful: int,
    failed: int,
) -> None:
    """Queue a "job_completed" notification for a batch job (non-blocking).

    Args:
        job_id: Identifier of the batch job.
        status: Status string to report for the job.
        total_companies: Number of companies in the batch.
        successful: Number of companies reported as succeeded.
        failed: Number of companies reported as failed.
    """
    event_data = {
        "job_id": job_id,
        "status": status,
        "total_companies": total_companies,
        "successful": successful,
        "failed": failed,
        "summary": f"Batch job {job_id}: {successful}/{total_companies} succeeded",
    }
    enqueue_notify("job_completed", event_data)
def notify_alert( def notify_alert(
company_name: str, company_name: str,
alert_type: str, alert_type: str,
message: str, message: str,
) -> None: ) -> None:
"""Send notification for a tracked company alert.""" """Send notification for a tracked company alert (blocking)."""
notify("patent_alert", { notify("patent_alert", {
"company_name": company_name, "company_name": company_name,
"alert_type": alert_type, "alert_type": alert_type,
"message": message, "message": message,
}) })
def enqueue_alert(
    company_name: str,
    alert_type: str,
    message: str,
) -> None:
    """Queue a "patent_alert" notification for a tracked company (non-blocking).

    Args:
        company_name: Company the alert refers to.
        alert_type: Kind of alert being raised.
        message: Text of the alert.
    """
    event_data = {
        "company_name": company_name,
        "alert_type": alert_type,
        "message": message,
    }
    enqueue_notify("patent_alert", event_data)
-9
View File
@@ -11,7 +11,6 @@ import { Batch } from './pages/Batch';
import { AnalyticsPage } from './pages/Analytics'; import { AnalyticsPage } from './pages/Analytics';
import { About } from './pages/About'; import { About } from './pages/About';
import { AdminUsers } from './pages/AdminUsers'; import { AdminUsers } from './pages/AdminUsers';
import { AdminRateLimits } from './pages/AdminRateLimits';
import { Compare } from './pages/Compare'; import { Compare } from './pages/Compare';
const queryClient = new QueryClient({ const queryClient = new QueryClient({
@@ -57,14 +56,6 @@ function App() {
</ProtectedRoute> </ProtectedRoute>
} }
/> />
<Route
path="/admin/rate-limits"
element={
<ProtectedRoute requireAdmin>
<AdminRateLimits />
</ProtectedRoute>
}
/>
</Route> </Route>
{/* Default redirect */} {/* Default redirect */}
-31
View File
@@ -201,32 +201,6 @@ export const analyticsApi = {
}, },
}; };
// Rate limit types
export interface RateLimitIpEntry {
ip: string;
total: number;
rejected: number;
}
export interface RateLimitEndpointStats {
endpoint: string;
limit: string;
total_requests: number;
rejected_requests: number;
by_ip: RateLimitIpEntry[];
}
export interface ThrottledBucket {
timestamp: string;
count: number;
}
export interface RateLimitStatsResponse {
rate_limits: RateLimitEndpointStats[];
throttled_24h: number;
throttled_over_time: ThrottledBucket[];
}
// Admin API // Admin API
export const adminApi = { export const adminApi = {
listUsers: async (limit = 100, offset = 0): Promise<User[]> => { listUsers: async (limit = 100, offset = 0): Promise<User[]> => {
@@ -242,11 +216,6 @@ export const adminApi = {
deleteUser: async (userId: number): Promise<void> => { deleteUser: async (userId: number): Promise<void> => {
await api.delete(`/admin/users/${userId}`); await api.delete(`/admin/users/${userId}`);
}, },
getRateLimits: async (): Promise<RateLimitStatsResponse> => {
const response = await api.get<RateLimitStatsResponse>('/admin/rate-limits');
return response.data;
},
}; };
export default api; export default api;
+1 -2
View File
@@ -1,7 +1,7 @@
import { Outlet, NavLink, useNavigate } from 'react-router-dom'; import { Outlet, NavLink, useNavigate } from 'react-router-dom';
import { useAuth } from '../context/AuthContext'; import { useAuth } from '../context/AuthContext';
import { useTheme } from '../context/ThemeContext'; import { useTheme } from '../context/ThemeContext';
import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon, ShieldAlert } from 'lucide-react'; import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon } from 'lucide-react';
export function Layout() { export function Layout() {
const { user, isAdmin, logout } = useAuth(); const { user, isAdmin, logout } = useAuth();
@@ -23,7 +23,6 @@ export function Layout() {
if (isAdmin) { if (isAdmin) {
navItems.push({ to: '/admin/users', icon: Users, label: 'Users' }); navItems.push({ to: '/admin/users', icon: Users, label: 'Users' });
navItems.push({ to: '/admin/rate-limits', icon: ShieldAlert, label: 'Rate Limits' });
} }
return ( return (
-240
View File
@@ -1,240 +0,0 @@
import { useState } from 'react';
import { useQuery } from '@tanstack/react-query';
import { adminApi } from '../api/client';
import type { RateLimitStatsResponse } from '../api/client';
import { ShieldAlert, Activity, AlertCircle, RefreshCw, Clock } from 'lucide-react';
const REFRESH_OPTIONS = [
{ label: '15s', value: 15_000 },
{ label: '30s', value: 30_000 },
{ label: '1m', value: 60_000 },
{ label: 'Off', value: 0 },
];
export function AdminRateLimits() {
const [refreshInterval, setRefreshInterval] = useState(30_000);
const { data, isLoading, isError, dataUpdatedAt } = useQuery<RateLimitStatsResponse>({
queryKey: ['admin-rate-limits'],
queryFn: () => adminApi.getRateLimits(),
refetchInterval: refreshInterval || false,
});
if (isLoading) {
return (
<div className="flex items-center justify-center min-h-[400px]">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-primary"></div>
</div>
);
}
if (isError) {
return (
<div className="flex items-center gap-2 bg-error/10 border border-error/20 text-error rounded-xl px-4 py-3">
<AlertCircle size={18} />
<span>Failed to load rate limit statistics.</span>
</div>
);
}
const maxThrottledCount = data?.throttled_over_time?.length
? Math.max(...data.throttled_over_time.map((b) => b.count))
: 0;
return (
<div className="space-y-6">
{/* Header */}
<div className="flex items-center justify-between flex-wrap gap-4">
<div>
<h2 className="text-xl font-semibold text-text-primary border-b-2 border-primary/30 pb-2 mb-2">
Rate Limiting Dashboard
</h2>
<p className="text-text-secondary">Monitor API rate limits and throttled requests.</p>
</div>
<div className="flex items-center gap-3">
{/* Last updated */}
{dataUpdatedAt > 0 && (
<span className="text-xs text-text-secondary flex items-center gap-1">
<Clock size={12} />
Updated {new Date(dataUpdatedAt).toLocaleTimeString()}
</span>
)}
{/* Refresh interval selector */}
<div className="flex items-center gap-1 bg-bg-card/60 border border-primary/15 rounded-xl p-1">
<RefreshCw size={14} className="text-text-secondary ml-2" />
{REFRESH_OPTIONS.map((opt) => (
<button
key={opt.value}
onClick={() => setRefreshInterval(opt.value)}
className={`px-3 py-1 rounded-lg text-xs font-medium transition-all ${
refreshInterval === opt.value
? 'bg-primary text-white'
: 'text-text-secondary hover:text-text-primary hover:bg-bg-card-hover'
}`}
>
{opt.label}
</button>
))}
</div>
</div>
</div>
{/* Summary cards */}
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<Activity size={18} className="text-primary" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total Requests
</span>
</div>
<div className="text-3xl font-bold text-text-primary">
{data?.rate_limits.reduce((sum, rl) => sum + rl.total_requests, 0) ?? 0}
</div>
</div>
<div className="bg-bg-card/60 border border-error/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<ShieldAlert size={18} className="text-error" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Throttled (24h)
</span>
</div>
<div className="text-3xl font-bold text-error">
{data?.throttled_24h ?? 0}
</div>
</div>
<div className="bg-bg-card/60 border border-secondary/15 rounded-2xl p-5">
<div className="flex items-center gap-2 mb-2">
<ShieldAlert size={18} className="text-secondary" />
<span className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rate-Limited Endpoints
</span>
</div>
<div className="text-3xl font-bold text-text-primary">
{data?.rate_limits.length ?? 0}
</div>
</div>
</div>
{/* Throttled over time chart (simple bar chart) */}
{data?.throttled_over_time && data.throttled_over_time.length > 0 && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl p-5">
<h3 className="text-sm font-semibold text-text-secondary uppercase tracking-wider mb-4">
Throttled Requests Over Time (Last 24h)
</h3>
<div className="flex items-end gap-1 h-32">
{data.throttled_over_time.map((bucket) => {
const height = maxThrottledCount > 0 ? (bucket.count / maxThrottledCount) * 100 : 0;
const hour = new Date(bucket.timestamp).getHours();
return (
<div key={bucket.timestamp} className="flex-1 flex flex-col items-center gap-1">
<span className="text-xs text-text-secondary">{bucket.count}</span>
<div
className="w-full bg-error/70 rounded-t-sm min-h-[2px] transition-all"
style={{ height: `${Math.max(height, 2)}%` }}
title={`${bucket.timestamp}: ${bucket.count} throttled`}
/>
<span className="text-[10px] text-text-secondary">{hour}:00</span>
</div>
);
})}
</div>
</div>
)}
{/* Per-endpoint table */}
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl overflow-hidden">
<div className="overflow-x-auto">
<table className="w-full">
<thead>
<tr className="border-b border-primary/10">
<th className="text-left px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Endpoint
</th>
<th className="text-left px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Limit
</th>
<th className="text-right px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total Requests
</th>
<th className="text-right px-6 py-4 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rejected
</th>
</tr>
</thead>
<tbody className="divide-y divide-primary/10">
{data?.rate_limits.map((rl) => (
<tr key={rl.endpoint} className="hover:bg-bg-card-hover/50 transition-colors">
<td className="px-6 py-4 font-mono text-sm text-text-primary">{rl.endpoint}</td>
<td className="px-6 py-4">
<span className="inline-flex px-2 py-0.5 rounded-full text-xs font-medium bg-primary/10 text-primary border border-primary/20">
{rl.limit}
</span>
</td>
<td className="px-6 py-4 text-right text-text-primary font-semibold">
{rl.total_requests}
</td>
<td className="px-6 py-4 text-right">
<span className={rl.rejected_requests > 0 ? 'text-error font-semibold' : 'text-text-secondary'}>
{rl.rejected_requests}
</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
{/* Per-IP breakdown */}
{data?.rate_limits.some((rl) => rl.by_ip.length > 0) && (
<div className="bg-bg-card/60 border border-primary/15 rounded-2xl overflow-hidden">
<div className="px-6 py-4 border-b border-primary/10">
<h3 className="text-sm font-semibold text-text-secondary uppercase tracking-wider">
Per-IP Breakdown
</h3>
</div>
<div className="overflow-x-auto">
<table className="w-full">
<thead>
<tr className="border-b border-primary/10">
<th className="text-left px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Endpoint
</th>
<th className="text-left px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
IP Address
</th>
<th className="text-right px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Total
</th>
<th className="text-right px-6 py-3 text-sm font-semibold text-text-secondary uppercase tracking-wider">
Rejected
</th>
</tr>
</thead>
<tbody className="divide-y divide-primary/10">
{data.rate_limits.flatMap((rl) =>
rl.by_ip.map((ipEntry) => (
<tr
key={`${rl.endpoint}-${ipEntry.ip}`}
className="hover:bg-bg-card-hover/50 transition-colors"
>
<td className="px-6 py-3 font-mono text-sm text-text-primary">{rl.endpoint}</td>
<td className="px-6 py-3 font-mono text-sm text-text-secondary">{ipEntry.ip}</td>
<td className="px-6 py-3 text-right text-text-primary">{ipEntry.total}</td>
<td className="px-6 py-3 text-right">
<span className={ipEntry.rejected > 0 ? 'text-error font-semibold' : 'text-text-secondary'}>
{ipEntry.rejected}
</span>
</td>
</tr>
))
)}
</tbody>
</table>
</div>
</div>
)}
</div>
);
}
+2 -71
View File
@@ -20,10 +20,8 @@ def client():
def reset_stats(): def reset_stats():
"""Reset rate limit stats between tests.""" """Reset rate limit stats between tests."""
api._rate_limit_stats.clear() api._rate_limit_stats.clear()
api._rejected_log.clear()
yield yield
api._rate_limit_stats.clear() api._rate_limit_stats.clear()
api._rejected_log.clear()
def _mock_admin(): def _mock_admin():
@@ -52,7 +50,8 @@ class TestRateLimitAdminEndpoint:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_non_admin_rejected(self, client): def test_non_admin_rejected(self, client):
"""Non-admin users should get 401/403.""" """Non-admin users should get 403."""
# Without overriding the dependency, it should fail auth
response = client.get("/admin/rate-limits") response = client.get("/admin/rate-limits")
assert response.status_code in (401, 403) assert response.status_code in (401, 403)
@@ -78,9 +77,6 @@ class TestRateLimitAdminEndpoint:
for rl in data["rate_limits"]: for rl in data["rate_limits"]:
assert rl["total_requests"] == 0 assert rl["total_requests"] == 0
assert rl["rejected_requests"] == 0 assert rl["rejected_requests"] == 0
assert rl["by_ip"] == []
assert data["throttled_24h"] == 0
assert data["throttled_over_time"] == []
finally: finally:
app.dependency_overrides.clear() app.dependency_overrides.clear()
@@ -111,68 +107,3 @@ class TestRateLimitAdminEndpoint:
assert isinstance(rl["limit"], str) assert isinstance(rl["limit"], str)
finally: finally:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_per_ip_breakdown(self, client):
"""Stats should include per-IP breakdown with total and rejected counts."""
api._track_rate_limit_request("/auth/login", "10.0.0.1")
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
api._track_rate_limit_request("/auth/login", "10.0.0.2")
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login")
by_ip = login_stats["by_ip"]
assert len(by_ip) == 2
ip1 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.1")
assert ip1["total"] == 2
assert ip1["rejected"] == 1
ip2 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.2")
assert ip2["total"] == 1
assert ip2["rejected"] == 0
finally:
app.dependency_overrides.clear()
def test_throttled_24h_count(self, client):
"""Should report total throttled requests in the last 24 hours."""
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
api._track_rate_limit_request("/auth/register", "10.0.0.2", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
assert data["throttled_24h"] == 2
finally:
app.dependency_overrides.clear()
def test_throttled_over_time_structure(self, client):
"""Throttled-over-time should be a list of {timestamp, count} buckets."""
api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True)
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
assert len(data["throttled_over_time"]) >= 1
entry = data["throttled_over_time"][0]
assert "timestamp" in entry
assert "count" in entry
assert entry["count"] >= 1
finally:
app.dependency_overrides.clear()
def test_response_shape_matches_contract(self, client):
"""The full response should match the expected shape for the frontend."""
app.dependency_overrides[api.get_current_admin] = _mock_admin
try:
response = client.get("/admin/rate-limits")
data = response.json()
# Top-level keys
assert set(data.keys()) == {"rate_limits", "throttled_24h", "throttled_over_time"}
# Each rate_limit entry
for rl in data["rate_limits"]:
assert set(rl.keys()) == {"endpoint", "limit", "total_requests", "rejected_requests", "by_ip"}
finally:
app.dependency_overrides.clear()
+262
View File
@@ -0,0 +1,262 @@
"""Tests for the webhook background task queue.
Covers:
- Worker lifecycle (start / stop / idempotent start)
- Tasks are processed asynchronously by the worker
- Retry logic is preserved (executed inside the worker thread)
- enqueue_notify / enqueue_job_completed / enqueue_alert non-blocking helpers
- Integration: queued webhook task is eventually delivered (mocked HTTP)
"""
import threading
import time
from unittest.mock import MagicMock, call, patch
import pytest
from SPARC.task_queue import (
WebhookTask,
drain,
enqueue,
queue_size,
start_worker,
stop_worker,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _worker_lifecycle():
    """Give every test a running worker and shut it down afterwards."""
    # Setup: make sure a worker thread is consuming the queue.
    start_worker()
    yield
    # Teardown: stop the worker, waiting at most three seconds for it.
    stop_worker(timeout=3)
# ---------------------------------------------------------------------------
# Worker lifecycle
# ---------------------------------------------------------------------------
class TestWorkerLifecycle:
    """Starting and stopping the worker thread."""

    def test_start_is_idempotent(self):
        """A second start_worker() call keeps the existing thread."""
        import SPARC.task_queue as tq

        thread_before = tq._worker_thread
        start_worker()
        assert tq._worker_thread is thread_before

    def test_stop_worker_gracefully(self):
        """stop_worker() clears the module-level thread reference."""
        import SPARC.task_queue as tq

        # The autouse fixture has already started a worker.
        assert tq._worker_thread is not None
        stop_worker(timeout=3)
        assert tq._worker_thread is None
# ---------------------------------------------------------------------------
# Task processing
# ---------------------------------------------------------------------------
class TestTaskProcessing:
    """Queued tasks reach the (mocked) delivery function via the worker."""

    @patch("SPARC.webhooks._send_with_retry")
    def test_enqueued_task_is_delivered(self, mock_send):
        """A single queued task results in exactly one delivery call."""
        mock_send.return_value = True

        enqueue(WebhookTask(url="https://example.com/hook", payload={"event": "test"}))
        drain(timeout=5)

        mock_send.assert_called_once_with("https://example.com/hook", {"event": "test"})

    @patch("SPARC.webhooks._send_with_retry")
    def test_multiple_tasks_processed_in_order(self, mock_send):
        """Deliveries happen in the order the tasks were queued (FIFO)."""
        mock_send.return_value = True

        for n in range(3):
            enqueue(WebhookTask(url=f"https://example.com/{n}", payload={"n": n}))
        drain(timeout=5)

        assert mock_send.call_count == 3
        # First positional argument of each recorded call is the URL.
        delivered_urls = [recorded.args[0] for recorded in mock_send.call_args_list]
        assert delivered_urls == [
            "https://example.com/0",
            "https://example.com/1",
            "https://example.com/2",
        ]

    @patch("SPARC.webhooks._send_with_retry")
    def test_enqueue_returns_immediately(self, mock_send):
        """enqueue() returns quickly even while the delivery call is blocked."""
        release = threading.Event()

        def slow_send(url, payload):
            # Hold the worker until the test releases it (or 5 s pass).
            release.wait(timeout=5)
            return True

        mock_send.side_effect = slow_send

        started_at = time.monotonic()
        enqueue(WebhookTask(url="https://slow.example.com", payload={}))
        elapsed = time.monotonic() - started_at

        # enqueue should return in well under 1 second
        assert elapsed < 0.5

        # Unblock the worker and let it finish before the test ends.
        release.set()
        drain(timeout=5)

    @patch("SPARC.webhooks._send_with_retry", side_effect=RuntimeError("boom"))
    def test_worker_survives_unexpected_error(self, mock_send):
        """An exception raised by one delivery leaves the worker running."""
        enqueue(WebhookTask(url="https://example.com/bad", payload={}))
        drain(timeout=5)

        # Switch the mock to succeed and confirm a second task is still handled.
        mock_send.side_effect = None
        mock_send.return_value = True
        enqueue(WebhookTask(url="https://example.com/good", payload={"ok": True}))
        drain(timeout=5)

        assert mock_send.call_count == 2
# ---------------------------------------------------------------------------
# Retry logic preserved in worker context
# ---------------------------------------------------------------------------
class TestRetryInWorker:
    """Retry behaviour observed through the queue.

    ``requests.post`` and ``time.sleep`` are patched via ``SPARC.webhooks``;
    the tests count how often each mock is invoked after a task is drained.
    """

    @patch("SPARC.webhooks.time.sleep")
    @patch("SPARC.webhooks.requests.post")
    def test_retry_logic_runs_inside_worker(self, mock_post, mock_sleep):
        """A 500 followed by a 200 yields two POSTs and a single sleep."""
        failed_response = MagicMock(status_code=500)
        ok_response = MagicMock(status_code=200)
        mock_post.side_effect = [failed_response, ok_response]

        retry_task = WebhookTask(
            url="https://example.com/retry",
            payload={"event": "test"},
        )
        enqueue(retry_task)
        drain(timeout=10)

        assert mock_post.call_count == 2
        mock_sleep.assert_called_once()

    @patch("SPARC.webhooks.time.sleep")
    @patch("SPARC.webhooks.requests.post")
    def test_all_retries_exhausted_in_worker(self, mock_post, mock_sleep):
        """When every POST returns 500, exactly MAX_RETRIES POSTs are made."""
        always_failing = MagicMock(status_code=500)
        mock_post.return_value = always_failing

        failing_task = WebhookTask(
            url="https://example.com/fail",
            payload={"event": "test"},
        )
        enqueue(failing_task)
        drain(timeout=10)

        from SPARC.webhooks import MAX_RETRIES

        attempts = mock_post.call_count
        assert attempts == MAX_RETRIES
# ---------------------------------------------------------------------------
# Integration: enqueue_notify and convenience helpers
# ---------------------------------------------------------------------------
class TestEnqueueHelpers:
    """Tests for the ``enqueue_*`` helpers exposed by ``SPARC.webhooks``.

    ``WEBHOOK_URLS`` is patched per test to a fixed list and
    ``_send_with_retry`` is mocked, so each test checks the URL and payload
    that the mock received after ``drain`` is called.
    """

    @patch("SPARC.webhooks._send_with_retry")
    @patch("SPARC.webhooks.WEBHOOK_URLS", ["https://example.com/hook"])
    def test_enqueue_notify_delivers_via_worker(self, mock_send):
        """enqueue_notify results in one delivery carrying event and data."""
        from SPARC.webhooks import enqueue_notify

        mock_send.return_value = True

        enqueue_notify("test_event", {"key": "val"})
        drain(timeout=5)

        mock_send.assert_called_once()
        # Positional arguments of the single recorded call: (url, payload).
        positional_args = mock_send.call_args[0]
        url, payload = positional_args
        assert url == "https://example.com/hook"
        assert payload["event"] == "test_event"
        assert payload["key"] == "val"

    @patch("SPARC.webhooks._send_with_retry")
    @patch("SPARC.webhooks.WEBHOOK_URLS", ["https://example.com/hook"])
    def test_enqueue_job_completed(self, mock_send):
        """enqueue_job_completed delivers a ``job_completed`` payload."""
        from SPARC.webhooks import enqueue_job_completed

        mock_send.return_value = True

        job_fields = {
            "job_id": "job-1",
            "status": "completed",
            "total_companies": 5,
            "successful": 4,
            "failed": 1,
        }
        enqueue_job_completed(**job_fields)
        drain(timeout=5)

        mock_send.assert_called_once()
        delivered = mock_send.call_args[0][1]
        assert delivered["event"] == "job_completed"
        assert delivered["job_id"] == "job-1"
        assert delivered["successful"] == 4

    @patch("SPARC.webhooks._send_with_retry")
    @patch("SPARC.webhooks.WEBHOOK_URLS", ["https://example.com/hook"])
    def test_enqueue_alert(self, mock_send):
        """enqueue_alert delivers a ``patent_alert`` payload."""
        from SPARC.webhooks import enqueue_alert

        mock_send.return_value = True

        alert_fields = {
            "company_name": "NVIDIA",
            "alert_type": "patent_count_change",
            "message": "Patent count increased by 30%",
        }
        enqueue_alert(**alert_fields)
        drain(timeout=5)

        mock_send.assert_called_once()
        delivered = mock_send.call_args[0][1]
        assert delivered["event"] == "patent_alert"
        assert delivered["company_name"] == "NVIDIA"

    @patch("SPARC.webhooks._send_with_retry")
    @patch("SPARC.webhooks.WEBHOOK_URLS", [])
    def test_enqueue_notify_noop_when_no_urls(self, mock_send):
        """With an empty WEBHOOK_URLS list nothing is delivered."""
        from SPARC.webhooks import enqueue_notify

        enqueue_notify("test_event", {"key": "val"})
        drain(timeout=2)

        mock_send.assert_not_called()

    @patch("SPARC.webhooks._send_with_retry")
    @patch("SPARC.webhooks.WEBHOOK_URLS", [
        "https://hooks.slack.com/services/T00/B00/xxx",
        "https://example.com/generic",
    ])
    def test_enqueue_notify_slack_formatting(self, mock_send):
        """The Slack URL gets a ``text`` payload, the generic URL an ``event`` one."""
        from SPARC.webhooks import enqueue_notify

        mock_send.return_value = True

        enqueue_notify("test_event", {"key": "val"})
        drain(timeout=5)

        assert mock_send.call_count == 2
        # Calls are checked in the same order as the patched URL list.
        recorded = mock_send.call_args_list
        slack_payload = recorded[0][0][1]
        assert "text" in slack_payload
        generic_payload = recorded[1][0][1]
        assert "event" in generic_payload