diff --git a/SPARC/api.py b/SPARC/api.py index 6fbbcdb..3a28033 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -479,6 +479,20 @@ SUPPORTED_MODELS = [ {"id": "meta-llama/llama-3.1-70b-instruct", "name": "Llama 3.1 70B", "provider": "Meta"}, ] +_SUPPORTED_MODEL_IDS = {m["id"] for m in SUPPORTED_MODELS} + + +def _validate_model(model: str | None) -> None: + """Raise HTTP 400 if *model* is not in the supported allow-list.""" + if model is not None and model not in _SUPPORTED_MODEL_IDS: + raise HTTPException( + status_code=400, + detail=( + f"Unsupported model '{model}'. " + f"Supported models: {', '.join(sorted(_SUPPORTED_MODEL_IDS))}" + ), + ) + @app.get("/models", tags=["System"]) async def list_models(): @@ -814,6 +828,7 @@ async def analyze_company( Returns: Analysis results including patent count, AI insights, and success status """ + _validate_model(model) if not _analyzer: raise HTTPException(status_code=503, detail="Analyzer not initialized") @@ -873,6 +888,7 @@ async def analyze_companies_batch( Returns: Batch results with individual company analyses and summary statistics """ + _validate_model(request.model) if not _analyzer: raise HTTPException(status_code=503, detail="Analyzer not initialized") @@ -983,6 +999,7 @@ async def analyze_companies_async( Returns: Job status with job_id for polling """ + _validate_model(request.model) global _job_counter _job_counter += 1 diff --git a/tests/test_api.py b/tests/test_api.py index 169be27..fd16921 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -182,3 +182,47 @@ class TestJobEndpoints: """Test listing jobs with status filter.""" response = client.get("/jobs?status=completed") assert response.status_code == 200 + + +class TestModelValidation: + """Test that unsupported model identifiers are rejected.""" + + def test_analyze_rejects_unsupported_model(self, client, mock_analyzer): + """GET /analyze/{company} with unsupported model returns 400.""" + response = client.get("/analyze/nvidia?model=fake/nonexistent-model") + assert response.status_code == 400 + assert "Unsupported model" in response.json()["detail"] + + def test_analyze_accepts_supported_model(self, client, mock_analyzer): + """GET /analyze/{company} with a supported model succeeds.""" + mock_result = CompanyAnalysisResult( + company_name="nvidia", + analysis="test", + patent_count=1, + success=True, + timestamp=datetime.now(), + model="anthropic/claude-3.5-sonnet", + ) + mock_analyzer._analyze_company_safe.return_value = mock_result + + response = client.get("/analyze/nvidia?model=anthropic/claude-3.5-sonnet") + assert response.status_code == 200 + + def test_batch_rejects_unsupported_model(self, client, mock_analyzer): + """POST /analyze/batch with unsupported model returns 400.""" + response = client.post( + "/analyze/batch", + json={"companies": ["nvidia"], "model": "fake/nonexistent-model"}, + ) + assert response.status_code == 400 + assert "Unsupported model" in response.json()["detail"] + + def test_list_models_returns_supported(self, client): + """GET /models returns the allow-list.""" + response = client.get("/models") + assert response.status_code == 200 + data = response.json() + assert "models" in data + assert "default" in data + assert len(data["models"]) > 0 + assert all("id" in m and "name" in m and "provider" in m for m in data["models"])