forked from 0xWheatyz/SPARC
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5c25a0f589 |
+20
-8
@@ -36,16 +36,28 @@ from SPARC.auth import (
|
|||||||
)
|
)
|
||||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||||
|
|
||||||
# Validated company name type: 2-100 chars, alphanumeric + spaces/hyphens/ampersands/periods only.
|
# Validated company name type: 2-128 chars, alphanumeric + spaces/hyphens/ampersands/periods only.
|
||||||
CompanyName = Annotated[
|
CompanyName = Annotated[
|
||||||
str,
|
str,
|
||||||
StringConstraints(
|
StringConstraints(
|
||||||
min_length=2,
|
min_length=2,
|
||||||
max_length=100,
|
max_length=128,
|
||||||
pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
|
pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Reusable Query constraint for optional company_name filter parameters.
|
||||||
|
_COMPANY_NAME_FILTER_QUERY = Query(
|
||||||
|
default=None,
|
||||||
|
min_length=2,
|
||||||
|
max_length=128,
|
||||||
|
pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$",
|
||||||
|
description=(
|
||||||
|
"Company name filter (2-128 chars; alphanumeric, spaces, hyphens, "
|
||||||
|
"periods, and ampersands only)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Pydantic models for API
|
# Pydantic models for API
|
||||||
class CompanyAnalysisResponse(BaseModel):
|
class CompanyAnalysisResponse(BaseModel):
|
||||||
@@ -489,7 +501,7 @@ async def add_tracked_company(
|
|||||||
|
|
||||||
@app.delete("/admin/tracked/{company_name}", tags=["Admin"])
|
@app.delete("/admin/tracked/{company_name}", tags=["Admin"])
|
||||||
async def remove_tracked_company(
|
async def remove_tracked_company(
|
||||||
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||||
_: UserResponse = Depends(get_current_admin),
|
_: UserResponse = Depends(get_current_admin),
|
||||||
):
|
):
|
||||||
"""Remove a company from the tracked list (admin only)."""
|
"""Remove a company from the tracked list (admin only)."""
|
||||||
@@ -677,7 +689,7 @@ async def get_analytics_trends(
|
|||||||
|
|
||||||
@app.get("/export/{company_name}", tags=["Export"])
|
@app.get("/export/{company_name}", tags=["Export"])
|
||||||
async def export_company_csv(
|
async def export_company_csv(
|
||||||
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||||
_: UserResponse = Depends(get_current_user),
|
_: UserResponse = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Export analysis results for a company as a CSV file.
|
"""Export analysis results for a company as a CSV file.
|
||||||
@@ -729,7 +741,7 @@ async def export_company_csv(
|
|||||||
|
|
||||||
@app.get("/export/{company_name}/pdf", tags=["Export"])
|
@app.get("/export/{company_name}/pdf", tags=["Export"])
|
||||||
async def export_company_pdf(
|
async def export_company_pdf(
|
||||||
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||||
_: UserResponse = Depends(get_current_user),
|
_: UserResponse = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Export analysis results for a company as a formatted PDF report.
|
"""Export analysis results for a company as a formatted PDF report.
|
||||||
@@ -903,7 +915,7 @@ async def health_check():
|
|||||||
tags=["Analysis"],
|
tags=["Analysis"],
|
||||||
)
|
)
|
||||||
async def analyze_company(
|
async def analyze_company(
|
||||||
company_name: Annotated[str, Path(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
company_name: Annotated[str, Path(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$")],
|
||||||
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
|
model: str | None = Query(default=None, description="LLM model to use (e.g. 'openai/gpt-4o'). Defaults to server config."),
|
||||||
_: UserResponse = Depends(get_current_user),
|
_: UserResponse = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
@@ -933,7 +945,7 @@ async def analyze_company(
|
|||||||
)
|
)
|
||||||
async def analyze_single_patent(
|
async def analyze_single_patent(
|
||||||
patent_id: str,
|
patent_id: str,
|
||||||
company_name: Annotated[str, Query(min_length=2, max_length=100, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
|
company_name: Annotated[str, Query(min_length=2, max_length=128, pattern=r"^[a-zA-Z0-9][a-zA-Z0-9 \-&.]*$", description="Company name for analysis context")],
|
||||||
_: UserResponse = Depends(get_current_user),
|
_: UserResponse = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Analyze a single patent by its publication ID.
|
"""Analyze a single patent by its publication ID.
|
||||||
@@ -967,7 +979,7 @@ async def analyze_single_patent(
|
|||||||
async def list_analysis_results(
|
async def list_analysis_results(
|
||||||
company_name: Annotated[
|
company_name: Annotated[
|
||||||
str | None,
|
str | None,
|
||||||
Query(description="Filter results by company name"),
|
_COMPANY_NAME_FILTER_QUERY,
|
||||||
] = None,
|
] = None,
|
||||||
limit: Annotated[int, Query(ge=1, le=200)] = 50,
|
limit: Annotated[int, Query(ge=1, le=200)] = 50,
|
||||||
cursor: Annotated[
|
cursor: Annotated[
|
||||||
|
|||||||
@@ -43,12 +43,18 @@ class TestCompanyNameValidation:
|
|||||||
|
|
||||||
# --- Too long ---
|
# --- Too long ---
|
||||||
|
|
||||||
def test_over_100_chars_rejected(self, client, mock_analyzer):
|
def test_over_128_chars_rejected(self, client, mock_analyzer):
|
||||||
"""A company name longer than 100 characters should be rejected."""
|
"""A company name longer than 128 characters should be rejected."""
|
||||||
long_name = "A" * 101
|
long_name = "A" * 129
|
||||||
response = client.get(f"/analyze/{long_name}")
|
response = client.get(f"/analyze/{long_name}")
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
def test_exactly_128_chars_accepted(self, client, mock_analyzer):
|
||||||
|
"""A company name of exactly 128 characters should be accepted."""
|
||||||
|
max_name = "A" * 128
|
||||||
|
response = client.get(f"/analyze/{max_name}")
|
||||||
|
assert response.status_code != 422
|
||||||
|
|
||||||
# --- Special characters ---
|
# --- Special characters ---
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -95,7 +101,7 @@ class TestCompanyNameValidation:
|
|||||||
"3M",
|
"3M",
|
||||||
"21st Century Fox",
|
"21st Century Fox",
|
||||||
"ab", # minimum length
|
"ab", # minimum length
|
||||||
"A" * 100, # maximum length
|
"A" * 128, # maximum length
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_valid_names_accepted(self, client, mock_analyzer, valid_name):
|
def test_valid_names_accepted(self, client, mock_analyzer, valid_name):
|
||||||
@@ -118,7 +124,7 @@ class TestCompanyNameValidation:
|
|||||||
"""Batch endpoint should reject company names that are too long."""
|
"""Batch endpoint should reject company names that are too long."""
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/analyze/batch",
|
"/analyze/batch",
|
||||||
json={"companies": ["A" * 101]},
|
json={"companies": ["A" * 129]},
|
||||||
)
|
)
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
@@ -155,3 +161,30 @@ class TestCompanyNameValidation:
|
|||||||
json={"companies": ["-nvidia"]},
|
json={"companies": ["-nvidia"]},
|
||||||
)
|
)
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
# --- GET /analyze/batch company_name filter validation ---
|
||||||
|
|
||||||
|
def test_batch_filter_special_chars_rejected(self, client, mock_analyzer):
|
||||||
|
"""GET /analyze/batch company_name filter rejects disallowed chars."""
|
||||||
|
response = client.get("/analyze/batch", params={"company_name": "nvidia!"})
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
def test_batch_filter_too_short_rejected(self, client, mock_analyzer):
|
||||||
|
"""GET /analyze/batch company_name filter rejects names under 2 chars."""
|
||||||
|
response = client.get("/analyze/batch", params={"company_name": "X"})
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
def test_batch_filter_too_long_rejected(self, client, mock_analyzer):
|
||||||
|
"""GET /analyze/batch company_name filter rejects names over 128 chars."""
|
||||||
|
response = client.get("/analyze/batch", params={"company_name": "A" * 129})
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
def test_batch_filter_valid_name_accepted(self, client, mock_analyzer):
|
||||||
|
"""GET /analyze/batch company_name filter accepts a valid name."""
|
||||||
|
response = client.get("/analyze/batch", params={"company_name": "nvidia"})
|
||||||
|
assert response.status_code != 422
|
||||||
|
|
||||||
|
def test_batch_filter_omitted_accepted(self, client, mock_analyzer):
|
||||||
|
"""GET /analyze/batch without company_name filter should work fine."""
|
||||||
|
response = client.get("/analyze/batch")
|
||||||
|
assert response.status_code != 422
|
||||||
|
|||||||
Reference in New Issue
Block a user