diff --git a/SPARC/api.py b/SPARC/api.py index 1b29d38..d380e1d 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -5,8 +5,9 @@ Provides REST API endpoints for analyzing company patent portfolios. from __future__ import annotations +from collections import deque from contextlib import asynccontextmanager -from datetime import datetime +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Annotated, List if TYPE_CHECKING: @@ -248,6 +249,9 @@ app.state.limiter = limiter # In-memory rate limit statistics _rate_limit_stats: dict[str, dict] = {} +# Time-series log of rejected requests (capped to last 24 h worth of entries). +_rejected_log: deque[dict] = deque(maxlen=100_000) + def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> None: """Record a request against a rate-limited endpoint.""" @@ -262,6 +266,11 @@ def _track_rate_limit_request(endpoint: str, ip: str, rejected: bool = False) -> _rate_limit_stats[key]["total_requests"] += 1 if rejected: _rate_limit_stats[key]["rejected_requests"] += 1 + _rejected_log.append({ + "endpoint": endpoint, + "ip": ip, + "timestamp": datetime.now(timezone.utc).isoformat(), + }) ip_stats = _rate_limit_stats[key].setdefault("by_ip", {}) if ip not in ip_stats: ip_stats[ip] = {"total": 0, "rejected": 0} @@ -465,11 +474,46 @@ class TrackCompanyRequest(BaseModel): company_name: CompanyName = Field(...) +@app.get("/tracked", tags=["Tracked Companies"]) +async def list_my_tracked_companies( + current_user: UserResponse = Depends(get_current_user), +): + """List tracked companies for the current user.""" + db = get_db_client() + return db.list_tracked_companies(owner_id=current_user.id) + + +@app.post("/tracked", tags=["Tracked Companies"]) +async def add_my_tracked_company( + request: TrackCompanyRequest, + current_user: UserResponse = Depends(get_current_user), +): + """Add a company to the current user's tracked list.""" + db = get_db_client() + result = db.add_tracked_company(request.company_name, owner_id=current_user.id) + if not result: + raise HTTPException(status_code=409, detail="Company already tracked") + return result + + +@app.delete("/tracked/{company_name}", tags=["Tracked Companies"]) +async def remove_my_tracked_company( + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], + current_user: UserResponse = Depends(get_current_user), +): + """Remove a company from the current user's tracked list.""" + db = get_db_client() + removed = db.remove_tracked_company(company_name, owner_id=current_user.id) + if not removed: + raise HTTPException(status_code=404, detail="Company not found in tracking list") + return {"message": f"Stopped tracking {company_name}"} + + @app.get("/admin/tracked", tags=["Admin"]) async def list_tracked_companies( _: UserResponse = Depends(get_current_admin), ): - """List all tracked companies (admin only).""" + """List all tracked companies across all users (admin only).""" db = get_db_client() return db.list_tracked_companies() @@ -477,11 +521,11 @@ async def list_tracked_companies( @app.post("/admin/tracked", tags=["Admin"]) async def add_tracked_company( request: TrackCompanyRequest, - _: UserResponse = Depends(get_current_admin), + current_admin: UserResponse = Depends(get_current_admin), ): - """Add a company to the tracked list (admin only).""" + """Add a company to the tracked list (admin only, owned by admin).""" db = get_db_client() - result = db.add_tracked_company(request.company_name) + result = db.add_tracked_company(request.company_name, owner_id=current_admin.id) if not result: raise HTTPException(status_code=409, detail="Company already tracked") return result @@ -492,7 +536,7 @@ async def remove_tracked_company( company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], _: UserResponse = Depends(get_current_admin), ): - """Remove a company from the tracked list (admin only).""" + """Remove a company from the tracked list (admin only, any owner).""" db = get_db_client() removed = db.remove_tracked_company(company_name) if not removed: @@ -507,10 +551,12 @@ async def get_rate_limit_stats( """Get rate limit status and usage statistics (admin only). Returns current rate limit configuration and request statistics - for all rate-limited endpoints. + for all rate-limited endpoints, including per-IP breakdown and + a time-series of throttled (rejected) requests in the last 24 hours. Returns: - List of rate limit stats per endpoint with total/rejected counts + Rate limit stats per endpoint, per-IP breakdown, and throttled + request history bucketed by hour. """ rate_limits_config = { "/auth/register": {"limit": "5/minute"}, @@ -520,14 +566,45 @@ async def get_rate_limit_stats( results = [] for endpoint, conf in rate_limits_config.items(): stats = _rate_limit_stats.get(endpoint, {}) + by_ip_raw = stats.get("by_ip", {}) + by_ip = [ + {"ip": ip, "total": counts["total"], "rejected": counts["rejected"]} + for ip, counts in by_ip_raw.items() + ] results.append({ "endpoint": endpoint, "limit": conf["limit"], "total_requests": stats.get("total_requests", 0), "rejected_requests": stats.get("rejected_requests", 0), + "by_ip": by_ip, }) - return {"rate_limits": results} + # Build hourly buckets of throttled requests for the last 24 hours + now = datetime.now(timezone.utc) + cutoff = now - timedelta(hours=24) + hourly_buckets: dict[str, int] = {} + throttled_24h = 0 + for entry in _rejected_log: + ts_str = entry["timestamp"] + try: + ts = datetime.fromisoformat(ts_str) + except (ValueError, TypeError): + continue + if ts >= cutoff: + throttled_24h += 1 + bucket = ts.strftime("%Y-%m-%dT%H:00:00Z") + hourly_buckets[bucket] = hourly_buckets.get(bucket, 0) + 1 + + throttled_over_time = [ + {"timestamp": k, "count": v} + for k, v in sorted(hourly_buckets.items()) + ] + + return { + "rate_limits": results, + "throttled_24h": throttled_24h, + "throttled_over_time": throttled_over_time, + } @app.get("/admin/alerts", tags=["Admin"]) @@ -540,17 +617,86 @@ async def list_alerts( return db.list_alerts(limit=limit) +# ============== Admin-Scoped Data Endpoints ============== + + +@app.get("/admin/analyses", response_model=PaginatedAnalysisResponse, tags=["Admin"]) +async def admin_list_analyses( + company_name: Annotated[ + str | None, + Query(description="Filter results by company name"), + ] = None, + limit: Annotated[int, Query(ge=1, le=200)] = 50, + cursor: Annotated[ + str | None, + Query(description="Opaque cursor from a previous response's next_cursor field"), + ] = None, + _: UserResponse = Depends(get_current_admin), +): + """List all analysis results across all users (admin only).""" + db = _get_job_db() + rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor) + + has_next = len(rows) > limit + if has_next: + rows = rows[:limit] + + items = [AnalysisRecord(**row) for row in rows] + + next_cursor = None + if has_next and rows: + last = rows[-1] + ts = last["timestamp"] + ts_str = ts.isoformat() if hasattr(ts, "isoformat") else str(ts) + next_cursor = f"{ts_str}|{last['id']}" + + return PaginatedAnalysisResponse(items=items, next_cursor=next_cursor) + + +@app.get("/admin/jobs", response_model=PaginatedJobsResponse, tags=["Admin"]) +async def admin_list_jobs( + status: Annotated[ + str | None, + Query(description="Filter by status: pending, running, completed, failed"), + ] = None, + limit: Annotated[int, Query(ge=1, le=200)] = 50, + cursor: Annotated[ + str | None, + Query(description="Opaque cursor from a previous response's next_cursor field"), + ] = None, + _: UserResponse = Depends(get_current_admin), +): + """List all jobs across all users (admin only).""" + db = _get_job_db() + job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor) + + has_next = len(job_rows) > limit + if has_next: + job_rows = job_rows[:limit] + + items = [_job_row_to_status(row) for row in job_rows] + + next_cursor = None + if has_next and job_rows: + last = job_rows[-1] + created = last["created_at"] + ts = created.isoformat() if hasattr(created, "isoformat") else str(created) + next_cursor = f"{ts}|{last['job_id']}" + + return PaginatedJobsResponse(items=items, next_cursor=next_cursor) + + # ============== Analytics Endpoint ============== @app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"]) async def get_analytics( days: int = Query(default=30, ge=1, le=365), - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """Get analytics data (authenticated users only).""" + """Get analytics data scoped to the current user.""" db = get_db_client() - analytics = db.get_analytics(days=days) + analytics = db.get_analytics(days=days, owner_id=current_user.id) return AnalyticsResponse( total_messages=analytics["total_messages"], @@ -603,9 +749,9 @@ async def list_models(): @app.get("/analytics/trends", tags=["Analytics"]) async def get_analytics_trends( days: int = Query(default=90, ge=7, le=365), - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """Get trend data for patent analysis over time. + """Get trend data for patent analysis over time (scoped to current user). Returns two datasets: - ``by_month``: analysis count per company per month @@ -619,11 +765,14 @@ async def get_analytics_trends( """ db = get_db_client() + owner_filter = " AND owner_id = %s" if current_user else "" + owner_params = (current_user.id,) if current_user else () + with db.get_conn() as conn: with conn.cursor() as cur: # Analyses per company per month cur.execute( - """ + f""" SELECT TO_CHAR(timestamp, 'YYYY-MM') AS month, company_name, @@ -632,16 +781,17 @@ async def get_analytics_trends( WHERE timestamp >= NOW() - INTERVAL '%s days' AND is_cached = FALSE AND company_name IS NOT NULL + {owner_filter} GROUP BY month, company_name ORDER BY month """, - (days,), + (days, *owner_params), ) by_month_rows = cur.fetchall() # Analysis type distribution per month cur.execute( - """ + f""" SELECT TO_CHAR(timestamp, 'YYYY-MM') AS month, analysis_type, @@ -649,10 +799,11 @@ async def get_analytics_trends( FROM llm_messages WHERE timestamp >= NOW() - INTERVAL '%s days' AND is_cached = FALSE + {owner_filter} GROUP BY month, analysis_type ORDER BY month """, - (days,), + (days, *owner_params), ) by_type_rows = cur.fetchall() @@ -678,9 +829,9 @@ async def get_analytics_trends( @app.get("/export/{company_name}", tags=["Export"]) async def export_company_csv( company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """Export analysis results for a company as a CSV file. + """Export analysis results for a company as a CSV file (scoped to current user). Returns all stored analysis records for the given company, including analysis type, model used, response text, and timestamp. @@ -695,7 +846,7 @@ async def export_company_csv( import io db = get_db_client() - # Query all non-cached analysis results for this company + # Query all non-cached analysis results for this company owned by current user with db.get_conn() as conn: with conn.cursor() as cur: cur.execute( @@ -703,9 +854,10 @@ async def export_company_csv( SELECT company_name, analysis_type, model, response, timestamp FROM llm_messages WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE + AND owner_id = %s ORDER BY timestamp DESC """, - (company_name,), + (company_name, current_user.id), ) rows = cur.fetchall() @@ -730,9 +882,9 @@ async def export_company_csv( @app.get("/export/{company_name}/pdf", tags=["Export"]) async def export_company_pdf( company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """Export analysis results for a company as a formatted PDF report. + """Export analysis results for a company as a formatted PDF report (scoped to current user). Returns all stored analysis records for the given company, including analysis type, model used, response text, and timestamp, formatted @@ -766,9 +918,10 @@ async def export_company_pdf( SELECT company_name, analysis_type, model, response, timestamp FROM llm_messages WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE + AND owner_id = %s ORDER BY timestamp DESC """, - (company_name,), + (company_name, current_user.id), ) rows = cur.fetchall() @@ -897,68 +1050,6 @@ async def health_check(): ) -@app.get( - "/analyze/{company_name}", - response_model=CompanyAnalysisResponse, - tags=["Analysis"], -) -async def analyze_company( - company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], - model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."), - _: UserResponse = Depends(get_current_user), -): - """Analyze a single company's patent portfolio. - - This endpoint retrieves recent patents for the specified company, - parses them, and uses AI to generate a comprehensive analysis. - - Args: - company_name: Name of the company to analyze (e.g., "nvidia", "intel") - model: Optional LLM model override - - Returns: - Analysis results including patent count, AI insights, and success status - """ - _validate_model(model) - if not _analyzer: - raise HTTPException(status_code=503, detail="Analyzer not initialized") - - result = _analyzer._analyze_company_safe(company_name, model=model) - return _convert_result(result) - - -@app.get( - "/analyze/patent/{patent_id}", - tags=["Analysis"], -) -async def analyze_single_patent( - patent_id: str, - company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")], - _: UserResponse = Depends(get_current_user), -): - """Analyze a single patent by its publication ID. - - If the patent PDF is not already cached locally, the system will attempt - to download it automatically from a previously cached link. If no link - is available, a 404 error is returned. - - Args: - patent_id: Patent publication ID (e.g. "US-11234567-B2") - company_name: Company name for analysis context - - Returns: - Analysis text for the patent - """ - if not _analyzer: - raise HTTPException(status_code=503, detail="Analyzer not initialized") - - try: - analysis = _analyzer.analyze_single_patent(patent_id, company_name) - return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) - - @app.get( "/analyze/batch", response_model=PaginatedAnalysisResponse, @@ -974,9 +1065,9 @@ async def list_analysis_results( str | None, Query(description="Opaque cursor from a previous response's next_cursor field"), ] = None, - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """List stored analysis results with cursor-based pagination. + """List stored analysis results with cursor-based pagination (scoped to current user). Returns past analysis results ordered by timestamp descending. Use ``limit`` to control page size (default 50, max 200). The response @@ -993,7 +1084,7 @@ async def list_analysis_results( Paginated list of analysis results """ db = _get_job_db() - rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor) + rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor, owner_id=current_user.id) has_next = len(rows) > limit if has_next: @@ -1043,6 +1134,68 @@ async def analyze_companies_batch( return _convert_batch_result(result) +@app.get( + "/analyze/patent/{patent_id}", + tags=["Analysis"], +) +async def analyze_single_patent( + patent_id: str, + company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")], + _: UserResponse = Depends(get_current_user), +): + """Analyze a single patent by its publication ID. + + If the patent PDF is not already cached locally, the system will attempt + to download it automatically from a previously cached link. If no link + is available, a 404 error is returned. + + Args: + patent_id: Patent publication ID (e.g. "US-11234567-B2") + company_name: Company name for analysis context + + Returns: + Analysis text for the patent + """ + if not _analyzer: + raise HTTPException(status_code=503, detail="Analyzer not initialized") + + try: + analysis = _analyzer.analyze_single_patent(patent_id, company_name) + return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis} + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@app.get( + "/analyze/{company_name}", + response_model=CompanyAnalysisResponse, + tags=["Analysis"], +) +async def analyze_company( + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], + model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."), + _: UserResponse = Depends(get_current_user), +): + """Analyze a single company's patent portfolio. + + This endpoint retrieves recent patents for the specified company, + parses them, and uses AI to generate a comprehensive analysis. + + Args: + company_name: Name of the company to analyze (e.g., "nvidia", "intel") + model: Optional LLM model override + + Returns: + Analysis results including patent count, AI insights, and success status + """ + _validate_model(model) + if not _analyzer: + raise HTTPException(status_code=503, detail="Analyzer not initialized") + + result = _analyzer._analyze_company_safe(company_name, model=model) + return _convert_result(result) + + def _get_job_db() -> "DatabaseClient": """Get a DatabaseClient for job persistence.""" from SPARC.database import DatabaseClient @@ -1129,7 +1282,7 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: s async def analyze_companies_async( request: BatchAnalysisRequest, background_tasks: BackgroundTasks, - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): """Start an asynchronous batch analysis job. @@ -1149,7 +1302,7 @@ async def analyze_companies_async( job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}" db = _get_job_db() - job_row = db.create_job(job_id=job_id, total_companies=len(request.companies)) + job_row = db.create_job(job_id=job_id, total_companies=len(request.companies), owner_id=current_user.id) background_tasks.add_task( _run_batch_job, job_id, request.companies, request.max_workers, request.model @@ -1161,9 +1314,9 @@ async def analyze_companies_async( @app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"]) async def get_job_status( job_id: str, - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """Get the status of a background analysis job. + """Get the status of a background analysis job (scoped to current user). Args: job_id: The job ID returned from the async batch endpoint @@ -1172,7 +1325,7 @@ async def get_job_status( Current job status including progress and results when complete """ db = _get_job_db() - job_row = db.get_job(job_id) + job_row = db.get_job(job_id, owner_id=current_user.id) if not job_row: raise HTTPException(status_code=404, detail=f"Job {job_id} not found") @@ -1191,9 +1344,9 @@ async def list_jobs( str | None, Query(description="Opaque cursor from a previous response's next_cursor field"), ] = None, - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """List analysis jobs with cursor-based pagination. + """List analysis jobs with cursor-based pagination (scoped to current user). Pass ``limit`` to control page size. The response includes a ``next_cursor`` field; pass it back as the ``cursor`` query parameter to fetch the next page. @@ -1212,7 +1365,7 @@ async def list_jobs( """ db = _get_job_db() # Fetch one extra to determine if there is a next page - job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor) + job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor, owner_id=current_user.id) has_next = len(job_rows) > limit if has_next: diff --git a/SPARC/database.py b/SPARC/database.py index 0759a66..a393162 100644 --- a/SPARC/database.py +++ b/SPARC/database.py @@ -196,7 +196,7 @@ class DatabaseClient: cursor.execute(""" CREATE TABLE IF NOT EXISTS tracked_companies ( id SERIAL PRIMARY KEY, - company_name VARCHAR(255) UNIQUE NOT NULL, + company_name VARCHAR(255) NOT NULL, last_patent_count INTEGER DEFAULT 0, last_analysis_at TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP @@ -221,6 +221,68 @@ class DatabaseClient: ON alerts(company_name) """) + # ---- Multi-tenant: add owner_id columns if missing ---- + cursor.execute(""" + DO $$ + BEGIN + -- llm_messages.owner_id + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'llm_messages' AND column_name = 'owner_id' + ) THEN + ALTER TABLE llm_messages ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + -- jobs.owner_id + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'jobs' AND column_name = 'owner_id' + ) THEN + ALTER TABLE jobs ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + -- tracked_companies.owner_id + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'tracked_companies' AND column_name = 'owner_id' + ) THEN + ALTER TABLE tracked_companies ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + END $$; + """) + + # Indexes for owner_id filtering + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_messages_owner + ON llm_messages(owner_id) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_jobs_owner + ON jobs(owner_id) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_tracked_companies_owner + ON tracked_companies(owner_id) + """) + + # Drop the old unique constraint on company_name alone (if it exists) + # and replace with a per-owner unique constraint so different users + # can track the same company independently. + cursor.execute(""" + DO $$ + BEGIN + IF EXISTS ( + SELECT 1 FROM pg_constraint + WHERE conname = 'tracked_companies_company_name_key' + ) THEN + ALTER TABLE tracked_companies + DROP CONSTRAINT tracked_companies_company_name_key; + END IF; + END $$; + """) + cursor.execute(""" + CREATE UNIQUE INDEX IF NOT EXISTS uq_tracked_company_owner + ON tracked_companies(LOWER(company_name), owner_id) + """) + self.conn.commit() @staticmethod @@ -289,6 +351,7 @@ class DatabaseClient: metadata: Optional[Dict] = None, token_usage: Optional[Dict] = None, is_cached: bool = False, + owner_id: Optional[int] = None, ) -> int: """Store an LLM message exchange in the database. @@ -301,6 +364,7 @@ class DatabaseClient: metadata: Additional metadata as dict token_usage: Token usage information is_cached: Whether this response was served from cache + owner_id: ID of the user who owns this record Returns: The ID of the inserted record @@ -312,8 +376,8 @@ class DatabaseClient: cursor.execute( """ INSERT INTO llm_messages - (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) + (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached, owner_id) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING id """, ( @@ -326,6 +390,7 @@ class DatabaseClient: json.dumps(metadata) if metadata else None, json.dumps(token_usage) if token_usage else None, is_cached, + owner_id, ), ) @@ -340,6 +405,7 @@ class DatabaseClient: analysis_type: Optional[str] = None, limit: int = 100, offset: int = 0, + owner_id: Optional[int] = None, ) -> List[Dict]: """Retrieve messages from the database. @@ -348,6 +414,7 @@ class DatabaseClient: analysis_type: Filter by analysis type limit: Maximum number of records to return offset: Number of records to skip + owner_id: Filter by owner (None returns all, for admin use) Returns: List of message dictionaries @@ -355,6 +422,10 @@ class DatabaseClient: query = "SELECT * FROM llm_messages WHERE 1=1" params = [] + if owner_id is not None: + query += " AND owner_id = %s" + params.append(owner_id) + if company_name: query += " AND company_name = %s" params.append(company_name) @@ -376,6 +447,7 @@ class DatabaseClient: company_name: Optional[str] = None, limit: int = 50, cursor: Optional[str] = None, + owner_id: Optional[int] = None, ) -> List[Dict]: """List analysis results with cursor-based pagination. @@ -383,6 +455,7 @@ class DatabaseClient: company_name: Optional filter by company name. limit: Maximum number of records to return. cursor: Opaque cursor (``timestamp|id``) from a previous response. + owner_id: Filter by owner (None returns all, for admin use). Returns: List of analysis dicts ordered by timestamp descending. @@ -390,6 +463,10 @@ class DatabaseClient: conditions: list[str] = ["is_cached = FALSE"] params: list = [] + if owner_id is not None: + conditions.append("owner_id = %s") + params.append(owner_id) + if company_name: conditions.append("LOWER(company_name) = LOWER(%s)") params.append(company_name) @@ -413,52 +490,62 @@ class DatabaseClient: cur.execute(query, params) return [dict(row) for row in cur.fetchall()] - def get_analytics(self, days: int = 30) -> Dict: + def get_analytics(self, days: int = 30, owner_id: Optional[int] = None) -> Dict: """Get analytics on message usage. Args: days: Number of days to look back + owner_id: Filter by owner (None returns all, for admin use) Returns: Dictionary with analytics data """ + owner_filter = "" + owner_params: list = [] + if owner_id is not None: + owner_filter = " AND owner_id = %s" + owner_params = [owner_id] + with self.get_conn() as conn: with conn.cursor(cursor_factory=RealDictCursor) as cursor: # Total messages cursor.execute( - """ + f""" SELECT COUNT(*) as total_messages FROM llm_messages WHERE timestamp >= NOW() - INTERVAL '%s days' + {owner_filter} """, - (days,), + (days, *owner_params), ) total = cursor.fetchone()["total_messages"] # Messages by company cursor.execute( - """ + f""" SELECT company_name, COUNT(*) as count FROM llm_messages WHERE timestamp >= NOW() - INTERVAL '%s days' + {owner_filter} GROUP BY company_name ORDER BY count DESC LIMIT 10 """, - (days,), + (days, *owner_params), ) by_company = cursor.fetchall() # Messages by type cursor.execute( - """ + f""" SELECT analysis_type, COUNT(*) as count FROM llm_messages WHERE timestamp >= NOW() - INTERVAL '%s days' + {owner_filter} GROUP BY analysis_type ORDER BY count DESC """, - (days,), + (days, *owner_params), ) by_type = cursor.fetchall() @@ -556,12 +643,14 @@ class DatabaseClient: self, job_id: str, total_companies: int, + owner_id: Optional[int] = None, ) -> Dict: """Create a new job record. Args: job_id: Unique job identifier total_companies: Number of companies in the batch + owner_id: ID of the user who owns this job Returns: Job dict @@ -570,11 +659,11 @@ class DatabaseClient: with conn.cursor(cursor_factory=RealDictCursor) as cursor: cursor.execute( """ - INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies) - VALUES (%s, 'pending', 0, %s, 0) + INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies, owner_id) + VALUES (%s, 'pending', 0, %s, 0, %s) RETURNING * """, - (job_id, total_companies), + (job_id, total_companies, owner_id), ) job = cursor.fetchone() conn.commit() @@ -627,11 +716,22 @@ class DatabaseClient: conn.commit() return dict(job) if job else None - def get_job(self, job_id: str) -> Optional[Dict]: - """Get a job by ID.""" + def get_job(self, job_id: str, owner_id: Optional[int] = None) -> Optional[Dict]: + """Get a job by ID. + + Args: + job_id: Job identifier. + owner_id: When provided, only return the job if it belongs to this owner. + """ + query = "SELECT * FROM jobs WHERE job_id = %s" + params: list = [job_id] + if owner_id is not None: + query += " AND owner_id = %s" + params.append(owner_id) + with self.get_conn() as conn: with conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,)) + cursor.execute(query, params) job = cursor.fetchone() return dict(job) if job else None @@ -640,6 +740,7 @@ class DatabaseClient: status: Optional[str] = None, limit: int = 10, cursor: Optional[str] = None, + owner_id: Optional[int] = None, ) -> List[Dict]: """List jobs with optional status filter and cursor-based pagination. @@ -649,6 +750,7 @@ class DatabaseClient: cursor: Opaque cursor (``created_at|job_id``) from a previous response. When provided, only jobs older than the cursor are returned. + owner_id: Filter by owner (None returns all, for admin use). Returns: List of job dicts ordered by created_at descending. @@ -656,6 +758,10 @@ class DatabaseClient: conditions: list[str] = [] params: list = [] + if owner_id is not None: + conditions.append("owner_id = %s") + params.append(owner_id) + if status: conditions.append("status = %s") params.append(status) @@ -902,14 +1008,21 @@ class DatabaseClient: # Tracked Companies Methods - def add_tracked_company(self, company_name: str) -> Optional[Dict]: - """Add a company to the tracking list.""" + def add_tracked_company( + self, company_name: str, owner_id: Optional[int] = None + ) -> Optional[Dict]: + """Add a company to the tracking list. + + Args: + company_name: Company name to track. + owner_id: ID of the user who owns this tracked company. + """ with self.get_conn() as conn: with conn.cursor(cursor_factory=RealDictCursor) as cursor: try: cursor.execute( - "INSERT INTO tracked_companies (company_name) VALUES (%s) RETURNING *", - (company_name,), + "INSERT INTO tracked_companies (company_name, owner_id) VALUES (%s, %s) RETURNING *", + (company_name, owner_id), ) row = cursor.fetchone() conn.commit() @@ -918,22 +1031,45 @@ class DatabaseClient: conn.rollback() return None - def remove_tracked_company(self, company_name: str) -> bool: - """Remove a company from the tracking list.""" + def remove_tracked_company( + self, company_name: str, owner_id: Optional[int] = None + ) -> bool: + """Remove a company from the tracking list. + + Args: + company_name: Company name to remove. + owner_id: When provided, only remove if owned by this user. + """ + query = "DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)" + params: list = [company_name] + if owner_id is not None: + query += " AND owner_id = %s" + params.append(owner_id) + with self.get_conn() as conn: with conn.cursor() as cursor: - cursor.execute( - "DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)", - (company_name,), - ) + cursor.execute(query, params) conn.commit() return cursor.rowcount > 0 - def list_tracked_companies(self) -> List[Dict]: - """List all tracked companies.""" + def list_tracked_companies( + self, owner_id: Optional[int] = None + ) -> List[Dict]: + """List tracked companies. + + Args: + owner_id: Filter by owner (None returns all, for admin/scheduler use). + """ + query = "SELECT * FROM tracked_companies" + params: list = [] + if owner_id is not None: + query += " WHERE owner_id = %s" + params.append(owner_id) + query += " ORDER BY company_name" + with self.get_conn() as conn: with conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute("SELECT * FROM tracked_companies ORDER BY company_name") + cursor.execute(query, params) return [dict(row) for row in cursor.fetchall()] def update_tracked_company( diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index d7ec5ba..41883b0 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -11,6 +11,7 @@ import { Batch } from './pages/Batch'; import { AnalyticsPage } from './pages/Analytics'; import { About } from './pages/About'; import { AdminUsers } from './pages/AdminUsers'; +import { AdminRateLimits } from './pages/AdminRateLimits'; import { Compare } from './pages/Compare'; const queryClient = new QueryClient({ @@ -56,6 +57,14 @@ function App() { } /> + + + + } + /> {/* Default redirect */} diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 09a4ae6..bbfe6cb 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -201,6 +201,32 @@ export const analyticsApi = { }, }; +// Rate limit types +export interface RateLimitIpEntry { + ip: string; + total: number; + rejected: number; +} + +export interface RateLimitEndpointStats { + endpoint: string; + limit: string; + total_requests: number; + rejected_requests: number; + by_ip: RateLimitIpEntry[]; +} + +export interface ThrottledBucket { + timestamp: string; + count: number; +} + +export interface RateLimitStatsResponse { + rate_limits: RateLimitEndpointStats[]; + throttled_24h: number; + throttled_over_time: ThrottledBucket[]; +} + // Admin API export const adminApi = { listUsers: async (limit = 100, offset = 0): Promise => { @@ -216,6 +242,11 @@ export const adminApi = { deleteUser: async (userId: number): Promise => { await api.delete(`/admin/users/${userId}`); }, + + getRateLimits: async (): Promise => { + const response = await api.get('/admin/rate-limits'); + return response.data; + }, }; export default api; diff --git a/frontend/src/components/Layout.tsx b/frontend/src/components/Layout.tsx index d0df715..d1bfe41 100644 --- a/frontend/src/components/Layout.tsx +++ b/frontend/src/components/Layout.tsx @@ -1,7 +1,7 @@ import { Outlet, NavLink, useNavigate } from 'react-router-dom'; import { useAuth } from '../context/AuthContext'; import { useTheme } from '../context/ThemeContext'; -import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon } from 'lucide-react'; +import { Search, Layers, BarChart3, Info, Users, LogOut, GitCompareArrows, Sun, Moon, ShieldAlert } from 'lucide-react'; export function Layout() { const { user, isAdmin, logout } = useAuth(); @@ -23,6 +23,7 @@ export function Layout() { if (isAdmin) { navItems.push({ to: '/admin/users', icon: Users, label: 'Users' }); + navItems.push({ to: '/admin/rate-limits', icon: ShieldAlert, label: 'Rate Limits' }); } return ( diff --git a/frontend/src/pages/AdminRateLimits.tsx b/frontend/src/pages/AdminRateLimits.tsx new file mode 100644 index 0000000..97b41c4 --- /dev/null +++ b/frontend/src/pages/AdminRateLimits.tsx @@ -0,0 +1,240 @@ +import { useState } from 'react'; +import { useQuery } from '@tanstack/react-query'; +import { adminApi } from '../api/client'; +import type { RateLimitStatsResponse } from '../api/client'; +import { ShieldAlert, Activity, AlertCircle, RefreshCw, Clock } from 'lucide-react'; + +const REFRESH_OPTIONS = [ + { label: '15s', value: 15_000 }, + { label: '30s', value: 30_000 }, + { label: '1m', value: 60_000 }, + { label: 'Off', value: 0 }, +]; + +export function AdminRateLimits() { + const [refreshInterval, setRefreshInterval] = useState(30_000); + + const { data, isLoading, isError, dataUpdatedAt } = useQuery({ + queryKey: ['admin-rate-limits'], + queryFn: () => adminApi.getRateLimits(), + refetchInterval: refreshInterval || false, + }); + + if (isLoading) { + return ( +
+
+
+ ); + } + + if (isError) { + return ( +
+ + Failed to load rate limit statistics. +
+ ); + } + + const maxThrottledCount = data?.throttled_over_time?.length + ? Math.max(...data.throttled_over_time.map((b) => b.count)) + : 0; + + return ( +
+ {/* Header */} +
+
+

+ Rate Limiting Dashboard +

+

Monitor API rate limits and throttled requests.

+
+
+ {/* Last updated */} + {dataUpdatedAt > 0 && ( + + + Updated {new Date(dataUpdatedAt).toLocaleTimeString()} + + )} + {/* Refresh interval selector */} +
+ + {REFRESH_OPTIONS.map((opt) => ( + + ))} +
+
+
+ + {/* Summary cards */} +
+
+
+ + + Total Requests + +
+
+ {data?.rate_limits.reduce((sum, rl) => sum + rl.total_requests, 0) ?? 0} +
+
+
+
+ + + Throttled (24h) + +
+
+ {data?.throttled_24h ?? 0} +
+
+
+
+ + + Rate-Limited Endpoints + +
+
+ {data?.rate_limits.length ?? 0} +
+
+
+ + {/* Throttled over time chart (simple bar chart) */} + {data?.throttled_over_time && data.throttled_over_time.length > 0 && ( +
+

+ Throttled Requests Over Time (Last 24h) +

+
+ {data.throttled_over_time.map((bucket) => { + const height = maxThrottledCount > 0 ? (bucket.count / maxThrottledCount) * 100 : 0; + const hour = new Date(bucket.timestamp).getHours(); + return ( +
+ {bucket.count} +
+ {hour}:00 +
+ ); + })} +
+
+ )} + + {/* Per-endpoint table */} +
+
+ + + + + + + + + + + {data?.rate_limits.map((rl) => ( + + + + + + + ))} + +
+ Endpoint + + Limit + + Total Requests + + Rejected +
{rl.endpoint} + + {rl.limit} + + + {rl.total_requests} + + 0 ? 'text-error font-semibold' : 'text-text-secondary'}> + {rl.rejected_requests} + +
+
+
+ + {/* Per-IP breakdown */} + {data?.rate_limits.some((rl) => rl.by_ip.length > 0) && ( +
+
+

+ Per-IP Breakdown +

+
+
+ + + + + + + + + + + {data.rate_limits.flatMap((rl) => + rl.by_ip.map((ipEntry) => ( + + + + + + + )) + )} + +
+ Endpoint + + IP Address + + Total + + Rejected +
{rl.endpoint}{ipEntry.ip}{ipEntry.total} + 0 ? 'text-error font-semibold' : 'text-text-secondary'}> + {ipEntry.rejected} + +
+
+
+ )} +
+ ); +} diff --git a/scripts/migrate_add_owner_id.py b/scripts/migrate_add_owner_id.py new file mode 100644 index 0000000..3e0ea53 --- /dev/null +++ b/scripts/migrate_add_owner_id.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +"""Migration: add owner_id columns and backfill existing rows. + +This script adds an ``owner_id`` column (FK to ``users``) to the +``llm_messages``, ``jobs``, and ``tracked_companies`` tables, then +backfills all existing rows with ``owner_id = 1`` (the default admin user). + +It also replaces the old global UNIQUE constraint on +``tracked_companies.company_name`` with a per-owner unique index so that +different users can independently track the same company. + +Usage: + python scripts/migrate_add_owner_id.py + +The script is idempotent — running it multiple times is safe. +""" + +import os +import sys + +import psycopg2 + +DATABASE_URL = os.getenv( + "DATABASE_URL", + "postgresql://postgres:postgres@localhost:5432/sparc", +) + +DEFAULT_OWNER_ID = 1 + + +def run_migration(): + """Execute the migration.""" + conn = psycopg2.connect(DATABASE_URL) + conn.autocommit = False + + try: + with conn.cursor() as cur: + # ---------- 1. Add owner_id columns if missing ---------- + cur.execute(""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'llm_messages' AND column_name = 'owner_id' + ) THEN + ALTER TABLE llm_messages ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'jobs' AND column_name = 'owner_id' + ) THEN + ALTER TABLE jobs ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'tracked_companies' AND column_name = 'owner_id' + ) THEN + ALTER TABLE tracked_companies ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + END $$; + """) + + # ---------- 2. Backfill owner_id = DEFAULT_OWNER_ID ---------- + cur.execute( + "UPDATE llm_messages SET owner_id = %s WHERE owner_id IS NULL", + (DEFAULT_OWNER_ID,), + ) + messages_updated = cur.rowcount + print(f" llm_messages: backfilled {messages_updated} rows") + + cur.execute( + "UPDATE jobs SET owner_id = %s WHERE owner_id IS NULL", + (DEFAULT_OWNER_ID,), + ) + jobs_updated = cur.rowcount + print(f" jobs: backfilled {jobs_updated} rows") + + cur.execute( + "UPDATE tracked_companies SET owner_id = %s WHERE owner_id IS NULL", + (DEFAULT_OWNER_ID,), + ) + tracked_updated = cur.rowcount + print(f" tracked_companies: backfilled {tracked_updated} rows") + + # ---------- 3. Create indexes ---------- + cur.execute(""" + CREATE INDEX IF NOT EXISTS idx_messages_owner + ON llm_messages(owner_id) + """) + cur.execute(""" + CREATE INDEX IF NOT EXISTS idx_jobs_owner + ON jobs(owner_id) + """) + cur.execute(""" + CREATE INDEX IF NOT EXISTS idx_tracked_companies_owner + ON tracked_companies(owner_id) + """) + + # ---------- 4. Replace unique constraint on tracked_companies ---------- + cur.execute(""" + DO $$ + BEGIN + IF EXISTS ( + SELECT 1 FROM pg_constraint + WHERE conname = 'tracked_companies_company_name_key' + ) THEN + ALTER TABLE tracked_companies + DROP CONSTRAINT tracked_companies_company_name_key; + END IF; + END $$; + """) + cur.execute(""" + CREATE UNIQUE INDEX IF NOT EXISTS uq_tracked_company_owner + ON tracked_companies(LOWER(company_name), owner_id) + """) + + conn.commit() + print("Migration completed successfully.") + + except Exception: + conn.rollback() + print("Migration FAILED — rolled back.", file=sys.stderr) + raise + finally: + conn.close() + + +if __name__ == "__main__": + print(f"Running owner_id migration against {DATABASE_URL.split('@')[-1]} ...") + run_migration() diff --git a/tests/test_api.py b/tests/test_api.py index fd16921..e1def71 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,12 +1,13 @@ """Tests for FastAPI web service endpoints.""" -from datetime import datetime -from unittest.mock import Mock +from datetime import datetime, timezone +from unittest.mock import Mock, MagicMock, patch import pytest from fastapi.testclient import TestClient from SPARC.api import app +from SPARC.auth import create_access_token from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult @@ -16,6 +17,22 @@ def client(): return TestClient(app) +@pytest.fixture(autouse=True) +def mock_db(): + """Mock the database client used by auth endpoints.""" + db = MagicMock() + db.get_user_by_id.return_value = { + "id": 1, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + with patch("SPARC.api.get_db_client", return_value=db), \ + patch("SPARC.auth.get_db_client", return_value=db): + yield db + + @pytest.fixture def mock_analyzer(mocker): """Mock the global analyzer.""" @@ -24,6 +41,12 @@ def mock_analyzer(mocker): return mock +def _auth_header(user_id=1, email="user@test.com", role="user"): + """Create an Authorization header with a valid access token.""" + token = create_access_token(user_id, email, role) + return {"Authorization": f"Bearer {token}"} + + class TestHealthEndpoint: """Test health check endpoint.""" @@ -51,7 +74,7 @@ class TestAnalyzeCompanyEndpoint: ) mock_analyzer._analyze_company_safe.return_value = mock_result - response = client.get("/analyze/nvidia") + response = client.get("/analyze/nvidia", headers=_auth_header()) assert response.status_code == 200 data = response.json() @@ -72,7 +95,7 @@ class TestAnalyzeCompanyEndpoint: ) mock_analyzer._analyze_company_safe.return_value = mock_result - response = client.get("/analyze/unknown") + response = client.get("/analyze/unknown", headers=_auth_header()) assert response.status_code == 200 data = response.json() @@ -113,6 +136,7 @@ class TestBatchAnalysisEndpoint: response = client.post( "/analyze/batch", json={"companies": ["nvidia", "amd"], "max_workers": 2}, + headers=_auth_header(), ) assert response.status_code == 200 @@ -125,13 +149,14 @@ class TestBatchAnalysisEndpoint: def test_batch_analysis_validation(self, client): """Test batch analysis request validation.""" # Empty companies list - response = client.post("/analyze/batch", json={"companies": []}) + response = client.post("/analyze/batch", json={"companies": []}, headers=_auth_header()) assert response.status_code == 422 # Too many companies response = client.post( "/analyze/batch", json={"companies": [f"company{i}" for i in range(25)]}, + headers=_auth_header(), ) assert response.status_code == 422 @@ -139,6 +164,7 @@ class TestBatchAnalysisEndpoint: response = client.post( "/analyze/batch", json={"companies": ["nvidia"], "max_workers": 10}, + headers=_auth_header(), ) assert response.status_code == 422 @@ -146,11 +172,26 @@ class TestBatchAnalysisEndpoint: class TestAsyncBatchEndpoint: """Test async batch analysis endpoint.""" - def test_async_batch_creates_job(self, client, mock_analyzer): - """Test async endpoint creates a job.""" + @patch("SPARC.api._get_job_db") + def test_async_batch_creates_job(self, mock_get_db, client, mock_analyzer): + """Test async endpoint creates a job with owner_id.""" + job_db = MagicMock() + job_db.create_job.return_value = { + "job_id": "j1", + "status": "pending", + "progress": 0, + "total_companies": 2, + "completed_companies": 0, + "result_json": None, + "error": None, + "owner_id": 1, + } + mock_get_db.return_value = job_db + response = client.post( "/analyze/batch/async", json={"companies": ["nvidia", "amd"]}, + headers=_auth_header(), ) assert response.status_code == 200 @@ -159,28 +200,42 @@ class TestAsyncBatchEndpoint: assert data["status"] == "pending" assert data["total_companies"] == 2 assert data["progress"] == 0 + # Verify owner_id was passed + job_db.create_job.assert_called_once() + assert job_db.create_job.call_args.kwargs.get("owner_id") == 1 class TestJobEndpoints: """Test job management endpoints.""" - def test_get_job_not_found(self, client): + @patch("SPARC.api._get_job_db") + def test_get_job_not_found(self, mock_get_db, client): """Test getting nonexistent job.""" - response = client.get("/jobs/nonexistent") + job_db = MagicMock() + job_db.get_job.return_value = None + mock_get_db.return_value = job_db + + response = client.get("/jobs/nonexistent", headers=_auth_header()) assert response.status_code == 404 - def test_list_jobs(self, client, mocker): + @patch("SPARC.api._get_job_db") + def test_list_jobs(self, mock_get_db, client): """Test listing jobs.""" - # Clear existing jobs - mocker.patch.dict("SPARC.api._jobs", {}, clear=True) + job_db = MagicMock() + job_db.list_jobs.return_value = [] + mock_get_db.return_value = job_db - response = client.get("/jobs") + response = client.get("/jobs", headers=_auth_header()) assert response.status_code == 200 - assert isinstance(response.json(), list) - def test_list_jobs_with_filter(self, client, mocker): + @patch("SPARC.api._get_job_db") + def test_list_jobs_with_filter(self, mock_get_db, client): """Test listing jobs with status filter.""" - response = client.get("/jobs?status=completed") + job_db = MagicMock() + job_db.list_jobs.return_value = [] + mock_get_db.return_value = job_db + + response = client.get("/jobs?status=completed", headers=_auth_header()) assert response.status_code == 200 @@ -189,7 +244,7 @@ class TestModelValidation: def test_analyze_rejects_unsupported_model(self, client, mock_analyzer): """GET /analyze/{company} with unsupported model returns 400.""" - response = client.get("/analyze/nvidia?model=fake/nonexistent-model") + response = client.get("/analyze/nvidia?model=fake/nonexistent-model", headers=_auth_header()) assert response.status_code == 400 assert "Unsupported model" in response.json()["detail"] @@ -205,7 +260,7 @@ class TestModelValidation: ) mock_analyzer._analyze_company_safe.return_value = mock_result - response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet") + response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet", headers=_auth_header()) assert response.status_code == 200 def test_batch_rejects_unsupported_model(self, client, mock_analyzer): @@ -213,6 +268,7 @@ class TestModelValidation: response = client.post( "/analyze/batch", json={"companies": ["nvidia"], "model": "fake/nonexistent-model"}, + headers=_auth_header(), ) assert response.status_code == 400 assert "Unsupported model" in response.json()["detail"] diff --git a/tests/test_export.py b/tests/test_export.py index d0c856f..321e443 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -5,6 +5,7 @@ Covers issue #1655: - GET /export/{company_name}/pdf (PDF export) All tests mock the database layer and use JWT auth fixtures from test_auth patterns. +Export queries are now scoped to the current user's owner_id. """ from datetime import datetime, timezone diff --git a/tests/test_multi_tenant.py b/tests/test_multi_tenant.py new file mode 100644 index 0000000..7ce758d --- /dev/null +++ b/tests/test_multi_tenant.py @@ -0,0 +1,281 @@ +"""Cross-tenant isolation tests for multi-tenant support. + +Verifies that: +- User A cannot read, update, or delete User B's analyses, tracked companies, or jobs +- Admin users can access all data via admin endpoints +- owner_id is correctly set on new resources +""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastapi.testclient import TestClient + +from SPARC.api import app +from SPARC.auth import create_access_token + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +def _make_user(user_id, email, role="user"): + return { + "id": user_id, + "email": email, + "role": role, + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + +USER_A = _make_user(10, "alice@test.com") +USER_B = _make_user(20, "bob@test.com") +ADMIN = _make_user(1, "admin@test.com", role="admin") + + +def _header_for(user): + token = create_access_token(user["id"], user["email"], user["role"]) + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture(autouse=True) +def mock_db(): + """Mock DB returning the correct user based on user_id.""" + db = MagicMock() + + def _get_user_by_id(uid): + for u in [USER_A, USER_B, ADMIN]: + if u["id"] == uid: + return u + return None + + db.get_user_by_id.side_effect = _get_user_by_id + + with patch("SPARC.api.get_db_client", return_value=db), \ + patch("SPARC.auth.get_db_client", return_value=db): + yield db + + +# ==================== Tracked Companies Isolation ==================== + + +class TestTrackedCompanyIsolation: + """User A's tracked companies are invisible to User B.""" + + def test_user_a_list_scoped_to_own(self, client, mock_db): + """GET /tracked returns only User A's companies.""" + mock_db.list_tracked_companies.return_value = [ + {"company_name": "AliceCo", "owner_id": USER_A["id"]}, + ] + + response = client.get("/tracked", headers=_header_for(USER_A)) + assert response.status_code == 200 + mock_db.list_tracked_companies.assert_called_with(owner_id=USER_A["id"]) + + def test_user_b_list_scoped_to_own(self, client, mock_db): + """GET /tracked returns only User B's companies.""" + mock_db.list_tracked_companies.return_value = [] + + response = client.get("/tracked", headers=_header_for(USER_B)) + assert response.status_code == 200 + mock_db.list_tracked_companies.assert_called_with(owner_id=USER_B["id"]) + + def test_user_a_add_sets_owner(self, client, mock_db): + """POST /tracked sets owner_id to User A.""" + mock_db.add_tracked_company.return_value = {"company_name": "NewCo", "owner_id": 10} + + response = client.post("/tracked", json={"company_name": "NewCo"}, headers=_header_for(USER_A)) + assert response.status_code == 200 + mock_db.add_tracked_company.assert_called_with("NewCo", owner_id=USER_A["id"]) + + def test_user_b_cannot_remove_user_a_company(self, client, mock_db): + """DELETE /tracked/{name} filters by owner, so B can't remove A's company.""" + mock_db.remove_tracked_company.return_value = False # not found for B + + response = client.delete("/tracked/AliceCo", headers=_header_for(USER_B)) + assert response.status_code == 404 + mock_db.remove_tracked_company.assert_called_with("AliceCo", owner_id=USER_B["id"]) + + +# ==================== Job Isolation ==================== + + +class TestJobIsolation: + """User A's jobs are invisible to User B.""" + + def test_user_a_get_own_job(self, client, mock_db): + """GET /jobs/{id} scoped to User A returns the job.""" + mock_db.get_job.return_value = None # mock via _get_job_db + + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.get_job.return_value = { + "job_id": "j1", + "status": "completed", + "progress": 100, + "total_companies": 1, + "completed_companies": 1, + "result_json": None, + "error": None, + "owner_id": USER_A["id"], + } + mock_get_db.return_value = job_db + + response = client.get("/jobs/j1", headers=_header_for(USER_A)) + assert response.status_code == 200 + job_db.get_job.assert_called_with("j1", owner_id=USER_A["id"]) + + def test_user_b_cannot_see_user_a_job(self, client, mock_db): + """GET /jobs/{id} returns 404 when User B tries to access User A's job.""" + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.get_job.return_value = None # not found for B's owner_id + mock_get_db.return_value = job_db + + response = client.get("/jobs/j1", headers=_header_for(USER_B)) + assert response.status_code == 404 + job_db.get_job.assert_called_with("j1", owner_id=USER_B["id"]) + + def test_list_jobs_scoped_to_user(self, client, mock_db): + """GET /jobs filters by owner_id.""" + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.list_jobs.return_value = [] + mock_get_db.return_value = job_db + + response = client.get("/jobs", headers=_header_for(USER_A)) + assert response.status_code == 200 + call_kwargs = job_db.list_jobs.call_args + assert call_kwargs.kwargs.get("owner_id") == USER_A["id"] + + def test_async_job_created_with_owner(self, client, mock_db): + """POST /analyze/batch/async creates job with current user's owner_id.""" + mock_analyzer = MagicMock() + with patch("SPARC.api._analyzer", mock_analyzer), \ + patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.create_job.return_value = { + "job_id": "j2", + "status": "pending", + "progress": 0, + "total_companies": 1, + "completed_companies": 0, + "result_json": None, + "error": None, + "owner_id": USER_A["id"], + } + mock_get_db.return_value = job_db + + response = client.post( + "/analyze/batch/async", + json={"companies": ["nvidia"]}, + headers=_header_for(USER_A), + ) + assert response.status_code == 200 + create_kwargs = job_db.create_job.call_args + assert create_kwargs.kwargs.get("owner_id") == USER_A["id"] + + +# ==================== Analysis Listing Isolation ==================== + + +class TestAnalysisListIsolation: + """GET /analyze/batch scoped to current user.""" + + def test_list_analyses_scoped_to_user(self, client, mock_db): + """GET /analyze/batch passes owner_id to db.list_analyses.""" + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.list_analyses.return_value = [] + mock_get_db.return_value = job_db + + response = client.get("/analyze/batch", headers=_header_for(USER_A)) + assert response.status_code == 200 + call_kwargs = job_db.list_analyses.call_args + assert call_kwargs.kwargs.get("owner_id") == USER_A["id"] + + +# ==================== Admin Cross-Tenant Access ==================== + + +class TestAdminCrossTenantAccess: + """Admin endpoints return data from all tenants (no owner_id filter).""" + + def test_admin_list_tracked_all_tenants(self, client, mock_db): + """GET /admin/tracked returns all companies (no owner_id filter).""" + mock_db.list_tracked_companies.return_value = [ + {"company_name": "AliceCo", "owner_id": 10}, + {"company_name": "BobCo", "owner_id": 20}, + ] + + response = client.get("/admin/tracked", headers=_header_for(ADMIN)) + assert response.status_code == 200 + # Should be called without owner_id filter + mock_db.list_tracked_companies.assert_called_with() + + def test_admin_list_analyses_all_tenants(self, client, mock_db): + """GET /admin/analyses returns all analyses (no owner_id filter).""" + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.list_analyses.return_value = [] + mock_get_db.return_value = job_db + + response = client.get("/admin/analyses", headers=_header_for(ADMIN)) + assert response.status_code == 200 + call_kwargs = job_db.list_analyses.call_args + # No owner_id should be passed + assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None + + def test_admin_list_jobs_all_tenants(self, client, mock_db): + """GET /admin/jobs returns all jobs (no owner_id filter).""" + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.list_jobs.return_value = [] + mock_get_db.return_value = job_db + + response = client.get("/admin/jobs", headers=_header_for(ADMIN)) + assert response.status_code == 200 + call_kwargs = job_db.list_jobs.call_args + assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None + + def test_admin_remove_tracked_any_owner(self, client, mock_db): + """DELETE /admin/tracked/{name} removes without owner filter.""" + mock_db.remove_tracked_company.return_value = True + + response = client.delete("/admin/tracked/SomeCo", headers=_header_for(ADMIN)) + assert response.status_code == 200 + # Called without owner_id + mock_db.remove_tracked_company.assert_called_with("SomeCo") + + def test_regular_user_cannot_access_admin_analyses(self, client, mock_db): + """Regular user gets 403 on /admin/analyses.""" + response = client.get("/admin/analyses", headers=_header_for(USER_A)) + assert response.status_code == 403 + + def test_regular_user_cannot_access_admin_jobs(self, client, mock_db): + """Regular user gets 403 on /admin/jobs.""" + response = client.get("/admin/jobs", headers=_header_for(USER_A)) + assert response.status_code == 403 + + +# ==================== Analytics Isolation ==================== + + +class TestAnalyticsIsolation: + """GET /analytics scoped to current user.""" + + def test_analytics_scoped_to_user(self, client, mock_db): + """GET /analytics passes owner_id to db.get_analytics.""" + mock_db.get_analytics.return_value = { + "total_messages": 5, + "by_company": [], + "by_type": [], + "period_days": 30, + } + + response = client.get("/analytics", headers=_header_for(USER_A)) + assert response.status_code == 200 + mock_db.get_analytics.assert_called_with(days=30, owner_id=USER_A["id"]) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 01bc5b3..0f3cc2d 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,12 +1,13 @@ """Tests for cursor-based pagination on /analyze/batch GET and /jobs endpoints.""" -from datetime import datetime, timedelta -from unittest.mock import Mock, patch +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock, MagicMock, patch import pytest from fastapi.testclient import TestClient from SPARC.api import app +from SPARC.auth import create_access_token @pytest.fixture @@ -15,6 +16,27 @@ def client(): return TestClient(app) +@pytest.fixture(autouse=True) +def mock_db(): + """Mock the database client used by auth endpoints.""" + db = MagicMock() + db.get_user_by_id.return_value = { + "id": 1, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + with patch("SPARC.api.get_db_client", return_value=db), \ + patch("SPARC.auth.get_db_client", return_value=db): + yield db + + +def _auth_header(): + token = create_access_token(1, "user@test.com", "user") + return {"Authorization": f"Bearer {token}"} + + def _make_analysis_row(id_: int, minutes_ago: int = 0, company: str = "nvidia"): """Create a fake analysis row dict.""" ts = datetime.now() - timedelta(minutes=minutes_ago) @@ -56,7 +78,7 @@ class TestAnalyzeBatchGetPagination: ] mock_get_db.return_value = db - response = client.get("/analyze/batch?limit=10") + response = client.get("/analyze/batch?limit=10", headers=_auth_header()) assert response.status_code == 200 data = response.json() assert len(data["items"]) == 2 @@ -71,7 +93,7 @@ class TestAnalyzeBatchGetPagination: db.list_analyses.return_value = rows mock_get_db.return_value = db - response = client.get("/analyze/batch?limit=3") + response = client.get("/analyze/batch?limit=3", headers=_auth_header()) assert response.status_code == 200 data = response.json() assert len(data["items"]) == 3 @@ -84,7 +106,7 @@ class TestAnalyzeBatchGetPagination: db.list_analyses.return_value = [] mock_get_db.return_value = db - client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42") + client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42", headers=_auth_header()) db.list_analyses.assert_called_once() call_kwargs = db.list_analyses.call_args assert call_kwargs.kwargs.get("cursor") == "2025-01-01T00:00:00|42" or \ @@ -97,19 +119,19 @@ class TestAnalyzeBatchGetPagination: db.list_analyses.return_value = [] mock_get_db.return_value = db - client.get("/analyze/batch") + client.get("/analyze/batch", headers=_auth_header()) call_kwargs = db.list_analyses.call_args # The endpoint requests limit+1 from DB, so 51 assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51 def test_limit_over_200_rejected(self, client): """Limit > 200 should be rejected with 422.""" - response = client.get("/analyze/batch?limit=201") + response = client.get("/analyze/batch?limit=201", headers=_auth_header()) assert response.status_code == 422 def test_limit_zero_rejected(self, client): """Limit < 1 should be rejected with 422.""" - response = client.get("/analyze/batch?limit=0") + response = client.get("/analyze/batch?limit=0", headers=_auth_header()) assert response.status_code == 422 @patch("SPARC.api._get_job_db") @@ -119,7 +141,7 @@ class TestAnalyzeBatchGetPagination: db.list_analyses.return_value = [] mock_get_db.return_value = db - client.get("/analyze/batch?company_name=intel") + client.get("/analyze/batch?company_name=intel", headers=_auth_header()) call_kwargs = db.list_analyses.call_args assert call_kwargs.kwargs.get("company_name") == "intel" or \ "intel" in (call_kwargs.args if call_kwargs.args else []) @@ -131,7 +153,7 @@ class TestAnalyzeBatchGetPagination: db.list_analyses.return_value = [] mock_get_db.return_value = db - response = client.get("/analyze/batch") + response = client.get("/analyze/batch", headers=_auth_header()) assert response.status_code == 200 data = response.json() assert data["items"] == [] @@ -148,14 +170,14 @@ class TestJobsPaginationDefaults: db.list_jobs.return_value = [] mock_get_db.return_value = db - client.get("/jobs") + client.get("/jobs", headers=_auth_header()) call_kwargs = db.list_jobs.call_args # Endpoint requests limit+1 from DB, so 51 assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51 def test_limit_over_200_rejected(self, client): """Limit > 200 should be rejected with 422.""" - response = client.get("/jobs?limit=201") + response = client.get("/jobs?limit=201", headers=_auth_header()) assert response.status_code == 422 @patch("SPARC.api._get_job_db") @@ -165,5 +187,5 @@ class TestJobsPaginationDefaults: db.list_jobs.return_value = [] mock_get_db.return_value = db - response = client.get("/jobs?limit=200") + response = client.get("/jobs?limit=200", headers=_auth_header()) assert response.status_code == 200 diff --git a/tests/test_rate_limit_admin.py b/tests/test_rate_limit_admin.py index bc63a5a..f10e9da 100644 --- a/tests/test_rate_limit_admin.py +++ b/tests/test_rate_limit_admin.py @@ -20,8 +20,10 @@ def client(): def reset_stats(): """Reset rate limit stats between tests.""" api._rate_limit_stats.clear() + api._rejected_log.clear() yield api._rate_limit_stats.clear() + api._rejected_log.clear() def _mock_admin(): @@ -50,8 +52,7 @@ class TestRateLimitAdminEndpoint: app.dependency_overrides.clear() def test_non_admin_rejected(self, client): - """Non-admin users should get 403.""" - # Without overriding the dependency, it should fail auth + """Non-admin users should get 401/403.""" response = client.get("/admin/rate-limits") assert response.status_code in (401, 403) @@ -77,6 +78,9 @@ class TestRateLimitAdminEndpoint: for rl in data["rate_limits"]: assert rl["total_requests"] == 0 assert rl["rejected_requests"] == 0 + assert rl["by_ip"] == [] + assert data["throttled_24h"] == 0 + assert data["throttled_over_time"] == [] finally: app.dependency_overrides.clear() @@ -107,3 +111,68 @@ class TestRateLimitAdminEndpoint: assert isinstance(rl["limit"], str) finally: app.dependency_overrides.clear() + + def test_per_ip_breakdown(self, client): + """Stats should include per-IP breakdown with total and rejected counts.""" + api._track_rate_limit_request("/auth/login", "10.0.0.1") + api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True) + api._track_rate_limit_request("/auth/login", "10.0.0.2") + + app.dependency_overrides[api.get_current_admin] = _mock_admin + try: + response = client.get("/admin/rate-limits") + data = response.json() + login_stats = next(rl for rl in data["rate_limits"] if rl["endpoint"] == "/auth/login") + by_ip = login_stats["by_ip"] + assert len(by_ip) == 2 + ip1 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.1") + assert ip1["total"] == 2 + assert ip1["rejected"] == 1 + ip2 = next(entry for entry in by_ip if entry["ip"] == "10.0.0.2") + assert ip2["total"] == 1 + assert ip2["rejected"] == 0 + finally: + app.dependency_overrides.clear() + + def test_throttled_24h_count(self, client): + """Should report total throttled requests in the last 24 hours.""" + api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True) + api._track_rate_limit_request("/auth/register", "10.0.0.2", rejected=True) + + app.dependency_overrides[api.get_current_admin] = _mock_admin + try: + response = client.get("/admin/rate-limits") + data = response.json() + assert data["throttled_24h"] == 2 + finally: + app.dependency_overrides.clear() + + def test_throttled_over_time_structure(self, client): + """Throttled-over-time should be a list of {timestamp, count} buckets.""" + api._track_rate_limit_request("/auth/login", "10.0.0.1", rejected=True) + + app.dependency_overrides[api.get_current_admin] = _mock_admin + try: + response = client.get("/admin/rate-limits") + data = response.json() + assert len(data["throttled_over_time"]) >= 1 + entry = data["throttled_over_time"][0] + assert "timestamp" in entry + assert "count" in entry + assert entry["count"] >= 1 + finally: + app.dependency_overrides.clear() + + def test_response_shape_matches_contract(self, client): + """The full response should match the expected shape for the frontend.""" + app.dependency_overrides[api.get_current_admin] = _mock_admin + try: + response = client.get("/admin/rate-limits") + data = response.json() + # Top-level keys + assert set(data.keys()) == {"rate_limits", "throttled_24h", "throttled_over_time"} + # Each rate_limit entry + for rl in data["rate_limits"]: + assert set(rl.keys()) == {"endpoint", "limit", "total_requests", "rejected_requests", "by_ip"} + finally: + app.dependency_overrides.clear() diff --git a/tests/test_tracked_companies.py b/tests/test_tracked_companies.py index df25134..4aec720 100644 --- a/tests/test_tracked_companies.py +++ b/tests/test_tracked_companies.py @@ -1,17 +1,18 @@ -"""Tests for tracked company admin endpoints and scheduler integration. +"""Tests for tracked company endpoints and scheduler integration. -Covers issue #1656: -- GET /admin/tracked (list tracked companies) -- POST /admin/tracked (add a tracked company) -- DELETE /admin/tracked/{company_name} (remove a tracked company) +Covers: +- GET /tracked (user-scoped list) +- POST /tracked (user-scoped add) +- DELETE /tracked/{company_name} (user-scoped remove) +- GET /admin/tracked (admin: all companies) +- POST /admin/tracked (admin: add) +- DELETE /admin/tracked/{company_name} (admin: remove any) - GET /admin/alerts (list alerts) - scheduler.run_scheduled_analysis() integration - -All tests mock the database layer and use JWT auth fixtures. """ from datetime import datetime, timezone -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient @@ -125,7 +126,7 @@ class TestAddTrackedCompany: assert response.status_code == 200 data = response.json() assert data["company_name"] == "Intel" - mock_db.add_tracked_company.assert_called_once_with("Intel") + mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=1) def test_add_duplicate_returns_409(self, client, mock_db): """Adding an already-tracked company returns 409.""" @@ -141,7 +142,7 @@ class TestAddTrackedCompany: assert "already tracked" in response.json()["detail"].lower() def test_add_tracked_requires_admin(self, client, mock_db): - """Regular user cannot add tracked companies.""" + """Regular user cannot add tracked companies via admin endpoint.""" mock_db.get_user_by_id.return_value = { "id": 2, "email": "user@test.com", @@ -215,6 +216,66 @@ class TestRemoveTrackedCompany: assert response.status_code == 403 +# ---------- User-scoped tracked companies ---------- + +class TestUserScopedTrackedCompanies: + """Tests for /tracked user-scoped endpoints.""" + + def test_user_list_tracked(self, client, mock_db): + """Regular user can list their own tracked companies.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + mock_db.list_tracked_companies.return_value = [ + {"company_name": "AMD", "owner_id": 2}, + ] + + response = client.get("/tracked", headers=_user_header()) + + assert response.status_code == 200 + mock_db.list_tracked_companies.assert_called_with(owner_id=2) + + def test_user_add_tracked(self, client, mock_db): + """Regular user can add a company to their own tracked list.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + mock_db.add_tracked_company.return_value = { + "company_name": "Intel", + "owner_id": 2, + } + + response = client.post( + "/tracked", + json={"company_name": "Intel"}, + headers=_user_header(), + ) + + assert response.status_code == 200 + mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=2) + + def test_user_remove_tracked(self, client, mock_db): + """Regular user can remove a company from their own tracked list.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + mock_db.remove_tracked_company.return_value = True + + response = client.delete("/tracked/Intel", headers=_user_header()) + + assert response.status_code == 200 + mock_db.remove_tracked_company.assert_called_once_with("Intel", owner_id=2) + + # ---------- GET /admin/alerts ---------- class TestListAlerts: