diff --git a/SPARC/api.py b/SPARC/api.py index cf053e8..cd5e6e6 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -106,6 +106,24 @@ class JobStatus(BaseModel): error: str | None = None +class AnalysisRecord(BaseModel): + """A single stored analysis result.""" + + id: int + company_name: str | None = None + analysis_type: str | None = None + model: str | None = None + response: str | None = None + timestamp: datetime | None = None + + +class PaginatedAnalysisResponse(BaseModel): + """Paginated response for analysis result listings.""" + + items: list[AnalysisRecord] + next_cursor: str | None = None + + class PaginatedJobsResponse(BaseModel): """Paginated response for job listings.""" @@ -882,6 +900,58 @@ async def analyze_single_patent( raise HTTPException(status_code=404, detail=str(e)) +@app.get( + "/analyze/batch", + response_model=PaginatedAnalysisResponse, + tags=["Analysis"], +) +async def list_analysis_results( + company_name: Annotated[ + str | None, + Query(description="Filter results by company name"), + ] = None, + limit: Annotated[int, Query(ge=1, le=200)] = 50, + cursor: Annotated[ + str | None, + Query(description="Opaque cursor from a previous response's next_cursor field"), + ] = None, + _: UserResponse = Depends(get_current_user), +): + """List stored analysis results with cursor-based pagination. + + Returns past analysis results ordered by timestamp descending. Use + ``limit`` to control page size (default 50, max 200). The response + includes a ``next_cursor`` field; pass it back as the ``cursor`` query + parameter to fetch the next page. When ``next_cursor`` is ``null``, + there are no more results. + + Args: + company_name: Optional filter by company name + limit: Maximum number of results to return (default 50, max 200) + cursor: Opaque pagination cursor from a previous response + + Returns: + Paginated list of analysis results + """ + db = _get_job_db() + rows = db.list_analyses(company_name=company_name, limit=limit + 1, cursor=cursor) + + has_next = len(rows) > limit + if has_next: + rows = rows[:limit] + + items = [AnalysisRecord(**row) for row in rows] + + next_cursor = None + if has_next and rows: + last = rows[-1] + ts = last["timestamp"] + ts_str = ts.isoformat() if hasattr(ts, "isoformat") else str(ts) + next_cursor = f"{ts_str}|{last['id']}" + + return PaginatedAnalysisResponse(items=items, next_cursor=next_cursor) + + @app.post( "/analyze/batch", response_model=BatchAnalysisResponse, @@ -1057,7 +1127,7 @@ async def list_jobs( str | None, Query(description="Filter by status: pending, running, completed, failed"), ] = None, - limit: Annotated[int, Query(ge=1, le=100)] = 10, + limit: Annotated[int, Query(ge=1, le=200)] = 50, cursor: Annotated[ str | None, Query(description="Opaque cursor from a previous response's next_cursor field"), diff --git a/SPARC/database.py b/SPARC/database.py index 24c7081..0759a66 100644 --- a/SPARC/database.py +++ b/SPARC/database.py @@ -371,6 +371,48 @@ class DatabaseClient: cursor.execute(query, params) return [dict(row) for row in cursor.fetchall()] + def list_analyses( + self, + company_name: Optional[str] = None, + limit: int = 50, + cursor: Optional[str] = None, + ) -> List[Dict]: + """List analysis results with cursor-based pagination. + + Args: + company_name: Optional filter by company name. + limit: Maximum number of records to return. + cursor: Opaque cursor (``timestamp|id``) from a previous response. + + Returns: + List of analysis dicts ordered by timestamp descending. + """ + conditions: list[str] = ["is_cached = FALSE"] + params: list = [] + + if company_name: + conditions.append("LOWER(company_name) = LOWER(%s)") + params.append(company_name) + + if cursor: + try: + ts_str, cursor_id = cursor.rsplit("|", 1) + conditions.append("(timestamp, id) < (%s, %s)") + params.extend([ts_str, int(cursor_id)]) + except (ValueError, TypeError): + pass # Ignore malformed cursors; return from start + + query = "SELECT id, company_name, analysis_type, model, response, timestamp FROM llm_messages" + if conditions: + query += " WHERE " + " AND ".join(conditions) + query += " ORDER BY timestamp DESC, id DESC LIMIT %s" + params.append(limit) + + with self.get_conn() as conn: + with conn.cursor(cursor_factory=RealDictCursor) as cur: + cur.execute(query, params) + return [dict(row) for row in cur.fetchall()] + def get_analytics(self, days: int = 30) -> Dict: """Get analytics on message usage. diff --git a/tests/test_pagination.py b/tests/test_pagination.py new file mode 100644 index 0000000..01bc5b3 --- /dev/null +++ b/tests/test_pagination.py @@ -0,0 +1,169 @@ +"""Tests for cursor-based pagination on /analyze/batch GET and /jobs endpoints.""" + +from datetime import datetime, timedelta +from unittest.mock import Mock, patch + +import pytest +from fastapi.testclient import TestClient + +from SPARC.api import app + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +def _make_analysis_row(id_: int, minutes_ago: int = 0, company: str = "nvidia"): + """Create a fake analysis row dict.""" + ts = datetime.now() - timedelta(minutes=minutes_ago) + return { + "id": id_, + "company_name": company, + "analysis_type": "patent_portfolio", + "model": "openai/gpt-4o", + "response": f"Analysis for {company}", + "timestamp": ts, + } + + +def _make_job_row(job_id: str, minutes_ago: int = 0, status: str = "completed"): + """Create a fake job row dict.""" + ts = datetime.now() - timedelta(minutes=minutes_ago) + return { + "job_id": job_id, + "status": status, + "progress": 100 if status == "completed" else 0, + "total_companies": 1, + "completed_companies": 1 if status == "completed" else 0, + "result": None, + "error": None, + "created_at": ts, + } + + +class TestAnalyzeBatchGetPagination: + """Test cursor-based pagination on GET /analyze/batch.""" + + @patch("SPARC.api._get_job_db") + def test_returns_items_and_no_cursor_when_less_than_limit(self, mock_get_db, client): + """When fewer results than limit, next_cursor should be null.""" + db = Mock() + db.list_analyses.return_value = [ + _make_analysis_row(1, minutes_ago=10), + _make_analysis_row(2, minutes_ago=20), + ] + mock_get_db.return_value = db + + response = client.get("/analyze/batch?limit=10") + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 2 + assert data["next_cursor"] is None + + @patch("SPARC.api._get_job_db") + def test_returns_cursor_when_more_results_exist(self, mock_get_db, client): + """When more results exist than limit, next_cursor should be set.""" + db = Mock() + # Return limit+1 rows to simulate more data + rows = [_make_analysis_row(i, minutes_ago=i) for i in range(4)] + db.list_analyses.return_value = rows + mock_get_db.return_value = db + + response = client.get("/analyze/batch?limit=3") + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 3 + assert data["next_cursor"] is not None + + @patch("SPARC.api._get_job_db") + def test_cursor_passed_to_db(self, mock_get_db, client): + """The cursor query param should be forwarded to the database layer.""" + db = Mock() + db.list_analyses.return_value = [] + mock_get_db.return_value = db + + client.get("/analyze/batch?cursor=2025-01-01T00:00:00|42") + db.list_analyses.assert_called_once() + call_kwargs = db.list_analyses.call_args + assert call_kwargs.kwargs.get("cursor") == "2025-01-01T00:00:00|42" or \ + (call_kwargs[1].get("cursor") == "2025-01-01T00:00:00|42" if len(call_kwargs) > 1 else False) + + @patch("SPARC.api._get_job_db") + def test_default_limit_is_50(self, mock_get_db, client): + """Default limit should be 50.""" + db = Mock() + db.list_analyses.return_value = [] + mock_get_db.return_value = db + + client.get("/analyze/batch") + call_kwargs = db.list_analyses.call_args + # The endpoint requests limit+1 from DB, so 51 + assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51 + + def test_limit_over_200_rejected(self, client): + """Limit > 200 should be rejected with 422.""" + response = client.get("/analyze/batch?limit=201") + assert response.status_code == 422 + + def test_limit_zero_rejected(self, client): + """Limit < 1 should be rejected with 422.""" + response = client.get("/analyze/batch?limit=0") + assert response.status_code == 422 + + @patch("SPARC.api._get_job_db") + def test_company_name_filter(self, mock_get_db, client): + """The company_name filter should be forwarded to the database.""" + db = Mock() + db.list_analyses.return_value = [] + mock_get_db.return_value = db + + client.get("/analyze/batch?company_name=intel") + call_kwargs = db.list_analyses.call_args + assert call_kwargs.kwargs.get("company_name") == "intel" or \ + "intel" in (call_kwargs.args if call_kwargs.args else []) + + @patch("SPARC.api._get_job_db") + def test_empty_result_set(self, mock_get_db, client): + """Empty result set returns empty items and null cursor.""" + db = Mock() + db.list_analyses.return_value = [] + mock_get_db.return_value = db + + response = client.get("/analyze/batch") + assert response.status_code == 200 + data = response.json() + assert data["items"] == [] + assert data["next_cursor"] is None + + +class TestJobsPaginationDefaults: + """Test that /jobs endpoint uses updated defaults.""" + + @patch("SPARC.api._get_job_db") + def test_default_limit_is_50(self, mock_get_db, client): + """Default limit should now be 50.""" + db = Mock() + db.list_jobs.return_value = [] + mock_get_db.return_value = db + + client.get("/jobs") + call_kwargs = db.list_jobs.call_args + # Endpoint requests limit+1 from DB, so 51 + assert 51 in call_kwargs.args or call_kwargs.kwargs.get("limit") == 51 + + def test_limit_over_200_rejected(self, client): + """Limit > 200 should be rejected with 422.""" + response = client.get("/jobs?limit=201") + assert response.status_code == 422 + + @patch("SPARC.api._get_job_db") + def test_limit_200_accepted(self, mock_get_db, client): + """Limit of exactly 200 should be accepted.""" + db = Mock() + db.list_jobs.return_value = [] + mock_get_db.return_value = db + + response = client.get("/jobs?limit=200") + assert response.status_code == 200