diff --git a/SPARC/api.py b/SPARC/api.py index b6095bd..d380e1d 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -474,11 +474,46 @@ class TrackCompanyRequest(BaseModel): company_name: CompanyName = Field(...) +@app.get("/tracked", tags=["Tracked Companies"]) +async def list_my_tracked_companies( + current_user: UserResponse = Depends(get_current_user), +): + """List tracked companies for the current user.""" + db = get_db_client() + return db.list_tracked_companies(owner_id=current_user.id) + + +@app.post("/tracked", tags=["Tracked Companies"]) +async def add_my_tracked_company( + request: TrackCompanyRequest, + current_user: UserResponse = Depends(get_current_user), +): + """Add a company to the current user's tracked list.""" + db = get_db_client() + result = db.add_tracked_company(request.company_name, owner_id=current_user.id) + if not result: + raise HTTPException(status_code=409, detail="Company already tracked") + return result + + +@app.delete("/tracked/{company_name}", tags=["Tracked Companies"]) +async def remove_my_tracked_company( + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], + current_user: UserResponse = Depends(get_current_user), +): + """Remove a company from the current user's tracked list.""" + db = get_db_client() + removed = db.remove_tracked_company(company_name, owner_id=current_user.id) + if not removed: + raise HTTPException(status_code=404, detail="Company not found in tracking list") + return {"message": f"Stopped tracking {company_name}"} + + @app.get("/admin/tracked", tags=["Admin"]) async def list_tracked_companies( _: UserResponse = Depends(get_current_admin), ): - """List all tracked companies (admin only).""" + """List all tracked companies across all users (admin only).""" db = get_db_client() return db.list_tracked_companies() @@ -486,11 +521,11 @@ async def list_tracked_companies( @app.post("/admin/tracked", tags=["Admin"]) async def add_tracked_company( request: TrackCompanyRequest, - _: UserResponse = Depends(get_current_admin), + current_admin: UserResponse = Depends(get_current_admin), ): - """Add a company to the tracked list (admin only).""" + """Add a company to the tracked list (admin only, owned by admin).""" db = get_db_client() - result = db.add_tracked_company(request.company_name) + result = db.add_tracked_company(request.company_name, owner_id=current_admin.id) if not result: raise HTTPException(status_code=409, detail="Company already tracked") return result @@ -501,7 +536,7 @@ async def remove_tracked_company( company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], _: UserResponse = Depends(get_current_admin), ): - """Remove a company from the tracked list (admin only).""" + """Remove a company from the tracked list (admin only, any owner).""" db = get_db_client() removed = db.remove_tracked_company(company_name) if not removed: @@ -582,17 +617,86 @@ async def list_alerts( return db.list_alerts(limit=limit) +# ============== Admin-Scoped Data Endpoints ============== + + +@app.get("/admin/analyses", response_model=PaginatedAnalysisResponse, tags=["Admin"]) +async def admin_list_analyses( + company_name: Annotated[ + str | None, + Query(description="Filter results by company name"), + ] = None, + limit: Annotated[int, Query(ge=1, le=200)] = 50, + cursor: Annotated[ + str | None, + Query(description="Opaque cursor from a previous response's next_cursor field"), + ] = None, + _: UserResponse = Depends(get_current_admin), +): + """List all analysis results across all users (admin only).""" + db = _get_job_db() + rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor) + + has_next = len(rows) > limit + if has_next: + rows = rows[:limit] + + items = [AnalysisRecord(**row) for row in rows] + + next_cursor = None + if has_next and rows: + last = rows[-1] + ts = last["timestamp"] + ts_str = ts.isoformat() if hasattr(ts, "isoformat") else str(ts) + next_cursor = f"{ts_str}|{last['id']}" + + return PaginatedAnalysisResponse(items=items, next_cursor=next_cursor) + + +@app.get("/admin/jobs", response_model=PaginatedJobsResponse, tags=["Admin"]) +async def admin_list_jobs( + status: Annotated[ + str | None, + Query(description="Filter by status: pending, running, completed, failed"), + ] = None, + limit: Annotated[int, Query(ge=1, le=200)] = 50, + cursor: Annotated[ + str | None, + Query(description="Opaque cursor from a previous response's next_cursor field"), + ] = None, + _: UserResponse = Depends(get_current_admin), +): + """List all jobs across all users (admin only).""" + db = _get_job_db() + job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor) + + has_next = len(job_rows) > limit + if has_next: + job_rows = job_rows[:limit] + + items = [_job_row_to_status(row) for row in job_rows] + + next_cursor = None + if has_next and job_rows: + last = job_rows[-1] + created = last["created_at"] + ts = created.isoformat() if hasattr(created, "isoformat") else str(created) + next_cursor = f"{ts}|{last['job_id']}" + + return PaginatedJobsResponse(items=items, next_cursor=next_cursor) + + # ============== Analytics Endpoint ============== @app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"]) async def get_analytics( days: int = Query(default=30, ge=1, le=365), - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """Get analytics data (authenticated users only).""" + """Get analytics data scoped to the current user.""" db = get_db_client() - analytics = db.get_analytics(days=days) + analytics = db.get_analytics(days=days, owner_id=current_user.id) return AnalyticsResponse( total_messages=analytics["total_messages"], @@ -645,9 +749,9 @@ async def list_models(): @app.get("/analytics/trends", tags=["Analytics"]) async def get_analytics_trends( days: int = Query(default=90, ge=7, le=365), - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """Get trend data for patent analysis over time. + """Get trend data for patent analysis over time (scoped to current user). Returns two datasets: - ``by_month``: analysis count per company per month @@ -661,11 +765,14 @@ async def get_analytics_trends( """ db = get_db_client() + owner_filter = " AND owner_id = %s" if current_user else "" + owner_params = (current_user.id,) if current_user else () + with db.get_conn() as conn: with conn.cursor() as cur: # Analyses per company per month cur.execute( - """ + f""" SELECT TO_CHAR(timestamp, 'YYYY-MM') AS month, company_name, @@ -674,16 +781,17 @@ async def get_analytics_trends( WHERE timestamp >= NOW() - INTERVAL '%s days' AND is_cached = FALSE AND company_name IS NOT NULL + {owner_filter} GROUP BY month, company_name ORDER BY month """, - (days,), + (days, *owner_params), ) by_month_rows = cur.fetchall() # Analysis type distribution per month cur.execute( - """ + f""" SELECT TO_CHAR(timestamp, 'YYYY-MM') AS month, analysis_type, @@ -691,10 +799,11 @@ async def get_analytics_trends( FROM llm_messages WHERE timestamp >= NOW() - INTERVAL '%s days' AND is_cached = FALSE + {owner_filter} GROUP BY month, analysis_type ORDER BY month """, - (days,), + (days, *owner_params), ) by_type_rows = cur.fetchall() @@ -720,9 +829,9 @@ async def get_analytics_trends( @app.get("/export/{company_name}", tags=["Export"]) async def export_company_csv( company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """Export analysis results for a company as a CSV file. + """Export analysis results for a company as a CSV file (scoped to current user). Returns all stored analysis records for the given company, including analysis type, model used, response text, and timestamp. @@ -737,7 +846,7 @@ async def export_company_csv( import io db = get_db_client() - # Query all non-cached analysis results for this company + # Query all non-cached analysis results for this company owned by current user with db.get_conn() as conn: with conn.cursor() as cur: cur.execute( @@ -745,9 +854,10 @@ async def export_company_csv( SELECT company_name, analysis_type, model, response, timestamp FROM llm_messages WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE + AND owner_id = %s ORDER BY timestamp DESC """, - (company_name,), + (company_name, current_user.id), ) rows = cur.fetchall() @@ -772,9 +882,9 @@ async def export_company_csv( @app.get("/export/{company_name}/pdf", tags=["Export"]) async def export_company_pdf( company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """Export analysis results for a company as a formatted PDF report. + """Export analysis results for a company as a formatted PDF report (scoped to current user). Returns all stored analysis records for the given company, including analysis type, model used, response text, and timestamp, formatted @@ -808,9 +918,10 @@ async def export_company_pdf( SELECT company_name, analysis_type, model, response, timestamp FROM llm_messages WHERE LOWER(company_name) = LOWER(%s) AND is_cached = FALSE + AND owner_id = %s ORDER BY timestamp DESC """, - (company_name,), + (company_name, current_user.id), ) rows = cur.fetchall() @@ -939,68 +1050,6 @@ async def health_check(): ) -@app.get( - "/analyze/{company_name}", - response_model=CompanyAnalysisResponse, - tags=["Analysis"], -) -async def analyze_company( - company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], - model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."), - _: UserResponse = Depends(get_current_user), -): - """Analyze a single company's patent portfolio. - - This endpoint retrieves recent patents for the specified company, - parses them, and uses AI to generate a comprehensive analysis. - - Args: - company_name: Name of the company to analyze (e.g., "nvidia", "intel") - model: Optional LLM model override - - Returns: - Analysis results including patent count, AI insights, and success status - """ - _validate_model(model) - if not _analyzer: - raise HTTPException(status_code=503, detail="Analyzer not initialized") - - result = _analyzer._analyze_company_safe(company_name, model=model) - return _convert_result(result) - - -@app.get( - "/analyze/patent/{patent_id}", - tags=["Analysis"], -) -async def analyze_single_patent( - patent_id: str, - company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")], - _: UserResponse = Depends(get_current_user), -): - """Analyze a single patent by its publication ID. - - If the patent PDF is not already cached locally, the system will attempt - to download it automatically from a previously cached link. If no link - is available, a 404 error is returned. - - Args: - patent_id: Patent publication ID (e.g. "US-11234567-B2") - company_name: Company name for analysis context - - Returns: - Analysis text for the patent - """ - if not _analyzer: - raise HTTPException(status_code=503, detail="Analyzer not initialized") - - try: - analysis = _analyzer.analyze_single_patent(patent_id, company_name) - return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) - - @app.get( "/analyze/batch", response_model=PaginatedAnalysisResponse, @@ -1016,9 +1065,9 @@ async def list_analysis_results( str | None, Query(description="Opaque cursor from a previous response's next_cursor field"), ] = None, - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """List stored analysis results with cursor-based pagination. + """List stored analysis results with cursor-based pagination (scoped to current user). Returns past analysis results ordered by timestamp descending. Use ``limit`` to control page size (default 50, max 200). The response @@ -1035,7 +1084,7 @@ async def list_analysis_results( Paginated list of analysis results """ db = _get_job_db() - rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor) + rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor, owner_id=current_user.id) has_next = len(rows) > limit if has_next: @@ -1085,6 +1134,68 @@ async def analyze_companies_batch( return _convert_batch_result(result) +@app.get( + "/analyze/patent/{patent_id}", + tags=["Analysis"], +) +async def analyze_single_patent( + patent_id: str, + company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")], + _: UserResponse = Depends(get_current_user), +): + """Analyze a single patent by its publication ID. + + If the patent PDF is not already cached locally, the system will attempt + to download it automatically from a previously cached link. If no link + is available, a 404 error is returned. + + Args: + patent_id: Patent publication ID (e.g. "US-11234567-B2") + company_name: Company name for analysis context + + Returns: + Analysis text for the patent + """ + if not _analyzer: + raise HTTPException(status_code=503, detail="Analyzer not initialized") + + try: + analysis = _analyzer.analyze_single_patent(patent_id, company_name) + return {"patent_id": patent_id, "company_name": company_name, "analysis": analysis} + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@app.get( + "/analyze/{company_name}", + response_model=CompanyAnalysisResponse, + tags=["Analysis"], +) +async def analyze_company( + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], + model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."), + _: UserResponse = Depends(get_current_user), +): + """Analyze a single company's patent portfolio. + + This endpoint retrieves recent patents for the specified company, + parses them, and uses AI to generate a comprehensive analysis. + + Args: + company_name: Name of the company to analyze (e.g., "nvidia", "intel") + model: Optional LLM model override + + Returns: + Analysis results including patent count, AI insights, and success status + """ + _validate_model(model) + if not _analyzer: + raise HTTPException(status_code=503, detail="Analyzer not initialized") + + result = _analyzer._analyze_company_safe(company_name, model=model) + return _convert_result(result) + + def _get_job_db() -> "DatabaseClient": """Get a DatabaseClient for job persistence.""" from SPARC.database import DatabaseClient @@ -1171,7 +1282,7 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int, model: s async def analyze_companies_async( request: BatchAnalysisRequest, background_tasks: BackgroundTasks, - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): """Start an asynchronous batch analysis job. @@ -1191,7 +1302,7 @@ async def analyze_companies_async( job_id = f"job_{_job_counter}_{datetime.now().strftime('%Y%m%d%H%M%S')}" db = _get_job_db() - job_row = db.create_job(job_id=job_id, total_companies=len(request.companies)) + job_row = db.create_job(job_id=job_id, total_companies=len(request.companies), owner_id=current_user.id) background_tasks.add_task( _run_batch_job, job_id, request.companies, request.max_workers, request.model @@ -1203,9 +1314,9 @@ async def analyze_companies_async( @app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"]) async def get_job_status( job_id: str, - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """Get the status of a background analysis job. + """Get the status of a background analysis job (scoped to current user). Args: job_id: The job ID returned from the async batch endpoint @@ -1214,7 +1325,7 @@ async def get_job_status( Current job status including progress and results when complete """ db = _get_job_db() - job_row = db.get_job(job_id) + job_row = db.get_job(job_id, owner_id=current_user.id) if not job_row: raise HTTPException(status_code=404, detail=f"Job {job_id} not found") @@ -1233,9 +1344,9 @@ async def list_jobs( str | None, Query(description="Opaque cursor from a previous response's next_cursor field"), ] = None, - _: UserResponse = Depends(get_current_user), + current_user: UserResponse = Depends(get_current_user), ): - """List analysis jobs with cursor-based pagination. + """List analysis jobs with cursor-based pagination (scoped to current user). Pass ``limit`` to control page size. The response includes a ``next_cursor`` field; pass it back as the ``cursor`` query parameter to fetch the next page. @@ -1254,7 +1365,7 @@ async def list_jobs( """ db = _get_job_db() # Fetch one extra to determine if there is a next page - job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor) + job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor, owner_id=current_user.id) has_next = len(job_rows) > limit if has_next: diff --git a/SPARC/database.py b/SPARC/database.py index 0759a66..a393162 100644 --- a/SPARC/database.py +++ b/SPARC/database.py @@ -196,7 +196,7 @@ class DatabaseClient: cursor.execute(""" CREATE TABLE IF NOT EXISTS tracked_companies ( id SERIAL PRIMARY KEY, - company_name VARCHAR(255) UNIQUE NOT NULL, + company_name VARCHAR(255) NOT NULL, last_patent_count INTEGER DEFAULT 0, last_analysis_at TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP @@ -221,6 +221,68 @@ class DatabaseClient: ON alerts(company_name) """) + # ---- Multi-tenant: add owner_id columns if missing ---- + cursor.execute(""" + DO $$ + BEGIN + -- llm_messages.owner_id + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'llm_messages' AND column_name = 'owner_id' + ) THEN + ALTER TABLE llm_messages ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + -- jobs.owner_id + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'jobs' AND column_name = 'owner_id' + ) THEN + ALTER TABLE jobs ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + -- tracked_companies.owner_id + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'tracked_companies' AND column_name = 'owner_id' + ) THEN + ALTER TABLE tracked_companies ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + END $$; + """) + + # Indexes for owner_id filtering + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_messages_owner + ON llm_messages(owner_id) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_jobs_owner + ON jobs(owner_id) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_tracked_companies_owner + ON tracked_companies(owner_id) + """) + + # Drop the old unique constraint on company_name alone (if it exists) + # and replace with a per-owner unique constraint so different users + # can track the same company independently. + cursor.execute(""" + DO $$ + BEGIN + IF EXISTS ( + SELECT 1 FROM pg_constraint + WHERE conname = 'tracked_companies_company_name_key' + ) THEN + ALTER TABLE tracked_companies + DROP CONSTRAINT tracked_companies_company_name_key; + END IF; + END $$; + """) + cursor.execute(""" + CREATE UNIQUE INDEX IF NOT EXISTS uq_tracked_company_owner + ON tracked_companies(LOWER(company_name), owner_id) + """) + self.conn.commit() @staticmethod @@ -289,6 +351,7 @@ class DatabaseClient: metadata: Optional[Dict] = None, token_usage: Optional[Dict] = None, is_cached: bool = False, + owner_id: Optional[int] = None, ) -> int: """Store an LLM message exchange in the database. @@ -301,6 +364,7 @@ class DatabaseClient: metadata: Additional metadata as dict token_usage: Token usage information is_cached: Whether this response was served from cache + owner_id: ID of the user who owns this record Returns: The ID of the inserted record @@ -312,8 +376,8 @@ class DatabaseClient: cursor.execute( """ INSERT INTO llm_messages - (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) + (prompt, prompt_hash, response, company_name, analysis_type, model, metadata, token_usage, is_cached, owner_id) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING id """, ( @@ -326,6 +390,7 @@ class DatabaseClient: json.dumps(metadata) if metadata else None, json.dumps(token_usage) if token_usage else None, is_cached, + owner_id, ), ) @@ -340,6 +405,7 @@ class DatabaseClient: analysis_type: Optional[str] = None, limit: int = 100, offset: int = 0, + owner_id: Optional[int] = None, ) -> List[Dict]: """Retrieve messages from the database. @@ -348,6 +414,7 @@ class DatabaseClient: analysis_type: Filter by analysis type limit: Maximum number of records to return offset: Number of records to skip + owner_id: Filter by owner (None returns all, for admin use) Returns: List of message dictionaries @@ -355,6 +422,10 @@ class DatabaseClient: query = "SELECT * FROM llm_messages WHERE 1=1" params = [] + if owner_id is not None: + query += " AND owner_id = %s" + params.append(owner_id) + if company_name: query += " AND company_name = %s" params.append(company_name) @@ -376,6 +447,7 @@ class DatabaseClient: company_name: Optional[str] = None, limit: int = 50, cursor: Optional[str] = None, + owner_id: Optional[int] = None, ) -> List[Dict]: """List analysis results with cursor-based pagination. @@ -383,6 +455,7 @@ class DatabaseClient: company_name: Optional filter by company name. limit: Maximum number of records to return. cursor: Opaque cursor (``timestamp|id``) from a previous response. + owner_id: Filter by owner (None returns all, for admin use). Returns: List of analysis dicts ordered by timestamp descending. @@ -390,6 +463,10 @@ class DatabaseClient: conditions: list[str] = ["is_cached = FALSE"] params: list = [] + if owner_id is not None: + conditions.append("owner_id = %s") + params.append(owner_id) + if company_name: conditions.append("LOWER(company_name) = LOWER(%s)") params.append(company_name) @@ -413,52 +490,62 @@ class DatabaseClient: cur.execute(query, params) return [dict(row) for row in cur.fetchall()] - def get_analytics(self, days: int = 30) -> Dict: + def get_analytics(self, days: int = 30, owner_id: Optional[int] = None) -> Dict: """Get analytics on message usage. Args: days: Number of days to look back + owner_id: Filter by owner (None returns all, for admin use) Returns: Dictionary with analytics data """ + owner_filter = "" + owner_params: list = [] + if owner_id is not None: + owner_filter = " AND owner_id = %s" + owner_params = [owner_id] + with self.get_conn() as conn: with conn.cursor(cursor_factory=RealDictCursor) as cursor: # Total messages cursor.execute( - """ + f""" SELECT COUNT(*) as total_messages FROM llm_messages WHERE timestamp >= NOW() - INTERVAL '%s days' + {owner_filter} """, - (days,), + (days, *owner_params), ) total = cursor.fetchone()["total_messages"] # Messages by company cursor.execute( - """ + f""" SELECT company_name, COUNT(*) as count FROM llm_messages WHERE timestamp >= NOW() - INTERVAL '%s days' + {owner_filter} GROUP BY company_name ORDER BY count DESC LIMIT 10 """, - (days,), + (days, *owner_params), ) by_company = cursor.fetchall() # Messages by type cursor.execute( - """ + f""" SELECT analysis_type, COUNT(*) as count FROM llm_messages WHERE timestamp >= NOW() - INTERVAL '%s days' + {owner_filter} GROUP BY analysis_type ORDER BY count DESC """, - (days,), + (days, *owner_params), ) by_type = cursor.fetchall() @@ -556,12 +643,14 @@ class DatabaseClient: self, job_id: str, total_companies: int, + owner_id: Optional[int] = None, ) -> Dict: """Create a new job record. Args: job_id: Unique job identifier total_companies: Number of companies in the batch + owner_id: ID of the user who owns this job Returns: Job dict @@ -570,11 +659,11 @@ class DatabaseClient: with conn.cursor(cursor_factory=RealDictCursor) as cursor: cursor.execute( """ - INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies) - VALUES (%s, 'pending', 0, %s, 0) + INSERT INTO jobs (job_id, status, progress, total_companies, completed_companies, owner_id) + VALUES (%s, 'pending', 0, %s, 0, %s) RETURNING * """, - (job_id, total_companies), + (job_id, total_companies, owner_id), ) job = cursor.fetchone() conn.commit() @@ -627,11 +716,22 @@ class DatabaseClient: conn.commit() return dict(job) if job else None - def get_job(self, job_id: str) -> Optional[Dict]: - """Get a job by ID.""" + def get_job(self, job_id: str, owner_id: Optional[int] = None) -> Optional[Dict]: + """Get a job by ID. + + Args: + job_id: Job identifier. + owner_id: When provided, only return the job if it belongs to this owner. + """ + query = "SELECT * FROM jobs WHERE job_id = %s" + params: list = [job_id] + if owner_id is not None: + query += " AND owner_id = %s" + params.append(owner_id) + with self.get_conn() as conn: with conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute("SELECT * FROM jobs WHERE job_id = %s", (job_id,)) + cursor.execute(query, params) job = cursor.fetchone() return dict(job) if job else None @@ -640,6 +740,7 @@ class DatabaseClient: status: Optional[str] = None, limit: int = 10, cursor: Optional[str] = None, + owner_id: Optional[int] = None, ) -> List[Dict]: """List jobs with optional status filter and cursor-based pagination. @@ -649,6 +750,7 @@ class DatabaseClient: cursor: Opaque cursor (``created_at|job_id``) from a previous response. When provided, only jobs older than the cursor are returned. + owner_id: Filter by owner (None returns all, for admin use). Returns: List of job dicts ordered by created_at descending. @@ -656,6 +758,10 @@ class DatabaseClient: conditions: list[str] = [] params: list = [] + if owner_id is not None: + conditions.append("owner_id = %s") + params.append(owner_id) + if status: conditions.append("status = %s") params.append(status) @@ -902,14 +1008,21 @@ class DatabaseClient: # Tracked Companies Methods - def add_tracked_company(self, company_name: str) -> Optional[Dict]: - """Add a company to the tracking list.""" + def add_tracked_company( + self, company_name: str, owner_id: Optional[int] = None + ) -> Optional[Dict]: + """Add a company to the tracking list. + + Args: + company_name: Company name to track. + owner_id: ID of the user who owns this tracked company. + """ with self.get_conn() as conn: with conn.cursor(cursor_factory=RealDictCursor) as cursor: try: cursor.execute( - "INSERT INTO tracked_companies (company_name) VALUES (%s) RETURNING *", - (company_name,), + "INSERT INTO tracked_companies (company_name, owner_id) VALUES (%s, %s) RETURNING *", + (company_name, owner_id), ) row = cursor.fetchone() conn.commit() @@ -918,22 +1031,45 @@ class DatabaseClient: conn.rollback() return None - def remove_tracked_company(self, company_name: str) -> bool: - """Remove a company from the tracking list.""" + def remove_tracked_company( + self, company_name: str, owner_id: Optional[int] = None + ) -> bool: + """Remove a company from the tracking list. + + Args: + company_name: Company name to remove. + owner_id: When provided, only remove if owned by this user. + """ + query = "DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)" + params: list = [company_name] + if owner_id is not None: + query += " AND owner_id = %s" + params.append(owner_id) + with self.get_conn() as conn: with conn.cursor() as cursor: - cursor.execute( - "DELETE FROM tracked_companies WHERE LOWER(company_name) = LOWER(%s)", - (company_name,), - ) + cursor.execute(query, params) conn.commit() return cursor.rowcount > 0 - def list_tracked_companies(self) -> List[Dict]: - """List all tracked companies.""" + def list_tracked_companies( + self, owner_id: Optional[int] = None + ) -> List[Dict]: + """List tracked companies. + + Args: + owner_id: Filter by owner (None returns all, for admin/scheduler use). + """ + query = "SELECT * FROM tracked_companies" + params: list = [] + if owner_id is not None: + query += " WHERE owner_id = %s" + params.append(owner_id) + query += " ORDER BY company_name" + with self.get_conn() as conn: with conn.cursor(cursor_factory=RealDictCursor) as cursor: - cursor.execute("SELECT * FROM tracked_companies ORDER BY company_name") + cursor.execute(query, params) return [dict(row) for row in cursor.fetchall()] def update_tracked_company( diff --git a/scripts/migrate_add_owner_id.py b/scripts/migrate_add_owner_id.py new file mode 100644 index 0000000..3e0ea53 --- /dev/null +++ b/scripts/migrate_add_owner_id.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +"""Migration: add owner_id columns and backfill existing rows. + +This script adds an ``owner_id`` column (FK to ``users``) to the +``llm_messages``, ``jobs``, and ``tracked_companies`` tables, then +backfills all existing rows with ``owner_id = 1`` (the default admin user). + +It also replaces the old global UNIQUE constraint on +``tracked_companies.company_name`` with a per-owner unique index so that +different users can independently track the same company. + +Usage: + python scripts/migrate_add_owner_id.py + +The script is idempotent — running it multiple times is safe. +""" + +import os +import sys + +import psycopg2 + +DATABASE_URL = os.getenv( + "DATABASE_URL", + "postgresql://postgres:postgres@localhost:5432/sparc", +) + +DEFAULT_OWNER_ID = 1 + + +def run_migration(): + """Execute the migration.""" + conn = psycopg2.connect(DATABASE_URL) + conn.autocommit = False + + try: + with conn.cursor() as cur: + # ---------- 1. Add owner_id columns if missing ---------- + cur.execute(""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'llm_messages' AND column_name = 'owner_id' + ) THEN + ALTER TABLE llm_messages ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'jobs' AND column_name = 'owner_id' + ) THEN + ALTER TABLE jobs ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'tracked_companies' AND column_name = 'owner_id' + ) THEN + ALTER TABLE tracked_companies ADD COLUMN owner_id INTEGER REFERENCES users(id); + END IF; + END $$; + """) + + # ---------- 2. Backfill owner_id = DEFAULT_OWNER_ID ---------- + cur.execute( + "UPDATE llm_messages SET owner_id = %s WHERE owner_id IS NULL", + (DEFAULT_OWNER_ID,), + ) + messages_updated = cur.rowcount + print(f" llm_messages: backfilled {messages_updated} rows") + + cur.execute( + "UPDATE jobs SET owner_id = %s WHERE owner_id IS NULL", + (DEFAULT_OWNER_ID,), + ) + jobs_updated = cur.rowcount + print(f" jobs: backfilled {jobs_updated} rows") + + cur.execute( + "UPDATE tracked_companies SET owner_id = %s WHERE owner_id IS NULL", + (DEFAULT_OWNER_ID,), + ) + tracked_updated = cur.rowcount + print(f" tracked_companies: backfilled {tracked_updated} rows") + + # ---------- 3. Create indexes ---------- + cur.execute(""" + CREATE INDEX IF NOT EXISTS idx_messages_owner + ON llm_messages(owner_id) + """) + cur.execute(""" + CREATE INDEX IF NOT EXISTS idx_jobs_owner + ON jobs(owner_id) + """) + cur.execute(""" + CREATE INDEX IF NOT EXISTS idx_tracked_companies_owner + ON tracked_companies(owner_id) + """) + + # ---------- 4. Replace unique constraint on tracked_companies ---------- + cur.execute(""" + DO $$ + BEGIN + IF EXISTS ( + SELECT 1 FROM pg_constraint + WHERE conname = 'tracked_companies_company_name_key' + ) THEN + ALTER TABLE tracked_companies + DROP CONSTRAINT tracked_companies_company_name_key; + END IF; + END $$; + """) + cur.execute(""" + CREATE UNIQUE INDEX IF NOT EXISTS uq_tracked_company_owner + ON tracked_companies(LOWER(company_name), owner_id) + """) + + conn.commit() + print("Migration completed successfully.") + + except Exception: + conn.rollback() + print("Migration FAILED — rolled back.", file=sys.stderr) + raise + finally: + conn.close() + + +if __name__ == "__main__": + print(f"Running owner_id migration against {DATABASE_URL.split('@')[-1]} ...") + run_migration() diff --git a/tests/test_api.py b/tests/test_api.py index fd16921..e1def71 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,12 +1,13 @@ """Tests for FastAPI web service endpoints.""" -from datetime import datetime -from unittest.mock import Mock +from datetime import datetime, timezone +from unittest.mock import Mock, MagicMock, patch import pytest from fastapi.testclient import TestClient from SPARC.api import app +from SPARC.auth import create_access_token from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult @@ -16,6 +17,22 @@ def client(): return TestClient(app) +@pytest.fixture(autouse=True) +def mock_db(): + """Mock the database client used by auth endpoints.""" + db = MagicMock() + db.get_user_by_id.return_value = { + "id": 1, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + with patch("SPARC.api.get_db_client", return_value=db), \ + patch("SPARC.auth.get_db_client", return_value=db): + yield db + + @pytest.fixture def mock_analyzer(mocker): """Mock the global analyzer.""" @@ -24,6 +41,12 @@ def mock_analyzer(mocker): return mock +def _auth_header(user_id=1, email="user@test.com", role="user"): + """Create an Authorization header with a valid access token.""" + token = create_access_token(user_id, email, role) + return {"Authorization": f"Bearer {token}"} + + class TestHealthEndpoint: """Test health check endpoint.""" @@ -51,7 +74,7 @@ class TestAnalyzeCompanyEndpoint: ) mock_analyzer._analyze_company_safe.return_value = mock_result - response = client.get("/analyze/nvidia") + response = client.get("/analyze/nvidia", headers=_auth_header()) assert response.status_code == 200 data = response.json() @@ -72,7 +95,7 @@ class TestAnalyzeCompanyEndpoint: ) mock_analyzer._analyze_company_safe.return_value = mock_result - response = client.get("/analyze/unknown") + response = client.get("/analyze/unknown", headers=_auth_header()) assert response.status_code == 200 data = response.json() @@ -113,6 +136,7 @@ class TestBatchAnalysisEndpoint: response = client.post( "/analyze/batch", json={"companies": ["nvidia", "amd"], "max_workers": 2}, + headers=_auth_header(), ) assert response.status_code == 200 @@ -125,13 +149,14 @@ class TestBatchAnalysisEndpoint: def test_batch_analysis_validation(self, client): """Test batch analysis request validation.""" # Empty companies list - response = client.post("/analyze/batch", json={"companies": []}) + response = client.post("/analyze/batch", json={"companies": []}, headers=_auth_header()) assert response.status_code == 422 # Too many companies response = client.post( "/analyze/batch", json={"companies": [f"company{i}" for i in range(25)]}, + headers=_auth_header(), ) assert response.status_code == 422 @@ -139,6 +164,7 @@ class TestBatchAnalysisEndpoint: response = client.post( "/analyze/batch", json={"companies": ["nvidia"], "max_workers": 10}, + headers=_auth_header(), ) assert response.status_code == 422 @@ -146,11 +172,26 @@ class TestBatchAnalysisEndpoint: class TestAsyncBatchEndpoint: """Test async batch analysis endpoint.""" - def test_async_batch_creates_job(self, client, mock_analyzer): - """Test async endpoint creates a job.""" + @patch("SPARC.api._get_job_db") + def test_async_batch_creates_job(self, mock_get_db, client, mock_analyzer): + """Test async endpoint creates a job with owner_id.""" + job_db = MagicMock() + job_db.create_job.return_value = { + "job_id": "j1", + "status": "pending", + "progress": 0, + "total_companies": 2, + "completed_companies": 0, + "result_json": None, + "error": None, + "owner_id": 1, + } + mock_get_db.return_value = job_db + response = client.post( "/analyze/batch/async", json={"companies": ["nvidia", "amd"]}, + headers=_auth_header(), ) assert response.status_code == 200 @@ -159,28 +200,42 @@ class TestAsyncBatchEndpoint: assert data["status"] == "pending" assert data["total_companies"] == 2 assert data["progress"] == 0 + # Verify owner_id was passed + job_db.create_job.assert_called_once() + assert job_db.create_job.call_args.kwargs.get("owner_id") == 1 class TestJobEndpoints: """Test job management endpoints.""" - def test_get_job_not_found(self, client): + @patch("SPARC.api._get_job_db") + def test_get_job_not_found(self, mock_get_db, client): """Test getting nonexistent job.""" - response = client.get("/jobs/nonexistent") + job_db = MagicMock() + job_db.get_job.return_value = None + mock_get_db.return_value = job_db + + response = client.get("/jobs/nonexistent", headers=_auth_header()) assert response.status_code == 404 - def test_list_jobs(self, client, mocker): + @patch("SPARC.api._get_job_db") + def test_list_jobs(self, mock_get_db, client): """Test listing jobs.""" - # Clear existing jobs - mocker.patch.dict("SPARC.api._jobs", {}, clear=True) + job_db = MagicMock() + job_db.list_jobs.return_value = [] + mock_get_db.return_value = job_db - response = client.get("/jobs") + response = client.get("/jobs", headers=_auth_header()) assert response.status_code == 200 - assert isinstance(response.json(), list) - def test_list_jobs_with_filter(self, client, mocker): + @patch("SPARC.api._get_job_db") + def test_list_jobs_with_filter(self, mock_get_db, client): """Test listing jobs with status filter.""" - response = client.get("/jobs?status=completed") + job_db = MagicMock() + job_db.list_jobs.return_value = [] + mock_get_db.return_value = job_db + + response = client.get("/jobs?status=completed", headers=_auth_header()) assert response.status_code == 200 @@ -189,7 +244,7 @@ class TestModelValidation: def test_analyze_rejects_unsupported_model(self, client, mock_analyzer): """GET /analyze/{company} with unsupported model returns 400.""" - response = client.get("/analyze/nvidia?model=fake/nonexistent-model") + response = client.get("/analyze/nvidia?model=fake/nonexistent-model", headers=_auth_header()) assert response.status_code == 400 assert "Unsupported model" in response.json()["detail"] @@ -205,7 +260,7 @@ class TestModelValidation: ) mock_analyzer._analyze_company_safe.return_value = mock_result - response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet") + response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet", headers=_auth_header()) assert response.status_code == 200 def test_batch_rejects_unsupported_model(self, client, mock_analyzer): @@ -213,6 +268,7 @@ class TestModelValidation: response = client.post( "/analyze/batch", json={"companies": ["nvidia"], "model": "fake/nonexistent-model"}, + headers=_auth_header(), ) assert response.status_code == 400 assert "Unsupported model" in response.json()["detail"] diff --git a/tests/test_export.py b/tests/test_export.py index d0c856f..321e443 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -5,6 +5,7 @@ Covers issue #1655: - GET /export/{company_name}/pdf (PDF export) All tests mock the database layer and use JWT auth fixtures from test_auth patterns. +Export queries are now scoped to the current user's owner_id. """ from datetime import datetime, timezone diff --git a/tests/test_multi_tenant.py b/tests/test_multi_tenant.py new file mode 100644 index 0000000..7ce758d --- /dev/null +++ b/tests/test_multi_tenant.py @@ -0,0 +1,281 @@ +"""Cross-tenant isolation tests for multi-tenant support. + +Verifies that: +- User A cannot read, update, or delete User B's analyses, tracked companies, or jobs +- Admin users can access all data via admin endpoints +- owner_id is correctly set on new resources +""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastapi.testclient import TestClient + +from SPARC.api import app +from SPARC.auth import create_access_token + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +def _make_user(user_id, email, role="user"): + return { + "id": user_id, + "email": email, + "role": role, + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + +USER_A = _make_user(10, "alice@test.com") +USER_B = _make_user(20, "bob@test.com") +ADMIN = _make_user(1, "admin@test.com", role="admin") + + +def _header_for(user): + token = create_access_token(user["id"], user["email"], user["role"]) + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture(autouse=True) +def mock_db(): + """Mock DB returning the correct user based on user_id.""" + db = MagicMock() + + def _get_user_by_id(uid): + for u in [USER_A, USER_B, ADMIN]: + if u["id"] == uid: + return u + return None + + db.get_user_by_id.side_effect = _get_user_by_id + + with patch("SPARC.api.get_db_client", return_value=db), \ + patch("SPARC.auth.get_db_client", return_value=db): + yield db + + +# ==================== Tracked Companies Isolation ==================== + + +class TestTrackedCompanyIsolation: + """User A's tracked companies are invisible to User B.""" + + def test_user_a_list_scoped_to_own(self, client, mock_db): + """GET /tracked returns only User A's companies.""" + mock_db.list_tracked_companies.return_value = [ + {"company_name": "AliceCo", "owner_id": USER_A["id"]}, + ] + + response = client.get("/tracked", headers=_header_for(USER_A)) + assert response.status_code == 200 + mock_db.list_tracked_companies.assert_called_with(owner_id=USER_A["id"]) + + def test_user_b_list_scoped_to_own(self, client, mock_db): + """GET /tracked returns only User B's companies.""" + mock_db.list_tracked_companies.return_value = [] + + response = client.get("/tracked", headers=_header_for(USER_B)) + assert response.status_code == 200 + mock_db.list_tracked_companies.assert_called_with(owner_id=USER_B["id"]) + + def test_user_a_add_sets_owner(self, client, mock_db): + """POST /tracked sets owner_id to User A.""" + mock_db.add_tracked_company.return_value = {"company_name": "NewCo", "owner_id": 10} + + response = client.post("/tracked", json={"company_name": "NewCo"}, headers=_header_for(USER_A)) + assert response.status_code == 200 + mock_db.add_tracked_company.assert_called_with("NewCo", owner_id=USER_A["id"]) + + def test_user_b_cannot_remove_user_a_company(self, client, mock_db): + """DELETE /tracked/{name} filters by owner, so B can't remove A's company.""" + mock_db.remove_tracked_company.return_value = False # not found for B + + response = client.delete("/tracked/AliceCo", headers=_header_for(USER_B)) + assert response.status_code == 404 + mock_db.remove_tracked_company.assert_called_with("AliceCo", owner_id=USER_B["id"]) + + +# ==================== Job Isolation ==================== + + +class TestJobIsolation: + """User A's jobs are invisible to User B.""" + + def test_user_a_get_own_job(self, client, mock_db): + """GET /jobs/{id} scoped to User A returns the job.""" + mock_db.get_job.return_value = None # mock via _get_job_db + + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.get_job.return_value = { + "job_id": "j1", + "status": "completed", + "progress": 100, + "total_companies": 1, + "completed_companies": 1, + "result_json": None, + "error": None, + "owner_id": USER_A["id"], + } + mock_get_db.return_value = job_db + + response = client.get("/jobs/j1", headers=_header_for(USER_A)) + assert response.status_code == 200 + job_db.get_job.assert_called_with("j1", owner_id=USER_A["id"]) + + def test_user_b_cannot_see_user_a_job(self, client, mock_db): + """GET /jobs/{id} returns 404 when User B tries to access User A's job.""" + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.get_job.return_value = None # not found for B's owner_id + mock_get_db.return_value = job_db + + response = client.get("/jobs/j1", headers=_header_for(USER_B)) + assert response.status_code == 404 + job_db.get_job.assert_called_with("j1", owner_id=USER_B["id"]) + + def test_list_jobs_scoped_to_user(self, client, mock_db): + """GET /jobs filters by owner_id.""" + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.list_jobs.return_value = [] + mock_get_db.return_value = job_db + + response = client.get("/jobs", headers=_header_for(USER_A)) + assert response.status_code == 200 + call_kwargs = job_db.list_jobs.call_args + assert call_kwargs.kwargs.get("owner_id") == USER_A["id"] + + def test_async_job_created_with_owner(self, client, mock_db): + """POST /analyze/batch/async creates job with current user's owner_id.""" + mock_analyzer = MagicMock() + with patch("SPARC.api._analyzer", mock_analyzer), \ + patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.create_job.return_value = { + "job_id": "j2", + "status": "pending", + "progress": 0, + "total_companies": 1, + "completed_companies": 0, + "result_json": None, + "error": None, + "owner_id": USER_A["id"], + } + mock_get_db.return_value = job_db + + response = client.post( + "/analyze/batch/async", + json={"companies": ["nvidia"]}, + headers=_header_for(USER_A), + ) + assert response.status_code == 200 + create_kwargs = job_db.create_job.call_args + assert create_kwargs.kwargs.get("owner_id") == USER_A["id"] + + +# ==================== Analysis Listing Isolation ==================== + + +class TestAnalysisListIsolation: + """GET /analyze/batch scoped to current user.""" + + def test_list_analyses_scoped_to_user(self, client, mock_db): + """GET /analyze/batch passes owner_id to db.list_analyses.""" + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.list_analyses.return_value = [] + mock_get_db.return_value = job_db + + response = client.get("/analyze/batch", headers=_header_for(USER_A)) + assert response.status_code == 200 + call_kwargs = job_db.list_analyses.call_args + assert call_kwargs.kwargs.get("owner_id") == USER_A["id"] + + +# ==================== Admin Cross-Tenant Access ==================== + + +class TestAdminCrossTenantAccess: + """Admin endpoints return data from all tenants (no owner_id filter).""" + + def test_admin_list_tracked_all_tenants(self, client, mock_db): + """GET /admin/tracked returns all companies (no owner_id filter).""" + mock_db.list_tracked_companies.return_value = [ + {"company_name": "AliceCo", "owner_id": 10}, + {"company_name": "BobCo", "owner_id": 20}, + ] + + response = client.get("/admin/tracked", headers=_header_for(ADMIN)) + assert response.status_code == 200 + # Should be called without owner_id filter + mock_db.list_tracked_companies.assert_called_with() + + def test_admin_list_analyses_all_tenants(self, client, mock_db): + """GET /admin/analyses returns all analyses (no owner_id filter).""" + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.list_analyses.return_value = [] + mock_get_db.return_value = job_db + + response = client.get("/admin/analyses", headers=_header_for(ADMIN)) + assert response.status_code == 200 + call_kwargs = job_db.list_analyses.call_args + # No owner_id should be passed + assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None + + def test_admin_list_jobs_all_tenants(self, client, mock_db): + """GET /admin/jobs returns all jobs (no owner_id filter).""" + with patch("SPARC.api._get_job_db") as mock_get_db: + job_db = MagicMock() + job_db.list_jobs.return_value = [] + mock_get_db.return_value = job_db + + response = client.get("/admin/jobs", headers=_header_for(ADMIN)) + assert response.status_code == 200 + call_kwargs = job_db.list_jobs.call_args + assert "owner_id" not in call_kwargs.kwargs or call_kwargs.kwargs["owner_id"] is None + + def test_admin_remove_tracked_any_owner(self, client, mock_db): + """DELETE /admin/tracked/{name} removes without owner filter.""" + mock_db.remove_tracked_company.return_value = True + + response = client.delete("/admin/tracked/SomeCo", headers=_header_for(ADMIN)) + assert response.status_code == 200 + # Called without owner_id + mock_db.remove_tracked_company.assert_called_with("SomeCo") + + def test_regular_user_cannot_access_admin_analyses(self, client, mock_db): + """Regular user gets 403 on /admin/analyses.""" + response = client.get("/admin/analyses", headers=_header_for(USER_A)) + assert response.status_code == 403 + + def test_regular_user_cannot_access_admin_jobs(self, client, mock_db): + """Regular user gets 403 on /admin/jobs.""" + response = client.get("/admin/jobs", headers=_header_for(USER_A)) + assert response.status_code == 403 + + +# ==================== Analytics Isolation ==================== + + +class TestAnalyticsIsolation: + """GET /analytics scoped to current user.""" + + def test_analytics_scoped_to_user(self, client, mock_db): + """GET /analytics passes owner_id to db.get_analytics.""" + mock_db.get_analytics.return_value = { + "total_messages": 5, + "by_company": [], + "by_type": [], + "period_days": 30, + } + + response = client.get("/analytics", headers=_header_for(USER_A)) + assert response.status_code == 200 + mock_db.get_analytics.assert_called_with(days=30, owner_id=USER_A["id"]) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 01bc5b3..0f3cc2d 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,12 +1,13 @@ """Tests for cursor-based pagination on /analyze/batch GET and /jobs endpoints.""" -from datetime import datetime, timedelta -from unittest.mock import Mock, patch +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock, MagicMock, patch import pytest from fastapi.testclient import TestClient from SPARC.api import app +from SPARC.auth import create_access_token @pytest.fixture @@ -15,6 +16,27 @@ def client(): return TestClient(app) +@pytest.fixture(autouse=True) +def mock_db(): + """Mock the database client used by auth endpoints.""" + db = MagicMock() + db.get_user_by_id.return_value = { + "id": 1, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + + with patch("SPARC.api.get_db_client", return_value=db), \ + patch("SPARC.auth.get_db_client", return_value=db): + yield db + + +def _auth_header(): + token = create_access_token(1, "user@test.com", "user") + return {"Authorization": f"Bearer {token}"} + + def _make_analysis_row(id_: int, minutes_ago: int = 0, company: str = "nvidia"): """Create a fake analysis row dict.""" ts = datetime.now() - timedelta(minutes=minutes_ago) @@ -56,7 +78,7 @@ class TestAnalyzeBatchGetPagination: ] mock_get_db.return_value = db - response = client.get("/analyze/batch?limit=10") + response = client.get("/analyze/batch?limit=10", headers=_auth_header()) assert response.status_code == 200 data = response.json() assert len(data["items"]) == 2 @@ -71,7 +93,7 @@ class TestAnalyzeBatchGetPagination: db.list_analyses.return_value = rows mock_get_db.return_value = db - response = client.get("/analyze/batch?limit=3") + response = client.get("/analyze/batch?limit=3", headers=_auth_header()) assert response.status_code == 200 data = response.json() assert len(data["items"]) == 3 @@ -84,7 +106,7 @@ class TestAnalyzeBatchGetPagination: db.list_analyses.return_value = [] mock_get_db.return_value = db - client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42") + client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42", headers=_auth_header()) db.list_analyses.assert_called_once() call_kwargs = db.list_analyses.call_args assert call_kwargs.kwargs.get("cursor") == "2025-01-01T00:00:00|42" or \ @@ -97,19 +119,19 @@ class TestAnalyzeBatchGetPagination: db.list_analyses.return_value = [] mock_get_db.return_value = db - client.get("/analyze/batch") + client.get("/analyze/batch", headers=_auth_header()) call_kwargs = db.list_analyses.call_args # The endpoint requests limit+1 from DB, so 51 assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51 def test_limit_over_200_rejected(self, client): """Limit > 200 should be rejected with 422.""" - response = client.get("/analyze/batch?limit=201") + response = client.get("/analyze/batch?limit=201", headers=_auth_header()) assert response.status_code == 422 def test_limit_zero_rejected(self, client): """Limit < 1 should be rejected with 422.""" - response = client.get("/analyze/batch?limit=0") + response = client.get("/analyze/batch?limit=0", headers=_auth_header()) assert response.status_code == 422 @patch("SPARC.api._get_job_db") @@ -119,7 +141,7 @@ class TestAnalyzeBatchGetPagination: db.list_analyses.return_value = [] mock_get_db.return_value = db - client.get("/analyze/batch?company_name=intel") + client.get("/analyze/batch?company_name=intel", headers=_auth_header()) call_kwargs = db.list_analyses.call_args assert call_kwargs.kwargs.get("company_name") == "intel" or \ "intel" in (call_kwargs.args if call_kwargs.args else []) @@ -131,7 +153,7 @@ class TestAnalyzeBatchGetPagination: db.list_analyses.return_value = [] mock_get_db.return_value = db - response = client.get("/analyze/batch") + response = client.get("/analyze/batch", headers=_auth_header()) assert response.status_code == 200 data = response.json() assert data["items"] == [] @@ -148,14 +170,14 @@ class TestJobsPaginationDefaults: db.list_jobs.return_value = [] mock_get_db.return_value = db - client.get("/jobs") + client.get("/jobs", headers=_auth_header()) call_kwargs = db.list_jobs.call_args # Endpoint requests limit+1 from DB, so 51 assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51 def test_limit_over_200_rejected(self, client): """Limit > 200 should be rejected with 422.""" - response = client.get("/jobs?limit=201") + response = client.get("/jobs?limit=201", headers=_auth_header()) assert response.status_code == 422 @patch("SPARC.api._get_job_db") @@ -165,5 +187,5 @@ class TestJobsPaginationDefaults: db.list_jobs.return_value = [] mock_get_db.return_value = db - response = client.get("/jobs?limit=200") + response = client.get("/jobs?limit=200", headers=_auth_header()) assert response.status_code == 200 diff --git a/tests/test_tracked_companies.py b/tests/test_tracked_companies.py index df25134..4aec720 100644 --- a/tests/test_tracked_companies.py +++ b/tests/test_tracked_companies.py @@ -1,17 +1,18 @@ -"""Tests for tracked company admin endpoints and scheduler integration. +"""Tests for tracked company endpoints and scheduler integration. -Covers issue #1656: -- GET /admin/tracked (list tracked companies) -- POST /admin/tracked (add a tracked company) -- DELETE /admin/tracked/{company_name} (remove a tracked company) +Covers: +- GET /tracked (user-scoped list) +- POST /tracked (user-scoped add) +- DELETE /tracked/{company_name} (user-scoped remove) +- GET /admin/tracked (admin: all companies) +- POST /admin/tracked (admin: add) +- DELETE /admin/tracked/{company_name} (admin: remove any) - GET /admin/alerts (list alerts) - scheduler.run_scheduled_analysis() integration - -All tests mock the database layer and use JWT auth fixtures. """ from datetime import datetime, timezone -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient @@ -125,7 +126,7 @@ class TestAddTrackedCompany: assert response.status_code == 200 data = response.json() assert data["company_name"] == "Intel" - mock_db.add_tracked_company.assert_called_once_with("Intel") + mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=1) def test_add_duplicate_returns_409(self, client, mock_db): """Adding an already-tracked company returns 409.""" @@ -141,7 +142,7 @@ class TestAddTrackedCompany: assert "already tracked" in response.json()["detail"].lower() def test_add_tracked_requires_admin(self, client, mock_db): - """Regular user cannot add tracked companies.""" + """Regular user cannot add tracked companies via admin endpoint.""" mock_db.get_user_by_id.return_value = { "id": 2, "email": "user@test.com", @@ -215,6 +216,66 @@ class TestRemoveTrackedCompany: assert response.status_code == 403 +# ---------- User-scoped tracked companies ---------- + +class TestUserScopedTrackedCompanies: + """Tests for /tracked user-scoped endpoints.""" + + def test_user_list_tracked(self, client, mock_db): + """Regular user can list their own tracked companies.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + mock_db.list_tracked_companies.return_value = [ + {"company_name": "AMD", "owner_id": 2}, + ] + + response = client.get("/tracked", headers=_user_header()) + + assert response.status_code == 200 + mock_db.list_tracked_companies.assert_called_with(owner_id=2) + + def test_user_add_tracked(self, client, mock_db): + """Regular user can add a company to their own tracked list.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + mock_db.add_tracked_company.return_value = { + "company_name": "Intel", + "owner_id": 2, + } + + response = client.post( + "/tracked", + json={"company_name": "Intel"}, + headers=_user_header(), + ) + + assert response.status_code == 200 + mock_db.add_tracked_company.assert_called_once_with("Intel", owner_id=2) + + def test_user_remove_tracked(self, client, mock_db): + """Regular user can remove a company from their own tracked list.""" + mock_db.get_user_by_id.return_value = { + "id": 2, + "email": "user@test.com", + "role": "user", + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + mock_db.remove_tracked_company.return_value = True + + response = client.delete("/tracked/Intel", headers=_user_header()) + + assert response.status_code == 200 + mock_db.remove_tracked_company.assert_called_once_with("Intel", owner_id=2) + + # ---------- GET /admin/alerts ---------- class TestListAlerts: