Add search backend: FastAPI + FastMCP + pgvector for docs Q&A and live SQL
Search stack replicates the Hamilton site pattern with pg_orrery-specific additions: - FastAPI REST API (chat SSE streaming, semantic search, health check) - FastMCP server at /mcp with doc search and live SQL query tools - pgvector + pgai vectorizer for 1024-dim document embeddings - Hybrid search (semantic cosine + text ILIKE with pg_trgm GIN) - Dual LLM backend: self-hosted qwen3 via GPU gateway or Anthropic Claude - Live read-only pg_orrery SQL execution with safety guardrails (SELECT-only validation, read-only transaction, 5s timeout, 100-row cap) - Convenience MCP tools: planet_position, sky_survey, satellite_pass - MDX content ingestion from docs/src/content/docs/ (50 pages) - Docker Compose: pg_orrery+pgvector DB, pgai, vectorizer-worker, API - Alembic async migrations, Makefile, .env.example
This commit is contained in:
parent
c850277efe
commit
317f74b33b
24
search/.env.example
Normal file
24
search/.env.example
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# pg_orrery Search — Environment Configuration
|
||||||
|
COMPOSE_PROJECT_NAME=pg-orrery-search
|
||||||
|
|
||||||
|
# Domain for caddy-docker-proxy labels
|
||||||
|
DOMAIN=pg-orrery.warehack.ing
|
||||||
|
|
||||||
|
# PostgreSQL
|
||||||
|
POSTGRES_USER=orrery
|
||||||
|
POSTGRES_PASSWORD=changeme
|
||||||
|
|
||||||
|
# Database URLs (must match POSTGRES_USER/PASSWORD above)
|
||||||
|
DATABASE_URL=postgresql+asyncpg://orrery:changeme@db:5432/orrery_search
|
||||||
|
ORRERY_DB_URL=postgresql://orrery:changeme@db:5432/orrery_search
|
||||||
|
|
||||||
|
# GPU Gateway (embedding + chat)
|
||||||
|
GPU_API_KEY=sk-gpu-lb-master-key-2026
|
||||||
|
GPU_BASE_URL=https://orrery-search.gpu.supported.systems/v1
|
||||||
|
|
||||||
|
# LLM Provider: "gpu" (self-hosted qwen3) or "anthropic" (Claude)
|
||||||
|
LLM_PROVIDER=gpu
|
||||||
|
|
||||||
|
# Anthropic (only needed if LLM_PROVIDER=anthropic)
|
||||||
|
ANTHROPIC_API_KEY=
|
||||||
|
ANTHROPIC_MODEL=claude-sonnet-4-20250514
|
||||||
29
search/Dockerfile
Normal file
29
search/Dockerfile
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# syntax=docker/dockerfile:1
|
||||||
|
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim AS base
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
ENV UV_COMPILE_BYTECODE=1
|
||||||
|
ENV UV_LINK_MODE=copy
|
||||||
|
|
||||||
|
# Install dependencies first (cache layer)
|
||||||
|
COPY pyproject.toml ./
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
uv pip install --system -r pyproject.toml
|
||||||
|
|
||||||
|
# Copy source
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# --- Development (editable install for hot reload) ---
|
||||||
|
FROM base AS dev
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
uv pip install --system --no-deps -e .
|
||||||
|
CMD ["uvicorn", "orrery_search.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload", "--proxy-headers", "--forwarded-allow-ips", "*"]
|
||||||
|
|
||||||
|
# --- Production (non-editable, non-root) ---
|
||||||
|
FROM base AS prod
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
uv pip install --system --no-deps .
|
||||||
|
RUN adduser --disabled-password --gecos "" orrery && \
|
||||||
|
mkdir -p /data && chown -R orrery:orrery /data
|
||||||
|
USER orrery
|
||||||
|
CMD ["uvicorn", "orrery_search.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2", "--proxy-headers", "--forwarded-allow-ips", "*"]
|
||||||
8
search/Dockerfile.db
Normal file
8
search/Dockerfile.db
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# pg_orrery + pgvector: celestial mechanics + vector search in one database
|
||||||
|
FROM git.supported.systems/warehack.ing/pg_orrery:pg17
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
postgresql-17-pgvector \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY docker/030_install_extensions.sh /docker-entrypoint-initdb.d/
|
||||||
57
search/Makefile
Normal file
57
search/Makefile
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
.PHONY: dev prod build logs down clean restart migrate ingest shell status test-search
|
||||||
|
|
||||||
|
.DEFAULT_GOAL := dev
|
||||||
|
|
||||||
|
# Development mode with hot reload
|
||||||
|
dev:
|
||||||
|
@echo "Starting pg_orrery Search in development mode..."
|
||||||
|
docker compose --profile dev up -d --build
|
||||||
|
@sleep 3
|
||||||
|
docker compose logs -f
|
||||||
|
|
||||||
|
# Production mode
|
||||||
|
prod:
|
||||||
|
@echo "Starting pg_orrery Search in production mode..."
|
||||||
|
docker compose --profile prod up -d --build
|
||||||
|
@sleep 3
|
||||||
|
docker compose logs -f
|
||||||
|
|
||||||
|
# Build without starting
|
||||||
|
build:
|
||||||
|
docker compose build
|
||||||
|
|
||||||
|
# View logs
|
||||||
|
logs:
|
||||||
|
docker compose logs -f
|
||||||
|
|
||||||
|
# Stop containers
|
||||||
|
down:
|
||||||
|
docker compose down
|
||||||
|
|
||||||
|
# Clean up containers, images, volumes
|
||||||
|
clean:
|
||||||
|
docker compose down -v --rmi local
|
||||||
|
|
||||||
|
# Restart
|
||||||
|
restart: down dev
|
||||||
|
|
||||||
|
# Run database migrations
|
||||||
|
migrate:
|
||||||
|
docker compose exec api-dev alembic upgrade head
|
||||||
|
|
||||||
|
# Run content ingestion
|
||||||
|
ingest:
|
||||||
|
docker compose exec api-dev python -m orrery_search.ingest
|
||||||
|
|
||||||
|
# Shell into running API container
|
||||||
|
shell:
|
||||||
|
docker compose exec api-dev bash
|
||||||
|
|
||||||
|
# Check vectorizer status
|
||||||
|
status:
|
||||||
|
docker compose exec db psql -U orrery -d orrery_search -c "SELECT * FROM ai.vectorizer_status;"
|
||||||
|
|
||||||
|
# Test search
|
||||||
|
test-search:
|
||||||
|
@echo "Testing search API..."
|
||||||
|
curl -s "https://$$(grep DOMAIN .env | cut -d= -f2)/api/search?q=satellite+tracking" | python3 -m json.tool
|
||||||
36
search/alembic.ini
Normal file
36
search/alembic.ini
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
[alembic]
|
||||||
|
script_location = alembic
|
||||||
|
sqlalchemy.url = postgresql+asyncpg://orrery:orrery@localhost:5432/orrery_search
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
72
search/alembic/env.py
Normal file
72
search/alembic/env.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sqlalchemy import pool
|
||||||
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||||
|
|
||||||
|
from orrery_search.models import Base
|
||||||
|
|
||||||
|
config = context.config
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
db_url = os.environ.get("DATABASE_URL")
|
||||||
|
if db_url:
|
||||||
|
config.set_main_option("sqlalchemy.url", db_url)
|
||||||
|
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
# Table names that include_object() hides from Alembic.
PGAI_MANAGED_TABLES = {
    "document_embedding_store",
}


def include_object(object, name, type_, reflected, compare_to):
    """Alembic ``include_object`` hook.

    Returns False for tables whose name is listed in PGAI_MANAGED_TABLES and
    True for every other object; only ``name`` and ``type_`` are consulted.
    """
    is_managed_table = type_ == "table" and name in PGAI_MANAGED_TABLES
    return not is_managed_table
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline():
    """Run migrations in offline mode.

    Configures the Alembic context from the ``sqlalchemy.url`` option only
    (no connection is passed) with literal binds enabled, then runs the
    migrations inside ``context.begin_transaction()``.
    """
    offline_options = {
        "url": config.get_main_option("sqlalchemy.url"),
        "target_metadata": target_metadata,
        "literal_binds": True,
        "include_object": include_object,
    }
    context.configure(**offline_options)

    with context.begin_transaction():
        context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection):
    """Configure the Alembic context on ``connection`` and run migrations.

    Called through ``connection.run_sync`` from run_async_migrations().
    """
    online_options = {
        "connection": connection,
        "target_metadata": target_metadata,
        "include_object": include_object,
    }
    context.configure(**online_options)

    with context.begin_transaction():
        context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_async_migrations():
    """Build an async engine from the Alembic config and run the migrations.

    The engine uses ``NullPool``. It is disposed in a ``finally`` block so the
    engine is released even when connecting or migrating raises.
    """
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    try:
        async with connectable.connect() as connection:
            # Alembic's migration API is synchronous; run_sync bridges it.
            await connection.run_sync(do_run_migrations)
    finally:
        await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online():
    """Run migrations against a live database by driving the async runner."""
    migration_coro = run_async_migrations()
    asyncio.run(migration_coro)
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
23
search/alembic/script.py.mako
Normal file
23
search/alembic/script.py.mako
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
revision = ${repr(up_revision)}
|
||||||
|
down_revision = ${repr(down_revision)}
|
||||||
|
branch_labels = ${repr(branch_labels)}
|
||||||
|
depends_on = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
83
search/alembic/versions/001_baseline.py
Normal file
83
search/alembic/versions/001_baseline.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
"""document table with pgai vectorizer
|
||||||
|
|
||||||
|
Revision ID: 001_baseline
|
||||||
|
Revises: None
|
||||||
|
Create Date: 2026-03-01
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "001_baseline"
|
||||||
|
down_revision = None
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
    """Create the ``document`` table, its trigram indexes, and the vectorizer."""
    document_columns = (
        sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
        sa.Column("content_type", sa.String(20), nullable=False, index=True),
        sa.Column("slug", sa.String(300), nullable=False, unique=True),
        sa.Column("title", sa.String(300), nullable=False),
        sa.Column("section", sa.String(200), nullable=False, index=True),
        sa.Column("description", sa.Text, nullable=True),
        sa.Column("body", sa.Text, nullable=False),
        sa.Column("search_text", sa.Text, nullable=True),
        sa.Column("url", sa.String(300), nullable=False),
        sa.Column("word_count", sa.Integer, nullable=False, server_default="0"),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
    )
    op.create_table("document", *document_columns)

    # Enable pg_trgm for fast ILIKE with GIN indexes
    op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")

    for column in ("title", "body"):
        op.execute(
            f"CREATE INDEX ix_document_{column}_trgm ON document "
            f"USING gin ({column} gin_trgm_ops)"
        )

    # pgai vectorizer — reads search_text, generates 1024-dim embeddings
    # Uses mxbai-embed-large via the GPU embedding gateway.
    op.execute("""
        SELECT ai.create_vectorizer(
            'document'::regclass,
            name => 'document_embedder',
            loading => ai.loading_column(column_name => 'search_text'),
            embedding => ai.embedding_openai(
                model => 'mxbai-embed-large',
                dimensions => 1024
            ),
            chunking => ai.chunking_recursive_character_text_splitter(
                chunk_size => 400,
                chunk_overlap => 50,
                separators => array[E'\\n\\n', E'\\n', '. ', ' ']
            ),
            formatting => ai.formatting_python_template(
                template => '$chunk'
            )
        )
    """)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
    """Drop the vectorizer, the trigram indexes, and the ``document`` table."""
    # The DO block turns any failure of drop_vectorizer into a NOTICE so the
    # remaining drops still run.
    drop_vectorizer_sql = """
        DO $$ BEGIN
            PERFORM ai.drop_vectorizer('document_embedder', drop_all => true);
        EXCEPTION WHEN OTHERS THEN
            RAISE NOTICE 'document_embedder not found, skipping drop';
        END $$
    """
    op.execute(drop_vectorizer_sql)

    for index_name in ("ix_document_title_trgm", "ix_document_body_trgm"):
        op.execute(f"DROP INDEX IF EXISTS {index_name}")

    op.drop_table("document")
|
||||||
134
search/docker-compose.yml
Normal file
134
search/docker-compose.yml
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
services:
|
||||||
|
db:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile.db
|
||||||
|
environment:
|
||||||
|
POSTGRES_DB: orrery_search
|
||||||
|
POSTGRES_USER: ${POSTGRES_USER:-orrery}
|
||||||
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||||
|
volumes:
|
||||||
|
- pg-data:/var/lib/postgresql/data
|
||||||
|
networks:
|
||||||
|
- internal
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB}"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
start_period: 10s
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
pgai-install:
|
||||||
|
image: timescale/pgai-vectorizer-worker:latest
|
||||||
|
entrypoint: ["python", "-m", "pgai", "install", "-d", "postgresql://${POSTGRES_USER:-orrery}:${POSTGRES_PASSWORD}@db:5432/orrery_search"]
|
||||||
|
networks:
|
||||||
|
- internal
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
restart: "no"
|
||||||
|
|
||||||
|
vectorizer-worker:
|
||||||
|
image: timescale/pgai-vectorizer-worker:latest
|
||||||
|
environment:
|
||||||
|
PGAI_VECTORIZER_WORKER_DB_URL: postgresql://${POSTGRES_USER:-orrery}:${POSTGRES_PASSWORD}@db:5432/orrery_search
|
||||||
|
OPENAI_BASE_URL: ${GPU_BASE_URL:-https://orrery-search.gpu.supported.systems/v1}
|
||||||
|
OPENAI_API_KEY: ${GPU_API_KEY}
|
||||||
|
command: ["--poll-interval", "5s"]
|
||||||
|
networks:
|
||||||
|
- internal
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
pgai-install:
|
||||||
|
condition: service_completed_successfully
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
api-dev:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
target: dev
|
||||||
|
profiles: ["dev"]
|
||||||
|
env_file: .env
|
||||||
|
volumes:
|
||||||
|
- ./src:/app/src
|
||||||
|
- ./alembic:/app/alembic
|
||||||
|
- ../docs/src/content/docs:/data/content:ro
|
||||||
|
networks:
|
||||||
|
- internal
|
||||||
|
- caddy
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
pgai-install:
|
||||||
|
condition: service_completed_successfully
|
||||||
|
labels:
|
||||||
|
caddy: ${DOMAIN:-pg-orrery.warehack.ing}
|
||||||
|
caddy.handle: /api/search*
|
||||||
|
caddy.handle.0_reverse_proxy: "{{upstreams 8000}}"
|
||||||
|
caddy.handle_1: /health
|
||||||
|
caddy.handle_1.0_reverse_proxy: "{{upstreams 8000}}"
|
||||||
|
caddy.handle_2: /mcp*
|
||||||
|
caddy.handle_2.0_reverse_proxy: "{{upstreams 8000}}"
|
||||||
|
caddy.handle_3: /api/chat*
|
||||||
|
caddy.handle_3.0_reverse_proxy: "{{upstreams 8000}}"
|
||||||
|
caddy.handle_3.0_reverse_proxy.flush_interval: "-1"
|
||||||
|
caddy.handle_3.0_reverse_proxy.transport: "http"
|
||||||
|
caddy.handle_3.0_reverse_proxy.transport.read_timeout: "0"
|
||||||
|
caddy.handle_3.0_reverse_proxy.transport.write_timeout: "0"
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health')"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 3
|
||||||
|
start_period: 15s
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
api-prod:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
target: prod
|
||||||
|
profiles: ["prod"]
|
||||||
|
env_file: .env
|
||||||
|
volumes:
|
||||||
|
- ../docs/src/content/docs:/data/content:ro
|
||||||
|
networks:
|
||||||
|
- internal
|
||||||
|
- caddy
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
pgai-install:
|
||||||
|
condition: service_completed_successfully
|
||||||
|
labels:
|
||||||
|
caddy: ${DOMAIN:-pg-orrery.warehack.ing}
|
||||||
|
caddy.handle: /api/search*
|
||||||
|
caddy.handle.0_reverse_proxy: "{{upstreams 8000}}"
|
||||||
|
caddy.handle_1: /health
|
||||||
|
caddy.handle_1.0_reverse_proxy: "{{upstreams 8000}}"
|
||||||
|
caddy.handle_2: /mcp*
|
||||||
|
caddy.handle_2.0_reverse_proxy: "{{upstreams 8000}}"
|
||||||
|
caddy.handle_3: /api/chat*
|
||||||
|
caddy.handle_3.0_reverse_proxy: "{{upstreams 8000}}"
|
||||||
|
caddy.handle_3.0_reverse_proxy.flush_interval: "-1"
|
||||||
|
caddy.handle_3.0_reverse_proxy.transport: "http"
|
||||||
|
caddy.handle_3.0_reverse_proxy.transport.read_timeout: "0"
|
||||||
|
caddy.handle_3.0_reverse_proxy.transport.write_timeout: "0"
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health')"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 3
|
||||||
|
start_period: 15s
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
pg-data:
|
||||||
|
|
||||||
|
networks:
|
||||||
|
internal:
|
||||||
|
caddy:
|
||||||
|
external: true
|
||||||
8
search/docker/030_install_extensions.sh
Executable file
8
search/docker/030_install_extensions.sh
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
|
||||||
|
CREATE EXTENSION IF NOT EXISTS vector;
|
||||||
|
CREATE EXTENSION IF NOT EXISTS pg_trgm;
|
||||||
|
CREATE EXTENSION IF NOT EXISTS pg_orrery;
|
||||||
|
EOSQL
|
||||||
50
search/pyproject.toml
Normal file
50
search/pyproject.toml
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "orrery-search"
|
||||||
|
version = "2026.03.01"
|
||||||
|
description = "Semantic search and chat API for the pg_orrery documentation"
|
||||||
|
license = "MIT"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
authors = [{name = "Ryan Malloy", email = "ryan@supported.systems"}]
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.115.0",
|
||||||
|
"uvicorn[standard]>=0.32.0",
|
||||||
|
"sqlalchemy[asyncio]>=2.0.36",
|
||||||
|
"asyncpg>=0.30.0",
|
||||||
|
"pgvector>=0.4.2",
|
||||||
|
"pgai[sqlalchemy]>=0.12.0",
|
||||||
|
"tiktoken>=0.7.0",
|
||||||
|
"alembic>=1.14.0",
|
||||||
|
"pydantic-settings>=2.6.0",
|
||||||
|
"openai>=1.60.0",
|
||||||
|
"anthropic>=0.40.0",
|
||||||
|
"pyyaml>=6.0",
|
||||||
|
"fastmcp>=3.0.0",
|
||||||
|
"httpx>=0.28.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
orrery-search = "orrery_search.main:run"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py312"
|
||||||
|
src = ["src"]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["E", "F", "I", "UP", "B", "SIM"]
|
||||||
|
ignore = ["B008"]
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src/orrery_search"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"pytest>=9.0.2",
|
||||||
|
"pytest-asyncio>=1.3.0",
|
||||||
|
]
|
||||||
0
search/src/orrery_search/__init__.py
Normal file
0
search/src/orrery_search/__init__.py
Normal file
36
search/src/orrery_search/config.py
Normal file
36
search/src/orrery_search/config.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
    """Application settings.

    Fields are populated from environment variables (no prefix) and from a
    ``.env`` file; unknown keys are ignored (see ``model_config``).
    """

    # --- HTTP server ---
    api_host: str = "0.0.0.0"
    api_port: int = 8000
    api_log_level: str = "info"

    # --- Databases ---
    database_url: str = "postgresql+asyncpg://orrery:orrery@localhost:5432/orrery_search"

    # Raw asyncpg URL for direct pg_orrery SQL execution (no SQLAlchemy)
    orrery_db_url: str = "postgresql://orrery:orrery@localhost:5432/orrery_search"

    # --- GPU gateway / embeddings ---
    gpu_api_key: str = ""
    gpu_base_url: str = "https://orrery-search.gpu.supported.systems/v1"
    gpu_embed_model: str = "mxbai-embed-large"
    gpu_embed_dimensions: int = 1024

    # --- Search ---
    search_max_results: int = 50

    # --- Chat ---
    # LLM provider: "gpu" for self-hosted (qwen3), "anthropic" for Claude
    llm_provider: str = "gpu"

    gpu_chat_model: str = "qwen3"
    chat_timeout: float = 30.0
    chat_max_tokens: int = 8192

    anthropic_api_key: str = ""
    anthropic_model: str = "claude-sonnet-4-20250514"

    # --- Live SQL tool ---
    run_query_timeout: float = 5.0

    model_config = {"env_prefix": "", "env_file": ".env", "extra": "ignore"}
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
17
search/src/orrery_search/db.py
Normal file
17
search/src/orrery_search/db.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from collections.abc import AsyncIterator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from orrery_search.config import settings
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
settings.database_url,
|
||||||
|
echo=False,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
pool_size=10,
|
||||||
|
max_overflow=20,
|
||||||
|
)
|
||||||
|
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db() -> AsyncIterator[AsyncSession]:
    """Yield one session from ``async_session``.

    This is an async generator, so it is annotated as an iterator of
    sessions rather than as a session. The ``async with`` block closes the
    session when the generator is resumed or finalised after the yield.

    NOTE(review): presumably consumed as a FastAPI dependency — the callers
    are not visible from this module.
    """
    async with async_session() as session:
        yield session
|
||||||
0
search/src/orrery_search/ingest/__init__.py
Normal file
0
search/src/orrery_search/ingest/__init__.py
Normal file
6
search/src/orrery_search/ingest/__main__.py
Normal file
6
search/src/orrery_search/ingest/__main__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from orrery_search.ingest.runner import ingest
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(ingest())
|
||||||
54
search/src/orrery_search/ingest/mdx_parser.py
Normal file
54
search/src/orrery_search/ingest/mdx_parser.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
"""MDX/Markdown frontmatter extraction and content cleaning.
|
||||||
|
|
||||||
|
Strips JSX components, import statements, Starlight admonitions,
|
||||||
|
and other MDX-specific syntax to produce clean plain text for indexing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import html
|
||||||
|
import re
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
_FRONTMATTER_RE = re.compile(r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL)
_IMPORT_RE = re.compile(r"^import\s+.*$", re.MULTILINE)
_MERMAID_BLOCK_RE = re.compile(r"```mermaid\n.*?```", re.DOTALL)
# Capturing group so that re.split() keeps the fenced blocks in its result.
_CODE_FENCE_RE = re.compile(r"(```.*?```)", re.DOTALL)
_JSX_SELF_CLOSING_RE = re.compile(r"<\w+[^>]*/\s*>")
_JSX_BLOCK_RE = re.compile(
    r"<(Aside|CardGrid|Steps|Tabs|TabItem|Card|LinkCard)[^>]*>(.*?)</\1>",
    re.DOTALL,
)
_JSX_OPEN_CLOSE_RE = re.compile(
    r"</?(?:Aside|CardGrid|Card|LinkCard|Steps|Tabs|TabItem|Icon)[^>]*>"
)
_ADMONITION_FENCE_RE = re.compile(r"^:::.*$", re.MULTILINE)
_MULTI_BLANK_RE = re.compile(r"\n{3,}")


def strip_mdx(raw: str) -> tuple[dict, str]:
    """Extract frontmatter and clean MDX content to plain text.

    Args:
        raw: Full text of an ``.mdx`` file.

    Returns:
        ``(frontmatter_dict, cleaned_body)``. The frontmatter is ``{}`` when
        the file has none, when its YAML is invalid, or when the YAML does not
        parse to a mapping.
    """
    frontmatter: dict = {}

    # Normalise a leading BOM and CRLF line endings so the "\n"-based patterns
    # (frontmatter delimiter, blank-line collapsing) behave the same on files
    # saved with Windows conventions.
    body = raw.removeprefix("\ufeff").replace("\r\n", "\n")

    fm_match = _FRONTMATTER_RE.match(body)
    if fm_match:
        with contextlib.suppress(yaml.YAMLError):
            loaded = yaml.safe_load(fm_match.group(1))
            # safe_load may return a scalar or list; only a mapping is usable
            # as frontmatter by callers that call .get() on it.
            if isinstance(loaded, dict):
                frontmatter = loaded
        # The frontmatter block never belongs in the indexed body, even when
        # its YAML could not be parsed.
        body = body[fm_match.end():]

    body = _MERMAID_BLOCK_RE.sub("", body)

    # Strip MDX import statements only outside fenced code blocks, so code
    # samples that contain their own "import ..." lines are kept intact.
    # With the capturing split, even indices are prose and odd ones are fences.
    segments = _CODE_FENCE_RE.split(body)
    segments[::2] = [_IMPORT_RE.sub("", prose) for prose in segments[::2]]
    body = "".join(segments)

    # Unwrap nested component blocks from the outside in, keeping inner text.
    while _JSX_BLOCK_RE.search(body):
        body = _JSX_BLOCK_RE.sub(r"\2", body)

    body = _JSX_SELF_CLOSING_RE.sub("", body)
    body = _JSX_OPEN_CLOSE_RE.sub("", body)
    body = _ADMONITION_FENCE_RE.sub("", body)
    body = html.unescape(body)
    body = _MULTI_BLANK_RE.sub("\n\n", body).strip()

    return frontmatter, body
|
||||||
183
search/src/orrery_search/ingest/runner.py
Normal file
183
search/src/orrery_search/ingest/runner.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
"""Ingestion orchestrator — walks content directories, classifies, and upserts.
|
||||||
|
|
||||||
|
Run: docker compose exec api-dev python -m orrery_search.ingest
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from orrery_search.db import async_session
|
||||||
|
from orrery_search.ingest.mdx_parser import strip_mdx
|
||||||
|
from orrery_search.models.document import Document
|
||||||
|
from orrery_search.services.search_text import build_search_text
|
||||||
|
|
||||||
|
CONTENT_DIR = Path("/data/content")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_paths() -> Path:
    """Return the docs content directory.

    CONTENT_DIR is used when it exists. Otherwise up to ten ancestors of this
    file are checked for a ``docs/src`` directory and the result is
    ``<ancestor>/docs/src/content/docs``. If no ancestor matches, the path is
    built from the last ancestor visited and may not exist.
    """
    if CONTENT_DIR.exists():
        return CONTENT_DIR

    root = Path(__file__).resolve()
    hops = 0
    while hops < 10:
        root = root.parent
        hops += 1
        if (root / "docs" / "src").exists():
            break

    return root.joinpath("docs", "src", "content", "docs")
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_content_type(rel_path: str) -> str:
    """Map the first ``/``-separated path segment to a content type.

    Paths whose first segment is not one of the known sections are
    classified as ``"page"``.
    """
    top_level, _, _ = rel_path.partition("/")
    type_by_section = {
        "getting-started": "getting_started",
        "guides": "guide",
        "workflow": "workflow",
        "reference": "reference",
        "architecture": "architecture",
        "performance": "performance",
    }
    return type_by_section.get(top_level, "page")
|
||||||
|
|
||||||
|
|
||||||
|
def _mdx_path_to_url(rel_path: str) -> str:
    """Convert a relative ``.mdx`` path to a site URL path.

    ``guides/foo.mdx`` becomes ``/guides/foo/`` and ``guides/index.mdx``
    becomes ``/guides/``. The root ``index.mdx`` maps to ``/``; it has no
    ``/index`` suffix to strip and would otherwise come out as ``/index/``.
    """
    slug = rel_path.removesuffix(".mdx")
    if slug == "index":
        return "/"
    slug = slug.removesuffix("/index")
    return f"/{slug}/"
|
||||||
|
|
||||||
|
|
||||||
|
def _mdx_path_to_section(rel_path: str) -> str:
    """Return the directory part of ``rel_path`` joined with ``/``.

    A file that sits directly in the content root has an empty section.
    """
    directory_parts = Path(rel_path).parent.parts
    return "/".join(directory_parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _mdx_path_to_slug(rel_path: str) -> str:
    """Derive the dedup slug: the path minus ``.mdx`` and a trailing ``/index``."""
    without_extension = rel_path.removesuffix(".mdx")
    return without_extension.removesuffix("/index")
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_mdx_pages(content_dir: Path) -> list[dict]:
    """Walk the content directory and parse all ``.mdx`` files.

    Args:
        content_dir: Root directory searched recursively for ``*.mdx``.

    Returns:
        One dict per readable file with the keys ``content_type``, ``slug``,
        ``title``, ``section``, ``description``, ``body``, ``url`` and
        ``word_count``. Files that cannot be read or decoded are reported on
        stderr and skipped.
    """
    pages = []
    for mdx_path in sorted(content_dir.rglob("*.mdx")):
        try:
            raw = mdx_path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as exc:
            print(f" SKIP {mdx_path}: {exc}", file=sys.stderr)
            continue

        frontmatter, body = strip_mdx(raw)

        # as_posix() keeps "/" separators on every platform; the path helpers
        # split on "/" and strip a "/index" suffix.
        rel_path = mdx_path.relative_to(content_dir).as_posix()

        # Fall back to a title derived from the file name when the frontmatter
        # has no usable title (missing, null, or empty). str() guards against
        # YAML scalars such as numbers or dates.
        fallback_title = mdx_path.stem.replace("-", " ").title()
        title = str(frontmatter.get("title") or fallback_title)

        pages.append({
            "content_type": _classify_content_type(rel_path),
            "slug": _mdx_path_to_slug(rel_path),
            "title": title,
            "section": _mdx_path_to_section(rel_path),
            "description": frontmatter.get("description"),
            "body": body,
            "url": _mdx_path_to_url(rel_path),
            "word_count": len(body.split()),
        })

    return pages
|
||||||
|
|
||||||
|
|
||||||
|
async def ingest():
    """Main ingestion: read docs content, upsert into the document table.

    Each page is looked up by slug and updated in place, or inserted if it is
    new. Every page runs inside its own savepoint, so a failing page is rolled
    back and reported without aborting the run. The session is committed every
    50 pages and once more at the end.

    Exits the process with status 1 when the content directory does not exist.
    """
    content_dir = _resolve_paths()
    print(f"Content dir: {content_dir}", file=sys.stderr)

    if not content_dir.exists():
        print(f"Content directory not found: {content_dir}", file=sys.stderr)
        sys.exit(1)

    pages = _collect_mdx_pages(content_dir)
    print(f"Found {len(pages)} published pages", file=sys.stderr)

    # Columns refreshed on an existing row; the slug is the lookup key and is
    # left untouched.
    updatable_fields = (
        "title",
        "section",
        "description",
        "body",
        "url",
        "content_type",
        "word_count",
    )

    async with async_session() as db:
        inserted = 0
        updated = 0
        errors = 0

        for i, page_data in enumerate(pages):
            try:
                search_text = build_search_text(
                    title=page_data["title"],
                    section=page_data["section"],
                    content_type=page_data["content_type"],
                    description=page_data["description"],
                    body=page_data["body"],
                )

                async with db.begin_nested():
                    result = await db.execute(
                        select(Document).where(Document.slug == page_data["slug"])
                    )
                    existing = result.scalar_one_or_none()

                    if existing is not None:
                        for field in updatable_fields:
                            setattr(existing, field, page_data[field])
                        existing.search_text = search_text
                    else:
                        db.add(Document(**page_data, search_text=search_text))

                # Count the page only after the savepoint block has exited
                # cleanly. Incrementing inside the block would count a page
                # whose write fails on savepoint exit as both a success and an
                # error.
                if existing is not None:
                    updated += 1
                else:
                    inserted += 1

                if (i + 1) % 50 == 0:
                    await db.commit()
                    print(
                        f" progress: {i + 1}/{len(pages)}",
                        file=sys.stderr,
                    )

            except Exception as exc:
                # Best-effort ingestion: report the page and keep going.
                print(
                    f" ERROR on {page_data['slug']}: {exc}",
                    file=sys.stderr,
                )
                errors += 1

        await db.commit()

    print(
        f"Ingestion complete: {inserted} inserted, {updated} updated, "
        f"{errors} errors ({inserted + updated} total)",
        file=sys.stderr,
    )
|
||||||
49
search/src/orrery_search/main.py
Normal file
49
search/src/orrery_search/main.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastmcp.utilities.lifespan import combine_lifespans
|
||||||
|
|
||||||
|
from orrery_search.config import settings
|
||||||
|
from orrery_search.db import engine
|
||||||
|
from orrery_search.mcp import mcp
|
||||||
|
from orrery_search.mcp.chat import close_chat_client
|
||||||
|
from orrery_search.mcp.query import close_query_pool
|
||||||
|
from orrery_search.routers import chat, health, search
|
||||||
|
|
||||||
|
# Root logging is configured at WARNING; only the "orrery_search" logger is
# lowered to INFO.
logging.basicConfig(
    format="%(name)s: %(message)s",
    level=logging.WARNING,
)
logging.getLogger("orrery_search").setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: release shared resources on shutdown.

    This hook performs no startup work. On exit it closes the chat HTTP
    client, the query connection pool and the SQLAlchemy engine.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    try:
        yield
    finally:
        # Nested try/finally so that every resource is released even when the
        # application body raised or an earlier close step fails.
        try:
            await close_chat_client()
        finally:
            try:
                await close_query_pool()
            finally:
                await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
# ASGI application for the MCP server; it is mounted under /mcp below.
mcp_app = mcp.http_app(path="/", stateless_http=True)

# The application lifespan is built from this module's `lifespan` and the MCP
# app's own lifespan via combine_lifespans.
app = FastAPI(
    lifespan=combine_lifespans(lifespan, mcp_app.lifespan),
    title="pg_orrery Search",
    description="Semantic search and chat for the pg_orrery documentation",
    version="2026.03.01",
)

# REST routers, registered in the same order as before: search, chat, health.
for _router, _router_options in (
    (search.router, {"prefix": "/api/search", "tags": ["search"]}),
    (chat.router, {"prefix": "/api/chat", "tags": ["chat"]}),
    (health.router, {"tags": ["health"]}),
):
    app.include_router(_router, **_router_options)

app.mount("/mcp", mcp_app)
|
||||||
|
|
||||||
|
|
||||||
|
def run():
    """Start uvicorn for ``orrery_search.main:app``.

    Host, port and log level are read from ``settings``.
    """
    # NOTE(review): reload is always enabled here — confirm this entry point
    # is only used for development.
    server_options = {
        "host": settings.api_host,
        "port": settings.api_port,
        "log_level": settings.api_log_level,
        "reload": True,
    }
    uvicorn.run("orrery_search.main:app", **server_options)
|
||||||
17
search/src/orrery_search/mcp/__init__.py
Normal file
17
search/src/orrery_search/mcp/__init__.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
|
# Shared FastMCP server instance; the tool, resource and prompt modules
# imported at the bottom of this file register themselves against it.
mcp = FastMCP(
    "pg_orrery",
    instructions="".join(
        [
            "MCP server for the pg_orrery documentation and live celestial mechanics. ",
            "Search 50+ pages of docs covering 225 SQL functions for satellite tracking, ",
            "planetary ephemerides, rise/set prediction, Lagrange points, and more. ",
            "Use ask_orrery for natural-language Q&A with source citations. ",
            "Use run_query to execute live pg_orrery SQL against a PostgreSQL database.",
        ]
    ),
)
|
||||||
|
|
||||||
|
# Register tools, resources, and prompts
|
||||||
|
import orrery_search.mcp.chat # noqa: E402, F401
|
||||||
|
import orrery_search.mcp.query # noqa: E402, F401
|
||||||
|
import orrery_search.mcp.tools # noqa: E402, F401
|
||||||
307
search/src/orrery_search/mcp/chat.py
Normal file
307
search/src/orrery_search/mcp/chat.py
Normal file
@ -0,0 +1,307 @@
|
|||||||
|
import asyncio
|
||||||
|
import collections.abc
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from orrery_search.config import settings
|
||||||
|
from orrery_search.db import async_session
|
||||||
|
from orrery_search.mcp import mcp
|
||||||
|
from orrery_search.models.document import Document
|
||||||
|
from orrery_search.services.search import search_documents
|
||||||
|
|
||||||
|
logger = logging.getLogger("orrery_search")

# System prompt sent with every chat completion request in this module.
SYSTEM_PROMPT = "".join(
    (
        "/no_think\n",
        "You are a knowledgeable assistant for pg_orrery, a PostgreSQL extension providing ",
        "celestial mechanics types and functions. Answer questions using the provided context ",
        "from the documentation.\n\n",
        "Key domain knowledge:\n",
        "- Body IDs: 0=Sun, 1=Mercury, 2=Venus, 3=Earth, 4=Mars, 5=Jupiter, 6=Saturn, ",
        "7=Uranus, 8=Neptune, 10=Moon\n",
        "- Observer format: '(lat_deg, lon_deg, alt_m)' e.g. '(40.7128, -74.0060, 10)'\n",
        "- Key functions: planet_observe(body_id, observer, timestamp), ",
        "planet_equatorial(body_id, timestamp), sun_observe(), moon_observe(), ",
        "predict_passes(tle, observer, timestamp, count), satellite_is_eclipsed()\n",
        "- Coordinate systems: topocentric (az/el/range), equatorial (RA/Dec), ",
        "heliocentric (x,y,z AU), geodetic (lat/lon/alt)\n",
        "- All functions are PARALLEL SAFE. VSOP87 functions are IMMUTABLE.\n\n",
        "If the context doesn't contain enough information, say so clearly. ",
        "Cite specific documents by title when referencing information. ",
        "Be precise and factual — never fabricate claims.",
    )
)

# Budget, in characters of document body text, for the RAG context.
MAX_CONTEXT_CHARS = 2000
|
||||||
|
|
||||||
|
# Lazily created HTTP client shared by all GPU-gateway requests, plus the lock
# that serialises its creation.
_chat_client: httpx.AsyncClient | None = None
_client_lock = asyncio.Lock()


async def _get_chat_client() -> httpx.AsyncClient:
    """Return the shared httpx client, creating it when missing or closed.

    The open/closed check is done once without the lock and repeated while
    holding ``_client_lock`` before a new client is built.
    """
    global _chat_client

    client = _chat_client
    if client is None or client.is_closed:
        async with _client_lock:
            # Re-read under the lock: another task may have created it already.
            client = _chat_client
            if client is None or client.is_closed:
                request_timeout = httpx.Timeout(
                    connect=10.0,
                    read=120.0,
                    write=10.0,
                    pool=10.0,
                )
                pool_limits = httpx.Limits(
                    max_connections=20,
                    max_keepalive_connections=10,
                )
                client = httpx.AsyncClient(timeout=request_timeout, limits=pool_limits)
                _chat_client = client
    return client
|
||||||
|
|
||||||
|
|
||||||
|
async def close_chat_client() -> None:
    """Close the shared chat HTTP client if it is open and clear the reference."""
    global _chat_client

    client = _chat_client
    needs_close = not (client is None or client.is_closed)
    if needs_close:
        await client.aclose()
    _chat_client = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_context(query: str) -> tuple[str, list[dict]]:
    """Assemble RAG context text and source metadata for *query*.

    Runs a hybrid search (top 5), loads the matching documents in one query,
    and concatenates their bodies in search-rank order until
    ``MAX_CONTEXT_CHARS`` characters of body text have been used. A body that
    does not fit in the remaining budget is cut and suffixed with "...".

    Returns:
        ``(context_text, sources)``; ``("", [])`` when the search finds nothing.
        Each source is a dict with ``title``, ``slug``, ``url`` and ``score``.
    """
    async with async_session() as session:
        found = await search_documents(q=query, db=session, mode="hybrid", limit=5)
        hits = found.results
        if not hits:
            return "", []

        ranked_slugs = [hit.slug for hit in hits]
        scores = {hit.slug: hit.score for hit in hits}

        fetched = await session.execute(
            select(Document).where(Document.slug.in_(ranked_slugs))
        )
        documents = {}
        for document in fetched.scalars():
            documents[document.slug] = document

        sections = []
        citations = []
        budget = MAX_CONTEXT_CHARS

        for slug in ranked_slugs:
            document = documents.get(slug)
            if not document:
                continue
            if budget <= 0:
                break

            text = document.body
            if len(text) > budget:
                text = text[:budget] + "..."

            sections.append(f"--- {document.title} (/{document.slug}) ---\n{text}")
            citations.append(
                {
                    "title": document.title,
                    "slug": document.slug,
                    "url": document.url,
                    "score": scores.get(slug, 0.0),
                }
            )
            # The "..." suffix counts against the budget as well.
            budget -= len(text)

        return "\n\n".join(sections), citations
|
||||||
|
|
||||||
|
|
||||||
|
async def _chat_completion_stream_gpu(
    context: str, question: str
) -> collections.abc.AsyncIterator[tuple[str, str]]:
    """Stream chat completion tokens from the GPU gateway (OpenAI-compatible).

    Posts a streaming chat-completions request and parses the ``data: `` lines
    of the response until a ``[DONE]`` payload is seen.

    Args:
        context: Documentation text placed ahead of the question.
        question: The user's question.

    Yields:
        ``("content", text)`` for answer deltas and ``("reasoning", text)`` for
        ``reasoning_content`` deltas.

    Raises:
        httpx.HTTPStatusError: If the gateway responds with a status >= 400.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
        },
    ]

    client = await _get_chat_client()
    async with client.stream(
        "POST",
        f"{settings.gpu_base_url}/chat/completions",
        headers={
            "Authorization": f"Bearer {settings.gpu_api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": settings.gpu_chat_model,
            "messages": messages,
            "temperature": 0.3,
            "max_tokens": settings.chat_max_tokens,
            # NOTE(review): num_ctx mirrors max_tokens — confirm the gateway
            # expects the context window size to equal the output limit.
            "num_ctx": settings.chat_max_tokens,
            "stream": True,
        },
    ) as resp:
        if resp.status_code >= 400:
            # Read the body first so the first 500 bytes can be logged.
            body = await resp.aread()
            error_text = body[:500].decode("utf-8", errors="replace")
            logger.error("GPU gateway returned %d: %s", resp.status_code, error_text)
            raise httpx.HTTPStatusError(
                f"GPU gateway error {resp.status_code}",
                request=resp.request, response=resp,
            )

        content_count = 0
        try:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                payload = line[6:]
                if payload.strip() == "[DONE]":
                    return
                # Only the parsing is guarded; the yields stay outside so that
                # exceptions raised at a yield are never mistaken for a
                # malformed chunk.
                try:
                    delta = json.loads(payload)["choices"][0]["delta"]
                    content = delta.get("content")
                    reasoning = delta.get("reasoning_content")
                except (
                    json.JSONDecodeError,
                    KeyError,
                    IndexError,
                    AttributeError,
                    TypeError,
                ):
                    # Malformed or unexpected chunk shape: skip it.
                    continue
                if content:
                    content_count += 1
                    yield ("content", content)
                elif reasoning:
                    yield ("reasoning", reasoning)
        except GeneratorExit:
            # The consumer stopped iterating: close the response and finish.
            logger.debug("Chat stream cancelled after %d content chunks", content_count)
            await resp.aclose()
            return
