From 317f74b33b7e42a7929400681e05748b7cbd8d53 Mon Sep 17 00:00:00 2001 From: Ryan Malloy Date: Sun, 1 Mar 2026 15:42:14 -0700 Subject: [PATCH] Add search backend: FastAPI + FastMCP + pgvector for docs Q&A and live SQL Search stack replicates the Hamilton site pattern with pg_orrery-specific additions: - FastAPI REST API (chat SSE streaming, semantic search, health check) - FastMCP server at /mcp with doc search and live SQL query tools - pgvector + pgai vectorizer for 1024-dim document embeddings - Hybrid search (semantic cosine + text ILIKE with pg_trgm GIN) - Dual LLM backend: self-hosted qwen3 via GPU gateway or Anthropic Claude - Live read-only pg_orrery SQL execution with safety guardrails (SELECT-only validation, read-only transaction, 5s timeout, 100-row cap) - Convenience MCP tools: planet_position, sky_survey, satellite_pass - MDX content ingestion from docs/src/content/docs/ (50 pages) - Docker Compose: pg_orrery+pgvector DB, pgai, vectorizer-worker, API - Alembic async migrations, Makefile, .env.example --- search/.env.example | 24 ++ search/Dockerfile | 29 ++ search/Dockerfile.db | 8 + search/Makefile | 57 ++++ search/alembic.ini | 36 ++ search/alembic/env.py | 72 ++++ search/alembic/script.py.mako | 23 ++ search/alembic/versions/001_baseline.py | 83 +++++ search/docker-compose.yml | 134 ++++++++ search/docker/030_install_extensions.sh | 8 + search/pyproject.toml | 50 +++ search/src/orrery_search/__init__.py | 0 search/src/orrery_search/config.py | 36 ++ search/src/orrery_search/db.py | 17 + search/src/orrery_search/ingest/__init__.py | 0 search/src/orrery_search/ingest/__main__.py | 6 + search/src/orrery_search/ingest/mdx_parser.py | 54 +++ search/src/orrery_search/ingest/runner.py | 183 +++++++++++ search/src/orrery_search/main.py | 49 +++ search/src/orrery_search/mcp/__init__.py | 17 + search/src/orrery_search/mcp/chat.py | 307 ++++++++++++++++++ search/src/orrery_search/mcp/query.py | 247 ++++++++++++++ search/src/orrery_search/mcp/tools.py | 220 
+++++++++++++ search/src/orrery_search/models/__init__.py | 4 + search/src/orrery_search/models/base.py | 5 + search/src/orrery_search/models/document.py | 27 ++ search/src/orrery_search/routers/__init__.py | 0 search/src/orrery_search/routers/chat.py | 172 ++++++++++ search/src/orrery_search/routers/health.py | 46 +++ search/src/orrery_search/routers/search.py | 74 +++++ search/src/orrery_search/schemas/__init__.py | 0 search/src/orrery_search/schemas/search.py | 29 ++ search/src/orrery_search/services/__init__.py | 0 .../src/orrery_search/services/embedding.py | 55 ++++ search/src/orrery_search/services/search.py | 242 ++++++++++++++ .../src/orrery_search/services/search_text.py | 35 ++ 36 files changed, 2349 insertions(+) create mode 100644 search/.env.example create mode 100644 search/Dockerfile create mode 100644 search/Dockerfile.db create mode 100644 search/Makefile create mode 100644 search/alembic.ini create mode 100644 search/alembic/env.py create mode 100644 search/alembic/script.py.mako create mode 100644 search/alembic/versions/001_baseline.py create mode 100644 search/docker-compose.yml create mode 100755 search/docker/030_install_extensions.sh create mode 100644 search/pyproject.toml create mode 100644 search/src/orrery_search/__init__.py create mode 100644 search/src/orrery_search/config.py create mode 100644 search/src/orrery_search/db.py create mode 100644 search/src/orrery_search/ingest/__init__.py create mode 100644 search/src/orrery_search/ingest/__main__.py create mode 100644 search/src/orrery_search/ingest/mdx_parser.py create mode 100644 search/src/orrery_search/ingest/runner.py create mode 100644 search/src/orrery_search/main.py create mode 100644 search/src/orrery_search/mcp/__init__.py create mode 100644 search/src/orrery_search/mcp/chat.py create mode 100644 search/src/orrery_search/mcp/query.py create mode 100644 search/src/orrery_search/mcp/tools.py create mode 100644 search/src/orrery_search/models/__init__.py create mode 100644 
search/src/orrery_search/models/base.py
 create mode 100644 search/src/orrery_search/models/document.py
 create mode 100644 search/src/orrery_search/routers/__init__.py
 create mode 100644 search/src/orrery_search/routers/chat.py
 create mode 100644 search/src/orrery_search/routers/health.py
 create mode 100644 search/src/orrery_search/routers/search.py
 create mode 100644 search/src/orrery_search/schemas/__init__.py
 create mode 100644 search/src/orrery_search/schemas/search.py
 create mode 100644 search/src/orrery_search/services/__init__.py
 create mode 100644 search/src/orrery_search/services/embedding.py
 create mode 100644 search/src/orrery_search/services/search.py
 create mode 100644 search/src/orrery_search/services/search_text.py

diff --git a/search/.env.example b/search/.env.example
new file mode 100644
index 0000000..de51663
--- /dev/null
+++ b/search/.env.example
@@ -0,0 +1,24 @@
+# pg_orrery Search — Environment Configuration
+COMPOSE_PROJECT_NAME=pg-orrery-search
+
+# Domain for caddy-docker-proxy labels
+DOMAIN=pg-orrery.warehack.ing
+
+# PostgreSQL
+POSTGRES_USER=orrery
+POSTGRES_PASSWORD=changeme
+
+# Database URLs (must match POSTGRES_USER/PASSWORD above)
+DATABASE_URL=postgresql+asyncpg://orrery:changeme@db:5432/orrery_search
+ORRERY_DB_URL=postgresql://orrery:changeme@db:5432/orrery_search
+
+# GPU Gateway (embedding + chat) — placeholder only; put the real key in your untracked .env
+GPU_API_KEY=changeme
+GPU_BASE_URL=https://orrery-search.gpu.supported.systems/v1
+
+# LLM Provider: "gpu" (self-hosted qwen3) or "anthropic" (Claude)
+LLM_PROVIDER=gpu
+
+# Anthropic (only needed if LLM_PROVIDER=anthropic)
+ANTHROPIC_API_KEY=
+ANTHROPIC_MODEL=claude-sonnet-4-20250514
diff --git a/search/Dockerfile b/search/Dockerfile
new file mode 100644
index 0000000..0644f03
--- /dev/null
+++ b/search/Dockerfile
@@ -0,0 +1,29 @@
+# syntax=docker/dockerfile:1
+FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim AS base
+WORKDIR /app
+
+ENV UV_COMPILE_BYTECODE=1
+ENV UV_LINK_MODE=copy
+
+# Install dependencies first
(cache layer) +COPY pyproject.toml ./ +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r pyproject.toml + +# Copy source +COPY . . + +# --- Development (editable install for hot reload) --- +FROM base AS dev +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system --no-deps -e . +CMD ["uvicorn", "orrery_search.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload", "--proxy-headers", "--forwarded-allow-ips", "*"] + +# --- Production (non-editable, non-root) --- +FROM base AS prod +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system --no-deps . +RUN adduser --disabled-password --gecos "" orrery && \ + mkdir -p /data && chown -R orrery:orrery /data +USER orrery +CMD ["uvicorn", "orrery_search.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2", "--proxy-headers", "--forwarded-allow-ips", "*"] diff --git a/search/Dockerfile.db b/search/Dockerfile.db new file mode 100644 index 0000000..468bb7d --- /dev/null +++ b/search/Dockerfile.db @@ -0,0 +1,8 @@ +# pg_orrery + pgvector: celestial mechanics + vector search in one database +FROM git.supported.systems/warehack.ing/pg_orrery:pg17 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + postgresql-17-pgvector \ + && rm -rf /var/lib/apt/lists/* + +COPY docker/030_install_extensions.sh /docker-entrypoint-initdb.d/ diff --git a/search/Makefile b/search/Makefile new file mode 100644 index 0000000..93f3140 --- /dev/null +++ b/search/Makefile @@ -0,0 +1,57 @@ +.PHONY: dev prod build logs down clean restart migrate ingest shell status + +.DEFAULT_GOAL := dev + +# Development mode with hot reload +dev: + @echo "Starting pg_orrery Search in development mode..." + docker compose --profile dev up -d --build + @sleep 3 + docker compose logs -f + +# Production mode +prod: + @echo "Starting pg_orrery Search in production mode..." 
+ docker compose --profile prod up -d --build + @sleep 3 + docker compose logs -f + +# Build without starting +build: + docker compose build + +# View logs +logs: + docker compose logs -f + +# Stop containers +down: + docker compose down + +# Clean up containers, images, volumes +clean: + docker compose down -v --rmi local + +# Restart +restart: down dev + +# Run database migrations +migrate: + docker compose exec api-dev alembic upgrade head + +# Run content ingestion +ingest: + docker compose exec api-dev python -m orrery_search.ingest + +# Shell into running API container +shell: + docker compose exec api-dev bash + +# Check vectorizer status +status: + docker compose exec db psql -U orrery -d orrery_search -c "SELECT * FROM ai.vectorizer_status;" + +# Test search +test-search: + @echo "Testing search API..." + curl -s "https://$$(grep DOMAIN .env | cut -d= -f2)/api/search?q=satellite+tracking" | python3 -m json.tool diff --git a/search/alembic.ini b/search/alembic.ini new file mode 100644 index 0000000..ddf55eb --- /dev/null +++ b/search/alembic.ini @@ -0,0 +1,36 @@ +[alembic] +script_location = alembic +sqlalchemy.url = postgresql+asyncpg://orrery:orrery@localhost:5432/orrery_search + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/search/alembic/env.py b/search/alembic/env.py new file mode 100644 index 0000000..cae28f8 --- /dev/null +++ b/search/alembic/env.py @@ -0,0 +1,72 @@ +import asyncio +import os +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import pool 
+from sqlalchemy.ext.asyncio import async_engine_from_config + +from orrery_search.models import Base + +config = context.config +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +db_url = os.environ.get("DATABASE_URL") +if db_url: + config.set_main_option("sqlalchemy.url", db_url) + +target_metadata = Base.metadata + +PGAI_MANAGED_TABLES = { + "document_embedding_store", +} + + +def include_object(object, name, type_, reflected, compare_to): + if type_ == "table" and name in PGAI_MANAGED_TABLES: + return False + return True + + +def run_migrations_offline(): + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + include_object=include_object, + ) + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection): + context.configure( + connection=connection, + target_metadata=target_metadata, + include_object=include_object, + ) + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations(): + connectable = async_engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + await connectable.dispose() + + +def run_migrations_online(): + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/search/alembic/script.py.mako b/search/alembic/script.py.mako new file mode 100644 index 0000000..f72b90c --- /dev/null +++ b/search/alembic/script.py.mako @@ -0,0 +1,23 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} +""" + +import sqlalchemy as sa +from alembic import op +${imports if imports else ""} + +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} 
+branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/search/alembic/versions/001_baseline.py b/search/alembic/versions/001_baseline.py new file mode 100644 index 0000000..92e16a3 --- /dev/null +++ b/search/alembic/versions/001_baseline.py @@ -0,0 +1,83 @@ +"""document table with pgai vectorizer + +Revision ID: 001_baseline +Revises: None +Create Date: 2026-03-01 +""" + +import sqlalchemy as sa +from alembic import op + +revision = "001_baseline" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "document", + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("content_type", sa.String(20), nullable=False, index=True), + sa.Column("slug", sa.String(300), nullable=False, unique=True), + sa.Column("title", sa.String(300), nullable=False), + sa.Column("section", sa.String(200), nullable=False, index=True), + sa.Column("description", sa.Text, nullable=True), + sa.Column("body", sa.Text, nullable=False), + sa.Column("search_text", sa.Text, nullable=True), + sa.Column("url", sa.String(300), nullable=False), + sa.Column("word_count", sa.Integer, nullable=False, server_default="0"), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + ) + + # Enable pg_trgm for fast ILIKE with GIN indexes + op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm") + + op.execute( + "CREATE INDEX ix_document_title_trgm ON document USING gin (title gin_trgm_ops)" + ) + op.execute( + "CREATE INDEX ix_document_body_trgm ON document USING gin (body gin_trgm_ops)" + ) + + # pgai vectorizer — reads search_text, generates 1024-dim embeddings + # Uses mxbai-embed-large via the GPU embedding gateway. 
+ op.execute(""" + SELECT ai.create_vectorizer( + 'document'::regclass, + name => 'document_embedder', + loading => ai.loading_column(column_name => 'search_text'), + embedding => ai.embedding_openai( + model => 'mxbai-embed-large', + dimensions => 1024 + ), + chunking => ai.chunking_recursive_character_text_splitter( + chunk_size => 400, + chunk_overlap => 50, + separators => array[E'\\n\\n', E'\\n', '. ', ' '] + ), + formatting => ai.formatting_python_template( + template => '$chunk' + ) + ) + """) + + +def downgrade(): + op.execute(""" + DO $$ BEGIN + PERFORM ai.drop_vectorizer('document_embedder', drop_all => true); + EXCEPTION WHEN OTHERS THEN + RAISE NOTICE 'document_embedder not found, skipping drop'; + END $$ + """) + + op.execute("DROP INDEX IF EXISTS ix_document_title_trgm") + op.execute("DROP INDEX IF EXISTS ix_document_body_trgm") + + op.drop_table("document") diff --git a/search/docker-compose.yml b/search/docker-compose.yml new file mode 100644 index 0000000..062278f --- /dev/null +++ b/search/docker-compose.yml @@ -0,0 +1,134 @@ +services: + db: + build: + context: . 
+      dockerfile: Dockerfile.db
+    environment:
+      POSTGRES_DB: orrery_search
+      POSTGRES_USER: ${POSTGRES_USER:-orrery}
+      POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
+    volumes:
+      - pg-data:/var/lib/postgresql/data
+    networks:
+      - internal
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-orrery} -d orrery_search"]  # probe as the configured role, not a hard-coded one
+      interval: 10s
+      timeout: 5s
+      retries: 5
+      start_period: 10s
+    restart: unless-stopped
+
+  pgai-install:
+    image: timescale/pgai-vectorizer-worker:latest
+    entrypoint: ["python", "-m", "pgai", "install", "-d", "postgresql://${POSTGRES_USER:-orrery}:${POSTGRES_PASSWORD}@db:5432/orrery_search"]
+    networks:
+      - internal
+    depends_on:
+      db:
+        condition: service_healthy
+    restart: "no"
+
+  vectorizer-worker:
+    image: timescale/pgai-vectorizer-worker:latest
+    environment:
+      PGAI_VECTORIZER_WORKER_DB_URL: postgresql://${POSTGRES_USER:-orrery}:${POSTGRES_PASSWORD}@db:5432/orrery_search
+      OPENAI_BASE_URL: ${GPU_BASE_URL:-https://orrery-search.gpu.supported.systems/v1}
+      OPENAI_API_KEY: ${GPU_API_KEY}
+    command: ["--poll-interval", "5s"]
+    networks:
+      - internal
+    depends_on:
+      db:
+        condition: service_healthy
+      pgai-install:
+        condition: service_completed_successfully
+    restart: unless-stopped
+
+  api-dev:
+    build:
+      context: .
+ dockerfile: Dockerfile + target: dev + profiles: ["dev"] + env_file: .env + volumes: + - ./src:/app/src + - ./alembic:/app/alembic + - ../docs/src/content/docs:/data/content:ro + networks: + - internal + - caddy + depends_on: + db: + condition: service_healthy + pgai-install: + condition: service_completed_successfully + labels: + caddy: ${DOMAIN:-pg-orrery.warehack.ing} + caddy.handle: /api/search* + caddy.handle.0_reverse_proxy: "{{upstreams 8000}}" + caddy.handle_1: /health + caddy.handle_1.0_reverse_proxy: "{{upstreams 8000}}" + caddy.handle_2: /mcp* + caddy.handle_2.0_reverse_proxy: "{{upstreams 8000}}" + caddy.handle_3: /api/chat* + caddy.handle_3.0_reverse_proxy: "{{upstreams 8000}}" + caddy.handle_3.0_reverse_proxy.flush_interval: "-1" + caddy.handle_3.0_reverse_proxy.transport: "http" + caddy.handle_3.0_reverse_proxy.transport.read_timeout: "0" + caddy.handle_3.0_reverse_proxy.transport.write_timeout: "0" + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health')"] + interval: 30s + timeout: 5s + retries: 3 + start_period: 15s + restart: unless-stopped + + api-prod: + build: + context: . 
+ dockerfile: Dockerfile + target: prod + profiles: ["prod"] + env_file: .env + volumes: + - ../docs/src/content/docs:/data/content:ro + networks: + - internal + - caddy + depends_on: + db: + condition: service_healthy + pgai-install: + condition: service_completed_successfully + labels: + caddy: ${DOMAIN:-pg-orrery.warehack.ing} + caddy.handle: /api/search* + caddy.handle.0_reverse_proxy: "{{upstreams 8000}}" + caddy.handle_1: /health + caddy.handle_1.0_reverse_proxy: "{{upstreams 8000}}" + caddy.handle_2: /mcp* + caddy.handle_2.0_reverse_proxy: "{{upstreams 8000}}" + caddy.handle_3: /api/chat* + caddy.handle_3.0_reverse_proxy: "{{upstreams 8000}}" + caddy.handle_3.0_reverse_proxy.flush_interval: "-1" + caddy.handle_3.0_reverse_proxy.transport: "http" + caddy.handle_3.0_reverse_proxy.transport.read_timeout: "0" + caddy.handle_3.0_reverse_proxy.transport.write_timeout: "0" + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health')"] + interval: 30s + timeout: 5s + retries: 3 + start_period: 15s + restart: unless-stopped + +volumes: + pg-data: + +networks: + internal: + caddy: + external: true diff --git a/search/docker/030_install_extensions.sh b/search/docker/030_install_extensions.sh new file mode 100755 index 0000000..8f09a3a --- /dev/null +++ b/search/docker/030_install_extensions.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -e + +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL + CREATE EXTENSION IF NOT EXISTS vector; + CREATE EXTENSION IF NOT EXISTS pg_trgm; + CREATE EXTENSION IF NOT EXISTS pg_orrery; +EOSQL diff --git a/search/pyproject.toml b/search/pyproject.toml new file mode 100644 index 0000000..f4a85ca --- /dev/null +++ b/search/pyproject.toml @@ -0,0 +1,50 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "orrery-search" +version = "2026.03.01" +description = "Semantic search and chat API for the pg_orrery 
documentation" +license = "MIT" +requires-python = ">=3.12" +authors = [{name = "Ryan Malloy", email = "ryan@supported.systems"}] +dependencies = [ + "fastapi>=0.115.0", + "uvicorn[standard]>=0.32.0", + "sqlalchemy[asyncio]>=2.0.36", + "asyncpg>=0.30.0", + "pgvector>=0.4.2", + "pgai[sqlalchemy]>=0.12.0", + "tiktoken>=0.7.0", + "alembic>=1.14.0", + "pydantic-settings>=2.6.0", + "openai>=1.60.0", + "anthropic>=0.40.0", + "pyyaml>=6.0", + "fastmcp>=3.0.0", + "httpx>=0.28.0", +] + +[project.scripts] +orrery-search = "orrery_search.main:run" + +[tool.ruff] +target-version = "py312" +src = ["src"] + +[tool.ruff.lint] +select = ["E", "F", "I", "UP", "B", "SIM"] +ignore = ["B008"] + +[tool.hatch.build.targets.wheel] +packages = ["src/orrery_search"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" + +[dependency-groups] +dev = [ + "pytest>=9.0.2", + "pytest-asyncio>=1.3.0", +] diff --git a/search/src/orrery_search/__init__.py b/search/src/orrery_search/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/search/src/orrery_search/config.py b/search/src/orrery_search/config.py new file mode 100644 index 0000000..03cfcab --- /dev/null +++ b/search/src/orrery_search/config.py @@ -0,0 +1,36 @@ +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + api_host: str = "0.0.0.0" + api_port: int = 8000 + api_log_level: str = "info" + + database_url: str = "postgresql+asyncpg://orrery:orrery@localhost:5432/orrery_search" + + # Raw asyncpg URL for direct pg_orrery SQL execution (no SQLAlchemy) + orrery_db_url: str = "postgresql://orrery:orrery@localhost:5432/orrery_search" + + gpu_api_key: str = "" + gpu_base_url: str = "https://orrery-search.gpu.supported.systems/v1" + gpu_embed_model: str = "mxbai-embed-large" + gpu_embed_dimensions: int = 1024 + + search_max_results: int = 50 + + # LLM provider: "gpu" for self-hosted (qwen3), "anthropic" for Claude + llm_provider: str = "gpu" + + gpu_chat_model: str = "qwen3" + chat_timeout: float = 30.0 + 
chat_max_tokens: int = 8192 + + anthropic_api_key: str = "" + anthropic_model: str = "claude-sonnet-4-20250514" + + run_query_timeout: float = 5.0 + + model_config = {"env_prefix": "", "env_file": ".env", "extra": "ignore"} + + +settings = Settings() diff --git a/search/src/orrery_search/db.py b/search/src/orrery_search/db.py new file mode 100644 index 0000000..f8e52d6 --- /dev/null +++ b/search/src/orrery_search/db.py @@ -0,0 +1,17 @@ +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from orrery_search.config import settings + +engine = create_async_engine( + settings.database_url, + echo=False, + pool_pre_ping=True, + pool_size=10, + max_overflow=20, +) +async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + +async def get_db() -> AsyncSession: + async with async_session() as session: + yield session diff --git a/search/src/orrery_search/ingest/__init__.py b/search/src/orrery_search/ingest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/search/src/orrery_search/ingest/__main__.py b/search/src/orrery_search/ingest/__main__.py new file mode 100644 index 0000000..7dc2916 --- /dev/null +++ b/search/src/orrery_search/ingest/__main__.py @@ -0,0 +1,6 @@ +import asyncio + +from orrery_search.ingest.runner import ingest + +if __name__ == "__main__": + asyncio.run(ingest()) diff --git a/search/src/orrery_search/ingest/mdx_parser.py b/search/src/orrery_search/ingest/mdx_parser.py new file mode 100644 index 0000000..2a4a0e9 --- /dev/null +++ b/search/src/orrery_search/ingest/mdx_parser.py @@ -0,0 +1,54 @@ +"""MDX/Markdown frontmatter extraction and content cleaning. + +Strips JSX components, import statements, Starlight admonitions, +and other MDX-specific syntax to produce clean plain text for indexing. 
+"""
+
+import contextlib
+import html
+import re
+
+import yaml
+
+_FRONTMATTER_RE = re.compile(r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL)
+_IMPORT_RE = re.compile(r"^import\s+.*$", re.MULTILINE)
+_MERMAID_BLOCK_RE = re.compile(r"```mermaid\n.*?```", re.DOTALL)
+_JSX_SELF_CLOSING_RE = re.compile(r"<\w+[^>]*/\s*>")
+_JSX_BLOCK_RE = re.compile(
+    r"<(Aside|CardGrid|Steps|Tabs|TabItem|Card|LinkCard)[^>]*>(.*?)</\1>",
+    re.DOTALL,
+)
+_JSX_OPEN_CLOSE_RE = re.compile(
+    r"</?(?:Aside|CardGrid|Steps|Tabs|TabItem|Card|LinkCard)[^>]*>"
+)
+_ADMONITION_FENCE_RE = re.compile(r"^:::.*$", re.MULTILINE)
+_MULTI_BLANK_RE = re.compile(r"\n{3,}")
+
+
+def strip_mdx(raw: str) -> tuple[dict, str]:
+    """Extract YAML frontmatter and reduce MDX markup to plain text.
+
+    Returns (frontmatter_dict, cleaned_body); frontmatter is {} when absent, invalid, or not a mapping.
+    """
+    frontmatter: dict = {}
+    body = raw
+
+    fm_match = _FRONTMATTER_RE.match(raw)
+    if fm_match:
+        with contextlib.suppress(yaml.YAMLError):
+            loaded = yaml.safe_load(fm_match.group(1))
+            frontmatter = loaded if isinstance(loaded, dict) else {}  # scalar/list YAML is not frontmatter
+        body = raw[fm_match.end():]
+
+    body = _IMPORT_RE.sub("", body)
+    body = _MERMAID_BLOCK_RE.sub("", body)
+
+    while _JSX_BLOCK_RE.search(body):  # repeat until nested components are all unwrapped
+        body = _JSX_BLOCK_RE.sub(r"\2", body)  # keep the inner text, drop the tag pair
+    body = _JSX_SELF_CLOSING_RE.sub("", body)
+    body = _JSX_OPEN_CLOSE_RE.sub("", body)  # leftover unbalanced component tags only
+    body = _ADMONITION_FENCE_RE.sub("", body)
+    body = html.unescape(body)
+    body = _MULTI_BLANK_RE.sub("\n\n", body).strip()
+
+    return frontmatter, body
diff --git a/search/src/orrery_search/ingest/runner.py b/search/src/orrery_search/ingest/runner.py
new file mode 100644
index 0000000..f13bd88
--- /dev/null
+++ b/search/src/orrery_search/ingest/runner.py
@@ -0,0 +1,183 @@
+"""Ingestion orchestrator — walks content directories, classifies, and upserts.
+
+Run: docker compose exec api-dev python -m orrery_search.ingest
+"""
+
+import asyncio
+import sys
+from pathlib import Path
+
+from sqlalchemy import select
+
+from orrery_search.db import async_session
+from orrery_search.ingest.mdx_parser import strip_mdx
+from orrery_search.models.document import Document
+from orrery_search.services.search_text import build_search_text
+
+CONTENT_DIR = Path("/data/content")
+
+
+def _resolve_paths() -> Path:
+    """Return content_dir, preferring Docker mount."""
+    if CONTENT_DIR.exists():
+        return CONTENT_DIR
+
+    here = Path(__file__).resolve()
+    project_root = here
+    for _ in range(10):
+        project_root = project_root.parent
+        if (project_root / "docs" / "src").exists():
+            break
+    return project_root / "docs" / "src" / "content" / "docs"
+
+
+def _classify_content_type(rel_path: str) -> str:
+    """Classify content type from the relative path within content/docs/."""
+    parts = rel_path.split("/")
+
+    if parts[0] == "getting-started":
+        return "getting_started"
+    if parts[0] == "guides":
+        return "guide"
+    if parts[0] == "workflow":
+        return "workflow"
+    if parts[0] == "reference":
+        return "reference"
+    if parts[0] == "architecture":
+        return "architecture"
+    if parts[0] == "performance":
+        return "performance"
+
+    return "page"
+
+
+def _mdx_path_to_url(rel_path: str) -> str:
+    """Convert relative .mdx path to a page URL; the top-level index maps to "/", not "/index/"."""
+    slug = rel_path.removesuffix(".mdx").removesuffix("/index")
+    return "/" if slug in ("", "index") else f"/{slug}/"
+
+
+def _mdx_path_to_section(rel_path: str) -> str:
+    """Extract section from relative path."""
+    parts = Path(rel_path).parts
+    if len(parts) > 1:
+        return "/".join(parts[:-1])
+    return ""
+
+
+def _mdx_path_to_slug(rel_path: str) -> str:
+    """Convert to a unique slug for dedup."""
+    return rel_path.removesuffix(".mdx").removesuffix("/index")
+
+
+def _collect_mdx_pages(content_dir: Path) -> list[dict]:
+    """Walk the content directory and parse all .mdx files."""
+    pages = []
+    for mdx_path in
sorted(content_dir.rglob("*.mdx")): + try: + raw = mdx_path.read_text(encoding="utf-8") + except (OSError, UnicodeDecodeError) as exc: + print(f" SKIP {mdx_path}: {exc}", file=sys.stderr) + continue + + frontmatter, body = strip_mdx(raw) + + rel_path = str(mdx_path.relative_to(content_dir)) + title = frontmatter.get("title", mdx_path.stem.replace("-", " ").title()) + description = frontmatter.get("description") + content_type = _classify_content_type(rel_path) + word_count = len(body.split()) + + pages.append({ + "content_type": content_type, + "slug": _mdx_path_to_slug(rel_path), + "title": title, + "section": _mdx_path_to_section(rel_path), + "description": description, + "body": body, + "url": _mdx_path_to_url(rel_path), + "word_count": word_count, + }) + + return pages + + +async def ingest(): + """Main ingestion: read docs content, upsert into document table.""" + content_dir = _resolve_paths() + print(f"Content dir: {content_dir}", file=sys.stderr) + + if not content_dir.exists(): + print(f"Content directory not found: {content_dir}", file=sys.stderr) + sys.exit(1) + + pages = _collect_mdx_pages(content_dir) + print(f"Found {len(pages)} published pages", file=sys.stderr) + + async with async_session() as db: + inserted = 0 + updated = 0 + errors = 0 + + for i, page_data in enumerate(pages): + try: + search_text = build_search_text( + title=page_data["title"], + section=page_data["section"], + content_type=page_data["content_type"], + description=page_data["description"], + body=page_data["body"], + ) + + async with db.begin_nested(): + stmt = select(Document).where( + Document.slug == page_data["slug"] + ) + result = await db.execute(stmt) + existing = result.scalar_one_or_none() + + if existing: + existing.title = page_data["title"] + existing.section = page_data["section"] + existing.description = page_data["description"] + existing.body = page_data["body"] + existing.search_text = search_text + existing.url = page_data["url"] + existing.content_type = 
page_data["content_type"]
+                        existing.word_count = page_data["word_count"]
+                    else:
+                        db.add(Document(
+                            content_type=page_data["content_type"],
+                            slug=page_data["slug"],
+                            title=page_data["title"],
+                            section=page_data["section"],
+                            description=page_data["description"],
+                            body=page_data["body"],
+                            search_text=search_text,
+                            url=page_data["url"],
+                            word_count=page_data["word_count"],
+                        ))
+
+                # Tally only after the savepoint block has exited cleanly, so a page
+                # whose flush fails is reported once (as an error), not twice.
+                if existing:
+                    updated += 1
+                else:
+                    inserted += 1
+
+                if (i + 1) % 50 == 0:
+                    await db.commit()
+                    print(f" progress: {i + 1}/{len(pages)}", file=sys.stderr)
+
+            except Exception as exc:
+                # Best-effort ingestion: one bad page must not abort the whole run.
+                print(f" ERROR on {page_data['slug']}: {exc}", file=sys.stderr)
+                errors += 1
+
+        await db.commit()
+
+    print(
+        f"Ingestion complete: {inserted} inserted, {updated} updated, "
+        f"{errors} errors ({inserted + updated} total)",
+        file=sys.stderr,
+    )
diff --git a/search/src/orrery_search/main.py b/search/src/orrery_search/main.py
new file mode 100644
index 0000000..4c3a3cd
--- /dev/null
+++ b/search/src/orrery_search/main.py
@@ -0,0 +1,53 @@
+import logging
+from contextlib import asynccontextmanager
+
+import uvicorn
+from fastapi import FastAPI
+from fastmcp.utilities.lifespan import combine_lifespans
+
+from orrery_search.config import settings
+from orrery_search.db import engine
+from orrery_search.mcp import mcp
+from orrery_search.mcp.chat import close_chat_client
+from orrery_search.mcp.query import close_query_pool
+from orrery_search.routers import chat, health, search
+
+logging.basicConfig(level=logging.WARNING, format="%(name)s: %(message)s")
+logging.getLogger("orrery_search").setLevel(logging.INFO)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Yield while the app serves, then run the shutdown hooks and dispose the engine."""
+    try:
+        yield
+    finally:
+        # Teardown must also run when the app exits through an exception.
+        await close_chat_client()
+        await close_query_pool()
+        await engine.dispose()
+
+
+mcp_app = mcp.http_app(path="/", stateless_http=True)
+
+app = FastAPI(
+    title="pg_orrery Search",
+    description="Semantic search and chat for the pg_orrery documentation",
+    version="2026.03.01",
+    lifespan=combine_lifespans(lifespan,
mcp_app.lifespan), +) + +app.include_router(search.router, prefix="/api/search", tags=["search"]) +app.include_router(chat.router, prefix="/api/chat", tags=["chat"]) +app.include_router(health.router, tags=["health"]) +app.mount("/mcp", mcp_app) + + +def run(): + uvicorn.run( + "orrery_search.main:app", + host=settings.api_host, + port=settings.api_port, + log_level=settings.api_log_level, + reload=True, + ) diff --git a/search/src/orrery_search/mcp/__init__.py b/search/src/orrery_search/mcp/__init__.py new file mode 100644 index 0000000..646743d --- /dev/null +++ b/search/src/orrery_search/mcp/__init__.py @@ -0,0 +1,17 @@ +from fastmcp import FastMCP + +mcp = FastMCP( + "pg_orrery", + instructions=( + "MCP server for the pg_orrery documentation and live celestial mechanics. " + "Search 50+ pages of docs covering 225 SQL functions for satellite tracking, " + "planetary ephemerides, rise/set prediction, Lagrange points, and more. " + "Use ask_orrery for natural-language Q&A with source citations. " + "Use run_query to execute live pg_orrery SQL against a PostgreSQL database." 
+ ), +) + +# Register tools, resources, and prompts +import orrery_search.mcp.chat # noqa: E402, F401 +import orrery_search.mcp.query # noqa: E402, F401 +import orrery_search.mcp.tools # noqa: E402, F401 diff --git a/search/src/orrery_search/mcp/chat.py b/search/src/orrery_search/mcp/chat.py new file mode 100644 index 0000000..185b579 --- /dev/null +++ b/search/src/orrery_search/mcp/chat.py @@ -0,0 +1,307 @@ +import asyncio +import collections.abc +import json +import logging + +import httpx +from sqlalchemy import select + +from orrery_search.config import settings +from orrery_search.db import async_session +from orrery_search.mcp import mcp +from orrery_search.models.document import Document +from orrery_search.services.search import search_documents + +logger = logging.getLogger("orrery_search") + +SYSTEM_PROMPT = ( + "/no_think\n" + "You are a knowledgeable assistant for pg_orrery, a PostgreSQL extension providing " + "celestial mechanics types and functions. Answer questions using the provided context " + "from the documentation.\n\n" + "Key domain knowledge:\n" + "- Body IDs: 0=Sun, 1=Mercury, 2=Venus, 3=Earth, 4=Mars, 5=Jupiter, 6=Saturn, " + "7=Uranus, 8=Neptune, 10=Moon\n" + "- Observer format: '(lat_deg, lon_deg, alt_m)' e.g. '(40.7128, -74.0060, 10)'\n" + "- Key functions: planet_observe(body_id, observer, timestamp), " + "planet_equatorial(body_id, timestamp), sun_observe(), moon_observe(), " + "predict_passes(tle, observer, timestamp, count), satellite_is_eclipsed()\n" + "- Coordinate systems: topocentric (az/el/range), equatorial (RA/Dec), " + "heliocentric (x,y,z AU), geodetic (lat/lon/alt)\n" + "- All functions are PARALLEL SAFE. VSOP87 functions are IMMUTABLE.\n\n" + "If the context doesn't contain enough information, say so clearly. " + "Cite specific documents by title when referencing information. " + "Be precise and factual — never fabricate claims." 
+) + +MAX_CONTEXT_CHARS = 2_000 + +_chat_client: httpx.AsyncClient | None = None +_client_lock = asyncio.Lock() + + +async def _get_chat_client() -> httpx.AsyncClient: + global _chat_client + if _chat_client is not None and not _chat_client.is_closed: + return _chat_client + async with _client_lock: + if _chat_client is not None and not _chat_client.is_closed: + return _chat_client + _chat_client = httpx.AsyncClient( + timeout=httpx.Timeout( + connect=10.0, + read=120.0, + write=10.0, + pool=10.0, + ), + limits=httpx.Limits( + max_connections=20, + max_keepalive_connections=10, + ), + ) + return _chat_client + + +async def close_chat_client() -> None: + global _chat_client + if _chat_client is not None and not _chat_client.is_closed: + await _chat_client.aclose() + _chat_client = None + + +async def _build_context(query: str) -> tuple[str, list[dict]]: + """Search docs and batch-fetch full document bodies for RAG context.""" + async with async_session() as db: + output = await search_documents(q=query, db=db, mode="hybrid", limit=5) + + if not output.results: + return "", [] + + slugs = [r.slug for r in output.results] + score_by_slug = {r.slug: r.score for r in output.results} + + docs_result = await db.execute( + select(Document).where(Document.slug.in_(slugs)) + ) + docs_by_slug = {doc.slug: doc for doc in docs_result.scalars()} + + sources = [] + context_parts = [] + chars_used = 0 + + for slug in slugs: + doc = docs_by_slug.get(slug) + if not doc: + continue + + remaining = MAX_CONTEXT_CHARS - chars_used + if remaining <= 0: + break + + body = doc.body + if len(body) > remaining: + body = body[:remaining] + "..." 
+ + context_parts.append( + f"--- {doc.title} (/{doc.slug}) ---\n{body}" + ) + sources.append({ + "title": doc.title, + "slug": doc.slug, + "url": doc.url, + "score": score_by_slug.get(slug, 0.0), + }) + chars_used += len(body) + + context_text = "\n\n".join(context_parts) + return context_text, sources + + +async def _chat_completion_stream_gpu( + context: str, question: str +) -> collections.abc.AsyncIterator[tuple[str, str]]: + """Stream chat completion tokens from the GPU gateway (OpenAI-compatible).""" + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + { + "role": "user", + "content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}", + }, + ] + + client = await _get_chat_client() + async with client.stream( + "POST", + f"{settings.gpu_base_url}/chat/completions", + headers={ + "Authorization": f"Bearer {settings.gpu_api_key}", + "Content-Type": "application/json", + }, + json={ + "model": settings.gpu_chat_model, + "messages": messages, + "temperature": 0.3, + "max_tokens": settings.chat_max_tokens, + "num_ctx": settings.chat_max_tokens, + "stream": True, + }, + ) as resp: + if resp.status_code >= 400: + body = await resp.aread() + error_text = body[:500].decode("utf-8", errors="replace") + logger.error("GPU gateway returned %d: %s", resp.status_code, error_text) + raise httpx.HTTPStatusError( + f"GPU gateway error {resp.status_code}", + request=resp.request, response=resp, + ) + reasoning_count = 0 + content_count = 0 + try: + async for line in resp.aiter_lines(): + if not line.startswith("data: "): + continue + payload = line[6:] + if payload.strip() == "[DONE]": + return + try: + chunk = json.loads(payload) + delta = chunk["choices"][0]["delta"] + content = delta.get("content") + if content: + content_count += 1 + yield ("content", content) + elif delta.get("reasoning_content"): + reasoning_count += 1 + yield ("reasoning", delta["reasoning_content"]) + except (json.JSONDecodeError, KeyError, IndexError): + continue + 
except GeneratorExit: + logger.debug("Chat stream cancelled after %d content chunks", content_count) + await resp.aclose() + return + + +async def _chat_completion_stream_anthropic( + context: str, question: str +) -> collections.abc.AsyncIterator[tuple[str, str]]: + """Stream chat completion tokens via the Anthropic API.""" + import anthropic + + client = anthropic.AsyncAnthropic(api_key=settings.anthropic_api_key) + try: + async with client.messages.stream( + model=settings.anthropic_model, + max_tokens=settings.chat_max_tokens, + system=SYSTEM_PROMPT, + messages=[ + { + "role": "user", + "content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}", + }, + ], + ) as stream: + async for text in stream.text_stream: + yield ("content", text) + except anthropic.APIError as exc: + logger.error("Anthropic API error: %s", exc) + raise + + +async def _chat_completion_stream( + context: str, question: str +) -> collections.abc.AsyncIterator[tuple[str, str]]: + """Dispatch to the configured LLM provider.""" + if settings.llm_provider == "anthropic" and settings.anthropic_api_key: + async for item in _chat_completion_stream_anthropic(context, question): + yield item + else: + async for item in _chat_completion_stream_gpu(context, question): + yield item + + +async def _chat_completion(context: str, question: str) -> str: + """Non-streaming chat completion for MCP tool use.""" + if settings.llm_provider == "anthropic" and settings.anthropic_api_key: + import anthropic + + client = anthropic.AsyncAnthropic(api_key=settings.anthropic_api_key) + message = await client.messages.create( + model=settings.anthropic_model, + max_tokens=settings.chat_max_tokens, + system=SYSTEM_PROMPT, + messages=[ + { + "role": "user", + "content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}", + }, + ], + ) + return message.content[0].text + + client = await _get_chat_client() + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + { + "role": 
"user", + "content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}", + }, + ] + resp = await client.post( + f"{settings.gpu_base_url}/chat/completions", + headers={ + "Authorization": f"Bearer {settings.gpu_api_key}", + "Content-Type": "application/json", + }, + json={ + "model": settings.gpu_chat_model, + "messages": messages, + "temperature": 0.3, + "max_tokens": settings.chat_max_tokens, + "num_ctx": settings.chat_max_tokens, + }, + ) + resp.raise_for_status() + data = resp.json() + return data["choices"][0]["message"]["content"] + + +@mcp.tool() +async def ask_orrery( + question: str, + include_sources: bool = True, +) -> str: + """Ask a question about pg_orrery and get an answer grounded in documentation. + + Uses retrieval-augmented generation: searches the docs for relevant pages, + then synthesizes an answer citing specific sources. Best for conceptual questions + about celestial mechanics functions, types, coordinate systems, and workflows. + + Args: + question: Natural language question (e.g. 
"How do I predict satellite passes?") + include_sources: Whether to include source document references in the response + """ + try: + context, sources = await _build_context(question) + except Exception: + logger.exception("Failed to build RAG context") + raise RuntimeError("Search failed while building context") + + if not context: + return json.dumps({ + "answer": "No relevant documents found in the pg_orrery docs for this question.", + "sources": [], + }) + + try: + answer = await _chat_completion(context, question) + except httpx.HTTPStatusError as exc: + logger.error("Chat completion failed: HTTP %s", exc.response.status_code) + raise RuntimeError(f"Chat model returned HTTP {exc.response.status_code}") + except httpx.TimeoutException: + logger.warning("Chat completion timed out (limit: %.0fs)", settings.chat_timeout) + raise RuntimeError("Chat model timed out — try a simpler question") + + result: dict = {"answer": answer} + if include_sources: + result["sources"] = sources + return json.dumps(result) diff --git a/search/src/orrery_search/mcp/query.py b/search/src/orrery_search/mcp/query.py new file mode 100644 index 0000000..18851bf --- /dev/null +++ b/search/src/orrery_search/mcp/query.py @@ -0,0 +1,247 @@ +"""Live pg_orrery SQL query execution via MCP tools. + +Provides read-only access to pg_orrery functions running inside the same +PostgreSQL instance that stores document embeddings. Safety guardrails +ensure only SELECT statements execute, with statement timeouts and row limits. 
+""" + +import asyncio +import json +import logging +import re + +import asyncpg + +from orrery_search.config import settings +from orrery_search.mcp import mcp + +logger = logging.getLogger("orrery_search") + +_pool: asyncpg.Pool | None = None +_pool_lock = asyncio.Lock() + +MAX_ROWS = 100 + + +async def _get_pool() -> asyncpg.Pool: + global _pool + if _pool is not None: + return _pool + async with _pool_lock: + if _pool is not None: + return _pool + _pool = await asyncpg.create_pool( + settings.orrery_db_url, + min_size=2, + max_size=5, + command_timeout=settings.run_query_timeout, + ) + return _pool + + +async def close_query_pool() -> None: + global _pool + if _pool is not None: + await _pool.close() + _pool = None + + +_FORBIDDEN_RE = re.compile( + r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|GRANT|REVOKE|COPY|EXECUTE)\b", + re.IGNORECASE, +) + + +def _validate_sql(sql: str) -> None: + """Reject anything that isn't a SELECT (or WITH ... SELECT).""" + stripped = sql.strip().rstrip(";").strip() + if not stripped: + raise ValueError("Empty SQL statement") + + upper = stripped.upper() + if not (upper.startswith("SELECT") or upper.startswith("WITH")): + raise ValueError("Only SELECT statements are allowed") + + if _FORBIDDEN_RE.search(stripped): + raise ValueError("Statement contains forbidden keywords (INSERT/UPDATE/DELETE/DDL)") + + +def _serialize_row(row: asyncpg.Record) -> dict: + """Convert asyncpg Record to a JSON-serializable dict.""" + result = {} + for key, value in row.items(): + if hasattr(value, "isoformat"): + result[key] = value.isoformat() + elif isinstance(value, (int, float, str, bool, type(None))): + result[key] = value + else: + result[key] = str(value) + return result + + +@mcp.tool() +async def run_query(sql: str) -> str: + """Execute a read-only pg_orrery SQL query and return the results. + + Runs arbitrary SELECT statements against a PostgreSQL database with + pg_orrery installed. 
def _observer_sql(latitude: float, longitude: float, altitude: float) -> str:
    """Render an ``observer`` literal from numeric coordinates.

    Each value is coerced with ``float()`` so that only a numeric rendering
    can ever reach the SQL text.
    """
    return f"'({float(latitude)},{float(longitude)},{float(altitude)})'::observer"


def _timestamp_sql(time: str) -> str:
    """Render *time* as a ``timestamptz`` SQL expression.

    ``NOW()`` (any letter case, surrounding whitespace ignored) is passed
    through as the SQL function.  Anything else is emitted as a quoted
    literal cast to ``timestamptz``; embedded single quotes are doubled so
    the value cannot terminate the literal.
    """
    if time.strip().upper() == "NOW()":
        return "NOW()"
    escaped = time.replace("'", "''")
    return f"'{escaped}'::timestamptz"


# NOTE(review): the tools below await run_query(sql) directly, and run_query
# is itself registered with @mcp.tool().  Some FastMCP releases return a tool
# object from the decorator rather than the plain coroutine function --
# confirm direct calls work with the pinned FastMCP version.


@mcp.tool()
async def planet_position(
    body_id: int,
    latitude: float,
    longitude: float,
    altitude: float = 0.0,
    time: str = "NOW()",
) -> str:
    """Get a planet's current position as seen from an observer location.

    Returns azimuth, elevation, range, and range_rate in topocentric coordinates,
    plus equatorial RA/Dec.

    Args:
        body_id: Planet ID (0=Sun, 1=Mercury, 2=Venus, 4=Mars, 5=Jupiter,
                 6=Saturn, 7=Uranus, 8=Neptune, 10=Moon)
        latitude: Observer latitude in degrees (-90 to 90)
        longitude: Observer longitude in degrees (-180 to 180)
        altitude: Observer altitude in meters above WGS-84 (default 0)
        time: "NOW()" (default) or a timestamp literal that PostgreSQL can
              cast to timestamptz, e.g. "2026-03-01 12:00:00+00"
    """
    body = int(body_id)
    obs = _observer_sql(latitude, longitude, altitude)
    ts = _timestamp_sql(time)

    sql = f"""
        SELECT
            (t).azimuth_deg, (t).elevation_deg, (t).range_km, (t).range_rate_km_s,
            (e).ra_hours, (e).dec_degrees, (e).distance_km
        FROM (
            SELECT
                planet_observe({body}, {obs}, {ts}) AS t,
                planet_equatorial({body}, {ts}) AS e
        ) sub
    """

    return await run_query(sql)


