deploy: security hardening, multi-model support, S3 storage, analytics, CI improvements (70 commits) #4
@@ -479,6 +479,20 @@ SUPPORTED_MODELS = [
|
|||||||
{"id": "meta-llama/llama-3.1-70b-instruct", "name": "Llama 3.1 70B", "provider": "Meta"},
|
{"id": "meta-llama/llama-3.1-70b-instruct", "name": "Llama 3.1 70B", "provider": "Meta"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
_SUPPORTED_MODEL_IDS = {m["id"] for m in SUPPORTED_MODELS}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_model(model: str | None) -> None:
|
||||||
|
"""Raise HTTP 400 if *model* is not in the supported allow-list."""
|
||||||
|
if model is not None and model not in _SUPPORTED_MODEL_IDS:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=(
|
||||||
|
f"Unsupported model '{model}'. "
|
||||||
|
f"Supported models: {', '.join(sorted(_SUPPORTED_MODEL_IDS))}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/models", tags=["System"])
|
@app.get("/models", tags=["System"])
|
||||||
async def list_models():
|
async def list_models():
|
||||||
@@ -814,6 +828,7 @@ async def analyze_company(
|
|||||||
Returns:
|
Returns:
|
||||||
Analysis results including patent count, AI insights, and success status
|
Analysis results including patent count, AI insights, and success status
|
||||||
"""
|
"""
|
||||||
|
_validate_model(model)
|
||||||
if not _analyzer:
|
if not _analyzer:
|
||||||
raise HTTPException(status_code=503, detail="Analyzer not initialized")
|
raise HTTPException(status_code=503, detail="Analyzer not initialized")
|
||||||
|
|
||||||
@@ -873,6 +888,7 @@ async def analyze_companies_batch(
|
|||||||
Returns:
|
Returns:
|
||||||
Batch results with individual company analyses and summary statistics
|
Batch results with individual company analyses and summary statistics
|
||||||
"""
|
"""
|
||||||
|
_validate_model(request.model)
|
||||||
if not _analyzer:
|
if not _analyzer:
|
||||||
raise HTTPException(status_code=503, detail="Analyzer not initialized")
|
raise HTTPException(status_code=503, detail="Analyzer not initialized")
|
||||||
|
|
||||||
@@ -983,6 +999,7 @@ async def analyze_companies_async(
|
|||||||
Returns:
|
Returns:
|
||||||
Job status with job_id for polling
|
Job status with job_id for polling
|
||||||
"""
|
"""
|
||||||
|
_validate_model(request.model)
|
||||||
global _job_counter
|
global _job_counter
|
||||||
|
|
||||||
_job_counter += 1
|
_job_counter += 1
|
||||||
|
|||||||
@@ -182,3 +182,47 @@ class TestJobEndpoints:
|
|||||||
"""Test listing jobs with status filter."""
|
"""Test listing jobs with status filter."""
|
||||||
response = client.get("/jobs?status=completed")
|
response = client.get("/jobs?status=completed")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelValidation:
|
||||||
|
"""Test that unsupported model identifiers are rejected."""
|
||||||
|
|
||||||
|
def test_analyze_rejects_unsupported_model(self, client, mock_analyzer):
|
||||||
|
"""GET /analyze/{company} with unsupported model returns 400."""
|
||||||
|
response = client.get("/analyze/nvidia?model=fake/nonexistent-model")
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Unsupported model" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_analyze_accepts_supported_model(self, client, mock_analyzer):
|
||||||
|
"""GET /analyze/{company} with a supported model succeeds."""
|
||||||
|
mock_result = CompanyAnalysisResult(
|
||||||
|
company_name="nvidia",
|
||||||
|
analysis="test",
|
||||||
|
patent_count=1,
|
||||||
|
success=True,
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
model="anthropic/claude-3.5-sonnet",
|
||||||
|
)
|
||||||
|
mock_analyzer._analyze_company_safe.return_value = mock_result
|
||||||
|
|
||||||
|
response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_batch_rejects_unsupported_model(self, client, mock_analyzer):
|
||||||
|
"""POST /analyze/batch with unsupported model returns 400."""
|
||||||
|
response = client.post(
|
||||||
|
"/analyze/batch",
|
||||||
|
json={"companies": ["nvidia"], "model": "fake/nonexistent-model"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Unsupported model" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_list_models_returns_supported(self, client):
|
||||||
|
"""GET /models returns the allow-list."""
|
||||||
|
response = client.get("/models")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "models" in data
|
||||||
|
assert "default" in data
|
||||||
|
assert len(data["models"]) > 0
|
||||||
|
assert all("id" in m and "name" in m and "provider" in m for m in data["models"])
|
||||||
|
|||||||
Reference in New Issue
Block a user