fix: enforce max_length=128 and validate GET /analyze/batch filter

Closes leeworks-agents/SPARC#1685

- Increase CompanyName max_length from 100 to 128 everywhere (Pydantic
  type, Path constraints, and the inline Query on analyze/patent).
- Add _COMPANY_NAME_FILTER_QUERY reusable Query annotation and apply it
  to the optional company_name filter on GET /analyze/batch so it is
  validated with the same rules as all other endpoints.
- Update tests: rename test_over_100_chars_rejected →
  test_over_128_chars_rejected (now sending 129 chars), add
  test_exactly_128_chars_accepted at the new boundary, fix batch
  too-long test to use 129 chars, update valid-name parametrize to use
  "A"*128, and add five new tests covering GET /analyze/batch filter
  validation (special chars, too-short, too-long, valid, omitted).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
agent-company
2026-05-19 15:18:09 +00:00
parent 313800215c
commit 5c25a0f589
2 changed files with 58 additions and 13 deletions
+20 -8
View File
@@ -36,16 +36,28 @@ from SPARC.auth import (
) )
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
# Validated company name type: 2-100 chars, alphanumeric + spaces/hyphens/ampersands/periods only. # Validated company name type: 2-128 chars, alphanumeric + spaces/hyphens/ampersands/periods only.
CompanyName = Annotated[ CompanyName = Annotated[
str, str,
StringConstraints( StringConstraints(
min_length=2, min_length=2,
max_length=100, max_length=128,
pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
), ),
] ]
# Reusable Query constraint for optional company_name filter parameters.
# It mirrors the CompanyName rules: 2-128 chars, must start with an
# alphanumeric character, then alphanumerics, spaces, hyphens, ampersands
# and periods only.
#
# NOTE: no `default=` is set here on purpose. FastAPI refuses a Query that
# carries its own default when it is used inside `Annotated[...]`; the
# default must be supplied with `= None` on the endpoint parameter instead.
_COMPANY_NAME_FILTER_QUERY = Query(
    min_length=2,
    max_length=128,
    pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
    description=(
        "Company name filter (2-128 chars; alphanumeric, spaces, hyphens, "
        "periods, and ampersands only)"
    ),
)
# Pydantic models for API # Pydantic models for API
class CompanyAnalysisResponse(BaseModel): class CompanyAnalysisResponse(BaseModel):
@@ -489,7 +501,7 @@ async def add_tracked_company(
@app.delete("/admin/tracked/{company_name}", tags=["Admin"]) @app.delete("/admin/tracked/{company_name}", tags=["Admin"])
async def remove_tracked_company( async def remove_tracked_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_admin), _: UserResponse = Depends(get_current_admin),
): ):
"""Remove a company from the tracked list (admin only).""" """Remove a company from the tracked list (admin only)."""
@@ -677,7 +689,7 @@ async def get_analytics_trends(
@app.get("/export/{company_name}", tags=["Export"]) @app.get("/export/{company_name}", tags=["Export"])
async def export_company_csv( async def export_company_csv(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
"""Export analysis results for a company as a CSV file. """Export analysis results for a company as a CSV file.
@@ -729,7 +741,7 @@ async def export_company_csv(
@app.get("/export/{company_name}/pdf", tags=["Export"]) @app.get("/export/{company_name}/pdf", tags=["Export"])
async def export_company_pdf( async def export_company_pdf(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
"""Export analysis results for a company as a formatted PDF report. """Export analysis results for a company as a formatted PDF report.
@@ -903,7 +915,7 @@ async def health_check():
tags=["Analysis"], tags=["Analysis"],
) )
async def analyze_company( async def analyze_company(
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")], company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."), model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
@@ -933,7 +945,7 @@ async def analyze_company(
) )
async def analyze_single_patent( async def analyze_single_patent(
patent_id: str, patent_id: str,
company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")], company_name: Annotated[str, Query(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
_: UserResponse = Depends(get_current_user), _: UserResponse = Depends(get_current_user),
): ):
"""Analyze a single patent by its publication ID. """Analyze a single patent by its publication ID.
@@ -967,7 +979,7 @@ async def analyze_single_patent(
async def list_analysis_results( async def list_analysis_results(
company_name: Annotated[ company_name: Annotated[
str | None, str | None,
Query(description="Filter results by company name"), _COMPANY_NAME_FILTER_QUERY,
] = None, ] = None,
limit: Annotated[int, Query(ge=1, le=200)] = 50, limit: Annotated[int, Query(ge=1, le=200)] = 50,
cursor: Annotated[ cursor: Annotated[
+38 -5
View File
@@ -43,12 +43,18 @@ class TestCompanyNameValidation:
# --- Too long --- # --- Too long ---
def test_over_100_chars_rejected(self, client, mock_analyzer): def test_over_128_chars_rejected(self, client, mock_analyzer):
"""A company name longer than 100 characters should be rejected.""" """A company name longer than 128 characters should be rejected."""
long_name = "A" * 101 long_name = "A" * 129
response = client.get(f"/analyze/{long_name}") response = client.get(f"/analyze/{long_name}")
assert response.status_code == 422 assert response.status_code == 422
def test_exactly_128_chars_accepted(self, client, mock_analyzer):
    """The 128-character upper bound itself must not trigger a 422."""
    max_name = "A" * 128
    status = client.get(f"/analyze/{max_name}").status_code
    assert status != 422
# --- Special characters --- # --- Special characters ---
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -95,7 +101,7 @@ class TestCompanyNameValidation:
"3M", "3M",
"21st Century Fox", "21st Century Fox",
"ab", # minimum length "ab", # minimum length
"A" * 100, # maximum length "A" * 128, # maximum length
], ],
) )
def test_valid_names_accepted(self, client, mock_analyzer, valid_name): def test_valid_names_accepted(self, client, mock_analyzer, valid_name):
@@ -118,7 +124,7 @@ class TestCompanyNameValidation:
"""Batch endpoint should reject company names that are too long.""" """Batch endpoint should reject company names that are too long."""
response = client.post( response = client.post(
"/analyze/batch", "/analyze/batch",
json={"companies": ["A" * 101]}, json={"companies": ["A" * 129]},
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -155,3 +161,30 @@ class TestCompanyNameValidation:
json={"companies": ["-nvidia"]}, json={"companies": ["-nvidia"]},
) )
assert response.status_code == 422 assert response.status_code == 422
# --- GET /analyze/batch company_name filter validation ---
def test_batch_filter_special_chars_rejected(self, client, mock_analyzer):
    """A company_name filter containing '!' on GET /analyze/batch gives 422."""
    query = {"company_name": "nvidia!"}
    status = client.get("/analyze/batch", params=query).status_code
    assert status == 422
def test_batch_filter_too_short_rejected(self, client, mock_analyzer):
    """A one-character company_name filter on GET /analyze/batch gives 422."""
    query = {"company_name": "X"}
    status = client.get("/analyze/batch", params=query).status_code
    assert status == 422
def test_batch_filter_too_long_rejected(self, client, mock_analyzer):
    """A 129-character company_name filter on GET /analyze/batch gives 422."""
    query = {"company_name": "A" * 129}
    status = client.get("/analyze/batch", params=query).status_code
    assert status == 422
def test_batch_filter_valid_name_accepted(self, client, mock_analyzer):
    """A well-formed company_name filter on GET /analyze/batch is not a 422."""
    query = {"company_name": "nvidia"}
    status = client.get("/analyze/batch", params=query).status_code
    assert status != 422
def test_batch_filter_omitted_accepted(self, client, mock_analyzer):
    """GET /analyze/batch with no company_name filter is not a 422."""
    status = client.get("/analyze/batch").status_code
    assert status != 422