@mcp.tool()
async def sky_survey(
    latitude: float,
    longitude: float,
    altitude: float = 0.0,
    time: str = "NOW()",
) -> str:
    """Survey the entire sky: positions of all planets, Sun, and Moon from an observer.

    Returns a table of all major solar system bodies with their topocentric
    azimuth/elevation and equatorial RA/Dec coordinates.

    Args:
        latitude: Observer latitude in degrees (-90 to 90)
        longitude: Observer longitude in degrees (-180 to 180)
        altitude: Observer altitude in meters above WGS-84 (default 0)
        time: "NOW()" (default) or a timestamp literal that PostgreSQL can
              cast to timestamptz, e.g. "2026-03-01 12:00:00+00"
    """
    obs = _observer_sql(latitude, longitude, altitude)
    ts = _timestamp_sql(time)

    sql = f"""
        SELECT
            body_name,
            (topo).azimuth_deg AS az,
            (topo).elevation_deg AS el,
            (topo).range_km AS range_km,
            (eq).ra_hours AS ra,
            (eq).dec_degrees AS dec
        FROM (
            SELECT 'Sun' AS body_name,
                   sun_observe({obs}, {ts}) AS topo,
                   sun_equatorial({ts}) AS eq
            UNION ALL
            SELECT 'Moon',
                   moon_observe({obs}, {ts}),
                   moon_equatorial({ts})
            UNION ALL
            SELECT unnest(ARRAY['Mercury','Venus','Mars','Jupiter','Saturn','Uranus','Neptune']),
                   planet_observe(unnest(ARRAY[1,2,4,5,6,7,8]), {obs}, {ts}),
                   planet_equatorial(unnest(ARRAY[1,2,4,5,6,7,8]), {ts})
        ) survey
        ORDER BY el DESC
    """

    return await run_query(sql)


