forked from 0xWheatyz/SPARC
fix: enforce max_length=128 and validate GET /analyze/batch company_name filter #1688
+20
-8
@@ -36,16 +36,28 @@ from SPARC.auth import (
|
||||
)
|
||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||
|
||||
# Validated company name type: 2-100 chars, alphanumeric + spaces/hyphens/ampersands/periods only.
|
||||
# Validated company name type: 2-128 chars, alphanumeric + spaces/hyphens/ampersands/periods only.
|
||||
CompanyName = Annotated[
|
||||
str,
|
||||
StringConstraints(
|
||||
min_length=2,
|
||||
max_length=100,
|
||||
max_length=128,
|
||||
pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
|
||||
),
|
||||
]
|
||||
|
||||
# Reusable Query constraint for optional company_name filter parameters.
|
||||
_COMPANY_NAME_FILTER_QUERY = Query(
|
||||
default=None,
|
||||
min_length=2,
|
||||
max_length=128,
|
||||
pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
|
||||
description=(
|
||||
"Company name filter (2-128 chars; alphanumeric, spaces, hyphens, "
|
||||
"periods, and ampersands only)"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Pydantic models for API
|
||||
class CompanyAnalysisResponse(BaseModel):
|
||||
@@ -489,7 +501,7 @@ async def add_tracked_company(
|
||||
|
||||
@app.delete("/admin/tracked/{company_name}", tags=["Admin"])
|
||||
async def remove_tracked_company(
|
||||
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||
company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||
_: UserResponse = Depends(get_current_admin),
|
||||
):
|
||||
"""Remove a company from the tracked list (admin only)."""
|
||||
@@ -677,7 +689,7 @@ async def get_analytics_trends(
|
||||
|
||||
@app.get("/export/{company_name}", tags=["Export"])
|
||||
async def export_company_csv(
|
||||
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||
company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||
_: UserResponse = Depends(get_current_user),
|
||||
):
|
||||
"""Export analysis results for a company as a CSV file.
|
||||
@@ -729,7 +741,7 @@ async def export_company_csv(
|
||||
|
||||
@app.get("/export/{company_name}/pdf", tags=["Export"])
|
||||
async def export_company_pdf(
|
||||
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||
company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||
_: UserResponse = Depends(get_current_user),
|
||||
):
|
||||
"""Export analysis results for a company as a formatted PDF report.
|
||||
@@ -903,7 +915,7 @@ async def health_check():
|
||||
tags=["Analysis"],
|
||||
)
|
||||
async def analyze_company(
|
||||
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||
company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
|
||||
_: UserResponse = Depends(get_current_user),
|
||||
):
|
||||
@@ -933,7 +945,7 @@ async def analyze_company(
|
||||
)
|
||||
async def analyze_single_patent(
|
||||
patent_id: str,
|
||||
company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
|
||||
company_name: Annotated[str, Query(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
|
||||
_: UserResponse = Depends(get_current_user),
|
||||
):
|
||||
"""Analyze a single patent by its publication ID.
|
||||
@@ -967,7 +979,7 @@ async def analyze_single_patent(
|
||||
async def list_analysis_results(
|
||||
company_name: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter results by company name"),
|
||||
_COMPANY_NAME_FILTER_QUERY,
|
||||
] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=200)] = 50,
|
||||
cursor: Annotated[
|
||||
|
||||
@@ -43,12 +43,18 @@ class TestCompanyNameValidation:
|
||||
|
||||
# --- Too long ---
|
||||
|
||||
def test_over_100_chars_rejected(self, client, mock_analyzer):
|
||||
"""A company name longer than 100 characters should be rejected."""
|
||||
long_name = "A" * 101
|
||||
def test_over_128_chars_rejected(self, client, mock_analyzer):
|
||||
"""A company name longer than 128 characters should be rejected."""
|
||||
long_name = "A" * 129
|
||||
response = client.get(f"/analyze/{long_name}")
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_exactly_128_chars_accepted(self, client, mock_analyzer):
|
||||
"""A company name of exactly 128 characters should be accepted."""
|
||||
max_name = "A" * 128
|
||||
response = client.get(f"/analyze/{max_name}")
|
||||
assert response.status_code != 422
|
||||
|
||||
# --- Special characters ---
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -95,7 +101,7 @@ class TestCompanyNameValidation:
|
||||
"3M",
|
||||
"21st Century Fox",
|
||||
"ab", # minimum length
|
||||
"A" * 100, # maximum length
|
||||
"A" * 128, # maximum length
|
||||
],
|
||||
)
|
||||
def test_valid_names_accepted(self, client, mock_analyzer, valid_name):
|
||||
@@ -118,7 +124,7 @@ class TestCompanyNameValidation:
|
||||
"""Batch endpoint should reject company names that are too long."""
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["A" * 101]},
|
||||
json={"companies": ["A" * 129]},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -155,3 +161,30 @@ class TestCompanyNameValidation:
|
||||
json={"companies": ["-nvidia"]},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
# --- GET /analyze/batch company_name filter validation ---
|
||||
|
||||
def test_batch_filter_special_chars_rejected(self, client, mock_analyzer):
|
||||
"""GET /analyze/batch company_name filter rejects disallowed chars."""
|
||||
response = client.get("/analyze/batch", params={"company_name": "nvidia!"})
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_batch_filter_too_short_rejected(self, client, mock_analyzer):
|
||||
"""GET /analyze/batch company_name filter rejects names under 2 chars."""
|
||||
response = client.get("/analyze/batch", params={"company_name": "X"})
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_batch_filter_too_long_rejected(self, client, mock_analyzer):
|
||||
"""GET /analyze/batch company_name filter rejects names over 128 chars."""
|
||||
response = client.get("/analyze/batch", params={"company_name": "A" * 129})
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_batch_filter_valid_name_accepted(self, client, mock_analyzer):
|
||||
"""GET /analyze/batch company_name filter accepts a valid name."""
|
||||
response = client.get("/analyze/batch", params={"company_name": "nvidia"})
|
||||
assert response.status_code != 422
|
||||
|
||||
def test_batch_filter_omitted_accepted(self, client, mock_analyzer):
|
||||
"""GET /analyze/batch without company_name filter should work fine."""
|
||||
response = client.get("/analyze/batch")
|
||||
assert response.status_code != 422
|
||||
|
||||
Reference in New Issue
Block a user