forked from 0xWheatyz/SPARC
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3b6411869d | |||
| 55c131cb32 | |||
| fbb72fe2a5 | |||
| e484baaf5f | |||
| 069f1c343c | |||
| d366443b38 |
@@ -9,7 +9,43 @@ on:
|
|||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Install system dependencies
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
|
||||||
|
git checkout ${{ gitea.sha }}
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
pip3 install --break-system-packages -r requirements.txt ruff
|
||||||
|
|
||||||
|
- name: Run ruff linter
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
ruff check SPARC/ tests/
|
||||||
|
|
||||||
|
- name: Run pytest
|
||||||
|
shell: sh
|
||||||
|
env:
|
||||||
|
DATABASE_URL: "sqlite://"
|
||||||
|
API_KEY: "test-key"
|
||||||
|
OPENROUTER_API_KEY: "test-key"
|
||||||
|
JWT_SECRET: "test-secret-for-ci"
|
||||||
|
APP_ENV: "development"
|
||||||
|
run: |
|
||||||
|
python3 -m pytest tests/ -v --tb=short -x
|
||||||
|
|
||||||
build-api:
|
build-api:
|
||||||
|
needs: test
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
@@ -81,6 +117,7 @@ jobs:
|
|||||||
echo "API image available at ${{ steps.tags.outputs.IMAGE_TAG }}"
|
echo "API image available at ${{ steps.tags.outputs.IMAGE_TAG }}"
|
||||||
|
|
||||||
build-frontend:
|
build-frontend:
|
||||||
|
needs: test
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
name: Test and Lint
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Install system dependencies
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
apk add --no-cache git python3 py3-pip gcc musl-dev libpq-dev python3-dev
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
git clone http://gitea.gitea.svc.cluster.local/${{ gitea.repository }}.git .
|
||||||
|
git checkout ${{ gitea.sha }}
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
pip3 install --break-system-packages -r requirements.txt ruff
|
||||||
|
|
||||||
|
- name: Run ruff linter
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
ruff check SPARC/ tests/
|
||||||
|
|
||||||
|
- name: Run pytest
|
||||||
|
shell: sh
|
||||||
|
env:
|
||||||
|
DATABASE_URL: "sqlite://"
|
||||||
|
API_KEY: "test-key"
|
||||||
|
OPENROUTER_API_KEY: "test-key"
|
||||||
|
JWT_SECRET: "test-secret-for-ci"
|
||||||
|
APP_ENV: "development"
|
||||||
|
run: |
|
||||||
|
python3 -m pytest tests/ -v --tb=short -x
|
||||||
+3
-2
@@ -1,3 +1,4 @@
|
|||||||
from .types import Patents, Patent
|
from .types import Patent as Patent
|
||||||
|
from .types import Patents as Patents
|
||||||
|
|
||||||
all = ["Patents", "Patent"]
|
__all__ = ["Patents", "Patent"]
|
||||||
|
|||||||
+2
-2
@@ -13,9 +13,9 @@ from SPARC import config
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
from SPARC.database import DatabaseClient
|
from SPARC.database import DatabaseClient
|
||||||
from SPARC.serp_api import SERP
|
|
||||||
from SPARC.llm import LLMAnalyzer
|
from SPARC.llm import LLMAnalyzer
|
||||||
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult
|
from SPARC.serp_api import SERP
|
||||||
|
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult, Patent, Patents
|
||||||
|
|
||||||
|
|
||||||
class CompanyAnalyzer:
|
class CompanyAnalyzer:
|
||||||
|
|||||||
+44
-6
@@ -21,11 +21,13 @@ from SPARC.auth import (
|
|||||||
TokenResponse,
|
TokenResponse,
|
||||||
UserResponse,
|
UserResponse,
|
||||||
check_jwt_secret,
|
check_jwt_secret,
|
||||||
|
close_db_client,
|
||||||
create_tokens,
|
create_tokens,
|
||||||
decode_token,
|
decode_token,
|
||||||
get_current_admin,
|
get_current_admin,
|
||||||
get_current_user,
|
get_current_user,
|
||||||
get_db_client,
|
get_db_client,
|
||||||
|
init_db_client,
|
||||||
)
|
)
|
||||||
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||||
|
|
||||||
@@ -75,6 +77,13 @@ class JobStatus(BaseModel):
|
|||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedJobsResponse(BaseModel):
|
||||||
|
"""Paginated response for job listings."""
|
||||||
|
|
||||||
|
items: list["JobStatus"]
|
||||||
|
next_cursor: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class HealthResponse(BaseModel):
|
class HealthResponse(BaseModel):
|
||||||
"""Health check response."""
|
"""Health check response."""
|
||||||
|
|
||||||
@@ -155,6 +164,7 @@ async def lifespan(app: FastAPI):
|
|||||||
"""Initialize resources on startup, clean up on shutdown."""
|
"""Initialize resources on startup, clean up on shutdown."""
|
||||||
global _analyzer
|
global _analyzer
|
||||||
check_jwt_secret()
|
check_jwt_secret()
|
||||||
|
init_db_client()
|
||||||
_analyzer = CompanyAnalyzer()
|
_analyzer = CompanyAnalyzer()
|
||||||
# Mark any jobs that were running/pending before the restart as failed
|
# Mark any jobs that were running/pending before the restart as failed
|
||||||
from SPARC.database import DatabaseClient
|
from SPARC.database import DatabaseClient
|
||||||
@@ -167,8 +177,9 @@ async def lifespan(app: FastAPI):
|
|||||||
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
|
logging.getLogger(__name__).warning("Marked %d stale jobs as failed on startup", stale)
|
||||||
_db.close()
|
_db.close()
|
||||||
yield
|
yield
|
||||||
# Cleanup if needed
|
# Cleanup
|
||||||
_analyzer = None
|
_analyzer = None
|
||||||
|
close_db_client()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
@@ -573,24 +584,51 @@ async def get_job_status(
|
|||||||
return _job_row_to_status(job_row)
|
return _job_row_to_status(job_row)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/jobs", response_model=list[JobStatus], tags=["Jobs"])
|
@app.get("/jobs", response_model=PaginatedJobsResponse, tags=["Jobs"])
|
||||||
async def list_jobs(
|
async def list_jobs(
|
||||||
status: Annotated[
|
status: Annotated[
|
||||||
str | None,
|
str | None,
|
||||||
Query(description="Filter by status: pending, running, completed, failed"),
|
Query(description="Filter by status: pending, running, completed, failed"),
|
||||||
] = None,
|
] = None,
|
||||||
limit: Annotated[int, Query(ge=1, le=100)] = 10,
|
limit: Annotated[int, Query(ge=1, le=100)] = 10,
|
||||||
|
cursor: Annotated[
|
||||||
|
str | None,
|
||||||
|
Query(description="Opaque cursor from a previous response's next_cursor field"),
|
||||||
|
] = None,
|
||||||
_: UserResponse = Depends(get_current_user),
|
_: UserResponse = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List all analysis jobs.
|
"""List analysis jobs with cursor-based pagination.
|
||||||
|
|
||||||
|
Pass ``limit`` to control page size. The response includes a ``next_cursor``
|
||||||
|
field; pass it back as the ``cursor`` query parameter to fetch the next page.
|
||||||
|
When ``next_cursor`` is ``null``, there are no more results.
|
||||||
|
|
||||||
|
Existing clients that use only ``limit`` (without ``cursor``) continue to
|
||||||
|
work without modification.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
status: Optional filter by job status
|
status: Optional filter by job status
|
||||||
limit: Maximum number of jobs to return (default 10, max 100)
|
limit: Maximum number of jobs to return (default 10, max 100)
|
||||||
|
cursor: Opaque pagination cursor from a previous response
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of job statuses
|
Paginated list of job statuses
|
||||||
"""
|
"""
|
||||||
db = _get_job_db()
|
db = _get_job_db()
|
||||||
job_rows = db.list_jobs(status=status, limit=limit)
|
# Fetch one extra to determine if there is a next page
|
||||||
return [_job_row_to_status(row) for row in job_rows]
|
job_rows = db.list_jobs(status=status, limit=limit + 1, cursor=cursor)
|
||||||
|
|
||||||
|
has_next = len(job_rows) > limit
|
||||||
|
if has_next:
|
||||||
|
job_rows = job_rows[:limit]
|
||||||
|
|
||||||
|
items = [_job_row_to_status(row) for row in job_rows]
|
||||||
|
|
||||||
|
next_cursor = None
|
||||||
|
if has_next and job_rows:
|
||||||
|
last = job_rows[-1]
|
||||||
|
created = last["created_at"]
|
||||||
|
ts = created.isoformat() if hasattr(created, "isoformat") else str(created)
|
||||||
|
next_cursor = f"{ts}|{last['job_id']}"
|
||||||
|
|
||||||
|
return PaginatedJobsResponse(items=items, next_cursor=next_cursor)
|
||||||
|
|||||||
+29
-4
@@ -146,11 +146,36 @@ def decode_token(token: str) -> Optional[TokenPayload]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Shared database client singleton, initialized at startup via init_db_client()
|
||||||
|
_db_client: DatabaseClient | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_db_client() -> None:
|
||||||
|
"""Initialize the shared database client. Call once at app startup."""
|
||||||
|
global _db_client
|
||||||
|
_db_client = DatabaseClient(config.database_url)
|
||||||
|
_db_client.connect()
|
||||||
|
|
||||||
|
|
||||||
|
def close_db_client() -> None:
|
||||||
|
"""Close the shared database client. Call at app shutdown."""
|
||||||
|
global _db_client
|
||||||
|
if _db_client:
|
||||||
|
_db_client.close()
|
||||||
|
_db_client = None
|
||||||
|
|
||||||
|
|
||||||
def get_db_client() -> DatabaseClient:
|
def get_db_client() -> DatabaseClient:
|
||||||
"""Get database client for auth operations."""
|
"""Get the shared pooled database client for auth operations.
|
||||||
client = DatabaseClient(config.database_url)
|
|
||||||
client.connect()
|
Returns the module-level singleton DatabaseClient. If not yet initialized
|
||||||
return client
|
(e.g., during tests), creates a new instance as a fallback.
|
||||||
|
"""
|
||||||
|
global _db_client
|
||||||
|
if _db_client is None:
|
||||||
|
_db_client = DatabaseClient(config.database_url)
|
||||||
|
_db_client.connect()
|
||||||
|
return _db_client
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
|
|||||||
+67
-54
@@ -1,14 +1,15 @@
|
|||||||
"""Database client for storing and retrieving LLM messages and user authentication."""
|
"""Database client for storing and retrieving LLM messages and user authentication."""
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import psycopg2
|
|
||||||
from psycopg2.pool import ThreadedConnectionPool
|
|
||||||
from psycopg2.extras import RealDictCursor
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
import psycopg2
|
||||||
|
from psycopg2.extras import RealDictCursor
|
||||||
|
from psycopg2.pool import ThreadedConnectionPool
|
||||||
|
|
||||||
|
|
||||||
class DatabaseClient:
|
class DatabaseClient:
|
||||||
@@ -221,8 +222,6 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Cached message dict if found, None otherwise
|
Cached message dict if found, None otherwise
|
||||||
"""
|
"""
|
||||||
self.connect()
|
|
||||||
|
|
||||||
prompt_hash = self.hash_prompt(prompt)
|
prompt_hash = self.hash_prompt(prompt)
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
@@ -245,7 +244,8 @@ class DatabaseClient:
|
|||||||
|
|
||||||
query += " ORDER BY timestamp DESC LIMIT 1"
|
query += " ORDER BY timestamp DESC LIMIT 1"
|
||||||
|
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
return dict(result) if result else None
|
return dict(result) if result else None
|
||||||
@@ -276,11 +276,10 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
The ID of the inserted record
|
The ID of the inserted record
|
||||||
"""
|
"""
|
||||||
self.connect()
|
|
||||||
|
|
||||||
prompt_hash = self.hash_prompt(prompt)
|
prompt_hash = self.hash_prompt(prompt)
|
||||||
|
|
||||||
with self.conn.cursor() as cursor:
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO llm_messages
|
INSERT INTO llm_messages
|
||||||
@@ -302,7 +301,7 @@ class DatabaseClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
message_id = cursor.fetchone()[0]
|
message_id = cursor.fetchone()[0]
|
||||||
self.conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
return message_id
|
return message_id
|
||||||
|
|
||||||
@@ -324,8 +323,6 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
List of message dictionaries
|
List of message dictionaries
|
||||||
"""
|
"""
|
||||||
self.connect()
|
|
||||||
|
|
||||||
query = "SELECT * FROM llm_messages WHERE 1=1"
|
query = "SELECT * FROM llm_messages WHERE 1=1"
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
@@ -340,7 +337,8 @@ class DatabaseClient:
|
|||||||
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
|
query += " ORDER BY timestamp DESC LIMIT %s OFFSET %s"
|
||||||
params.extend([limit, offset])
|
params.extend([limit, offset])
|
||||||
|
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
@@ -353,9 +351,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary with analytics data
|
Dictionary with analytics data
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
# Total messages
|
# Total messages
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
@@ -571,20 +568,45 @@ class DatabaseClient:
|
|||||||
self,
|
self,
|
||||||
status: Optional[str] = None,
|
status: Optional[str] = None,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
|
cursor: Optional[str] = None,
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""List jobs, optionally filtered by status."""
|
"""List jobs with optional status filter and cursor-based pagination.
|
||||||
query = "SELECT * FROM jobs"
|
|
||||||
|
Args:
|
||||||
|
status: Optional status filter (pending, running, completed, failed).
|
||||||
|
limit: Maximum number of jobs to return.
|
||||||
|
cursor: Opaque cursor (``created_at|job_id``) from a previous
|
||||||
|
response. When provided, only jobs older than the cursor are
|
||||||
|
returned.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of job dicts ordered by created_at descending.
|
||||||
|
"""
|
||||||
|
conditions: list[str] = []
|
||||||
params: list = []
|
params: list = []
|
||||||
|
|
||||||
if status:
|
if status:
|
||||||
query += " WHERE status = %s"
|
conditions.append("status = %s")
|
||||||
params.append(status)
|
params.append(status)
|
||||||
query += " ORDER BY created_at DESC LIMIT %s"
|
|
||||||
|
if cursor:
|
||||||
|
try:
|
||||||
|
ts_str, cursor_job_id = cursor.rsplit("|", 1)
|
||||||
|
conditions.append("(created_at, job_id) < (%s, %s)")
|
||||||
|
params.extend([ts_str, cursor_job_id])
|
||||||
|
except ValueError:
|
||||||
|
pass # Ignore malformed cursors; return from start
|
||||||
|
|
||||||
|
query = "SELECT * FROM jobs"
|
||||||
|
if conditions:
|
||||||
|
query += " WHERE " + " AND ".join(conditions)
|
||||||
|
query += " ORDER BY created_at DESC, job_id DESC LIMIT %s"
|
||||||
params.append(limit)
|
params.append(limit)
|
||||||
|
|
||||||
with self.get_conn() as conn:
|
with self.get_conn() as conn:
|
||||||
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
with conn.cursor(cursor_factory=RealDictCursor) as cur:
|
||||||
cursor.execute(query, params)
|
cur.execute(query, params)
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
return [dict(row) for row in cur.fetchall()]
|
||||||
|
|
||||||
def mark_stale_jobs_failed(self) -> int:
|
def mark_stale_jobs_failed(self) -> int:
|
||||||
"""Mark any jobs in 'running' or 'pending' state as 'failed'.
|
"""Mark any jobs in 'running' or 'pending' state as 'failed'.
|
||||||
@@ -650,12 +672,11 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Created user dict or None if email exists
|
Created user dict or None if email exists
|
||||||
"""
|
"""
|
||||||
self.connect()
|
|
||||||
|
|
||||||
password_hash = self.hash_password(password)
|
password_hash = self.hash_password(password)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO users (email, password_hash, role)
|
INSERT INTO users (email, password_hash, role)
|
||||||
@@ -665,10 +686,9 @@ class DatabaseClient:
|
|||||||
(email, password_hash, role),
|
(email, password_hash, role),
|
||||||
)
|
)
|
||||||
user = cursor.fetchone()
|
user = cursor.fetchone()
|
||||||
self.conn.commit()
|
conn.commit()
|
||||||
return dict(user) if user else None
|
return dict(user) if user else None
|
||||||
except psycopg2.errors.UniqueViolation:
|
except psycopg2.errors.UniqueViolation:
|
||||||
self.conn.rollback()
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
|
def authenticate_user(self, email: str, password: str) -> Optional[Dict]:
|
||||||
@@ -681,9 +701,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
User dict if authenticated, None otherwise
|
User dict if authenticated, None otherwise
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT * FROM users WHERE email = %s",
|
"SELECT * FROM users WHERE email = %s",
|
||||||
(email,),
|
(email,),
|
||||||
@@ -708,9 +727,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
User dict or None
|
User dict or None
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
"SELECT id, email, role, created_at FROM users WHERE id = %s",
|
||||||
(user_id,),
|
(user_id,),
|
||||||
@@ -727,9 +745,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
User dict or None
|
User dict or None
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
"SELECT id, email, role, created_at FROM users WHERE email = %s",
|
||||||
(email,),
|
(email,),
|
||||||
@@ -747,9 +764,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
List of user dicts
|
List of user dicts
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
SELECT id, email, role, created_at
|
SELECT id, email, role, created_at
|
||||||
@@ -771,9 +787,8 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Updated user dict or None
|
Updated user dict or None
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
with self.conn.cursor(cursor_factory=RealDictCursor) as cursor:
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
@@ -784,7 +799,7 @@ class DatabaseClient:
|
|||||||
(role, user_id),
|
(role, user_id),
|
||||||
)
|
)
|
||||||
user = cursor.fetchone()
|
user = cursor.fetchone()
|
||||||
self.conn.commit()
|
conn.commit()
|
||||||
return dict(user) if user else None
|
return dict(user) if user else None
|
||||||
|
|
||||||
def delete_user(self, user_id: int) -> bool:
|
def delete_user(self, user_id: int) -> bool:
|
||||||
@@ -796,12 +811,11 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
True if deleted
|
True if deleted
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
with self.conn.cursor() as cursor:
|
|
||||||
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
|
||||||
deleted = cursor.rowcount > 0
|
deleted = cursor.rowcount > 0
|
||||||
self.conn.commit()
|
conn.commit()
|
||||||
return deleted
|
return deleted
|
||||||
|
|
||||||
def get_user_count(self) -> int:
|
def get_user_count(self) -> int:
|
||||||
@@ -810,8 +824,7 @@ class DatabaseClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Number of users
|
Number of users
|
||||||
"""
|
"""
|
||||||
self.connect()
|
with self.get_conn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
with self.conn.cursor() as cursor:
|
|
||||||
cursor.execute("SELECT COUNT(*) FROM users")
|
cursor.execute("SELECT COUNT(*) FROM users")
|
||||||
return cursor.fetchone()[0]
|
return cursor.fetchone()[0]
|
||||||
|
|||||||
+8
-5
@@ -1,12 +1,15 @@
|
|||||||
import os
|
import os
|
||||||
import serpapi
|
|
||||||
from SPARC import config
|
|
||||||
import re
|
import re
|
||||||
import pdfplumber # pip install pdfplumber
|
|
||||||
import requests
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from SPARC.types import Patents, Patent
|
|
||||||
|
import pdfplumber # pip install pdfplumber
|
||||||
|
import requests
|
||||||
|
import serpapi
|
||||||
|
|
||||||
|
from SPARC import config
|
||||||
|
from SPARC.types import Patent, Patents
|
||||||
|
|
||||||
|
|
||||||
class SERP:
|
class SERP:
|
||||||
def query(company: str, days_back: int = None) -> Patents:
|
def query(company: str, days_back: int = None) -> Patents:
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
[lint]
|
||||||
|
select = ["E", "F", "I"]
|
||||||
|
ignore = [
|
||||||
|
"E501", # line too long (handled by formatter)
|
||||||
|
]
|
||||||
|
|
||||||
|
[lint.per-file-ignores]
|
||||||
|
"tests/*" = ["E402", "F841"] # allow import not at top of file, unused vars (mocks) in tests
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
"""Tests for the high-level company analyzer orchestration."""
|
"""Tests for the high-level company analyzer orchestration."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, patch, call, MagicMock
|
|
||||||
from SPARC.analyzer import CompanyAnalyzer
|
from SPARC.analyzer import CompanyAnalyzer
|
||||||
from SPARC.types import Patent, Patents, CompanyAnalysisResult, BatchAnalysisResult
|
from SPARC.types import BatchAnalysisResult, Patent, Patents
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -24,7 +26,7 @@ class TestCompanyAnalyzer:
|
|||||||
"""Test analyzer initialization with API key."""
|
"""Test analyzer initialization with API key."""
|
||||||
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
mock_llm = mocker.patch("SPARC.analyzer.LLMAnalyzer")
|
||||||
|
|
||||||
analyzer = CompanyAnalyzer(openrouter_api_key="test-key")
|
_analyzer = CompanyAnalyzer(openrouter_api_key="test-key") # noqa: F841
|
||||||
|
|
||||||
mock_llm.assert_called_once_with(api_key="test-key")
|
mock_llm.assert_called_once_with(api_key="test-key")
|
||||||
|
|
||||||
|
|||||||
+4
-3
@@ -1,12 +1,13 @@
|
|||||||
"""Tests for FastAPI web service endpoints."""
|
"""Tests for FastAPI web service endpoints."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from SPARC.api import app
|
from SPARC.api import app
|
||||||
from SPARC.types import CompanyAnalysisResult, BatchAnalysisResult
|
from SPARC.types import BatchAnalysisResult, CompanyAnalysisResult
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
+3
-1
@@ -1,7 +1,9 @@
|
|||||||
"""Tests for LLM analysis functionality."""
|
"""Tests for LLM analysis functionality."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, MagicMock, patch
|
|
||||||
from SPARC.llm import LLMAnalyzer
|
from SPARC.llm import LLMAnalyzer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
"""Tests for SERP API patent retrieval and parsing functionality."""
|
"""Tests for SERP API patent retrieval and parsing functionality."""
|
||||||
|
|
||||||
import os
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import patch, Mock
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from SPARC.serp_api import SERP
|
from SPARC.serp_api import SERP
|
||||||
from SPARC.types import Patent
|
from SPARC.types import Patent
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user