forked from 0xWheatyz/SPARC
5c25a0f589
Closes leeworks-agents/SPARC#1685 - Increase CompanyName max_length from 100 to 128 everywhere (Pydantic type, Path constraints, and the inline Query on analyze/patent). - Add _COMPANY_NAME_FILTER_QUERY reusable Query annotation and apply it to the optional company_name filter on GET /analyze/batch so it is validated with the same rules as all other endpoints. - Update tests: rename test_over_100_chars_rejected → 128, add test_exactly_128_chars_accepted at the new boundary, fix batch too-long test to use 129 chars, update valid-name parametrize to use "A"*128, and add five new tests covering GET /analyze/batch filter validation (special chars, too-short, too-long, valid, omitted). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
191 lines
6.5 KiB
Python
191 lines
6.5 KiB
Python
"""Tests for company name input validation on analysis endpoints."""
|
|
|
|
from datetime import datetime
from unittest.mock import Mock
from urllib.parse import quote

import pytest
from fastapi.testclient import TestClient

from SPARC.api import app
from SPARC.types import CompanyAnalysisResult


@pytest.fixture
def client():
    """Build a ``TestClient`` around the SPARC ``app`` for issuing requests."""
    http_client = TestClient(app)
    return http_client


@pytest.fixture
def mock_analyzer(mocker):
    """Swap ``SPARC.api._analyzer`` for a ``Mock`` via pytest-mock's ``mocker``.

    ``_analyze_company_safe`` is stubbed to return a successful
    ``CompanyAnalysisResult`` so that requests which pass validation can
    complete. The mock itself is returned so tests may inspect its calls.
    """
    canned_result = CompanyAnalysisResult(
        company_name="nvidia",
        analysis="Test analysis",
        patent_count=1,
        success=True,
        timestamp=datetime.now(),
    )
    # Mock(**{"a.b": v}) is the constructor form of configure_mock: it sets
    # fake_analyzer.a.b = v, i.e. the stubbed method's return value.
    fake_analyzer = Mock(**{"_analyze_company_safe.return_value": canned_result})
    mocker.patch("SPARC.api._analyzer", fake_analyzer)
    return fake_analyzer


