diff --git a/src/rentcache/cache.py b/src/rentcache/cache.py index 667496a..6052d59 100644 --- a/src/rentcache/cache.py +++ b/src/rentcache/cache.py @@ -135,7 +135,8 @@ class SQLiteCacheBackend(CacheBackend): headers_json=json.dumps(data.get('headers', {})), expires_at=expires_at, ttl_seconds=ttl, - estimated_cost=data.get('estimated_cost', 0.0) + estimated_cost=data.get('estimated_cost', 0.0), + fetched_by_key_hash=data.get('fetched_by_key_hash'), ) db.add(new_entry) @@ -150,46 +151,53 @@ class SQLiteCacheBackend(CacheBackend): async def delete(self, key: str) -> bool: """Soft delete cache entry.""" - try: - await self._mark_invalid_by_key(key) - await self.db.commit() - logger.debug(f"Cache deleted: {key}") - return True - except Exception as e: - logger.error(f"Error deleting cache entry {key}: {e}") - await self.db.rollback() - return False + async with self.SessionLocal() as db: + try: + stmt = update(CacheEntry).where( + CacheEntry.cache_key == key + ).values(is_valid=False) + await db.execute(stmt) + await db.commit() + logger.debug(f"Cache deleted: {key}") + return True + except Exception as e: + logger.error(f"Error deleting cache entry {key}: {e}") + await db.rollback() + return False async def clear_pattern(self, pattern: str) -> int: """Clear cache entries matching pattern.""" - try: - # Convert pattern to SQL LIKE pattern - sql_pattern = pattern.replace('*', '%') - - stmt = update(CacheEntry).where( - CacheEntry.cache_key.like(sql_pattern) - ).values(is_valid=False) - - result = await self.db.execute(stmt) - await self.db.commit() - - count = result.rowcount - logger.info(f"Cleared {count} cache entries matching pattern: {pattern}") - return count - - except Exception as e: - logger.error(f"Error clearing cache pattern {pattern}: {e}") - await self.db.rollback() - return 0 - + async with self.SessionLocal() as db: + try: + # Escape SQL LIKE special characters first, then convert glob wildcards + escaped = pattern.replace('%', r'\%').replace('_', r'\_') + sql_pattern = escaped.replace('*', '%') + + stmt = update(CacheEntry).where( + CacheEntry.cache_key.like(sql_pattern, escape='\\') + ).values(is_valid=False) + + result = await db.execute(stmt) + await db.commit() + + count = result.rowcount + logger.info(f"Cleared {count} cache entries matching pattern: {pattern}") + return count + + except Exception as e: + logger.error(f"Error clearing cache pattern {pattern}: {e}") + await db.rollback() + return 0 + async def health_check(self) -> bool: """Check SQLite database health.""" - try: - await self.db.execute(select(1)) - return True - except Exception as e: - logger.error(f"SQLite health check failed: {e}") - return False + async with self.SessionLocal() as db: + try: + await db.execute(select(1)) + return True + except Exception as e: + logger.error(f"SQLite health check failed: {e}") + return False async def _mark_invalid(self, entry_id: int, db: AsyncSession): """Mark specific entry as invalid.""" diff --git a/src/rentcache/models.py b/src/rentcache/models.py index d50c768..004da2c 100644 --- a/src/rentcache/models.py +++ b/src/rentcache/models.py @@ -56,7 +56,10 @@ class CacheEntry(Base, TimestampMixin): # Cost tracking (if applicable) estimated_cost = Column(Float, default=0.0) - + + # Attribution - which Rentcast API key originally fetched this data + fetched_by_key_hash = Column(String(64), index=True, nullable=True) + __table_args__ = ( Index('idx_cache_valid_expires', 'is_valid', 'expires_at'), Index('idx_cache_endpoint_method', 'endpoint', 'method'), diff --git a/src/rentcache/server.py b/src/rentcache/server.py index cd59f4f..19bd80f 100644 --- a/src/rentcache/server.py +++ b/src/rentcache/server.py @@ -449,7 +449,8 @@ async def proxy_to_rentcast( response_data = response.json() if response.content else {} response_time = (time.time() - start_time) * 1000 - # Store in cache + # Store in cache with attribution to the key that fetched it + key_hash = hashlib.sha256(rentcast_key.encode()).hexdigest() if rentcast_key else None cache_entry_data = { "data": response_data, "status_code": response.status_code, @@ -458,7 +459,8 @@ async def proxy_to_rentcast( "method": method, "params": cache_data, "params_hash": hashlib.md5(json.dumps(cache_data, sort_keys=True).encode()).hexdigest(), - "estimated_cost": endpoint_config["cost_estimate"] + "estimated_cost": endpoint_config["cost_estimate"], + "fetched_by_key_hash": key_hash, } ttl = ttl_override or endpoint_config["ttl"] diff --git a/tests/conftest.py b/tests/conftest.py index eb1d46b..e3c33a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ Pytest configuration for RentCache tests. import asyncio import pytest import pytest_asyncio -from httpx import AsyncClient +from httpx import AsyncClient, ASGITransport from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy.pool import StaticPool @@ -63,7 +63,8 @@ async def test_client(test_session): app.dependency_overrides[get_db] = override_get_db - async with AsyncClient(app=app, base_url="http://test") as client: + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: yield client app.dependency_overrides.clear()