|
||||||
|
|
||||||
|
|
||||||
|
async def _chat_completion_stream_anthropic(
    context: str, question: str
) -> collections.abc.AsyncIterator[tuple[str, str]]:
    """Stream chat completion tokens via the Anthropic API.

    Args:
        context: Documentation text placed ahead of the question.
        question: The user's question.

    Yields:
        ``("content", text)`` for each text delta.

    Raises:
        anthropic.APIError: Logged and re-raised unchanged.
    """
    import anthropic

    try:
        # The client is created per call; the context manager closes it (and
        # its HTTP resources) once the stream ends or fails.
        async with anthropic.AsyncAnthropic(api_key=settings.anthropic_api_key) as client:
            async with client.messages.stream(
                model=settings.anthropic_model,
                max_tokens=settings.chat_max_tokens,
                system=SYSTEM_PROMPT,
                messages=[
                    {
                        "role": "user",
                        "content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}",
                    },
                ],
            ) as stream:
                async for text in stream.text_stream:
                    yield ("content", text)
    except anthropic.APIError as exc:
        logger.error("Anthropic API error: %s", exc)
        raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _chat_completion_stream(
    context: str, question: str
) -> collections.abc.AsyncIterator[tuple[str, str]]:
    """Stream ``(kind, text)`` items from the configured LLM provider.

    Anthropic is used only when ``settings.llm_provider`` is "anthropic" and an
    API key is set; every other configuration goes to the GPU gateway.
    """
    use_anthropic = settings.llm_provider == "anthropic" and settings.anthropic_api_key
    backend = (
        _chat_completion_stream_anthropic
        if use_anthropic
        else _chat_completion_stream_gpu
    )
    async for item in backend(context, question):
        yield item