class TestCompanyNameValidation:
    """Test that company names are validated on analysis endpoints."""

    @staticmethod
    def _analyze_path(company_name):
        """Return the ``/analyze/{company_name}`` URL with the name percent-encoded.

        ``safe=""`` encodes every reserved character. Interpolated raw, a
        ``?`` would start a query string and a ``#`` a fragment (which is
        never sent), so the server would only see the valid prefix of the name
        and the rejection tests would never reach the validator. Encoded, the
        server decodes the character back into the path parameter.
        """
        return f"/analyze/{quote(company_name, safe='')}"

    # --- Too short ---

    def test_single_char_rejected(self, client, mock_analyzer):
        """A one-character company name should be rejected."""
        response = client.get(self._analyze_path("X"))
        assert response.status_code == 422

    # --- Too long ---

    def test_over_128_chars_rejected(self, client, mock_analyzer):
        """A company name longer than 128 characters should be rejected."""
        response = client.get(self._analyze_path("A" * 129))
        assert response.status_code == 422

    def test_exactly_128_chars_accepted(self, client, mock_analyzer):
        """A company name of exactly 128 characters should be accepted."""
        response = client.get(self._analyze_path("A" * 128))
        assert response.status_code != 422

    # --- Special characters ---

    # "/" is deliberately not in this list: ASGI percent-decodes scope["path"]
    # before routing, so even "%2F" becomes a path separator and the request
    # 404s instead of reaching validation. See test_slash_rejected below.
    @pytest.mark.parametrize(
        "bad_name",
        [
            "nvidia!",
            "intel@corp",
            "test#company",
            "foo$bar",
            "a%b",
            "x^y",
            "semi;colon",
            "drop'table",
            'say"hello',
            "back\\slash",
            "pipe|char",
            "star*glob",
            "question?mark",
            "<script>",
            "curly{brace}",
            "equal=sign",
            "plus+plus",
            "comma,separated",
        ],
    )
    def test_special_chars_rejected(self, client, mock_analyzer, bad_name):
        """Company names with disallowed special characters should be rejected."""
        response = client.get(self._analyze_path(bad_name))
        assert response.status_code == 422

    def test_slash_rejected(self, client, mock_analyzer):
        """A company name containing '/' should be rejected.

        A '/' cannot travel inside the path parameter (it splits the URL
        segment), so it is checked through the POST body and the
        GET /analyze/batch query-string filter instead.
        """
        body_response = client.post(
            "/analyze/batch",
            json={"companies": ["path/traversal"]},
        )
        assert body_response.status_code == 422

        filter_response = client.get(
            "/analyze/batch", params={"company_name": "path/traversal"}
        )
        assert filter_response.status_code == 422

    # --- Valid names ---

    @pytest.mark.parametrize(
        "valid_name",
        [
            "nvidia",
            "Intel",
            "TSMC",
            "Texas Instruments",
            "Johnson-Johnson",
            "AT&T",
            "St. Jude Medical",
            "3M",
            "21st Century Fox",
            "ab",  # minimum length
            "A" * 128,  # maximum length
        ],
    )
    def test_valid_names_accepted(self, client, mock_analyzer, valid_name):
        """Valid company names should be accepted (200, not 422)."""
        response = client.get(self._analyze_path(valid_name))
        # Should not be a validation error; 200 or other non-422 status is fine
        assert response.status_code != 422

    # --- Batch endpoint validation ---

    def test_batch_too_short_rejected(self, client, mock_analyzer):
        """Batch endpoint should reject company names that are too short."""
        response = client.post(
            "/analyze/batch",
            json={"companies": ["X"]},
        )
        assert response.status_code == 422

    def test_batch_too_long_rejected(self, client, mock_analyzer):
        """Batch endpoint should reject company names that are too long."""
        response = client.post(
            "/analyze/batch",
            json={"companies": ["A" * 129]},
        )
        assert response.status_code == 422

    def test_batch_special_chars_rejected(self, client, mock_analyzer):
        """Batch endpoint should reject company names with special chars."""
        response = client.post(
            "/analyze/batch",
            json={"companies": ["nvidia!", "intel"]},
        )
        assert response.status_code == 422

    def test_batch_valid_names_accepted(self, client, mock_analyzer):
        """Batch endpoint should accept valid company names."""
        response = client.post(
            "/analyze/batch",
            json={"companies": ["nvidia", "Intel", "AT&T"]},
        )
        assert response.status_code != 422

    # --- Name must start with alphanumeric ---

    def test_leading_space_rejected(self, client, mock_analyzer):
        """Company name starting with a space should be rejected."""
        response = client.post(
            "/analyze/batch",
            json={"companies": [" nvidia"]},
        )
        assert response.status_code == 422

    def test_leading_hyphen_rejected(self, client, mock_analyzer):
        """Company name starting with a hyphen should be rejected."""
        response = client.post(
            "/analyze/batch",
            json={"companies": ["-nvidia"]},
        )
        assert response.status_code == 422

    # --- GET /analyze/batch company_name filter validation ---

    def test_batch_filter_special_chars_rejected(self, client, mock_analyzer):
        """GET /analyze/batch company_name filter rejects disallowed chars."""
        response = client.get("/analyze/batch", params={"company_name": "nvidia!"})
        assert response.status_code == 422

    def test_batch_filter_too_short_rejected(self, client, mock_analyzer):
        """GET /analyze/batch company_name filter rejects names under 2 chars."""
        response = client.get("/analyze/batch", params={"company_name": "X"})
        assert response.status_code == 422

    def test_batch_filter_too_long_rejected(self, client, mock_analyzer):
        """GET /analyze/batch company_name filter rejects names over 128 chars."""
        response = client.get("/analyze/batch", params={"company_name": "A" * 129})
        assert response.status_code == 422

    def test_batch_filter_valid_name_accepted(self, client, mock_analyzer):
        """GET /analyze/batch company_name filter accepts a valid name."""
        response = client.get("/analyze/batch", params={"company_name": "nvidia"})
        assert response.status_code != 422

    def test_batch_filter_omitted_accepted(self, client, mock_analyzer):
        """GET /analyze/batch without company_name filter should work fine."""
        response = client.get("/analyze/batch")
        assert response.status_code != 422