Fix ruff lint, Pyright type ambiguity, and MCP tool SQL queries

- Fix 28 ruff errors: E501 line length, B904 raise-from, F401 unused import
- Resolve the SQLAlchemy Row.count() ambiguity flagged by Pyright by using tuple indexing
- Replace composite column notation with accessor functions in MCP tools
  (topocentric/equatorial/pass_event are C-level base types, not composites)
- Fix satellite_pass: use a time window (start + end) instead of a count parameter,
  to match the predict_passes(tle, observer, start_ts, end_ts, min_el) signature
This commit is contained in:
Ryan Malloy 2026-03-01 16:41:07 -07:00
parent a40ae9437d
commit 33787e03da
6 changed files with 128 additions and 51 deletions

View File

@ -3,7 +3,6 @@
Run: docker compose exec api-dev python -m orrery_search.ingest Run: docker compose exec api-dev python -m orrery_search.ingest
""" """
import asyncio
import sys import sys
from pathlib import Path from pathlib import Path

View File

@ -17,7 +17,8 @@ logger = logging.getLogger("orrery_search")
SYSTEM_PROMPT = ( SYSTEM_PROMPT = (
"/no_think\n" "/no_think\n"
"You are a knowledgeable assistant for pg_orrery, a PostgreSQL extension providing " "You are a knowledgeable assistant for pg_orrery, a PostgreSQL extension providing "
"celestial mechanics types and functions. Answer questions using the provided context " "celestial mechanics types and functions. "
"Answer questions using the provided context "
"from the documentation.\n\n" "from the documentation.\n\n"
"Key domain knowledge:\n" "Key domain knowledge:\n"
"- Body IDs: 0=Sun, 1=Mercury, 2=Venus, 3=Earth, 4=Mars, 5=Jupiter, 6=Saturn, " "- Body IDs: 0=Sun, 1=Mercury, 2=Venus, 3=Earth, 4=Mars, 5=Jupiter, 6=Saturn, "
@ -125,7 +126,10 @@ async def _chat_completion_stream_gpu(
{"role": "system", "content": SYSTEM_PROMPT}, {"role": "system", "content": SYSTEM_PROMPT},
{ {
"role": "user", "role": "user",
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}", "content": (
f"Documentation context:\n\n{context}"
f"\n\n---\n\nQuestion: {question}"
),
}, },
] ]
@ -196,7 +200,10 @@ async def _chat_completion_stream_anthropic(
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}", "content": (
f"Documentation context:\n\n{context}"
f"\n\n---\n\nQuestion: {question}"
),
}, },
], ],
) as stream: ) as stream:
@ -232,7 +239,10 @@ async def _chat_completion(context: str, question: str) -> str:
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}", "content": (
f"Documentation context:\n\n{context}"
f"\n\n---\n\nQuestion: {question}"
),
}, },
], ],
) )
@ -243,7 +253,10 @@ async def _chat_completion(context: str, question: str) -> str:
{"role": "system", "content": SYSTEM_PROMPT}, {"role": "system", "content": SYSTEM_PROMPT},
{ {
"role": "user", "role": "user",
"content": f"Documentation context:\n\n{context}\n\n---\n\nQuestion: {question}", "content": (
f"Documentation context:\n\n{context}"
f"\n\n---\n\nQuestion: {question}"
),
}, },
] ]
resp = await client.post( resp = await client.post(
@ -284,11 +297,16 @@ async def ask_orrery(
context, sources = await _build_context(question) context, sources = await _build_context(question)
except Exception: except Exception:
logger.exception("Failed to build RAG context") logger.exception("Failed to build RAG context")
raise RuntimeError("Search failed while building context") raise RuntimeError(
"Search failed while building context"
) from None
if not context: if not context:
return json.dumps({ return json.dumps({
"answer": "No relevant documents found in the pg_orrery docs for this question.", "answer": (
"No relevant documents found in the "
"pg_orrery docs for this question."
),
"sources": [], "sources": [],
}) })
@ -296,10 +314,18 @@ async def ask_orrery(
answer = await _chat_completion(context, question) answer = await _chat_completion(context, question)
except httpx.HTTPStatusError as exc: except httpx.HTTPStatusError as exc:
logger.error("Chat completion failed: HTTP %s", exc.response.status_code) logger.error("Chat completion failed: HTTP %s", exc.response.status_code)
raise RuntimeError(f"Chat model returned HTTP {exc.response.status_code}") status = exc.response.status_code
raise RuntimeError(
f"Chat model returned HTTP {status}"
) from None
except httpx.TimeoutException: except httpx.TimeoutException:
logger.warning("Chat completion timed out (limit: %.0fs)", settings.chat_timeout) logger.warning(
raise RuntimeError("Chat model timed out — try a simpler question") "Chat completion timed out (limit: %.0fs)",
settings.chat_timeout,
)
raise RuntimeError(
"Chat model timed out — try a simpler question"
) from None
result: dict = {"answer": answer} result: dict = {"answer": answer}
if include_sources: if include_sources:

View File

@ -63,7 +63,10 @@ def _validate_sql(sql: str) -> None:
raise ValueError("Only SELECT statements are allowed") raise ValueError("Only SELECT statements are allowed")
if _FORBIDDEN_RE.search(stripped): if _FORBIDDEN_RE.search(stripped):
raise ValueError("Statement contains forbidden keywords (INSERT/UPDATE/DELETE/DDL)") raise ValueError(
"Statement contains forbidden keywords"
" (INSERT/UPDATE/DELETE/DDL)"
)
def _serialize_row(row: asyncpg.Record) -> dict: def _serialize_row(row: asyncpg.Record) -> dict:
@ -98,10 +101,10 @@ async def run_query(sql: str) -> str:
_validate_sql(sql) _validate_sql(sql)
pool = await _get_pool() pool = await _get_pool()
async with pool.acquire() as conn: async with pool.acquire() as conn, conn.transaction(readonly=True):
async with conn.transaction(readonly=True): timeout_ms = int(settings.run_query_timeout * 1000)
await conn.execute( await conn.execute(
f"SET LOCAL statement_timeout = '{int(settings.run_query_timeout * 1000)}'" f"SET LOCAL statement_timeout = '{timeout_ms}'"
) )
rows = await conn.fetch(sql) rows = await conn.fetch(sql)
@ -142,8 +145,13 @@ async def planet_position(
sql = f""" sql = f"""
SELECT SELECT
(t).azimuth_deg, (t).elevation_deg, (t).range_km, (t).range_rate_km_s, topo_azimuth(t) AS az_deg,
(e).ra_hours, (e).dec_degrees, (e).distance_km topo_elevation(t) AS el_deg,
topo_range(t) AS range_km,
topo_range_rate(t) AS range_rate_km_s,
eq_ra(e) AS ra_hours,
eq_dec(e) AS dec_deg,
eq_distance(e) AS distance_km
FROM ( FROM (
SELECT SELECT
planet_observe({body_id}, {obs}, {ts}) AS t, planet_observe({body_id}, {obs}, {ts}) AS t,
@ -178,11 +186,11 @@ async def sky_survey(
sql = f""" sql = f"""
SELECT SELECT
body_name, body_name,
(topo).azimuth_deg AS az, topo_azimuth(topo) AS az,
(topo).elevation_deg AS el, topo_elevation(topo) AS el,
(topo).range_km AS range_km, topo_range(topo) AS range_km,
(eq).ra_hours AS ra, eq_ra(eq) AS ra,
(eq).dec_degrees AS dec eq_dec(eq) AS dec
FROM ( FROM (
SELECT 'Sun' AS body_name, SELECT 'Sun' AS body_name,
sun_observe({obs}, {ts}) AS topo, sun_observe({obs}, {ts}) AS topo,
@ -192,9 +200,14 @@ async def sky_survey(
moon_observe({obs}, {ts}), moon_observe({obs}, {ts}),
moon_equatorial({ts}) moon_equatorial({ts})
UNION ALL UNION ALL
SELECT unnest(ARRAY['Mercury','Venus','Mars','Jupiter','Saturn','Uranus','Neptune']), SELECT unnest(ARRAY[
planet_observe(unnest(ARRAY[1,2,4,5,6,7,8]), {obs}, {ts}), 'Mercury','Venus','Mars','Jupiter',
planet_equatorial(unnest(ARRAY[1,2,4,5,6,7,8]), {ts}) 'Saturn','Uranus','Neptune']),
planet_observe(
unnest(ARRAY[1,2,4,5,6,7,8]),
{obs}, {ts}),
planet_equatorial(
unnest(ARRAY[1,2,4,5,6,7,8]), {ts})
) survey ) survey
ORDER BY el DESC ORDER BY el DESC
""" """
@ -210,11 +223,13 @@ async def satellite_pass(
longitude: float, longitude: float,
altitude: float = 0.0, altitude: float = 0.0,
time: str = "NOW()", time: str = "NOW()",
count: int = 5, hours: int = 48,
min_elevation: float = 0.0,
) -> str: ) -> str:
"""Predict upcoming satellite passes over an observer location. """Predict upcoming satellite passes over an observer location.
Uses SGP4/SDP4 propagation with the provided Two-Line Element set. Uses SGP4/SDP4 propagation with the provided Two-Line Element set.
Searches a time window from the start time forward.
Args: Args:
tle_line1: First line of the TLE (69 chars, starts with "1 ") tle_line1: First line of the TLE (69 chars, starts with "1 ")
@ -223,9 +238,10 @@ async def satellite_pass(
longitude: Observer longitude in degrees longitude: Observer longitude in degrees
altitude: Observer altitude in meters (default 0) altitude: Observer altitude in meters (default 0)
time: Start time for pass search (default "NOW()") time: Start time for pass search (default "NOW()")
count: Number of passes to predict (default 5, max 20) hours: Hours to search forward from start time (default 48)
min_elevation: Minimum peak elevation in degrees (default 0)
""" """
count = max(1, min(count, 20)) hours = max(1, min(hours, 168))
obs = f"'({latitude},{longitude},{altitude})'::observer" obs = f"'({latitude},{longitude},{altitude})'::observer"
ts = f"'{time}'::timestamptz" if time != "NOW()" else "NOW()" ts = f"'{time}'::timestamptz" if time != "NOW()" else "NOW()"
@ -235,12 +251,17 @@ async def satellite_pass(
sql = f""" sql = f"""
SELECT SELECT
(p).aos_time, (p).los_time, pass_aos_time(p) AS aos_time,
(p).max_elevation_deg, pass_los_time(p) AS los_time,
(p).aos_azimuth_deg, (p).los_azimuth_deg pass_max_elevation(p) AS max_el_deg,
pass_aos_azimuth(p) AS aos_az_deg,
pass_los_azimuth(p) AS los_az_deg,
pass_duration(p) AS duration
FROM predict_passes( FROM predict_passes(
tle_from_lines('{safe_l1}', '{safe_l2}'), tle_from_lines('{safe_l1}', '{safe_l2}'),
{obs}, {ts}, {count} {obs}, {ts},
{ts} + interval '{hours} hours',
{min_elevation}
) AS p ) AS p
""" """

View File

@ -71,7 +71,8 @@ async def get_document(slug: str) -> str:
Returns title, body text, metadata, and URL. Returns title, body text, metadata, and URL.
Args: Args:
slug: Document identifier (from search results, e.g. "guides/tracking-satellites") slug: Document identifier from search results
(e.g. "guides/tracking-satellites")
""" """
async with async_session() as db: async with async_session() as db:
result = await db.execute( result = await db.execute(
@ -122,10 +123,17 @@ async def list_content(
.order_by(Document.section) .order_by(Document.section)
) )
groups = [ groups = [
{"section": row.section, "content_type": row.content_type, "count": row.count} {
"section": row.section,
"content_type": row.content_type,
"count": row.count,
}
for row in result for row in result
] ]
return json.dumps({"summary": groups, "total": sum(g["count"] for g in groups)}) total = sum(g["count"] for g in groups)
return json.dumps(
{"summary": groups, "total": total}
)
limit = max(1, min(limit, 500)) limit = max(1, min(limit, 500))
stmt = ( stmt = (
@ -197,8 +205,9 @@ def orrery_expert(topic: str = "general") -> str:
) )
topic_context = { topic_context = {
"general": ( "general": (
"Help with any pg_orrery topic — satellite tracking, planetary observation, " "Help with any pg_orrery topic — satellite tracking, "
"rise/set prediction, constellation identification, or Lagrange points." "planetary observation, rise/set prediction, "
"constellation identification, or Lagrange points."
), ),
"satellites": ( "satellites": (
"Focus on satellite tracking: SGP4/SDP4 propagation, TLE parsing, " "Focus on satellite tracking: SGP4/SDP4 propagation, TLE parsing, "
@ -206,11 +215,13 @@ def orrery_expert(topic: str = "general") -> str:
), ),
"planets": ( "planets": (
"Focus on planetary ephemerides: VSOP87 and JPL DE441 providers, " "Focus on planetary ephemerides: VSOP87 and JPL DE441 providers, "
"planet/sun/moon observation, equatorial coordinates, and apparent positions." "planet/sun/moon observation, equatorial coordinates, "
"and apparent positions."
), ),
"navigation": ( "navigation": (
"Focus on observational astronomy: rise/set prediction, twilight computation, " "Focus on observational astronomy: rise/set prediction, "
"constellation identification, lunar phase, planet magnitude, and refraction." "twilight computation, constellation identification, "
"lunar phase, planet magnitude, and refraction."
), ),
"transfers": ( "transfers": (
"Focus on orbital mechanics: Lambert transfer solver, Lagrange equilibrium " "Focus on orbital mechanics: Lambert transfer solver, Lagrange equilibrium "

View File

@ -48,7 +48,10 @@ async def chat(req: ChatRequest):
context, sources = await _build_context(req.question) context, sources = await _build_context(req.question)
except Exception: except Exception:
logger.exception("Chat context build failed") logger.exception("Chat context build failed")
raise HTTPException(status_code=502, detail="Search service unavailable") raise HTTPException(
status_code=502,
detail="Search service unavailable",
) from None
if not context: if not context:
return ChatResponse( return ChatResponse(
@ -62,7 +65,10 @@ async def chat(req: ChatRequest):
answer = await _chat_completion(context, req.question) answer = await _chat_completion(context, req.question)
except Exception: except Exception:
logger.exception("Chat completion failed") logger.exception("Chat completion failed")
raise HTTPException(status_code=502, detail="Chat completion unavailable") raise HTTPException(
status_code=502,
detail="Chat completion unavailable",
) from None
return ChatResponse( return ChatResponse(
answer=answer, answer=answer,
@ -105,7 +111,12 @@ async def chat_stream(req: ChatStreamRequest):
async def generate(): async def generate():
question = req.question question = req.question
if req.page and req.page.title: if req.page and req.page.title:
page_context = f'[The user is currently reading: "{req.page.title}" ({req.page.path})' title = req.page.title
path = req.page.path
page_context = (
f'[The user is currently reading: '
f'"{title}" ({path})'
)
if req.page.description: if req.page.description:
page_context += f"\nDocument description: {req.page.description}" page_context += f"\nDocument description: {req.page.description}"
page_context += "]\n\n" page_context += "]\n\n"
@ -127,7 +138,8 @@ async def chat_stream(req: ChatStreamRequest):
{ {
"text": "I couldn't find any relevant documentation " "text": "I couldn't find any relevant documentation "
"for that question. Try asking about satellite tracking, " "for that question. Try asking about satellite tracking, "
"planetary observation, rise/set prediction, or other pg_orrery functions." "planetary observation, rise/set prediction, "
"or other pg_orrery functions."
}, },
) )
yield _sse_event("done", {}) yield _sse_event("done", {})
@ -136,7 +148,13 @@ async def chat_stream(req: ChatStreamRequest):
n = len(sources) n = len(sources)
yield _sse_event( yield _sse_event(
"status", "status",
{"text": f"Found {n} relevant page{'s' if n != 1 else ''}, generating answer\u2026"}, {
"text": (
f"Found {n} relevant "
f"page{'s' if n != 1 else ''}, "
"generating answer\u2026"
),
},
) )
yield _sse_event("sources", sources) yield _sse_event("sources", sources)

View File

@ -19,7 +19,9 @@ router = APIRouter()
async def search( async def search(
q: str = Query(..., min_length=1, max_length=500), q: str = Query(..., min_length=1, max_length=500),
content_type: str | None = Query(None, description="Filter by content type"), content_type: str | None = Query(None, description="Filter by content type"),
section: str | None = Query(None, max_length=200, description="Section prefix filter"), section: str | None = Query(
None, max_length=200, description="Section prefix filter"
),
limit: int = Query(10, ge=1, le=50), limit: int = Query(10, ge=1, le=50),
mode: str = Query("hybrid", pattern="^(hybrid|semantic|text)$"), mode: str = Query("hybrid", pattern="^(hybrid|semantic|text)$"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
@ -67,8 +69,8 @@ async def list_sections(
) )
result = await db.execute(stmt) result = await db.execute(stmt)
sections = [ sections = [
SectionCount(section=row.section, count=row.count) SectionCount(section=row[0], count=row[1])
for row in result for row in result
if row.section if row[0]
] ]
return SectionsResponse(sections=sections) return SectionsResponse(sections=sections)