|
||||||
|
|
||||||
|
|
||||||
|
async def _chat_completion(context: str, question: str) -> str:
    """Non-streaming chat completion for MCP tool use.

    Uses Anthropic when ``settings.llm_provider`` is "anthropic" and an API
    key is set, otherwise the GPU gateway.

    Args:
        context: Documentation text placed ahead of the question.
        question: The user's question.

    Returns:
        The model's answer text.

    Raises:
        httpx.HTTPStatusError: If the GPU gateway returns an error status.
    """
    user_content = f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}"

    if settings.llm_provider == "anthropic" and settings.anthropic_api_key:
        import anthropic

        # The client is created per call; the context manager closes it.
        async with anthropic.AsyncAnthropic(api_key=settings.anthropic_api_key) as client:
            message = await client.messages.create(
                model=settings.anthropic_model,
                max_tokens=settings.chat_max_tokens,
                system=SYSTEM_PROMPT,
                messages=[{"role": "user", "content": user_content}],
            )
        # Join every text block instead of indexing content[0], which fails on
        # an empty list or a non-text first block.
        return "".join(
            block.text
            for block in message.content
            if getattr(block, "type", None) == "text"
        )

    client = await _get_chat_client()
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    resp = await client.post(
        f"{settings.gpu_base_url}/chat/completions",
        headers={
            "Authorization": f"Bearer {settings.gpu_api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": settings.gpu_chat_model,
            "messages": messages,
            "temperature": 0.3,
            "max_tokens": settings.chat_max_tokens,
            "num_ctx": settings.chat_max_tokens,
        },
    )
    resp.raise_for_status()
    data = resp.json()
    return data["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def ask_orrery(
    question: str,
    include_sources: bool = True,
) -> str:
    """Ask a question about pg_orrery and get an answer grounded in documentation.

    Uses retrieval-augmented generation: searches the docs for relevant pages,
    then synthesizes an answer citing specific sources. Best for conceptual questions
    about celestial mechanics functions, types, coordinate systems, and workflows.

    Args:
        question: Natural language question (e.g. "How do I predict satellite passes?")
        include_sources: Whether to include source document references in the response
    """
    # Failures are re-raised as RuntimeError with a short message; `from exc`
    # keeps the original error attached as the cause.
    try:
        context, sources = await _build_context(question)
    except Exception as exc:
        logger.exception("Failed to build RAG context")
        raise RuntimeError("Search failed while building context") from exc

    if not context:
        return json.dumps({
            "answer": "No relevant documents found in the pg_orrery docs for this question.",
            "sources": [],
        })

    try:
        answer = await _chat_completion(context, question)
    except httpx.HTTPStatusError as exc:
        logger.error("Chat completion failed: HTTP %s", exc.response.status_code)
        raise RuntimeError(
            f"Chat model returned HTTP {exc.response.status_code}"
        ) from exc
    except httpx.TimeoutException as exc:
        logger.warning("Chat completion timed out (limit: %.0fs)", settings.chat_timeout)
        raise RuntimeError("Chat model timed out — try a simpler question") from exc
    except httpx.HTTPError as exc:
        # Remaining transport failures (e.g. connection errors) get the same
        # RuntimeError treatment instead of escaping as raw httpx exceptions.
        logger.error("Chat completion request failed: %s", exc)
        raise RuntimeError("Chat model request failed") from exc

    result: dict = {"answer": answer}
    if include_sources:
        result["sources"] = sources
    return json.dumps(result)
