diff --git a/src/exchange_data_manager/api/rest.py b/src/exchange_data_manager/api/rest.py index 906840a..fb9d8b1 100644 --- a/src/exchange_data_manager/api/rest.py +++ b/src/exchange_data_manager/api/rest.py @@ -3,6 +3,7 @@ REST API using FastAPI with WebSocket support for streaming. """ import logging +import time from typing import Optional, List from contextlib import asynccontextmanager @@ -15,7 +16,7 @@ from ..cache.manager import CacheManager from ..candles.models import CandleRequest, Candle from ..exchanges import CONNECTOR_REGISTRY, CCXTConnector from ..sessions import SessionManager -from ..monitoring import HealthChecker, MetricsCollector, CacheSource +from ..monitoring import HealthChecker, MetricsCollector, CacheSource, RequestMetrics from .websocket import ws_manager logger = logging.getLogger(__name__) @@ -142,6 +143,10 @@ async def lifespan(app: FastAPI): if connector: await connector.close() + # Close cache resources (e.g., async DB connection pool) + if cache_manager: + await cache_manager.close() + logger.info("Exchange Data Manager stopped") @@ -279,6 +284,10 @@ async def get_candles( start: Optional[int] = Query(None, description="Start timestamp (Unix seconds)"), end: Optional[int] = Query(None, description="End timestamp (Unix seconds)"), limit: Optional[int] = Query(100, description="Maximum candles to return", le=1000), + session_id: Optional[str] = Query( + None, + description="Optional session ID for per-session exchange credentials", + ), ): """ Get historical candle data. @@ -297,8 +306,43 @@ async def get_candles( limit=limit, ) + request_start = time.perf_counter() + connector_override = None + try: - candles = await cache_manager.get_candles(request) + if session_id is not None: + if session_manager is None: + raise HTTPException(status_code=503, detail="Session manager not initialized") + + connector_override = session_manager.get_session_connector( + session_id=session_id, + exchange=request.exchange, + ) + if connector_override is None: + raise HTTPException( + status_code=404, + detail=f"No active session connector for exchange '{request.exchange}'", + ) + + candles, source_name = await cache_manager.get_candles_with_source( + request, + connector_override=connector_override, + ) + + if metrics_collector is not None: + try: + source = CacheSource(source_name) + except ValueError: + source = CacheSource.EXCHANGE + metrics_collector.record_candle_request( + latency_ms=(time.perf_counter() - request_start) * 1000, + status_code=200, + cache_source=source, + exchange=request.exchange, + symbol=request.symbol, + timeframe=request.timeframe, + candle_count=len(candles), + ) return CandleResponse( exchange=request.exchange, @@ -307,7 +351,33 @@ async def get_candles( candles=[c.to_dict() for c in candles], count=len(candles), ) + except HTTPException as e: + if metrics_collector is not None: + metrics_collector.record( + RequestMetrics( + endpoint="/candles", + method="GET", + status_code=e.status_code, + latency_ms=(time.perf_counter() - request_start) * 1000, + exchange=request.exchange, + symbol=request.symbol, + timeframe=request.timeframe, + ) + ) + raise except Exception as e: + if metrics_collector is not None: + metrics_collector.record( + RequestMetrics( + endpoint="/candles", + method="GET", + status_code=500, + latency_ms=(time.perf_counter() - request_start) * 1000, + exchange=request.exchange, + symbol=request.symbol, + timeframe=request.timeframe, + ) + ) logger.error(f"Error fetching candles: {e}") raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/exchange_data_manager/cache/manager.py b/src/exchange_data_manager/cache/manager.py index e7107b1..ecb0fae 100644 --- a/src/exchange_data_manager/cache/manager.py +++ b/src/exchange_data_manager/cache/manager.py @@ -9,12 +9,12 @@ Implements three-tier caching with: import logging from typing import List, Optional, Tuple, Dict, Union, TYPE_CHECKING -from ..candles.models import Candle, CandleRequest, RequestMode +from ..candles.models import Candle, CandleRequest from ..config import CacheConfig, DatabaseConfig from .memory import MemoryCache from .database import DatabaseCache from .async_database import AsyncDatabaseCache -from .completeness import check_completeness, find_missing_ranges +from .completeness import check_completeness from .gap_filler import fill_gaps if TYPE_CHECKING: @@ -68,7 +68,11 @@ class CacheManager: self.database = database_cache self._use_async_db = isinstance(database_cache, AsyncDatabaseCache) elif use_async_db: - self.database = AsyncDatabaseCache(db_path=self._database_config.path) + self.database = AsyncDatabaseCache( + db_path=self._database_config.path, + pool_size=self._database_config.pool_size, + max_overflow=self._database_config.max_overflow, + ) self._use_async_db = True else: self.database = DatabaseCache(db_path=self._database_config.path) @@ -106,6 +110,25 @@ class CacheManager: Returns: List of candles (sorted by time ascending) """ + candles, _ = await self.get_candles_with_source(request) + return candles + + async def get_candles_with_source( + self, + request: CandleRequest, + connector_override: Optional["BaseExchangeConnector"] = None, + ) -> Tuple[List[Candle], str]: + """ + Get candles and identify which source satisfied the request. + + Args: + request: Candle request parameters + connector_override: Optional per-request connector (session-scoped) + + Returns: + Tuple of (candles, source) where source is one of: + "memory", "database", "exchange" + """ cache_key = request.cache_key all_candles: List[Candle] = [] @@ -118,7 +141,7 @@ class CacheManager: if memory_candles and not memory_gaps: # Complete data in memory logger.debug(f"Memory cache hit: {len(memory_candles)} candles") - return memory_candles + return memory_candles, "memory" all_candles.extend(memory_candles) gaps_to_fill = memory_gaps if memory_gaps else [(request.start, request.end)] @@ -126,7 +149,10 @@ class CacheManager: # Step 2: Handle cold-cache / limit-only requests # When start and end are both None, fetch directly from exchange if not gaps_to_fill or gaps_to_fill == [(None, None)]: - exchange_candles = await self._fetch_limit_only(request) + exchange_candles = await self._fetch_limit_only( + request, + connector_override=connector_override, + ) if exchange_candles: all_candles.extend(exchange_candles) # Sort, dedupe, and return @@ -143,7 +169,11 @@ class CacheManager: if request.limit and len(result) > request.limit: result = result[-request.limit:] - return result + source = "exchange" if exchange_candles else "memory" + return result, source + + db_hit = False + exchange_hit = False # Step 3: Check database for missing ranges for gap_start, gap_end in gaps_to_fill: @@ -173,17 +203,21 @@ class CacheManager: if db_candles: logger.debug(f"Database hit: {len(db_candles)} candles") all_candles.extend(db_candles) + db_hit = True # Store in memory for future requests self.memory.put(cache_key, db_candles) # Step 3: Fetch remaining gaps from exchange if db_gaps: exchange_candles = await self._fetch_from_exchange( - request, db_gaps + request, + db_gaps, + connector_override=connector_override, ) if exchange_candles: logger.debug(f"Exchange fetch: {len(exchange_candles)} candles") all_candles.extend(exchange_candles) + exchange_hit = True # Sort and deduplicate by_time = {c.time: c for c in all_candles} @@ -214,9 +248,21 @@ class CacheManager: if request.limit and len(result) > request.limit: result = result[-request.limit:] - return result + source = "memory" + if exchange_hit: + source = "exchange" + elif db_hit: + source = "database" + elif memory_candles: + source = "memory" - async def _fetch_limit_only(self, request: CandleRequest) -> List[Candle]: + return result, source + + async def _fetch_limit_only( + self, + request: CandleRequest, + connector_override: Optional["BaseExchangeConnector"] = None, + ) -> List[Candle]: """ Fetch candles when only limit is specified (no start/end). @@ -228,7 +274,7 @@ class CacheManager: Returns: List of fetched candles """ - connector = self._exchange_connectors.get(request.exchange) + connector = connector_override or self._exchange_connectors.get(request.exchange) if not connector: logger.warning(f"No connector registered for exchange: {request.exchange}") return [] @@ -274,7 +320,10 @@ class CacheManager: return [] async def _fetch_from_exchange( - self, request: CandleRequest, gaps: List[Tuple[int, int]] + self, + request: CandleRequest, + gaps: List[Tuple[int, int]], + connector_override: Optional["BaseExchangeConnector"] = None, ) -> List[Candle]: """ Fetch missing candles from exchange. @@ -286,7 +335,7 @@ class CacheManager: Returns: List of fetched candles """ - connector = self._exchange_connectors.get(request.exchange) + connector = connector_override or self._exchange_connectors.get(request.exchange) if not connector: logger.warning(f"No connector registered for exchange: {request.exchange}") return [] @@ -387,3 +436,8 @@ class CacheManager: "database": db_stats, "registered_exchanges": list(self._exchange_connectors.keys()), } + + async def close(self): + """Close async resources.""" + if self._use_async_db and hasattr(self.database, "close"): + await self.database.close() diff --git a/tests/test_cache_manager.py b/tests/test_cache_manager.py index 42e7883..56deaf3 100644 --- a/tests/test_cache_manager.py +++ b/tests/test_cache_manager.py @@ -4,6 +4,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock from exchange_data_manager.cache.manager import CacheManager +from exchange_data_manager.config import DatabaseConfig from exchange_data_manager.candles.models import Candle, CandleRequest @@ -116,6 +117,77 @@ class TestCacheManagerStats: assert isinstance(stats["registered_exchanges"], list) +class TestCacheManagerSources: + """Tests for source reporting and per-request connector overrides.""" + + @pytest.mark.asyncio + async def test_get_candles_with_source_reports_memory(self, cache_manager): + """Memory-only responses should report memory source.""" + candles = [ + Candle(time=1709337600, open=50000.0, high=50100.0, low=49900.0, close=50050.0, volume=10.0), + Candle(time=1709337660, open=50050.0, high=50200.0, low=50000.0, close=50150.0, volume=15.0), + ] + cache_manager.memory.put("binance:BTC/USDT:1m", candles) + + request = CandleRequest( + exchange="binance", + symbol="BTC/USDT", + timeframe="1m", + limit=2, + ) + + result, source = await cache_manager.get_candles_with_source(request) + + assert len(result) == 2 + assert source == "memory" + + @pytest.mark.asyncio + async def test_get_candles_with_source_uses_connector_override(self, cache_manager): + """A session-scoped connector override should be used when provided.""" + override_connector = MockConnector() + override_connector.fetch_candles.return_value = [ + Candle(time=1709337600, open=50000.0, high=50100.0, low=49900.0, close=50050.0, volume=10.0), + ] + + request = CandleRequest( + exchange="binance", + symbol="BTC/USDT", + timeframe="1m", + limit=1, + ) + + result, source = await cache_manager.get_candles_with_source( + request, + connector_override=override_connector, + ) + + assert len(result) == 1 + assert source == "exchange" + override_connector.fetch_candles.assert_called_once_with( + symbol="BTC/USDT", + timeframe="1m", + limit=1, + ) + + +class TestCacheManagerDatabaseConfig: + """Tests for database config wiring.""" + + def test_async_database_pool_config_is_wired(self): + """CacheManager should pass pool config into AsyncDatabaseCache.""" + manager = CacheManager( + database_config=DatabaseConfig( + path="./data/test-pool.db", + pool_size=7, + max_overflow=3, + ), + use_async_db=True, + ) + + assert manager.database._pool_size == 7 + assert manager.database._max_overflow == 3 + + class TestCacheManagerExchangeRegistration: """Tests for exchange connector registration.""" diff --git a/tests/test_rest_candles.py b/tests/test_rest_candles.py new file mode 100644 index 0000000..9acbcf2 --- /dev/null +++ b/tests/test_rest_candles.py @@ -0,0 +1,90 @@ +"""Unit tests for REST /candles handler wiring.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock +from fastapi import HTTPException + +from exchange_data_manager.api import rest +from exchange_data_manager.candles.models import Candle +from exchange_data_manager.monitoring import CacheSource + + +def _sample_candles(): + return [ + Candle( + time=1709337600, + open=50000.0, + high=50100.0, + low=49900.0, + close=50050.0, + volume=10.0, + ) + ] + + +@pytest.mark.asyncio +async def test_candles_uses_session_connector_and_records_metrics(monkeypatch): + """Session ID requests should use session connector and record metrics.""" + connector = MagicMock() + + mock_cache_manager = MagicMock() + mock_cache_manager.get_candles_with_source = AsyncMock( + return_value=(_sample_candles(), "exchange") + ) + + mock_session_manager = MagicMock() + mock_session_manager.get_session_connector.return_value = connector + + mock_metrics = MagicMock() + + monkeypatch.setattr(rest, "cache_manager", mock_cache_manager) + monkeypatch.setattr(rest, "session_manager", mock_session_manager) + monkeypatch.setattr(rest, "metrics_collector", mock_metrics) + + response = await rest.get_candles( + exchange="binance", + symbol="BTC/USDT", + timeframe="1m", + start=None, + end=None, + limit=1, + session_id="session-123", + ) + + assert response.count == 1 + mock_session_manager.get_session_connector.assert_called_once_with( + session_id="session-123", + exchange="binance", + ) + mock_cache_manager.get_candles_with_source.assert_awaited_once() + + kwargs = mock_metrics.record_candle_request.call_args.kwargs + assert kwargs["cache_source"] == CacheSource.EXCHANGE + assert kwargs["status_code"] == 200 + + +@pytest.mark.asyncio +async def test_candles_invalid_session_records_http_error_metric(monkeypatch): + """Invalid session connector should return 404 and record error metrics.""" + mock_cache_manager = MagicMock() + mock_session_manager = MagicMock() + mock_session_manager.get_session_connector.return_value = None + mock_metrics = MagicMock() + + monkeypatch.setattr(rest, "cache_manager", mock_cache_manager) + monkeypatch.setattr(rest, "session_manager", mock_session_manager) + monkeypatch.setattr(rest, "metrics_collector", mock_metrics) + + with pytest.raises(HTTPException) as exc: + await rest.get_candles( + exchange="binance", + symbol="BTC/USDT", + timeframe="1m", + start=None, + end=None, + limit=1, + session_id="missing-session", + ) + + assert exc.value.status_code == 404 + mock_metrics.record.assert_called_once()