@mcp.tool()
async def satellite_pass(
    tle_line1: str,
    tle_line2: str,
    latitude: float,
    longitude: float,
    altitude: float = 0.0,
    time: str = "NOW()",
    count: int = 5,
) -> str:
    """Predict upcoming satellite passes over an observer location.

    Uses SGP4/SDP4 propagation with the provided Two-Line Element set.

    Args:
        tle_line1: First line of the TLE (69 chars, starts with "1 ")
        tle_line2: Second line of the TLE (69 chars, starts with "2 ")
        latitude: Observer latitude in degrees
        longitude: Observer longitude in degrees
        altitude: Observer altitude in meters (default 0)
        time: Start time for pass search: "NOW()" (default) or a timestamp
              literal that PostgreSQL can cast to timestamptz
        count: Number of passes to predict (default 5, max 20)
    """
    # Clamp to the documented 1..20 range before it reaches the SQL text.
    count = max(1, min(int(count), 20))
    obs = _observer_sql(latitude, longitude, altitude)
    ts = _timestamp_sql(time)

    # Double any single quotes so the TLE lines stay inside their literals.
    safe_l1 = tle_line1.replace("'", "''")
    safe_l2 = tle_line2.replace("'", "''")

    sql = f"""
        SELECT
            (p).aos_time, (p).los_time,
            (p).max_elevation_deg,
            (p).aos_azimuth_deg, (p).los_azimuth_deg
        FROM predict_passes(
            tle_from_lines('{safe_l1}', '{safe_l2}'),
            {obs}, {ts}, {count}
        ) AS p
    """

    return await run_query(sql)
