Compare commits

...

2 Commits

Author SHA1 Message Date
agent-company a95129904e Add stricter input validation for company names on analysis endpoints
Add a CompanyName validated type enforcing 2-100 character length and
allowing only alphanumeric characters, spaces, hyphens, ampersands, and
periods. Applied to all endpoints accepting company names: /analyze,
/analyze/patent, /analyze/batch, /admin/tracked, and /export.

Includes unit tests covering too-short, too-long, special character,
leading-character, and valid edge cases for both single and batch
endpoints.

Closes leeworks-agents/SPARC#1670

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:38:44 +00:00
AI-Manager 4c411e1e0b Merge pull request 'Add tests for tracked company admin endpoints and scheduler' (#1667) from feature/1656-tracked-company-admin-tests into main
Merge: Add tests for tracked company admin endpoints and scheduler integration

Closes #1656
2026-04-20 23:05:57 +00:00
2 changed files with 176 additions and 9 deletions
+19 -9
View File
@@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, Annotated, List
if TYPE_CHECKING: if TYPE_CHECKING:
from SPARC.database import DatabaseClient from SPARC.database import DatabaseClient
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Path, Query, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field, StringConstraints
from slowapi import Limiter from slowapi import Limiter
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
@@ -36,6 +36,16 @@ from SPARC.auth import (
) )
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
# Validated company name type: 2-100 chars, alphanumeric + spaces/hyphens/ampersands/periods only.
CompanyName = Annotated[
str,
StringConstraints(
min_length=2,
max_length=100,
pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
),
]
# Pydantic models for API # Pydantic models for API
class CompanyAnalysisResponse(BaseModel): class CompanyAnalysisResponse(BaseModel):
@@ -72,7 +82,7 @@ class CompanyAnalysisRequest(BaseModel):
class BatchAnalysisRequest(BaseModel): class BatchAnalysisRequest(BaseModel):
"""Request model for batch company analysis.""" """Request model for batch company analysis."""
companies: list[str] = Field( companies: list[CompanyName] = Field(
..., min_length=1, max_length=20, description="List of company names to analyze" ..., min_length=1, max_length=20, description="List of company names to analyze"
) )
max_workers: int = Field( max_workers: int = Field(
@@ -405,7 +415,7 @@ async def delete_user(
class TrackCompanyRequest(BaseModel): class TrackCompanyRequest(BaseModel):
"""Request to add a company to tracking.""" """Request to add a company to tracking."""
company_name: str = Field(..., min_length=1, max_length=255) company_name: CompanyName = Field(...)
@app.get("/admin/tracked", tags=["Admin"]) @app.get("/admin/tracked", tags=["Admin"])
@@ -432,7 +442,7 @@ async def add_tracked_company(
@app.delete("/admin/tracked/{company_name}", tags=["Admin"]) @app.delete("/admin/tracked/{company_name}", tags=["Admin"])
async def remove_tracked_company( async def remove_tracked_company(
company_name: str, company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_admin), _: UserResponse = Depends(get_current_admin),
): ):
"""Remove a company from the tracked list (admin only).""" """Remove a company from the tracked list (admin only)."""
@@ -590,7 +600,7 @@ async def get_analytics_trends(
@app.get("/export/{company_name}", tags=["Export"]) @app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv( async def export_company_csv(
company_name: str, company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
"""Export analysis results for a company as a CSV file. """Export analysis results for a company as a CSV file.
@@ -642,7 +652,7 @@ async def export_company_csv(
@app.get("/export/{company_name}/pdf", tags=["Export"]) @app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf( async def export_company_pdf(
company_name: str, company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
"""Export analysis results for a company as a formatted PDF report. """Export analysis results for a company as a formatted PDF report.
@@ -816,7 +826,7 @@ async def health_check():
tags=["Analysis"], tags=["Analysis"],
) )
async def analyze_company( async def analyze_company(
company_name: str, company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."), model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
@@ -846,7 +856,7 @@ async def analyze_company(
) )
async def analyze_single_patent( async def analyze_single_patent(
patent_id: str, patent_id: str,
company_name: str = Query(description="Company name for analysis context"), company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
"""Analyze a single patent by its publication ID. """Analyze a single patent by its publication ID.
+157
View File
@@ -0,0 +1,157 @@
"""Tests for company name input validation on analysis endpoints."""
from datetime import datetime
from unittest.mock import Mock
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.types import CompanyAnalysisResult
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture
def mock_analyzer(mocker):
"""Mock the global analyzer so valid requests succeed."""
mock = Mock()
mock._analyze_company_safe.return_value = CompanyAnalysisResult(
company_name="nvidia",
analysis="Test analysis",
patent_count=1,
success=True,
timestamp=datetime.now(),
)
mocker.patch("SPARC.api._analyzer", mock)
return mock
class TestCompanyNameValidation:
"""Test that company names are validated on analysis endpoints."""
# --- Too short ---
def test_single_char_rejected(self, client, mock_analyzer):
"""A one-character company name should be rejected."""
response = client.get("/analyze/X")
assert response.status_code == 422
# --- Too long ---
def test_over_100_chars_rejected(self, client, mock_analyzer):
"""A company name longer than 100 characters should be rejected."""
long_name = "A" * 101
response = client.get(f"/analyze/{long_name}")
assert response.status_code == 422
# --- Special characters ---
@pytest.mark.parametrize(
"bad_name",
[
"nvidia!",
"intel@corp",
"test#company",
"foo$bar",
"a%b",
"x^y",
"semi;colon",
"drop'table",
'say"hello',
"path/traversal",
"back\\slash",
"pipe|char",
"star*glob",
"question?mark",
"<script>",
"curly{brace}",
"equal=sign",
"plus+plus",
"comma,separated",
],
)
def test_special_chars_rejected(self, client, mock_analyzer, bad_name):
"""Company names with disallowed special characters should be rejected."""
response = client.get(f"/analyze/{bad_name}")
assert response.status_code == 422
# --- Valid names ---
@pytest.mark.parametrize(
"valid_name",
[
"nvidia",
"Intel",
"TSMC",
"Texas Instruments",
"Johnson-Johnson",
"AT&T",
"St. Jude Medical",
"3M",
"21st Century Fox",
"ab", # minimum length
"A" * 100, # maximum length
],
)
def test_valid_names_accepted(self, client, mock_analyzer, valid_name):
"""Valid company names should be accepted (200, not 422)."""
response = client.get(f"/analyze/{valid_name}")
# Should not be a validation error; 200 or other non-422 status is fine
assert response.status_code != 422
# --- Batch endpoint validation ---
def test_batch_too_short_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names that are too short."""
response = client.post(
"/analyze/batch",
json={"companies": ["X"]},
)
assert response.status_code == 422
def test_batch_too_long_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names that are too long."""
response = client.post(
"/analyze/batch",
json={"companies": ["A" * 101]},
)
assert response.status_code == 422
def test_batch_special_chars_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names with special chars."""
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia!", "intel"]},
)
assert response.status_code == 422
def test_batch_valid_names_accepted(self, client, mock_analyzer):
"""Batch endpoint should accept valid company names."""
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia", "Intel", "AT&T"]},
)
assert response.status_code != 422
# --- Name must start with alphanumeric ---
def test_leading_space_rejected(self, client, mock_analyzer):
"""Company name starting with a space should be rejected."""
response = client.post(
"/analyze/batch",
json={"companies": [" nvidia"]},
)
assert response.status_code == 422
def test_leading_hyphen_rejected(self, client, mock_analyzer):
"""Company name starting with a hyphen should be rejected."""
response = client.post(
"/analyze/batch",
json={"companies": ["-nvidia"]},
)
assert response.status_code == 422