"""Tests for WebSocket streaming API."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from starlette.websockets import WebSocketState

from exchange_data_manager.api.websocket import (
    ClientConnection,
    Subscription,
    WebSocketManager,
)
from exchange_data_manager.candles.models import Candle

# Stream key the manager is expected to report for the binance BTC/USDT 1m
# stream used throughout these tests.
STREAM_KEY = "binance:BTC/USDT:1m"

# Raw client messages for subscribing to / unsubscribing from that stream.
SUBSCRIBE_MSG = '''{
    "action": "subscribe",
    "exchange": "binance",
    "symbol": "BTC/USDT",
    "timeframe": "1m"
}'''

UNSUBSCRIBE_MSG = '''{
    "action": "unsubscribe",
    "exchange": "binance",
    "symbol": "BTC/USDT",
    "timeframe": "1m"
}'''


class MockWebSocket:
    """Mock WebSocket for testing.

    Records whether it was accepted/closed and collects every payload passed
    to ``send_json`` in ``sent_messages`` so tests can assert on them.
    """

    def __init__(self):
        self.accepted = False
        self.closed = False
        self.sent_messages = []
        # Placeholder state; tests that need a connected socket overwrite
        # this with starlette's ``WebSocketState.CONNECTED``.
        self.client_state = MagicMock()
        self.client_state.CONNECTED = "CONNECTED"

    async def accept(self):
        """Mark the socket as accepted."""
        self.accepted = True

    async def close(self):
        """Mark the socket as closed."""
        self.closed = True

    async def send_json(self, data):
        """Record an outgoing JSON payload instead of sending it."""
        self.sent_messages.append(data)

    async def receive_text(self):
        """Receiving is not supported by this mock; always raises."""
        raise NotImplementedError("Not implemented in mock")


def _connected_websocket():
    """Return a ``MockWebSocket`` whose state is ``WebSocketState.CONNECTED``."""
    ws = MockWebSocket()
    # Make it appear connected
    ws.client_state = WebSocketState.CONNECTED
    return ws


def _mock_connector():
    """Return a connector mock whose ``subscribe`` resolves to ``"sub_123"``."""
    connector = MagicMock()
    connector.subscribe = AsyncMock(return_value="sub_123")
    return connector


class TestWebSocketManager:
    """Tests for WebSocketManager."""

    @pytest.fixture
    def manager(self):
        """Provide a fresh manager for each test."""
        return WebSocketManager()

    @pytest.fixture
    def mock_ws(self):
        """Provide a mock WebSocket that reports itself as connected."""
        return _connected_websocket()

    @pytest.mark.asyncio
    async def test_connect_accepts_websocket(self, manager, mock_ws):
        """Test that connecting accepts the WebSocket."""
        await manager.connect(mock_ws)

        assert mock_ws.accepted
        assert mock_ws in manager._clients

    @pytest.mark.asyncio
    async def test_disconnect_removes_client(self, manager, mock_ws):
        """Test that disconnecting removes the client."""
        await manager.connect(mock_ws)
        await manager.disconnect(mock_ws)

        assert mock_ws not in manager._clients

    @pytest.mark.asyncio
    async def test_handle_ping_responds_pong(self, manager, mock_ws):
        """Test that ping receives pong response."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, '{"action": "ping"}')

        assert len(mock_ws.sent_messages) == 1
        assert mock_ws.sent_messages[0]["action"] == "pong"

    @pytest.mark.asyncio
    async def test_handle_invalid_json(self, manager, mock_ws):
        """Test handling invalid JSON."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, "not valid json")

        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]

    @pytest.mark.asyncio
    async def test_handle_unknown_action(self, manager, mock_ws):
        """Test handling unknown action."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, '{"action": "unknown"}')

        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]
        assert "valid_actions" in mock_ws.sent_messages[0]

    @pytest.mark.asyncio
    async def test_subscribe_missing_fields(self, manager, mock_ws):
        """Test subscribe with missing fields."""
        await manager.connect(mock_ws)
        await manager.handle_message(
            mock_ws, '{"action": "subscribe", "exchange": "binance"}'
        )

        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]
        assert "Missing required fields" in mock_ws.sent_messages[0]["error"]

    @pytest.mark.asyncio
    async def test_subscribe_invalid_exchange(self, manager, mock_ws):
        """Test subscribe to invalid exchange."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, '''{
            "action": "subscribe",
            "exchange": "not_real_exchange",
            "symbol": "BTC/USDT",
            "timeframe": "1m"
        }''')

        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]
        assert "not supported" in mock_ws.sent_messages[0]["error"]

    @pytest.mark.asyncio
    async def test_subscribe_success(self, manager, mock_ws):
        """Test successful subscription."""
        await manager.connect(mock_ws)

        # Mock the connector
        with patch.object(manager, 'get_or_create_connector') as mock_get_connector:
            mock_get_connector.return_value = _mock_connector()

            await manager.handle_message(mock_ws, SUBSCRIBE_MSG)

        assert len(mock_ws.sent_messages) == 1
        assert mock_ws.sent_messages[0]["action"] == "subscribed"
        assert mock_ws.sent_messages[0]["stream"] == STREAM_KEY

        # Verify subscription is tracked
        client = manager._clients[mock_ws]
        assert STREAM_KEY in client.subscriptions

    @pytest.mark.asyncio
    async def test_unsubscribe_success(self, manager, mock_ws):
        """Test successful unsubscription."""
        await manager.connect(mock_ws)

        # Subscribe first
        with patch.object(manager, 'get_or_create_connector') as mock_get_connector:
            mock_connector = _mock_connector()
            mock_connector.unsubscribe = AsyncMock()
            mock_get_connector.return_value = mock_connector
            # Also register the connector directly on the manager, as the
            # original test did, so it is reachable without the patched factory.
            manager._connectors["binance"] = mock_connector

            await manager.handle_message(mock_ws, SUBSCRIBE_MSG)

            # Now unsubscribe
            await manager.handle_message(mock_ws, UNSUBSCRIBE_MSG)

        assert mock_ws.sent_messages[-1]["action"] == "unsubscribed"

        # Verify subscription is removed
        client = manager._clients[mock_ws]
        assert STREAM_KEY not in client.subscriptions

    @pytest.mark.asyncio
    async def test_multiple_clients_same_stream(self, manager):
        """Test multiple clients subscribing to the same stream."""
        ws1 = _connected_websocket()
        ws2 = _connected_websocket()

        await manager.connect(ws1)
        await manager.connect(ws2)

        with patch.object(manager, 'get_or_create_connector') as mock_get_connector:
            mock_connector = _mock_connector()
            mock_get_connector.return_value = mock_connector

            # Both clients subscribe to same stream
            await manager.handle_message(ws1, SUBSCRIBE_MSG)
            await manager.handle_message(ws2, SUBSCRIBE_MSG)

        # Only one polling subscription should be created
        assert mock_connector.subscribe.call_count == 1

        # Both clients should be in subscribers
        assert ws1 in manager._stream_subscribers[STREAM_KEY]
        assert ws2 in manager._stream_subscribers[STREAM_KEY]

    @pytest.mark.asyncio
    async def test_broadcast_candle(self, manager):
        """Test broadcasting candle to subscribers."""
        ws1 = _connected_websocket()
        ws2 = _connected_websocket()

        await manager.connect(ws1)
        await manager.connect(ws2)

        # Manually set up subscription
        manager._stream_subscribers[STREAM_KEY] = {ws1, ws2}

        # Broadcast a candle
        candle = Candle(
            time=1709337600,
            open=50000.0,
            high=50100.0,
            low=49900.0,
            close=50050.0,
            volume=10.0,
        )
        await manager._broadcast_candle(STREAM_KEY, candle)

        # Both should receive the candle
        assert len(ws1.sent_messages) == 1
        assert len(ws2.sent_messages) == 1
        assert ws1.sent_messages[0]["action"] == "candle"
        assert ws1.sent_messages[0]["data"]["close"] == 50050.0

    @pytest.mark.asyncio
    async def test_shutdown_cleans_up(self, manager, mock_ws):
        """Test that shutdown cleans up all resources."""
        await manager.connect(mock_ws)

        # Add a mock connector
        mock_connector = MagicMock()
        mock_connector.unsubscribe = AsyncMock()
        mock_connector.close = AsyncMock()
        manager._connectors["binance"] = mock_connector
        manager._active_streams[STREAM_KEY] = "sub_123"

        await manager.shutdown()

        assert len(manager._clients) == 0
        assert len(manager._connectors) == 0
        assert len(manager._active_streams) == 0
        mock_connector.close.assert_called_once()


class TestClientConnection:
    """Tests for ClientConnection dataclass."""

    def test_subscription_key(self):
        """Test subscription key generation."""
        mock_ws = MockWebSocket()
        conn = ClientConnection(websocket=mock_ws)

        key = conn.subscription_key("binance", "BTC/USDT", "1m")

        assert key == "binance:BTC/USDT:1m"