@mcp.tool()
async def list_content(
    content_type: ContentType | None = None,
    section: str | None = None,
    limit: int = 100,
) -> str:
    """Browse the documentation contents by type or section.

    Without filters, returns a summary of document counts grouped by section.
    With filters, returns matching document titles and slugs for further retrieval.

    Args:
        content_type: Filter by document type, or None
        section: Filter by section prefix (e.g. "guides", "reference")
        limit: Max documents to return when filtering (1-500, default 100)
    """
    async with async_session() as db:
        # No filters: return per-(section, content_type) counts instead of rows.
        if not content_type and not section:
            result = await db.execute(
                select(
                    Document.section,
                    Document.content_type,
                    func.count().label("count"),
                )
                .group_by(Document.section, Document.content_type)
                .order_by(Document.section)
            )
            groups = [
                {"section": row.section, "content_type": row.content_type, "count": row.count}
                for row in result
            ]
            return json.dumps({"summary": groups, "total": sum(g["count"] for g in groups)})

        limit = max(1, min(limit, 500))
        stmt = select(
            Document.title, Document.slug, Document.section,
            Document.content_type, Document.description,
        )

        if content_type:
            stmt = stmt.where(Document.content_type == content_type)
        if section:
            # autoescape keeps "%" and "_" in the caller's text literal, so the
            # filter is a true prefix match rather than a LIKE pattern.
            stmt = stmt.where(Document.section.startswith(section, autoescape=True))

        stmt = stmt.order_by(Document.section, Document.title).limit(limit)

        result = await db.execute(stmt)
        docs = [
            {
                "title": row.title,
                "slug": row.slug,
                "section": row.section,
                "content_type": row.content_type,
                "description": row.description,
            }
            for row in result
        ]
        return json.dumps({"documents": docs, "count": len(docs)})
