Compare commits

..

1 Commits

Author SHA1 Message Date
agent-company a95129904e Add stricter input validation for company names on analysis endpoints
Add a CompanyName validated type enforcing 2-100 character length and
allowing only alphanumeric characters, spaces, hyphens, ampersands, and
periods. Applied to all endpoints accepting company names: /analyze,
/analyze/patent, /analyze/batch, /admin/tracked, and /export.

Includes unit tests covering too-short, too-long, special character,
leading-character, and valid edge cases for both single and batch
endpoints.

Closes leeworks-agents/SPARC#1670

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-18 21:38:44 +00:00
3 changed files with 213 additions and 39 deletions
+37 -30
View File
@@ -81,50 +81,57 @@ Items that have been implemented and merged into main.
- ~~OpenAPI client generation.~~ TypeScript API client auto-generated from
FastAPI spec with CI freshness check.
### Resilience
- ~~`_jobs` dict is in-memory only.~~ Database-backed job persistence
implemented using `db.list_jobs()` and `mark_stale_jobs_failed()`. The
in-memory `_jobs` dict has been removed.
### Test coverage (P1/P2)
- ~~Export endpoint tests.~~ Tests added for CSV and PDF export endpoints.
- ~~Tracked company admin endpoint tests.~~ Tests added for `/admin/tracked`
CRUD endpoints and scheduler integration.
- ~~Webhook integration tests.~~ Tests added for retry logic, Slack/Discord
payload format, and multi-URL dispatch.
- ~~S3/MinIO storage backend tests.~~ Unit tests added for the S3 backend
(read, write, exists, delete, error handling).
- ~~`analyze_single_patent` auto-download path tests.~~ Tests added for the
auto-download fallback (cache lookup, PDF download, FileNotFoundError).
### Code quality
- ~~Scheduler creates its own DatabaseClient.~~ Refactored to use the
application-level pooled `get_db_client()`.
---
## P1 -- High Priority
No outstanding P1 items. All previously listed items have been completed and
moved to the Completed section above.
These items address correctness, reliability, and coverage gaps that should be
resolved before broader production use.
### Resilience
- **`_jobs` dict is in-memory only.** Job state is lost on API restart.
Persist job status in PostgreSQL or Redis so async batch results survive
restarts.
### Test coverage gaps
- **Export endpoint tests.** The CSV and PDF export endpoints (`/export/`)
lack test coverage. Add tests covering auth, success, 404, and edge cases.
*(Issue #1655)*
- **Tracked company admin endpoint tests.** The `/admin/tracked` CRUD
endpoints and scheduler integration lack test coverage. *(Issue #1656)*
---
## P2 -- Medium Priority
Improvements to the API surface.
Improvements to reliability, test coverage, and code quality.
### Test coverage
- **Webhook integration tests.** The retry logic, Slack/Discord payload
format, and multi-URL dispatch in `webhooks.py` need test coverage.
*(Issue #1657)*
- **S3/MinIO storage backend tests.** `storage.py` has local filesystem tests
but no unit tests for the S3 backend (read, write, exists, delete,
error handling). *(Issue #1660)*
- **`analyze_single_patent` auto-download path tests.** The auto-download
fallback (cache lookup, PDF download, FileNotFoundError) in
`analyzer.py` lacks test coverage. *(Issue #1661)*
### Code quality
- **Scheduler creates its own DatabaseClient.** `scheduler.py` bypasses the
application-level pooled client, creating a new connection on every tick.
Refactor to use `get_db_client()`. *(Issue #1658)*
### API improvements
- **API pagination.** The `/analyze/batch` endpoint needs cursor-based
pagination for large result sets. The `/jobs` endpoint already has cursor
pagination. *(Issue #1669)*
- **API pagination.** The `/analyze/batch` and `/jobs` endpoints could benefit
from cursor-based pagination for large result sets.
- **Request validation improvements.** Add stricter input validation for
company names (disallow special characters, enforce length limits).
*(Issue #1670)*
---
+19 -9
View File
@@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, Annotated, List
if TYPE_CHECKING:
from SPARC.database import DatabaseClient
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, Request
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Path, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, EmailStr, Field
from pydantic import BaseModel, EmailStr, Field, StringConstraints
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
@@ -36,6 +36,16 @@ from SPARC.auth import (
)
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
# Validated company name type: 2-100 chars, alphanumeric + spaces/hyphens/ampersands/periods only.
CompanyName = Annotated[
str,
StringConstraints(
min_length=2,
max_length=100,
pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
),
]
# Pydantic models for API
class CompanyAnalysisResponse(BaseModel):
@@ -72,7 +82,7 @@ class CompanyAnalysisRequest(BaseModel):
class BatchAnalysisRequest(BaseModel):
"""Request model for batch company analysis."""
companies: list[str] = Field(
companies: list[CompanyName] = Field(
..., min_length=1, max_length=20, description="List of company names to analyze"
)
max_workers: int = Field(
@@ -405,7 +415,7 @@ async def delete_user(
class TrackCompanyRequest(BaseModel):
"""Request to add a company to tracking."""
company_name: str = Field(..., min_length=1, max_length=255)
company_name: CompanyName = Field(...)
@app.get("/admin/tracked", tags=["Admin"])
@@ -432,7 +442,7 @@ async def add_tracked_company(
@app.delete("/admin/tracked/{company_name}", tags=["Admin"])
async def remove_tracked_company(
company_name: str,
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_admin),
):
"""Remove a company from the tracked list (admin only)."""
@@ -590,7 +600,7 @@ async def get_analytics_trends(
@app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv(
company_name: str,
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a CSV file.
@@ -642,7 +652,7 @@ async def export_company_csv(
@app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf(
company_name: str,
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user),
):
"""Export analysis results for a company as a formatted PDF report.
@@ -816,7 +826,7 @@ async def health_check():
tags=["Analysis"],
)
async def analyze_company(
company_name: str,
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user),
):
@@ -846,7 +856,7 @@ async def analyze_company(
)
async def analyze_single_patent(
patent_id: str,
company_name: str = Query(description="Company name for analysis context"),
company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
_: UserResponse = Depends(get_current_user),
):
"""Analyze a single patent by its publication ID.
+157
View File
@@ -0,0 +1,157 @@
"""Tests for company name input validation on analysis endpoints."""
from datetime import datetime
from unittest.mock import Mock
import pytest
from fastapi.testclient import TestClient
from SPARC.api import app
from SPARC.types import CompanyAnalysisResult
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture
def mock_analyzer(mocker):
"""Mock the global analyzer so valid requests succeed."""
mock = Mock()
mock._analyze_company_safe.return_value = CompanyAnalysisResult(
company_name="nvidia",
analysis="Test analysis",
patent_count=1,
success=True,
timestamp=datetime.now(),
)
mocker.patch("SPARC.api._analyzer", mock)
return mock
class TestCompanyNameValidation:
"""Test that company names are validated on analysis endpoints."""
# --- Too short ---
def test_single_char_rejected(self, client, mock_analyzer):
"""A one-character company name should be rejected."""
response = client.get("/analyze/X")
assert response.status_code == 422
# --- Too long ---
def test_over_100_chars_rejected(self, client, mock_analyzer):
"""A company name longer than 100 characters should be rejected."""
long_name = "A" * 101
response = client.get(f"/analyze/{long_name}")
assert response.status_code == 422
# --- Special characters ---
@pytest.mark.parametrize(
"bad_name",
[
"nvidia!",
"intel@corp",
"test#company",
"foo$bar",
"a%b",
"x^y",
"semi;colon",
"drop'table",
'say"hello',
"path/traversal",
"back\\slash",
"pipe|char",
"star*glob",
"question?mark",
"<script>",
"curly{brace}",
"equal=sign",
"plus+plus",
"comma,separated",
],
)
def test_special_chars_rejected(self, client, mock_analyzer, bad_name):
"""Company names with disallowed special characters should be rejected."""
response = client.get(f"/analyze/{bad_name}")
assert response.status_code == 422
# --- Valid names ---
@pytest.mark.parametrize(
"valid_name",
[
"nvidia",
"Intel",
"TSMC",
"Texas Instruments",
"Johnson-Johnson",
"AT&T",
"St. Jude Medical",
"3M",
"21st Century Fox",
"ab", # minimum length
"A" * 100, # maximum length
],
)
def test_valid_names_accepted(self, client, mock_analyzer, valid_name):
"""Valid company names should be accepted (200, not 422)."""
response = client.get(f"/analyze/{valid_name}")
# Should not be a validation error; 200 or other non-422 status is fine
assert response.status_code != 422
# --- Batch endpoint validation ---
def test_batch_too_short_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names that are too short."""
response = client.post(
"/analyze/batch",
json={"companies": ["X"]},
)
assert response.status_code == 422
def test_batch_too_long_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names that are too long."""
response = client.post(
"/analyze/batch",
json={"companies": ["A" * 101]},
)
assert response.status_code == 422
def test_batch_special_chars_rejected(self, client, mock_analyzer):
"""Batch endpoint should reject company names with special chars."""
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia!", "intel"]},
)
assert response.status_code == 422
def test_batch_valid_names_accepted(self, client, mock_analyzer):
"""Batch endpoint should accept valid company names."""
response = client.post(
"/analyze/batch",
json={"companies": ["nvidia", "Intel", "AT&T"]},
)
assert response.status_code != 422
# --- Name must start with alphanumeric ---
def test_leading_space_rejected(self, client, mock_analyzer):
"""Company name starting with a space should be rejected."""
response = client.post(
"/analyze/batch",
json={"companies": [" nvidia"]},
)
assert response.status_code == 422
def test_leading_hyphen_rejected(self, client, mock_analyzer):
"""Company name starting with a hyphen should be rejected."""
response = client.post(
"/analyze/batch",
json={"companies": ["-nvidia"]},
)
assert response.status_code == 422