diff --git a/.env.example b/.env.example index 1d776d0..acf4901 100644 --- a/.env.example +++ b/.env.example @@ -6,11 +6,16 @@ API_KEY=your_serpapi_key_here # OpenRouter API key for LLM analysis OPENROUTER_API_KEY=your_openrouter_key_here -# Database configuration (for docker-compose setup) +# Database configuration +# All messages are stored in the database for persistence and caching DATABASE_URL=postgresql://postgres:postgres@localhost:5432/sparc -# Toggle between database mode and API mode -# When USE_DATABASE=true: stores all messages in database instead of sending to OpenRouter -# When USE_DATABASE=false: sends messages to OpenRouter API as normal -# Default: false -USE_DATABASE=false +# Cache configuration +# When USE_CACHE=true: check database for cached responses before making API calls +# When USE_CACHE=false: always make fresh API calls (still stores results in database) +# Default: true +USE_CACHE=true + +# JWT Secret for authentication +# IMPORTANT: Change this to a secure random string in production +JWT_SECRET=your-secure-jwt-secret-change-in-production diff --git a/SPARC/api.py b/SPARC/api.py index 2a75fee..3b4130b 100644 --- a/SPARC/api.py +++ b/SPARC/api.py @@ -5,12 +5,23 @@ Provides REST API endpoints for analyzing company patent portfolios. from contextlib import asynccontextmanager from datetime import datetime -from typing import Annotated +from typing import Annotated, List -from fastapi import BackgroundTasks, FastAPI, HTTPException, Query -from pydantic import BaseModel, Field +from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, EmailStr, Field +from SPARC import config from SPARC.analyzer import CompanyAnalyzer +from SPARC.auth import ( + TokenResponse, + UserResponse, + create_tokens, + decode_token, + get_current_admin, + get_current_user, + get_db_client, +) from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult @@ -67,6 +78,42 @@ class HealthResponse(BaseModel): timestamp: datetime +# Auth request/response models +class RegisterRequest(BaseModel): + """User registration request.""" + + email: EmailStr + password: str = Field(..., min_length=8, description="Password (min 8 characters)") + + +class LoginRequest(BaseModel): + """User login request.""" + + email: EmailStr + password: str + + +class RefreshRequest(BaseModel): + """Token refresh request.""" + + refresh_token: str + + +class UpdateRoleRequest(BaseModel): + """Update user role request.""" + + role: str = Field(..., pattern="^(admin|user)$") + + +class AnalyticsResponse(BaseModel): + """Analytics response model.""" + + total_messages: int + by_company: List[dict] + by_type: List[dict] + period_days: int + + # In-memory job storage (for demo; production would use Redis/DB) _jobs: dict[str, JobStatus] = {} _job_counter = 0 @@ -116,6 +163,196 @@ app = FastAPI( lifespan=lifespan, ) +# Add CORS middleware for React frontend +app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:3000", "http://localhost:5173"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# ============== Auth Endpoints ============== + + +@app.post("/auth/register", response_model=UserResponse, tags=["Auth"]) +async def register(request: RegisterRequest): + """Register a new user. + + The first registered user automatically becomes an admin. + """ + db = get_db_client() + + # First user becomes admin + user_count = db.get_user_count() + role = "admin" if user_count == 0 else "user" + + user = db.create_user( + email=request.email, + password=request.password, + role=role, + ) + + if not user: + raise HTTPException( + status_code=400, + detail="Email already registered", + ) + + return UserResponse( + id=user["id"], + email=user["email"], + role=user["role"], + created_at=user["created_at"], + ) + + +@app.post("/auth/login", response_model=TokenResponse, tags=["Auth"]) +async def login(request: LoginRequest): + """Authenticate user and return JWT tokens.""" + db = get_db_client() + + user = db.authenticate_user(request.email, request.password) + + if not user: + raise HTTPException( + status_code=401, + detail="Invalid email or password", + ) + + return create_tokens(user["id"], user["email"], user["role"]) + + +@app.post("/auth/refresh", response_model=TokenResponse, tags=["Auth"]) +async def refresh_token(request: RefreshRequest): + """Refresh access token using refresh token.""" + payload = decode_token(request.refresh_token) + + if not payload or payload.type != "refresh": + raise HTTPException( + status_code=401, + detail="Invalid refresh token", + ) + + db = get_db_client() + user = db.get_user_by_id(payload.sub) + + if not user: + raise HTTPException( + status_code=401, + detail="User not found", + ) + + return create_tokens(user["id"], user["email"], user["role"]) + + +@app.get("/auth/me", response_model=UserResponse, tags=["Auth"]) +async def get_me(current_user: UserResponse = Depends(get_current_user)): + """Get current authenticated user.""" + return current_user + + +# ============== Admin Endpoints ============== + + +@app.get("/admin/users", response_model=List[UserResponse], tags=["Admin"]) +async def list_users( + limit: int = Query(default=100, ge=1, le=1000), + offset: int = Query(default=0, ge=0), + _: UserResponse = Depends(get_current_admin), +): + """List all users (admin only).""" + db = get_db_client() + users = db.get_all_users(limit=limit, offset=offset) + + return [ + UserResponse( + id=u["id"], + email=u["email"], + role=u["role"], + created_at=u["created_at"], + ) + for u in users + ] + + +@app.patch("/admin/users/{user_id}/role", response_model=UserResponse, tags=["Admin"]) +async def update_user_role( + user_id: int, + request: UpdateRoleRequest, + current_admin: UserResponse = Depends(get_current_admin), +): + """Update a user's role (admin only).""" + if user_id == current_admin.id: + raise HTTPException( + status_code=400, + detail="Cannot change your own role", + ) + + db = get_db_client() + user = db.update_user_role(user_id, request.role) + + if not user: + raise HTTPException( + status_code=404, + detail="User not found", + ) + + return UserResponse( + id=user["id"], + email=user["email"], + role=user["role"], + created_at=user["created_at"], + ) + + +@app.delete("/admin/users/{user_id}", tags=["Admin"]) +async def delete_user( + user_id: int, + current_admin: UserResponse = Depends(get_current_admin), +): + """Delete a user (admin only).""" + if user_id == current_admin.id: + raise HTTPException( + status_code=400, + detail="Cannot delete yourself", + ) + + db = get_db_client() + deleted = db.delete_user(user_id) + + if not deleted: + raise HTTPException( + status_code=404, + detail="User not found", + ) + + return {"message": "User deleted"} + + +# ============== Analytics Endpoint ============== + + +@app.get("/analytics", response_model=AnalyticsResponse, tags=["Analytics"]) +async def get_analytics( + days: int = Query(default=30, ge=1, le=365), + _: UserResponse = Depends(get_current_user), +): + """Get analytics data (authenticated users only).""" + db = get_db_client() + analytics = db.get_analytics(days=days) + + return AnalyticsResponse( + total_messages=analytics["total_messages"], + by_company=analytics["by_company"], + by_type=analytics["by_type"], + period_days=analytics["period_days"], + ) + + +# ============== System Endpoints ============== + @app.get("/health", response_model=HealthResponse, tags=["System"]) async def health_check(): @@ -132,7 +369,10 @@ async def health_check(): response_model=CompanyAnalysisResponse, tags=["Analysis"], ) -async def analyze_company(company_name: str): +async def analyze_company( + company_name: str, + _: UserResponse = Depends(get_current_user), +): """Analyze a single company's patent portfolio. This endpoint retrieves recent patents for the specified company, @@ -156,7 +396,10 @@ async def analyze_company(company_name: str): response_model=BatchAnalysisResponse, tags=["Analysis"], ) -async def analyze_companies_batch(request: BatchAnalysisRequest): +async def analyze_companies_batch( + request: BatchAnalysisRequest, + _: UserResponse = Depends(get_current_user), +): """Analyze multiple companies' patent portfolios. Processes companies concurrently for improved performance. @@ -209,7 +452,9 @@ def _run_batch_job(job_id: str, companies: list[str], max_workers: int): @app.post("/analyze/batch/async", response_model=JobStatus, tags=["Analysis"]) async def analyze_companies_async( - request: BatchAnalysisRequest, background_tasks: BackgroundTasks + request: BatchAnalysisRequest, + background_tasks: BackgroundTasks, + _: UserResponse = Depends(get_current_user), ): """Start an asynchronous batch analysis job. @@ -243,7 +488,10 @@ async def analyze_companies_async( @app.get("/jobs/{job_id}", response_model=JobStatus, tags=["Jobs"]) -async def get_job_status(job_id: str): +async def get_job_status( + job_id: str, + _: UserResponse = Depends(get_current_user), +): """Get the status of a background analysis job. Args: @@ -265,6 +513,7 @@ async def list_jobs( Query(description="Filter by status: pending, running, completed, failed"), ] = None, limit: Annotated[int, Query(ge=1, le=100)] = 10, + _: UserResponse = Depends(get_current_user), ): """List all analysis jobs.