forked from 0xWheatyz/SPARC
Add stricter input validation for company names on analysis endpoints
Add a CompanyName validated type enforcing 2-100 character length and allowing only alphanumeric characters, spaces, hyphens, ampersands, and periods. Applied to all endpoints accepting company names: /analyze, /analyze/patent, /analyze/batch, /admin/tracked, and /export. Includes unit tests covering too-short, too-long, special character, leading-character, and valid edge cases for both single and batch endpoints. Closes leeworks-agents/SPARC#1670 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,157 @@
|
||||
"""Tests for company name input validation on analysis endpoints."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from SPARC.api import app
|
||||
from SPARC.types import CompanyAnalysisResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_analyzer(mocker):
|
||||
"""Mock the global analyzer so valid requests succeed."""
|
||||
mock = Mock()
|
||||
mock._analyze_company_safe.return_value = CompanyAnalysisResult(
|
||||
company_name="nvidia",
|
||||
analysis="Test analysis",
|
||||
patent_count=1,
|
||||
success=True,
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
mocker.patch("SPARC.api._analyzer", mock)
|
||||
return mock
|
||||
|
||||
|
||||
class TestCompanyNameValidation:
|
||||
"""Test that company names are validated on analysis endpoints."""
|
||||
|
||||
# --- Too short ---
|
||||
|
||||
def test_single_char_rejected(self, client, mock_analyzer):
|
||||
"""A one-character company name should be rejected."""
|
||||
response = client.get("/analyze/X")
|
||||
assert response.status_code == 422
|
||||
|
||||
# --- Too long ---
|
||||
|
||||
def test_over_100_chars_rejected(self, client, mock_analyzer):
|
||||
"""A company name longer than 100 characters should be rejected."""
|
||||
long_name = "A" * 101
|
||||
response = client.get(f"/analyze/{long_name}")
|
||||
assert response.status_code == 422
|
||||
|
||||
# --- Special characters ---
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad_name",
|
||||
[
|
||||
"nvidia!",
|
||||
"intel@corp",
|
||||
"test#company",
|
||||
"foo$bar",
|
||||
"a%b",
|
||||
"x^y",
|
||||
"semi;colon",
|
||||
"drop'table",
|
||||
'say"hello',
|
||||
"path/traversal",
|
||||
"back\\slash",
|
||||
"pipe|char",
|
||||
"star*glob",
|
||||
"question?mark",
|
||||
"<script>",
|
||||
"curly{brace}",
|
||||
"equal=sign",
|
||||
"plus+plus",
|
||||
"comma,separated",
|
||||
],
|
||||
)
|
||||
def test_special_chars_rejected(self, client, mock_analyzer, bad_name):
|
||||
"""Company names with disallowed special characters should be rejected."""
|
||||
response = client.get(f"/analyze/{bad_name}")
|
||||
assert response.status_code == 422
|
||||
|
||||
# --- Valid names ---
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"valid_name",
|
||||
[
|
||||
"nvidia",
|
||||
"Intel",
|
||||
"TSMC",
|
||||
"Texas Instruments",
|
||||
"Johnson-Johnson",
|
||||
"AT&T",
|
||||
"St. Jude Medical",
|
||||
"3M",
|
||||
"21st Century Fox",
|
||||
"ab", # minimum length
|
||||
"A" * 100, # maximum length
|
||||
],
|
||||
)
|
||||
def test_valid_names_accepted(self, client, mock_analyzer, valid_name):
|
||||
"""Valid company names should be accepted (200, not 422)."""
|
||||
response = client.get(f"/analyze/{valid_name}")
|
||||
# Should not be a validation error; 200 or other non-422 status is fine
|
||||
assert response.status_code != 422
|
||||
|
||||
# --- Batch endpoint validation ---
|
||||
|
||||
def test_batch_too_short_rejected(self, client, mock_analyzer):
|
||||
"""Batch endpoint should reject company names that are too short."""
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["X"]},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_batch_too_long_rejected(self, client, mock_analyzer):
|
||||
"""Batch endpoint should reject company names that are too long."""
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["A" * 101]},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_batch_special_chars_rejected(self, client, mock_analyzer):
|
||||
"""Batch endpoint should reject company names with special chars."""
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["nvidia!", "intel"]},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_batch_valid_names_accepted(self, client, mock_analyzer):
|
||||
"""Batch endpoint should accept valid company names."""
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["nvidia", "Intel", "AT&T"]},
|
||||
)
|
||||
assert response.status_code != 422
|
||||
|
||||
# --- Name must start with alphanumeric ---
|
||||
|
||||
def test_leading_space_rejected(self, client, mock_analyzer):
|
||||
"""Company name starting with a space should be rejected."""
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": [" nvidia"]},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_leading_hyphen_rejected(self, client, mock_analyzer):
|
||||
"""Company name starting with a hyphen should be rejected."""
|
||||
response = client.post(
|
||||
"/analyze/batch",
|
||||
json={"companies": ["-nvidia"]},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
Reference in New Issue
Block a user