diff --git a/SPARC/api.py b/SPARC/api.py index a42ddd7..cf053e8 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, Annotated, List if TYPE_CHECKING: from SPARC.database import DatabaseClient -from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request +from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Path, Query, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse -from pydantic import BaseModel, EmailStr, Field +from pydantic import BaseModel, EmailStr, Field, StringConstraints from slowapi import Limiter from slowapi.errors import RateLimitExceeded from slowapi.util import get_remote_address @@ -36,6 +36,16 @@ from SPARC.auth import ( ) from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult +# Validated company name type: 2-100 chars, alphanumeric + spaces/hyphens/ampersands/periods only. +CompanyName = Annotated[ + str, + StringConstraints( + min_length=2, + max_length=100, + pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", + ), +] + # Pydantic models for API class CompanyAnalysisResponse(BaseModel): @@ -72,7 +82,7 @@ class CompanyAnalysisRequest(BaseModel): class BatchAnalysisRequest(BaseModel): """Request model for batch company analysis.""" - companies: list[str] = Field( + companies: list[CompanyName] = Field( ..., min_length=1, max_length=20, description="List of company names to analyze" ) max_workers: int = Field( @@ -405,7 +415,7 @@ async def delete_user( class TrackCompanyRequest(BaseModel): """Request to add a company to tracking.""" - company_name: str = Field(..., min_length=1, max_length=255) + company_name: CompanyName = Field(...) @app.get("/admin/tracked", tags=["Admin"]) @@ -432,7 +442,7 @@ async def add_tracked_company( @app.delete("/admin/tracked/{company_name}", tags=["Admin"]) async def remove_tracked_company( - company_name: str, + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], _: UserResponse = Depends(get_current_admin), ): """Remove a company from the tracked list (admin only).""" @@ -590,7 +600,7 @@ async def get_analytics_trends( @app.get("/export/{company_name}", tags=["Export"]) async def export_company_csv( - company_name: str, + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], _: UserResponse = Depends(get_current_user), ): """Export analysis results for a company as a CSV file. @@ -642,7 +652,7 @@ async def export_company_csv( @app.get("/export/{company_name}/pdf", tags=["Export"]) async def export_company_pdf( - company_name: str, + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], _: UserResponse = Depends(get_current_user), ): """Export analysis results for a company as a formatted PDF report. @@ -816,7 +826,7 @@ async def health_check(): tags=["Analysis"], ) async def analyze_company( - company_name: str, + company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."), _: UserResponse = Depends(get_current_user), ): @@ -846,7 +856,7 @@ async def analyze_company( ) async def analyze_single_patent( patent_id: str, - company_name: str = Query(description="Company name for analysis context"), + company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")], _: UserResponse = Depends(get_current_user), ): """Analyze a single patent by its publication ID. diff --git a/tests/test_company_name_validation.py b/tests/test_company_name_validation.py new file mode 100644 index 0000000..3e6855f --- /dev/null +++ b/tests/test_company_name_validation.py @@ -0,0 +1,157 @@ +"""Tests for company name input validation on analysis endpoints.""" + +from datetime import datetime +from unittest.mock import Mock + +import pytest +from fastapi.testclient import TestClient + +from SPARC.api import app +from SPARC.types import CompanyAnalysisResult + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +@pytest.fixture +def mock_analyzer(mocker): + """Mock the global analyzer so valid requests succeed.""" + mock = Mock() + mock._analyze_company_safe.return_value = CompanyAnalysisResult( + company_name="nvidia", + analysis="Test analysis", + patent_count=1, + success=True, + timestamp=datetime.now(), + ) + mocker.patch("SPARC.api._analyzer", mock) + return mock + + +class TestCompanyNameValidation: + """Test that company names are validated on analysis endpoints.""" + + # --- Too short --- + + def test_single_char_rejected(self, client, mock_analyzer): + """A one-character company name should be rejected.""" + response = client.get("/analyze/X") + assert response.status_code == 422 + + # --- Too long --- + + def test_over_100_chars_rejected(self, client, mock_analyzer): + """A company name longer than 100 characters should be rejected.""" + long_name = "A" * 101 + response = client.get(f"/analyze/{long_name}") + assert response.status_code == 422 + + # --- Special characters --- + + @pytest.mark.parametrize( + "bad_name", + [ + "nvidia!", + "intel@corp", + "test#company", + "foo$bar", + "a%b", + "x^y", + "semi;colon", + "drop'table", + 'say"hello', + "path/traversal", + "back\\slash", + "pipe|char", + "star*glob", + "question?mark", + "