|
||||||
247
search/src/orrery_search/mcp/query.py
Normal file
247
search/src/orrery_search/mcp/query.py
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
"""Live pg_orrery SQL query execution via MCP tools.
|
||||||
|
|
||||||
|
Provides read-only access to pg_orrery functions running inside the same
|
||||||
|
PostgreSQL instance that stores document embeddings. Safety guardrails
|
||||||
|
ensure only SELECT statements execute, with statement timeouts and row limits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
|
||||||
|
from orrery_search.config import settings
|
||||||
|
from orrery_search.mcp import mcp
|
||||||
|
|
||||||
|
logger = logging.getLogger("orrery_search")

# Lazily created asyncpg pool for live queries, plus the lock guarding creation.
_pool: asyncpg.Pool | None = None
_pool_lock = asyncio.Lock()

# Maximum number of rows returned by run_query.
MAX_ROWS = 100


async def _get_pool() -> asyncpg.Pool:
    """Return the shared asyncpg pool, creating it on first use.

    The ``None`` check is repeated under ``_pool_lock`` so the pool is only
    created once.
    """
    global _pool

    if _pool is None:
        async with _pool_lock:
            if _pool is None:
                _pool = await asyncpg.create_pool(
                    settings.orrery_db_url,
                    min_size=2,
                    max_size=5,
                    command_timeout=settings.run_query_timeout,
                )
    return _pool
|
||||||
|
|
||||||
|
|
||||||
|
async def close_query_pool() -> None:
    """Close the shared asyncpg pool, if one was created, and clear the reference."""
    global _pool

    if _pool is None:
        return
    await _pool.close()
    _pool = None
|
||||||
|
|
||||||
|
|
||||||
|
# Whole-word, case-insensitive match for write/DDL keywords anywhere in the text.
_FORBIDDEN_RE = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|GRANT|REVOKE|COPY|EXECUTE)\b",
    flags=re.IGNORECASE,
)


def _validate_sql(sql: str) -> None:
    """Raise ``ValueError`` unless *sql* looks like a plain SELECT / WITH query.

    Surrounding whitespace and trailing semicolons are ignored. The statement
    must begin with SELECT or WITH (case-insensitive) and must not contain any
    keyword matched by ``_FORBIDDEN_RE``.

    Raises:
        ValueError: On an empty statement, a non-SELECT start, or a forbidden keyword.
    """
    statement = sql.strip().rstrip(";").strip()
    if statement == "":
        raise ValueError("Empty SQL statement")

    if not statement.upper().startswith(("SELECT", "WITH")):
        raise ValueError("Only SELECT statements are allowed")

    if _FORBIDDEN_RE.search(statement) is not None:
        raise ValueError("Statement contains forbidden keywords (INSERT/UPDATE/DELETE/DDL)")
|
||||||
|
|
||||||
|
|
||||||
|
def _to_jsonable(value):
    """Convert a single column value into something ``json.dumps`` accepts."""
    if hasattr(value, "isoformat"):
        # Anything exposing isoformat() is rendered through it.
        return value.isoformat()
    if isinstance(value, (int, float, str, bool, type(None))):
        return value
    if isinstance(value, (list, tuple)):
        # Sequences become JSON arrays, element by element, instead of being
        # flattened into a Python repr string.
        return [_to_jsonable(item) for item in value]
    return str(value)


def _serialize_row(row: "asyncpg.Record") -> dict:
    """Convert an asyncpg Record to a JSON-serializable dict.

    Args:
        row: A record (any object with ``.items()`` yielding column/value pairs).

    Returns:
        A dict keyed by column name. Values with ``isoformat()`` are rendered
        through it, JSON primitives pass through, lists and tuples become
        lists (recursively), and everything else is converted with ``str()``.
    """
    return {key: _to_jsonable(value) for key, value in row.items()}
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def run_query(sql: str) -> str:
    """Execute a read-only pg_orrery SQL query and return the results.

    Runs arbitrary SELECT statements against a PostgreSQL database with
    pg_orrery installed. Use this for live celestial computations:
    planet positions, satellite passes, rise/set times, angular distances, etc.

    Safety: Only SELECT is allowed. Statements are wrapped in a read-only
    transaction with a 5-second timeout. Results capped at 100 rows.

    Args:
        sql: A SELECT statement using pg_orrery functions, e.g.
            "SELECT planet_equatorial(5, NOW())" or
            "SELECT * FROM planet_observe(4, '(40.7128,-74.0060,10)', NOW())"
    """
    _validate_sql(sql)

    timeout_ms = int(settings.run_query_timeout * 1000)
    pool = await _get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction(readonly=True):
            await conn.execute(f"SET LOCAL statement_timeout = '{timeout_ms}'")
            # Fetch through a cursor so at most MAX_ROWS + 1 rows are pulled
            # into memory; the extra row only signals truncation.
            cursor = await conn.cursor(sql)
            rows = await cursor.fetch(MAX_ROWS + 1)

    truncated = len(rows) > MAX_ROWS
    rows = rows[:MAX_ROWS]

    results = [_serialize_row(r) for r in rows]
    return json.dumps({
        "row_count": len(results),
        "columns": list(rows[0].keys()) if rows else [],
        "rows": results,
        # True when the query produced more than MAX_ROWS rows.
        "truncated": truncated,
    })
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def planet_position(
    body_id: int,
    latitude: float,
    longitude: float,
    altitude: float = 0.0,
    time: str = "NOW()",
) -> str:
    """Get a planet's current position as seen from an observer location.

    Returns azimuth, elevation, range, and range_rate in topocentric coordinates,
    plus equatorial RA/Dec.

    Args:
        body_id: Planet ID (0=Sun, 1=Mercury, 2=Venus, 4=Mars, 5=Jupiter,
            6=Saturn, 7=Uranus, 8=Neptune, 10=Moon)
        latitude: Observer latitude in degrees (-90 to 90)
        longitude: Observer longitude in degrees (-180 to 180)
        altitude: Observer altitude in meters above WGS-84 (default 0)
        time: Timestamp literal cast to timestamptz (e.g. "2026-03-01 12:00:00+00"),
            or "NOW()" for the current time (default "NOW()")
    """
    # Coerce numeric inputs so only numbers can reach the SQL text.
    body = int(body_id)
    obs = f"'({float(latitude)},{float(longitude)},{float(altitude)})'::observer"

    # `time` is embedded as a quoted literal, so single quotes are doubled to
    # keep the value inside the literal.
    if time.strip().upper() == "NOW()":
        ts = "NOW()"
    else:
        safe_time = time.replace("'", "''")
        ts = f"'{safe_time}'::timestamptz"

    sql = f"""
        SELECT
            (t).azimuth_deg, (t).elevation_deg, (t).range_km, (t).range_rate_km_s,
            (e).ra_hours, (e).dec_degrees, (e).distance_km
        FROM (
            SELECT
                planet_observe({body}, {obs}, {ts}) AS t,
                planet_equatorial({body}, {ts}) AS e
        ) sub
    """

    # NOTE(review): assumes the @mcp.tool()-decorated run_query is still
    # directly awaitable — confirm for the installed FastMCP version.
    return await run_query(sql)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def sky_survey(
    latitude: float,
    longitude: float,
    altitude: float = 0.0,
    time: str = "NOW()",
) -> str:
    """Survey the entire sky: positions of all planets, Sun, and Moon from an observer.

    Returns a table of all major solar system bodies with their topocentric
    azimuth/elevation and equatorial RA/Dec coordinates.

    Args:
        latitude: Observer latitude in degrees (-90 to 90)
        longitude: Observer longitude in degrees (-180 to 180)
        altitude: Observer altitude in meters above WGS-84 (default 0)
        time: Timestamp literal cast to timestamptz (e.g. "2026-03-01 12:00:00+00"),
            or "NOW()" for the current time (default "NOW()")
    """
    # Coerce numeric inputs so only numbers can reach the SQL text.
    obs = f"'({float(latitude)},{float(longitude)},{float(altitude)})'::observer"

    # `time` is embedded as a quoted literal, so single quotes are doubled to
    # keep the value inside the literal.
    if time.strip().upper() == "NOW()":
        ts = "NOW()"
    else:
        safe_time = time.replace("'", "''")
        ts = f"'{safe_time}'::timestamptz"

    sql = f"""
        SELECT
            body_name,
            (topo).azimuth_deg AS az,
            (topo).elevation_deg AS el,
            (topo).range_km AS range_km,
            (eq).ra_hours AS ra,
            (eq).dec_degrees AS dec
        FROM (
            SELECT 'Sun' AS body_name,
                sun_observe({obs}, {ts}) AS topo,
                sun_equatorial({ts}) AS eq
            UNION ALL
            SELECT 'Moon',
                moon_observe({obs}, {ts}),
                moon_equatorial({ts})
            UNION ALL
            SELECT unnest(ARRAY['Mercury','Venus','Mars','Jupiter','Saturn','Uranus','Neptune']),
                planet_observe(unnest(ARRAY[1,2,4,5,6,7,8]), {obs}, {ts}),
                planet_equatorial(unnest(ARRAY[1,2,4,5,6,7,8]), {ts})
        ) survey
        ORDER BY el DESC
    """

    # NOTE(review): assumes the @mcp.tool()-decorated run_query is still
    # directly awaitable — confirm for the installed FastMCP version.
    return await run_query(sql)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def satellite_pass(
    tle_line1: str,
    tle_line2: str,
    latitude: float,
    longitude: float,
    altitude: float = 0.0,
    time: str = "NOW()",
    count: int = 5,
) -> str:
    """Predict upcoming satellite passes over an observer location.

    Uses SGP4/SDP4 propagation with the provided Two-Line Element set.

    Args:
        tle_line1: First line of the TLE (69 chars, starts with "1 ")
        tle_line2: Second line of the TLE (69 chars, starts with "2 ")
        latitude: Observer latitude in degrees
        longitude: Observer longitude in degrees
        altitude: Observer altitude in meters (default 0)
        time: Start time for pass search, as a timestamp literal cast to
            timestamptz, or "NOW()" for the current time (default "NOW()")
        count: Number of passes to predict (default 5, max 20)
    """
    count = max(1, min(int(count), 20))
    # Coerce numeric inputs so only numbers can reach the SQL text.
    obs = f"'({float(latitude)},{float(longitude)},{float(altitude)})'::observer"

    # Every string embedded as a quoted SQL literal (the TLE lines and the
    # start time) has its single quotes doubled.
    safe_l1 = tle_line1.replace("'", "''")
    safe_l2 = tle_line2.replace("'", "''")
    if time.strip().upper() == "NOW()":
        ts = "NOW()"
    else:
        safe_time = time.replace("'", "''")
        ts = f"'{safe_time}'::timestamptz"

    sql = f"""
        SELECT
            (p).aos_time, (p).los_time,
            (p).max_elevation_deg,
            (p).aos_azimuth_deg, (p).los_azimuth_deg
        FROM predict_passes(
            tle_from_lines('{safe_l1}', '{safe_l2}'),
            {obs}, {ts}, {count}
        ) AS p
    """

    # NOTE(review): assumes the @mcp.tool()-decorated run_query is still
    # directly awaitable — confirm for the installed FastMCP version.
    return await run_query(sql)
|
||||||
220
search/src/orrery_search/mcp/tools.py
Normal file
220
search/src/orrery_search/mcp/tools.py
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
|
||||||
|
from orrery_search.db import async_session
|
||||||
|
from orrery_search.mcp import mcp
|
||||||
|
from orrery_search.models.document import Document
|
||||||
|
from orrery_search.services.search import search_documents
|
||||||
|
|
||||||
|
logger = logging.getLogger("orrery_search")

# Allowed values for the ``content_type`` filters on the tools below.
ContentType = Literal[
    "page",
    "guide",
    "reference",
    "architecture",
    "workflow",
    "getting_started",
    "performance",
]
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def search_docs(
    query: str,
    mode: Literal["hybrid", "semantic", "text"] = "hybrid",
    content_type: ContentType | None = None,
    limit: int = 10,
) -> str:
    """Search the pg_orrery documentation.

    Finds pages using semantic similarity, text matching, or both.
    Returns ranked results with titles, snippets, URLs, and relevance scores.
    Covers 225 SQL functions across satellites, planets, moons, stars, comets,
    rise/set prediction, Lagrange points, and more.

    Args:
        query: Natural language search query (e.g. "rise set prediction observer")
        mode: Search strategy — "hybrid" combines semantic + text (best recall),
            "semantic" uses embedding similarity only, "text" uses keyword matching
        content_type: Filter by document type, or None for all
        limit: Max results to return (1-50, default 10)
    """
    # Clamp the requested limit into the 1..50 range.
    limit = min(max(limit, 1), 50)
    # Fields copied from each search hit into the JSON payload, in output order.
    hit_fields = (
        "title", "slug", "section", "content_type",
        "snippet", "url", "score", "source",
    )

    async with async_session() as session:
        found = await search_documents(
            q=query, db=session, mode=mode, content_type=content_type, limit=limit
        )
        hits = [
            {name: getattr(hit, name) for name in hit_fields}
            for hit in found.results
        ]
        payload = {
            "query": found.query,
            "mode": found.mode,
            "count": found.count,
            "results": hits,
        }
        return json.dumps(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def get_document(slug: str) -> str:
    """Retrieve the full text of a documentation page by its slug.

    Use this after search_docs to read the complete content of a result.
    Returns title, body text, metadata, and URL.

    Args:
        slug: Document identifier (from search results, e.g. "guides/tracking-satellites")
    """
    # Document attributes copied into the JSON payload, in output order.
    exported = (
        "title", "slug", "section", "content_type",
        "description", "body", "url", "word_count",
    )

    async with async_session() as session:
        lookup = select(Document).where(Document.slug == slug)
        doc = (await session.execute(lookup)).scalar_one_or_none()

        if not doc:
            raise ValueError(f"Document not found: {slug}")

        return json.dumps({name: getattr(doc, name) for name in exported})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def list_content(
    content_type: ContentType | None = None,
    section: str | None = None,
    limit: int = 100,
) -> str:
    """Browse the documentation contents by type or section.

    Without filters, returns a summary of document counts grouped by section.
    With filters, returns matching document titles and slugs for further retrieval.

    Args:
        content_type: Filter by document type, or None
        section: Filter by section prefix (e.g. "guides", "reference")
        limit: Max documents to return when filtering (1-500, default 100)
    """
    async with async_session() as db:
        if not content_type and not section:
            result = await db.execute(
                select(
                    Document.section,
                    Document.content_type,
                    func.count().label("count"),
                )
                .group_by(Document.section, Document.content_type)
                .order_by(Document.section)
            )
            # Rows are unpacked positionally: `row.count` would resolve to the
            # Row object's sequence method `count`, not the labelled column.
            groups = [
                {"section": sec, "content_type": ctype, "count": n}
                for sec, ctype, n in result
            ]
            return json.dumps({"summary": groups, "total": sum(g["count"] for g in groups)})

        limit = max(1, min(limit, 500))
        stmt = (
            select(
                Document.title, Document.slug, Document.section,
                Document.content_type, Document.description,
            )
            .order_by(Document.section, Document.title)
            .limit(limit)
        )

        if content_type:
            stmt = stmt.where(Document.content_type == content_type)
        if section:
            # autoescape makes "%" and "_" in the caller's prefix match
            # literally instead of acting as LIKE wildcards.
            stmt = stmt.where(Document.section.startswith(section, autoescape=True))

        result = await db.execute(stmt)
        docs = [
            {
                "title": row.title,
                "slug": row.slug,
                "section": row.section,
                "content_type": row.content_type,
                "description": row.description,
            }
            for row in result
        ]
        return json.dumps({"documents": docs, "count": len(docs)})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.resource("orrery://stats")
async def orrery_stats() -> str:
    """Current document and embedding statistics for the pg_orrery documentation."""
    async with async_session() as db:
        doc_count = await db.scalar(select(func.count()).select_from(Document))
        type_counts = await db.execute(
            select(Document.content_type, func.count().label("count"))
            .group_by(Document.content_type)
        )
        section_counts = await db.execute(
            select(Document.section, func.count().label("count"))
            .group_by(Document.section)
            .order_by(func.count().desc())
        )

        # Rows are unpacked positionally: `row.count` would resolve to the Row
        # object's sequence method `count`, not the labelled column.
        return json.dumps({
            "total_documents": doc_count,
            "by_type": {ctype: n for ctype, n in type_counts},
            "by_section": {sec: n for sec, n in section_counts},
        })
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.prompt()
def orrery_expert(topic: str = "general") -> str:
    """System prompt for exploring pg_orrery documentation and running live queries.

    Configures an assistant persona tuned for celestial mechanics computation.

    Args:
        topic: Focus area — "general", "satellites" (SGP4/TLE tracking),
            "planets" (VSOP87/DE ephemeris), "navigation" (rise/set/constellation),
            or "transfers" (Lambert/Lagrange)
    """
    # One focus paragraph per supported topic; unknown topics use "general".
    focus_by_topic = {
        "general": (
            "Help with any pg_orrery topic — satellite tracking, planetary observation, "
            "rise/set prediction, constellation identification, or Lagrange points."
        ),
        "satellites": (
            "Focus on satellite tracking: SGP4/SDP4 propagation, TLE parsing, "
            "pass prediction, eclipse detection, and coordinate transforms."
        ),
        "planets": (
            "Focus on planetary ephemerides: VSOP87 and JPL DE441 providers, "
            "planet/sun/moon observation, equatorial coordinates, and apparent positions."
        ),
        "navigation": (
            "Focus on observational astronomy: rise/set prediction, twilight computation, "
            "constellation identification, lunar phase, planet magnitude, and refraction."
        ),
        "transfers": (
            "Focus on orbital mechanics: Lambert transfer solver, Lagrange equilibrium "
            "points (CR3BP), Hill radius, and interplanetary trajectory design."
        ),
    }

    # Shared preamble that names the tools the assistant is expected to use.
    preamble = (
        "You are an expert assistant for pg_orrery, a PostgreSQL extension for "
        "celestial mechanics. Use the search_docs and get_document tools to find "
        "documentation, and the run_query tool to execute live SQL computations. "
        "Always cite your sources and show the SQL queries you run.\n\n"
    )

    focus = focus_by_topic.get(topic)
    if focus is None:
        focus = focus_by_topic["general"]
    return preamble + focus
|
||||||
4
search/src/orrery_search/models/__init__.py
Normal file
4
search/src/orrery_search/models/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from orrery_search.models.base import Base
|
||||||
|
from orrery_search.models.document import Document
|
||||||
|
|
||||||
|
__all__ = ["Base", "Document"]
|
||||||
5
search/src/orrery_search/models/base.py
Normal file
5
search/src/orrery_search/models/base.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models in this package."""
|
||||||
27
search/src/orrery_search/models/document.py
Normal file
27
search/src/orrery_search/models/document.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pgai.sqlalchemy import vectorizer_relationship
|
||||||
|
from sqlalchemy import DateTime, Integer, String, Text, func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from orrery_search.models.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Document(Base):
    """ORM model for one documentation page stored in the ``document`` table."""

    __tablename__ = "document"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Indexed because searches filter on it with an equality test.
    content_type: Mapped[str] = mapped_column(String(20), index=True)
    # Unique per page; search results are de-duplicated by slug.
    slug: Mapped[str] = mapped_column(String(300), unique=True)
    title: Mapped[str] = mapped_column(String(300))
    # Indexed because searches filter on it with a prefix match.
    section: Mapped[str] = mapped_column(String(200), index=True)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Full page text; text search matches against title and body.
    body: Mapped[str] = mapped_column(Text)
    # Optional enriched text; presumably filled from build_search_text — TODO confirm.
    search_text: Mapped[str | None] = mapped_column(Text, nullable=True)
    url: Mapped[str] = mapped_column(String(300))
    word_count: Mapped[int] = mapped_column(Integer, default=0)
    # Set by the database on insert and refreshed on every ORM update.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    # pgai vectorizer relationship (1024-dimension embeddings); semantic search
    # joins through it to read each chunk and its embedding.
    embeddings = vectorizer_relationship(dimensions=1024)
