Add cache attribution and fix broken DB session methods

- Add fetched_by_key_hash column to CacheEntry model for tracking
  which Rentcast API key originally fetched each cached response
- Fix broken self.db references in SQLiteCacheBackend methods:
  - delete() now properly creates session from SessionLocal
  - clear_pattern() fixed with proper session management
  - health_check() fixed with proper session management
- Prevent unintended wildcard matching in clear_pattern() by escaping
  SQL LIKE metacharacters (%, _) before glob-to-LIKE pattern conversion
- Update proxy_to_rentcast to hash and store the client's
  Rentcast API key on cache miss
- Fix httpx test client to use ASGITransport (required for httpx 0.27+)
This commit is contained in:
Ryan Malloy 2026-01-18 18:09:47 -07:00
parent 26fee9ae74
commit dde8bd04de
4 changed files with 55 additions and 41 deletions

View File

@ -135,7 +135,8 @@ class SQLiteCacheBackend(CacheBackend):
headers_json=json.dumps(data.get('headers', {})),
expires_at=expires_at,
ttl_seconds=ttl,
estimated_cost=data.get('estimated_cost', 0.0)
estimated_cost=data.get('estimated_cost', 0.0),
fetched_by_key_hash=data.get('fetched_by_key_hash'),
)
db.add(new_entry)
@ -150,46 +151,53 @@ class SQLiteCacheBackend(CacheBackend):
async def delete(self, key: str) -> bool:
    """Soft-delete a cache entry by key.

    Marks the matching row invalid (``is_valid=False``) rather than
    removing it, so the row remains for auditing/attribution.

    Args:
        key: Exact cache key of the entry to invalidate.

    Returns:
        True if the update committed, False on any error (rolled back).
    """
    # Open a fresh session per call; this backend does not hold a
    # long-lived self.db session.
    async with self.SessionLocal() as db:
        try:
            stmt = update(CacheEntry).where(
                CacheEntry.cache_key == key
            ).values(is_valid=False)
            await db.execute(stmt)
            await db.commit()
            logger.debug(f"Cache deleted: {key}")
            return True
        except Exception as e:
            logger.error(f"Error deleting cache entry {key}: {e}")
            await db.rollback()
            return False
async def clear_pattern(self, pattern: str) -> int:
    """Soft-delete all cache entries whose key matches a glob pattern.

    The caller-supplied pattern uses ``*`` as the wildcard. SQL LIKE
    metacharacters (``%`` and ``_``) in the pattern are escaped first so
    they match literally; only ``*`` expands to the LIKE ``%`` wildcard.

    Args:
        pattern: Glob-style key pattern, e.g. ``"properties:*"``.

    Returns:
        Number of entries invalidated, or 0 on error (rolled back).
    """
    async with self.SessionLocal() as db:
        try:
            # Escape LIKE metacharacters first, then convert glob wildcards,
            # so a literal '%' or '_' in the key does not over-match.
            escaped = pattern.replace('%', r'\%').replace('_', r'\_')
            sql_pattern = escaped.replace('*', '%')
            stmt = update(CacheEntry).where(
                # escape='\\' tells LIKE that backslash introduces a literal.
                CacheEntry.cache_key.like(sql_pattern, escape='\\')
            ).values(is_valid=False)
            result = await db.execute(stmt)
            await db.commit()
            count = result.rowcount
            logger.info(f"Cleared {count} cache entries matching pattern: {pattern}")
            return count
        except Exception as e:
            logger.error(f"Error clearing cache pattern {pattern}: {e}")
            await db.rollback()
            return 0
async def health_check(self) -> bool:
    """Check SQLite database health with a trivial query.

    Opens a short-lived session and runs ``SELECT 1``; any exception
    (bad file, locked DB, closed engine) is treated as unhealthy.

    Returns:
        True if the probe query succeeds, False otherwise.
    """
    async with self.SessionLocal() as db:
        try:
            await db.execute(select(1))
            return True
        except Exception as e:
            logger.error(f"SQLite health check failed: {e}")
            return False
async def _mark_invalid(self, entry_id: int, db: AsyncSession):
"""Mark specific entry as invalid."""

View File

@ -57,6 +57,9 @@ class CacheEntry(Base, TimestampMixin):
# Cost tracking (if applicable)
estimated_cost = Column(Float, default=0.0)
# Attribution - which Rentcast API key originally fetched this data
fetched_by_key_hash = Column(String(64), index=True, nullable=True)
__table_args__ = (
Index('idx_cache_valid_expires', 'is_valid', 'expires_at'),
Index('idx_cache_endpoint_method', 'endpoint', 'method'),

View File

@ -449,7 +449,8 @@ async def proxy_to_rentcast(
response_data = response.json() if response.content else {}
response_time = (time.time() - start_time) * 1000
# Store in cache
# Store in cache with attribution to the key that fetched it
key_hash = hashlib.sha256(rentcast_key.encode()).hexdigest() if rentcast_key else None
cache_entry_data = {
"data": response_data,
"status_code": response.status_code,
@ -458,7 +459,8 @@ async def proxy_to_rentcast(
"method": method,
"params": cache_data,
"params_hash": hashlib.md5(json.dumps(cache_data, sort_keys=True).encode()).hexdigest(),
"estimated_cost": endpoint_config["cost_estimate"]
"estimated_cost": endpoint_config["cost_estimate"],
"fetched_by_key_hash": key_hash,
}
ttl = ttl_override or endpoint_config["ttl"]

View File

@ -4,7 +4,7 @@ Pytest configuration for RentCache tests.
import asyncio
import pytest
import pytest_asyncio
from httpx import AsyncClient
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.pool import StaticPool
@ -63,7 +63,8 @@ async def test_client(test_session):
app.dependency_overrides[get_db] = override_get_db
async with AsyncClient(app=app, base_url="http://test") as client:
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
app.dependency_overrides.clear()