.group_by(Document.section) + .order_by(func.count().desc()) + ) + + return json.dumps({ + "total_documents": doc_count, + "by_type": {row.content_type: row.count for row in type_counts}, + "by_section": {row.section: row.count for row in section_counts}, + }) + + +@mcp.prompt() +def orrery_expert(topic: str = "general") -> str: + """System prompt for exploring pg_orrery documentation and running live queries. + + Configures an assistant persona tuned for celestial mechanics computation. + + Args: + topic: Focus area — "general", "satellites" (SGP4/TLE tracking), + "planets" (VSOP87/DE ephemeris), "navigation" (rise/set/constellation), + or "transfers" (Lambert/Lagrange) + """ + base = ( + "You are an expert assistant for pg_orrery, a PostgreSQL extension for " + "celestial mechanics. Use the search_docs and get_document tools to find " + "documentation, and the run_query tool to execute live SQL computations. " + "Always cite your sources and show the SQL queries you run.\n\n" + ) + topic_context = { + "general": ( + "Help with any pg_orrery topic — satellite tracking, planetary observation, " + "rise/set prediction, constellation identification, or Lagrange points." + ), + "satellites": ( + "Focus on satellite tracking: SGP4/SDP4 propagation, TLE parsing, " + "pass prediction, eclipse detection, and coordinate transforms." + ), + "planets": ( + "Focus on planetary ephemerides: VSOP87 and JPL DE441 providers, " + "planet/sun/moon observation, equatorial coordinates, and apparent positions." + ), + "navigation": ( + "Focus on observational astronomy: rise/set prediction, twilight computation, " + "constellation identification, lunar phase, planet magnitude, and refraction." + ), + "transfers": ( + "Focus on orbital mechanics: Lambert transfer solver, Lagrange equilibrium " + "points (CR3BP), Hill radius, and interplanetary trajectory design." 
+ ), + } + return base + topic_context.get(topic, topic_context["general"]) diff --git a/search/src/orrery_search/models/__init__.py b/search/src/orrery_search/models/__init__.py new file mode 100644 index 0000000..aba7f0a --- /dev/null +++ b/search/src/orrery_search/models/__init__.py @@ -0,0 +1,4 @@ +from orrery_search.models.base import Base +from orrery_search.models.document import Document + +__all__ = ["Base", "Document"] diff --git a/search/src/orrery_search/models/base.py b/search/src/orrery_search/models/base.py new file mode 100644 index 0000000..fa2b68a --- /dev/null +++ b/search/src/orrery_search/models/base.py @@ -0,0 +1,5 @@ +from sqlalchemy.orm import DeclarativeBase + + +class Base(DeclarativeBase): + pass diff --git a/search/src/orrery_search/models/document.py b/search/src/orrery_search/models/document.py new file mode 100644 index 0000000..3386b33 --- /dev/null +++ b/search/src/orrery_search/models/document.py @@ -0,0 +1,27 @@ +from datetime import datetime + +from pgai.sqlalchemy import vectorizer_relationship +from sqlalchemy import DateTime, Integer, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column + +from orrery_search.models.base import Base + + +class Document(Base): + __tablename__ = "document" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + content_type: Mapped[str] = mapped_column(String(20), index=True) + slug: Mapped[str] = mapped_column(String(300), unique=True) + title: Mapped[str] = mapped_column(String(300)) + section: Mapped[str] = mapped_column(String(200), index=True) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + body: Mapped[str] = mapped_column(Text) + search_text: Mapped[str | None] = mapped_column(Text, nullable=True) + url: Mapped[str] = mapped_column(String(300)) + word_count: Mapped[int] = mapped_column(Integer, default=0) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), 
onupdate=func.now() + ) + + embeddings = vectorizer_relationship(dimensions=1024) diff --git a/search/src/orrery_search/routers/__init__.py b/search/src/orrery_search/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/search/src/orrery_search/routers/chat.py b/search/src/orrery_search/routers/chat.py new file mode 100644 index 0000000..fe11eae --- /dev/null +++ b/search/src/orrery_search/routers/chat.py @@ -0,0 +1,172 @@ +import asyncio +import json +import logging + +import httpx +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field, field_validator + +from orrery_search.mcp.chat import ( + _build_context, + _chat_completion, + _chat_completion_stream, +) + +logger = logging.getLogger("orrery_search") + +router = APIRouter() + + +class ChatRequest(BaseModel): + question: str = Field(..., min_length=1, max_length=1000) + + +class ChatSource(BaseModel): + title: str + slug: str + url: str + score: float + + @field_validator("url") + @classmethod + def url_must_be_relative_or_https(cls, v: str) -> str: + if v.startswith("/") or v.startswith("https://") or v.startswith("http://"): + return v + raise ValueError("URL must be relative or http(s)") + + +class ChatResponse(BaseModel): + answer: str + sources: list[ChatSource] + + +@router.post("", response_model=ChatResponse) +async def chat(req: ChatRequest): + """Ask a question about pg_orrery and get an answer grounded in documentation.""" + try: + context, sources = await _build_context(req.question) + except Exception: + logger.exception("Chat context build failed") + raise HTTPException(status_code=502, detail="Search service unavailable") + + if not context: + return ChatResponse( + answer="I couldn't find any relevant documentation for that question. 
" + "Try asking about satellite tracking, planetary observation, " + "rise/set prediction, or other pg_orrery functions.", + sources=[], + ) + + try: + answer = await _chat_completion(context, req.question) + except Exception: + logger.exception("Chat completion failed") + raise HTTPException(status_code=502, detail="Chat completion unavailable") + + return ChatResponse( + answer=answer, + sources=[ + ChatSource(title=s["title"], slug=s["slug"], url=s["url"], score=s["score"]) + for s in sources + ], + ) + + +# ---- Streaming endpoint ---- + + +class PageContext(BaseModel): + title: str = Field("", max_length=200) + path: str = Field("", max_length=500) + description: str = Field("", max_length=500) + + @field_validator("path") + @classmethod + def path_must_be_relative(cls, v: str) -> str: + if v and not v.startswith("/"): + raise ValueError("Path must start with /") + return v + + +class ChatStreamRequest(BaseModel): + question: str = Field(..., min_length=1, max_length=1000) + page: PageContext | None = None + + +def _sse_event(event: str, data: dict | list) -> str: + return f"event: {event}\ndata: {json.dumps(data)}\n\n" + + +@router.post("/stream") +async def chat_stream(req: ChatStreamRequest): + """SSE streaming endpoint for the chat widget.""" + + async def generate(): + question = req.question + if req.page and req.page.title: + page_context = f'[The user is currently reading: "{req.page.title}" ({req.page.path})' + if req.page.description: + page_context += f"\nDocument description: {req.page.description}" + page_context += "]\n\n" + question = page_context + req.question + + yield _sse_event("status", {"text": "Searching the documentation\u2026"}) + + try: + context, sources = await _build_context(question) + except Exception: + logger.exception("Chat stream context build failed") + yield _sse_event("error", {"text": "Search service unavailable"}) + return + + if not context: + yield _sse_event("status", {"text": "No relevant documents found"}) + yield 
@router.get("/health")
async def health(db: AsyncSession = Depends(get_db)):
    """Health check with document and embedding counts.

    Returns a JSON object with the document count, the embedding-store row
    count (0 when the store cannot be queried) and whether the pg_orrery
    extension is listed in ``pg_extension``.
    """
    doc_count = await db.scalar(select(func.count(Document.id)))

    # Each optional probe runs inside its own SAVEPOINT.  On PostgreSQL a
    # failed statement aborts the enclosing transaction, so without this a
    # missing embedding table would make the pg_orrery probe below fail too.
    embedding_count = 0
    try:
        async with db.begin_nested():
            result = await db.execute(
                text("SELECT count(*) FROM document_embedding_store")
            )
            embedding_count = result.scalar() or 0
    except Exception:
        # Best-effort: the health endpoint should still answer.
        logger.debug("Embedding store not yet available", exc_info=True)

    # Check if pg_orrery extension is loaded
    orrery_ok = False
    try:
        async with db.begin_nested():
            result = await db.execute(
                text("SELECT extversion FROM pg_extension WHERE extname = 'pg_orrery'")
            )
            row = result.first()
            orrery_ok = row is not None
    except Exception:
        logger.debug("pg_orrery extension check failed", exc_info=True)

    return {
        "status": "ok",
        "service": "orrery-search",
        "documents": doc_count or 0,
        "embeddings": embedding_count,
        "pg_orrery": orrery_ok,
    }
results=[ + SearchResult( + title=r.title, + slug=r.slug, + section=r.section, + content_type=r.content_type, + description=r.description, + snippet=r.snippet, + url=r.url, + score=r.score, + source=r.source, + ) + for r in output.results + ], + count=output.count, + mode=output.mode, + ) + + +@router.get("/sections", response_model=SectionsResponse) +async def list_sections( + db: AsyncSession = Depends(get_db), +): + """List all sections with document counts for the filter UI.""" + stmt = ( + select(Document.section, func.count(Document.id).label("count")) + .group_by(Document.section) + .order_by(Document.section) + ) + result = await db.execute(stmt) + sections = [ + SectionCount(section=row.section, count=row.count) + for row in result + if row.section + ] + return SectionsResponse(sections=sections) diff --git a/search/src/orrery_search/schemas/__init__.py b/search/src/orrery_search/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/search/src/orrery_search/schemas/search.py b/search/src/orrery_search/schemas/search.py new file mode 100644 index 0000000..507971d --- /dev/null +++ b/search/src/orrery_search/schemas/search.py @@ -0,0 +1,29 @@ +from pydantic import BaseModel + + +class SearchResult(BaseModel): + title: str + slug: str + section: str + content_type: str + description: str | None = None + snippet: str + url: str + score: float + source: str # "semantic" | "text" | "both" + + +class SearchResponse(BaseModel): + query: str + results: list[SearchResult] + count: int + mode: str # "hybrid" | "semantic" | "text" + + +class SectionCount(BaseModel): + section: str + count: int + + +class SectionsResponse(BaseModel): + sections: list[SectionCount] diff --git a/search/src/orrery_search/services/__init__.py b/search/src/orrery_search/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/search/src/orrery_search/services/embedding.py b/search/src/orrery_search/services/embedding.py new file mode 100644 index 
0000000..c9d66ab --- /dev/null +++ b/search/src/orrery_search/services/embedding.py @@ -0,0 +1,55 @@ +"""Embedding client for semantic search queries. + +Singleton AsyncOpenAI client configured against the GPU embedding gateway. +Falls back to text-only search if the client is unavailable or requests fail. +""" + +import logging + +from openai import ( + APIConnectionError, + APITimeoutError, + AsyncOpenAI, + AuthenticationError, + RateLimitError, +) + +from orrery_search.config import settings + +logger = logging.getLogger("orrery_search") + +embedding_client: AsyncOpenAI | None = None +if settings.gpu_api_key: + embedding_client = AsyncOpenAI( + base_url=settings.gpu_base_url, + api_key=settings.gpu_api_key, + timeout=10.0, + max_retries=0, + ) + + +async def embed_query(text: str) -> list[float] | None: + """Embed a search query using mxbai-embed-large (1024 dims). + + Returns None if the embedding client is not configured or the + request fails. Callers fall back to text-only search. + """ + if not embedding_client: + return None + try: + resp = await embedding_client.embeddings.create( + model=settings.gpu_embed_model, + input=text, + ) + return resp.data[0].embedding + except APITimeoutError: + logger.warning("Embedding timed out (10s) for: %s", text[:80]) + except AuthenticationError: + logger.error("Embedding auth failed — check GPU_API_KEY") + except RateLimitError: + logger.warning("Embedding rate-limited for: %s", text[:80]) + except APIConnectionError as exc: + logger.warning("Embedding connection failed: %s", exc) + except Exception: + logger.exception("Unexpected embedding failure for: %s", text[:80]) + return None diff --git a/search/src/orrery_search/services/search.py b/search/src/orrery_search/services/search.py new file mode 100644 index 0000000..6e2993f --- /dev/null +++ b/search/src/orrery_search/services/search.py @@ -0,0 +1,242 @@ +"""Hybrid semantic + text search service. 
+ +Extracts search logic into framework-agnostic functions returning +dataclasses. The router converts these to Pydantic models. +""" + +import logging +import re +from dataclasses import dataclass, field + +from sqlalchemy import or_, select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession + +from orrery_search.models.document import Document +from orrery_search.services.embedding import embed_query + +logger = logging.getLogger("orrery_search") + + +@dataclass +class SearchResult: + title: str + slug: str + section: str + content_type: str + description: str | None + snippet: str + url: str + score: float + source: str # "semantic" | "text" | "both" + + +@dataclass +class SearchOutput: + query: str + results: list[SearchResult] = field(default_factory=list) + count: int = 0 + mode: str = "hybrid" + + +def _escape_ilike(value: str) -> str: + return re.sub(r"([%_\\])", r"\\\1", value) + + +def _extract_snippet(body: str, query: str, context_chars: int = 100) -> str: + idx = body.lower().find(query.lower()) + if idx == -1: + return body[:200] + + start = max(0, idx - context_chars) + end = min(len(body), idx + len(query) + context_chars) + snippet = body[start:end] + + if start > 0: + snippet = "..." + snippet + if end < len(body): + snippet = snippet + "..." 
async def _semantic_search(
    query_vec: list[float],
    limit: int,
    db: AsyncSession,
    content_type: str | None,
    section: str | None,
) -> list[SearchResult]:
    """Return the nearest documents to *query_vec* by cosine distance.

    Args:
        query_vec: Embedding of the user query.
        limit: Maximum number of rows to fetch.
        db: Async session used to run the query.
        content_type: Optional exact-match filter on ``Document.content_type``.
        section: Optional prefix filter on ``Document.section``.

    Returns:
        Results ordered by ascending distance, with ``source="semantic"``.
        Rows come from the join against the document's embeddings, so the
        same slug may appear more than once (presumably one row per
        embedded chunk). Any failure is logged, the session is rolled back,
        and whatever was collected so far (usually nothing) is returned.
    """
    results: list[SearchResult] = []
    try:
        stmt = select(
            Document,
            Document.embeddings.embedding.cosine_distance(query_vec).label("distance"),
            Document.embeddings.chunk,
        ).join(Document.embeddings)

        if content_type:
            stmt = stmt.where(Document.content_type == content_type)
        if section:
            # autoescape keeps "%" and "_" in the caller-supplied prefix from
            # acting as LIKE wildcards (e.g. "getting_started" must not match
            # "gettingXstarted").
            stmt = stmt.where(Document.section.startswith(section, autoescape=True))

        stmt = stmt.order_by("distance").limit(limit)

        for page, dist, chunk in await db.execute(stmt):
            # Prefer the matched chunk text; fall back to the page body.
            snippet = chunk[:200] if chunk else page.body[:200]
            results.append(
                SearchResult(
                    title=page.title,
                    slug=page.slug,
                    section=page.section,
                    content_type=page.content_type,
                    description=page.description,
                    snippet=snippet,
                    url=page.url,
                    # Convert distance to a similarity score, clamped at 0
                    # because cosine distance can exceed 1.
                    score=round(max(0.0, 1.0 - float(dist)), 4),
                    source="semantic",
                )
            )
    except SQLAlchemyError:
        logger.warning("Semantic search failed (database)", exc_info=True)
        # Leave the session usable for the text-search fallback.
        await db.rollback()
    except Exception:
        logger.error("Semantic search failed (unexpected)", exc_info=True)
        await db.rollback()

    return results
content_type=page.content_type, + description=page.description, + snippet=snippet, + url=page.url, + score=0.5, + source="text", + ) + ) + except SQLAlchemyError: + logger.warning("Text search failed (database)", exc_info=True) + await db.rollback() + except Exception: + logger.error("Text search failed (unexpected)", exc_info=True) + await db.rollback() + + return results + + +def _merge_results( + semantic: list[SearchResult], + text: list[SearchResult], + limit: int, +) -> list[SearchResult]: + seen: dict[str, SearchResult] = {} + + for r in semantic: + if r.slug not in seen: + seen[r.slug] = r + + for r in text: + if r.slug in seen: + boosted = seen[r.slug] + seen[r.slug] = SearchResult( + title=boosted.title, + slug=boosted.slug, + section=boosted.section, + content_type=boosted.content_type, + description=boosted.description, + snippet=boosted.snippet, + url=boosted.url, + score=min(1.0, boosted.score + 0.1), + source="both", + ) + else: + seen[r.slug] = r + + merged = sorted(seen.values(), key=lambda r: r.score, reverse=True) + return merged[:limit] + + +async def search_documents( + q: str, + db: AsyncSession, + mode: str = "hybrid", + content_type: str | None = None, + section: str | None = None, + limit: int = 10, +) -> SearchOutput: + """Search pg_orrery docs using semantic and/or text matching.""" + semantic_results: list[SearchResult] = [] + text_results: list[SearchResult] = [] + actual_mode = mode + + if mode in ("hybrid", "semantic"): + query_vec = await embed_query(q) + if query_vec: + semantic_results = await _semantic_search( + query_vec, limit * 3, db, content_type, section + ) + elif mode == "semantic": + actual_mode = "text" + + if mode in ("hybrid", "text") or (mode == "semantic" and not semantic_results): + text_results = await _text_search(q, limit, db, content_type, section) + + if semantic_results and text_results: + final = _merge_results(semantic_results, text_results, limit) + elif semantic_results: + deduped: dict[str, SearchResult] = 
{} + for r in semantic_results: + if r.slug not in deduped: + deduped[r.slug] = r + final = sorted(deduped.values(), key=lambda r: r.score, reverse=True)[:limit] + else: + final = text_results[:limit] + if mode != "text": + actual_mode = "text" + + return SearchOutput( + query=q, + results=final, + count=len(final), + mode=actual_mode, + ) diff --git a/search/src/orrery_search/services/search_text.py b/search/src/orrery_search/services/search_text.py new file mode 100644 index 0000000..2ca7dd6 --- /dev/null +++ b/search/src/orrery_search/services/search_text.py @@ -0,0 +1,35 @@ +def build_search_text( + title: str, + section: str, + content_type: str, + description: str | None, + body: str, +) -> str: + """Build enriched search text for vectorizer indexing. + + Concatenates title, section path, content type context, description, + and body (capped at 4000 chars) so the vectorizer gets a dense, + searchable representation. + """ + parts = [title] + + if section: + parts.append(f"section: {section}") + + type_context = { + "guide": "pg_orrery usage guide and tutorial", + "reference": "pg_orrery SQL function and type reference", + "architecture": "pg_orrery internal architecture and design", + "workflow": "workflow translation from other tools to pg_orrery SQL", + "getting_started": "pg_orrery installation and getting started", + "performance": "pg_orrery performance benchmarks", + } + ctx = type_context.get(content_type) + if ctx: + parts.append(ctx) + + if description: + parts.append(description) + + parts.append(body[:4000]) + return " ".join(parts)