|
||||||
0
search/src/orrery_search/routers/__init__.py
Normal file
0
search/src/orrery_search/routers/__init__.py
Normal file
172
search/src/orrery_search/routers/chat.py
Normal file
172
search/src/orrery_search/routers/chat.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from orrery_search.mcp.chat import (
|
||||||
|
_build_context,
|
||||||
|
_chat_completion,
|
||||||
|
_chat_completion_stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("orrery_search")
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# Request body for the non-streaming chat endpoint.
class ChatRequest(BaseModel):
    # Required, non-empty, at most 1000 characters.
    question: str = Field(..., min_length=1, max_length=1000)
|
||||||
|
|
||||||
|
|
||||||
|
# One documentation page cited alongside a chat answer.
class ChatSource(BaseModel):
    title: str
    slug: str
    url: str
    score: float

    @field_validator("url")
    @classmethod
    def url_must_be_relative_or_https(cls, v: str) -> str:
        """Accept only site-relative URLs or absolute http(s) URLs."""
        allowed_prefixes = ("/", "https://", "http://")
        if not v.startswith(allowed_prefixes):
            raise ValueError("URL must be relative or http(s)")
        return v
|
||||||
|
|
||||||
|
|
||||||
|
# Response body of the non-streaming chat endpoint.
class ChatResponse(BaseModel):
    # Generated answer text (or a fixed fallback when nothing relevant was found).
    answer: str
    # Pages the answer was grounded in; empty when nothing relevant was found.
    sources: list[ChatSource]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Ask a question about pg_orrery and get an answer grounded in documentation."""
    user_question = req.question

    # Step 1: retrieve documentation context; any failure becomes a 502.
    try:
        context, sources = await _build_context(user_question)
    except Exception:
        logger.exception("Chat context build failed")
        raise HTTPException(status_code=502, detail="Search service unavailable")

    # Nothing relevant found: answer with a fixed hint instead of calling the LLM.
    if not context:
        fallback_answer = (
            "I couldn't find any relevant documentation for that question. "
            "Try asking about satellite tracking, planetary observation, "
            "rise/set prediction, or other pg_orrery functions."
        )
        return ChatResponse(answer=fallback_answer, sources=[])

    # Step 2: generate the answer; any failure becomes a 502.
    try:
        answer = await _chat_completion(context, user_question)
    except Exception:
        logger.exception("Chat completion failed")
        raise HTTPException(status_code=502, detail="Chat completion unavailable")

    cited = []
    for s in sources:
        cited.append(
            ChatSource(title=s["title"], slug=s["slug"], url=s["url"], score=s["score"])
        )
    return ChatResponse(answer=answer, sources=cited)
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Streaming endpoint ----
|
||||||
|
|
||||||
|
|
||||||
|
# Optional description of the page the user is reading when they ask.
class PageContext(BaseModel):
    title: str = Field("", max_length=200)
    path: str = Field("", max_length=500)
    description: str = Field("", max_length=500)

    @field_validator("path")
    @classmethod
    def path_must_be_relative(cls, v: str) -> str:
        """Allow an empty path or one that begins with a slash."""
        if not v or v.startswith("/"):
            return v
        raise ValueError("Path must start with /")
|
||||||
|
|
||||||
|
|
||||||
|
# Request body for the streaming chat endpoint.
class ChatStreamRequest(BaseModel):
    # Required, non-empty, at most 1000 characters.
    question: str = Field(..., min_length=1, max_length=1000)
    # Optional page context; when it has a title it is prepended to the question.
    page: PageContext | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _sse_event(event: str, data: dict | list) -> str:
|
||||||
|
return f"event: {event}\ndata: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/stream")
async def chat_stream(req: ChatStreamRequest):
    """SSE streaming endpoint for the chat widget."""

    async def generate():
        # Events emitted, in order: "status" updates, "sources", then a run of
        # "token"/"reasoning" chunks, closed by "done" -- or a single "error".
        question = req.question
        if req.page and req.page.title:
            # Prefix the question with the page the user is currently reading.
            page_context = f'[The user is currently reading: "{req.page.title}" ({req.page.path})'
            if req.page.description:
                page_context += f"\nDocument description: {req.page.description}"
            page_context += "]\n\n"
            question = page_context + req.question

        yield _sse_event("status", {"text": "Searching the documentation\u2026"})

        try:
            context, sources = await _build_context(question)
        except Exception:
            logger.exception("Chat stream context build failed")
            yield _sse_event("error", {"text": "Search service unavailable"})
            return

        if not context:
            # Nothing relevant: send the fixed hint as a single token and finish.
            yield _sse_event("status", {"text": "No relevant documents found"})
            yield _sse_event(
                "token",
                {
                    "text": "I couldn't find any relevant documentation "
                    "for that question. Try asking about satellite tracking, "
                    "planetary observation, rise/set prediction, or other pg_orrery functions."
                },
            )
            yield _sse_event("done", {})
            return

        n = len(sources)
        yield _sse_event(
            "status",
            {"text": f"Found {n} relevant page{'s' if n != 1 else ''}, generating answer\u2026"},
        )
        yield _sse_event("sources", sources)

        try:
            async for kind, text in _chat_completion_stream(context, question):
                yield _sse_event(
                    "reasoning" if kind == "reasoning" else "token",
                    {"text": text},
                )
        except (
            httpx.HTTPStatusError,
            httpx.ConnectError,
            httpx.ReadTimeout,
            httpx.PoolTimeout,
            httpx.ConnectTimeout,
        ) as exc:
            logger.warning("Chat stream failed: %s", exc)
            yield _sse_event("error", {"text": "Chat service unavailable"})
            return
        except asyncio.CancelledError:
            logger.debug("Chat stream cancelled by client disconnect")
            return
        except Exception:
            # Any other failure would otherwise end the stream without telling
            # the client. Report it the same way as the transport errors above,
            # matching the catch-all used by the non-streaming endpoint.
            logger.exception("Chat stream failed unexpectedly")
            yield _sse_event("error", {"text": "Chat service unavailable"})
            return

        yield _sse_event("done", {})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            # Asks buffering reverse proxies to pass events through immediately.
            "X-Accel-Buffering": "no",
        },
    )
|
||||||
46
search/src/orrery_search/routers/health.py
Normal file
46
search/src/orrery_search/routers/health.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy import func, select, text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from orrery_search.db import get_db
|
||||||
|
from orrery_search.models.document import Document
|
||||||
|
|
||||||
|
logger = logging.getLogger("orrery_search")
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health")
async def health(db: AsyncSession = Depends(get_db)):
    """Health check with document and embedding counts."""
    doc_count = await db.scalar(select(func.count(Document.id)))

    # The embedding store may not exist yet, so a failure here is not fatal.
    embedding_count = 0
    try:
        result = await db.execute(
            text("SELECT count(*) FROM document_embedding_store")
        )
        embedding_count = result.scalar() or 0
    except Exception:
        logger.debug("Embedding store not yet available", exc_info=True)
        # A failed statement leaves the PostgreSQL transaction aborted; without
        # a rollback the extension probe below would fail too and pg_orrery
        # would be reported as missing even when it is installed.
        await db.rollback()

    # Check if pg_orrery extension is loaded
    orrery_ok = False
    try:
        result = await db.execute(
            text("SELECT extversion FROM pg_extension WHERE extname = 'pg_orrery'")
        )
        row = result.first()
        orrery_ok = row is not None
    except Exception:
        logger.debug("pg_orrery extension check failed", exc_info=True)
        # Leave the session usable for whatever runs on it next.
        await db.rollback()

    return {
        "status": "ok",
        "service": "orrery-search",
        "documents": doc_count or 0,
        "embeddings": embedding_count,
        "pg_orrery": orrery_ok,
    }
|
||||||
74
search/src/orrery_search/routers/search.py
Normal file
74
search/src/orrery_search/routers/search.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from orrery_search.db import get_db
|
||||||
|
from orrery_search.models.document import Document
|
||||||
|
from orrery_search.schemas.search import (
|
||||||
|
SearchResponse,
|
||||||
|
SearchResult,
|
||||||
|
SectionCount,
|
||||||
|
SectionsResponse,
|
||||||
|
)
|
||||||
|
from orrery_search.services.search import search_documents
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=SearchResponse)
async def search(
    q: str = Query(..., min_length=1, max_length=500),
    content_type: str | None = Query(None, description="Filter by content type"),
    section: str | None = Query(None, max_length=200, description="Section prefix filter"),
    limit: int = Query(10, ge=1, le=50),
    mode: str = Query("hybrid", pattern="^(hybrid|semantic|text)$"),
    db: AsyncSession = Depends(get_db),
):
    """Search pg_orrery documentation using semantic and/or text matching."""
    outcome = await search_documents(
        q=q,
        db=db,
        limit=limit,
        mode=mode,
        section=section,
        content_type=content_type,
    )

    # Convert the service-layer dataclasses into the API's response models.
    hits = []
    for item in outcome.results:
        hits.append(
            SearchResult(
                title=item.title,
                slug=item.slug,
                section=item.section,
                content_type=item.content_type,
                description=item.description,
                snippet=item.snippet,
                url=item.url,
                score=item.score,
                source=item.source,
            )
        )

    return SearchResponse(
        query=outcome.query,
        results=hits,
        count=outcome.count,
        mode=outcome.mode,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sections", response_model=SectionsResponse)
async def list_sections(
    db: AsyncSession = Depends(get_db),
):
    """List all sections with document counts for the filter UI."""
    per_section = select(
        Document.section, func.count(Document.id).label("count")
    ).group_by(Document.section)
    rows = await db.execute(per_section.order_by(Document.section))

    sections = []
    for row in rows:
        # Skip documents whose section is empty or NULL.
        if not row.section:
            continue
        sections.append(SectionCount(section=row.section, count=row.count))

    return SectionsResponse(sections=sections)
|
||||||
0
search/src/orrery_search/schemas/__init__.py
Normal file
0
search/src/orrery_search/schemas/__init__.py
Normal file
29
search/src/orrery_search/schemas/search.py
Normal file
29
search/src/orrery_search/schemas/search.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
# API representation of a single search hit.
class SearchResult(BaseModel):
    title: str
    slug: str
    section: str
    content_type: str
    description: str | None = None
    # Short excerpt of the matching chunk or page body.
    snippet: str
    url: str
    # Relevance score; higher is better.
    score: float
    source: str  # "semantic" | "text" | "both"
|
||||||
|
|
||||||
|
|
||||||
|
# Response body of the search endpoint.
class SearchResponse(BaseModel):
    # The query string as searched.
    query: str
    results: list[SearchResult]
    # Number of entries in `results`.
    count: int
    # Mode reported by the search service; can differ from the requested mode.
    mode: str  # "hybrid" | "semantic" | "text"
|
||||||
|
|
||||||
|
|
||||||
|
# One section name together with the number of documents filed under it.
class SectionCount(BaseModel):
    section: str
    count: int
|
||||||
|
|
||||||
|
|
||||||
|
# Response body of the sections listing endpoint.
class SectionsResponse(BaseModel):
    sections: list[SectionCount]
|
||||||
0
search/src/orrery_search/services/__init__.py
Normal file
0
search/src/orrery_search/services/__init__.py
Normal file
55
search/src/orrery_search/services/embedding.py
Normal file
55
search/src/orrery_search/services/embedding.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""Embedding client for semantic search queries.
|
||||||
|
|
||||||
|
Singleton AsyncOpenAI client configured against the GPU embedding gateway.
|
||||||
|
Falls back to text-only search if the client is unavailable or requests fail.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from openai import (
|
||||||
|
APIConnectionError,
|
||||||
|
APITimeoutError,
|
||||||
|
AsyncOpenAI,
|
||||||
|
AuthenticationError,
|
||||||
|
RateLimitError,
|
||||||
|
)
|
||||||
|
|
||||||
|
from orrery_search.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger("orrery_search")
|
||||||
|
|
||||||
|
# Module-level client, created once at import time. It stays None when no GPU
# API key is configured, which makes embed_query() return None.
embedding_client: AsyncOpenAI | None = (
    AsyncOpenAI(
        base_url=settings.gpu_base_url,
        api_key=settings.gpu_api_key,
        timeout=10.0,
        max_retries=0,  # fail fast; callers fall back to text search
    )
    if settings.gpu_api_key
    else None
)
|
||||||
|
|
||||||
|
|
||||||
|
async def embed_query(text: str) -> list[float] | None:
    """Return the embedding vector for a search query, or None on failure.

    Uses the module-level embedding client with the model named in settings.
    Every failure is logged and swallowed so that callers can fall back to
    text-only search instead of erroring out.
    """
    if not embedding_client:
        return None

    vector: list[float] | None = None
    try:
        response = await embedding_client.embeddings.create(
            input=text,
            model=settings.gpu_embed_model,
        )
        first_item = response.data[0]
        vector = first_item.embedding
    # Handler order matters: the timeout is handled before the more general
    # connection error so that it gets its own log message.
    except APITimeoutError:
        logger.warning("Embedding timed out (10s) for: %s", text[:80])
    except AuthenticationError:
        logger.error("Embedding auth failed — check GPU_API_KEY")
    except RateLimitError:
        logger.warning("Embedding rate-limited for: %s", text[:80])
    except APIConnectionError as exc:
        logger.warning("Embedding connection failed: %s", exc)
    except Exception:
        logger.exception("Unexpected embedding failure for: %s", text[:80])
    return vector
|
||||||
242
search/src/orrery_search/services/search.py
Normal file
242
search/src/orrery_search/services/search.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
"""Hybrid semantic + text search service.
|
||||||
|
|
||||||
|
Extracts search logic into framework-agnostic functions returning
|
||||||
|
dataclasses. The router converts these to Pydantic models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from sqlalchemy import or_, select
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from orrery_search.models.document import Document
|
||||||
|
from orrery_search.services.embedding import embed_query
|
||||||
|
|
||||||
|
logger = logging.getLogger("orrery_search")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SearchResult:
    """Framework-agnostic search hit produced by the search service."""

    title: str
    slug: str
    section: str
    content_type: str
    description: str | None
    # Short excerpt: a matching chunk, or a window of the page body.
    snippet: str
    url: str
    # Relevance score; merged result lists are sorted by it, highest first.
    score: float
    source: str  # "semantic" | "text" | "both"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SearchOutput:
    """Result set returned by search_documents()."""

    query: str
    results: list[SearchResult] = field(default_factory=list)
    # Number of entries in `results`.
    count: int = 0
    # Mode that was effectively used; may differ from the requested one.
    mode: str = "hybrid"
|
||||||
|
|
||||||
|
|
||||||
|
def _escape_ilike(value: str) -> str:
|
||||||
|
return re.sub(r"([%_\\])", r"\\\1", value)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_snippet(body: str, query: str, context_chars: int = 100) -> str:
|
||||||
|
idx = body.lower().find(query.lower())
|
||||||
|
if idx == -1:
|
||||||
|
return body[:200]
|
||||||
|
|
||||||
|
start = max(0, idx - context_chars)
|
||||||
|
end = min(len(body), idx + len(query) + context_chars)
|
||||||
|
snippet = body[start:end]
|
||||||
|
|
||||||
|
if start > 0:
|
||||||
|
snippet = "..." + snippet
|
||||||
|
if end < len(body):
|
||||||
|
snippet = snippet + "..."
|
||||||
|
|
||||||
|
return snippet
|
||||||
|
|
||||||
|
|
||||||
|
async def _semantic_search(
    query_vec: list[float],
    limit: int,
    db: AsyncSession,
    content_type: str | None,
    section: str | None,
) -> list[SearchResult]:
    """Return the embedding chunks closest to ``query_vec`` by cosine distance.

    Args:
        query_vec: Embedding of the search query.
        limit: Maximum number of chunk rows to fetch. Several chunks of the
            same page may be returned; callers de-duplicate by slug.
        db: Session to query with; rolled back if the query fails.
        content_type: Optional exact-match filter on the content type.
        section: Optional prefix filter on the section, matched literally.

    Returns:
        Hits ordered by increasing distance. The list may be empty or partial
        if the query failed; failures are logged, not raised.
    """
    results: list[SearchResult] = []
    try:
        stmt = (
            select(
                Document,
                Document.embeddings.embedding.cosine_distance(query_vec).label("distance"),
                Document.embeddings.chunk,
            )
            .join(Document.embeddings)
            .order_by("distance")
            .limit(limit)
        )

        if content_type:
            stmt = stmt.where(Document.content_type == content_type)
        if section:
            # autoescape keeps "%" and "_" in the user-supplied prefix from
            # acting as LIKE wildcards.
            stmt = stmt.where(Document.section.startswith(section, autoescape=True))

        for page, dist, chunk in await db.execute(stmt):
            # Prefer the matched chunk as the excerpt; fall back to the page body.
            snippet = chunk[:200] if chunk else page.body[:200]
            results.append(
                SearchResult(
                    title=page.title,
                    slug=page.slug,
                    section=page.section,
                    content_type=page.content_type,
                    description=page.description,
                    snippet=snippet,
                    url=page.url,
                    # Turn distance into a similarity score, floored at 0.
                    score=round(max(0.0, 1.0 - float(dist)), 4),
                    source="semantic",
                )
            )
    except SQLAlchemyError:
        logger.warning("Semantic search failed (database)", exc_info=True)
        # Reset the session so a following text search can still run on it.
        await db.rollback()
    except Exception:
        logger.error("Semantic search failed (unexpected)", exc_info=True)
        await db.rollback()

    return results
|
||||||
|
|
||||||
|
|
||||||
|
async def _text_search(
    q: str,
    limit: int,
    db: AsyncSession,
    content_type: str | None,
    section: str | None,
) -> list[SearchResult]:
    """Return pages whose title or body contains ``q``, case-insensitively.

    Args:
        q: Literal text to look for; LIKE wildcards in it are escaped.
        limit: Maximum number of pages to return.
        db: Session to query with; rolled back if the query fails.
        content_type: Optional exact-match filter on the content type.
        section: Optional prefix filter on the section, matched literally.

    Returns:
        Hits with a fixed score of 0.5. The list is empty for an empty query
        and may be empty or partial if the query failed; failures are logged,
        not raised.
    """
    results: list[SearchResult] = []
    if not q:
        return results

    safe_q = _escape_ilike(q)
    try:
        stmt = (
            select(Document)
            .where(
                or_(
                    Document.title.ilike(f"%{safe_q}%"),
                    Document.body.ilike(f"%{safe_q}%"),
                )
            )
            .limit(limit)
        )

        if content_type:
            stmt = stmt.where(Document.content_type == content_type)
        if section:
            # autoescape keeps "%" and "_" in the user-supplied prefix from
            # acting as LIKE wildcards, as _escape_ilike does for the query.
            stmt = stmt.where(Document.section.startswith(section, autoescape=True))

        for page in (await db.execute(stmt)).scalars():
            snippet = _extract_snippet(page.body, q)
            results.append(
                SearchResult(
                    title=page.title,
                    slug=page.slug,
                    section=page.section,
                    content_type=page.content_type,
                    description=page.description,
                    snippet=snippet,
                    url=page.url,
                    # Substring matches carry no ranking signal, so all get 0.5.
                    score=0.5,
                    source="text",
                )
            )
    except SQLAlchemyError:
        logger.warning("Text search failed (database)", exc_info=True)
        await db.rollback()
    except Exception:
        logger.error("Text search failed (unexpected)", exc_info=True)
        await db.rollback()

    return results
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_results(
    semantic: list[SearchResult],
    text: list[SearchResult],
    limit: int,
) -> list[SearchResult]:
    """Combine semantic and text hits into one list ranked by score.

    Hits are keyed by slug. For each slug the first semantic hit is kept. A
    text hit for a slug that is already present replaces the stored entry
    with a copy whose score is raised by 0.1 (capped at 1.0) and whose source
    is "both". Text hits for new slugs are added unchanged. The result is
    sorted by descending score and cut to ``limit`` entries.
    """
    by_slug: dict[str, SearchResult] = {}

    for hit in semantic:
        by_slug.setdefault(hit.slug, hit)

    for hit in text:
        existing = by_slug.get(hit.slug)
        if existing is None:
            by_slug[hit.slug] = hit
            continue
        by_slug[hit.slug] = SearchResult(
            title=existing.title,
            slug=existing.slug,
            section=existing.section,
            content_type=existing.content_type,
            description=existing.description,
            snippet=existing.snippet,
            url=existing.url,
            score=min(1.0, existing.score + 0.1),
            source="both",
        )

    ranked = list(by_slug.values())
    ranked.sort(key=lambda r: r.score, reverse=True)
    return ranked[:limit]
|
||||||
|
|
||||||
|
|
||||||
|
async def search_documents(
    q: str,
    db: AsyncSession,
    mode: str = "hybrid",
    content_type: str | None = None,
    section: str | None = None,
    limit: int = 10,
) -> SearchOutput:
    """Run a semantic, text or hybrid search over the documentation.

    Semantic search is attempted for the "hybrid" and "semantic" modes. Text
    search runs for "hybrid" and "text", and also as a fallback when
    "semantic" produced no hits. The ``mode`` of the returned output is set to
    "text" whenever the final results came from text search alone.
    """
    semantic_hits: list[SearchResult] = []
    text_hits: list[SearchResult] = []
    effective_mode = mode

    wants_semantic = mode in ("hybrid", "semantic")
    if wants_semantic:
        vector = await embed_query(q)
        if vector:
            # Over-fetch chunk rows: several may belong to the same page.
            semantic_hits = await _semantic_search(
                vector, limit * 3, db, content_type, section
            )
        elif mode == "semantic":
            effective_mode = "text"

    semantic_came_up_empty = mode == "semantic" and not semantic_hits
    if mode in ("hybrid", "text") or semantic_came_up_empty:
        text_hits = await _text_search(q, limit, db, content_type, section)

    if semantic_hits and text_hits:
        ranked = _merge_results(semantic_hits, text_hits, limit)
    elif semantic_hits:
        # Semantic only: keep the first hit per page, then rank by score.
        first_per_slug: dict[str, SearchResult] = {}
        for hit in semantic_hits:
            first_per_slug.setdefault(hit.slug, hit)
        ordered = sorted(first_per_slug.values(), key=lambda r: r.score, reverse=True)
        ranked = ordered[:limit]
    else:
        ranked = text_hits[:limit]
        if mode != "text":
            effective_mode = "text"

    return SearchOutput(
        query=q,
        results=ranked,
        count=len(ranked),
        mode=effective_mode,
    )
|
||||||
35
search/src/orrery_search/services/search_text.py
Normal file
35
search/src/orrery_search/services/search_text.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
def build_search_text(
|
||||||
|
title: str,
|
||||||
|
section: str,
|
||||||
|
content_type: str,
|
||||||
|
description: str | None,
|
||||||
|
body: str,
|
||||||
|
) -> str:
|
||||||
|
"""Build enriched search text for vectorizer indexing.
|
||||||
|
|
||||||
|
Concatenates title, section path, content type context, description,
|
||||||
|
and body (capped at 4000 chars) so the vectorizer gets a dense,
|
||||||
|
searchable representation.
|
||||||
|
"""
|
||||||
|
parts = [title]
|
||||||
|
|
||||||
|
if section:
|
||||||
|
parts.append(f"section: {section}")
|
||||||
|
|
||||||
|
type_context = {
|
||||||
|
"guide": "pg_orrery usage guide and tutorial",
|
||||||
|
"reference": "pg_orrery SQL function and type reference",
|
||||||
|
"architecture": "pg_orrery internal architecture and design",
|
||||||
|
"workflow": "workflow translation from other tools to pg_orrery SQL",
|
||||||
|
"getting_started": "pg_orrery installation and getting started",
|
||||||
|
"performance": "pg_orrery performance benchmarks",
|
||||||
|
}
|
||||||
|
ctx = type_context.get(content_type)
|
||||||
|
if ctx:
|
||||||
|
parts.append(ctx)
|
||||||
|
|
||||||
|
if description:
|
||||||
|
parts.append(description)
|
||||||
|
|
||||||
|
parts.append(body[:4000])
|
||||||
|
return " ".join(parts)
|
||||||
Loading…
x
Reference in New Issue
Block a user