exchange-data-manager/tests/test_websocket.py

276 lines
9.5 KiB
Python

"""Tests for WebSocket streaming API."""
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from exchange_data_manager.api.websocket import WebSocketManager, ClientConnection, Subscription
from exchange_data_manager.candles.models import Candle
class MockWebSocket:
    """In-memory stand-in for a Starlette ``WebSocket`` used by these tests.

    Instead of doing network I/O it records what the code under test did to
    the socket: whether it was accepted or closed, and every JSON payload sent.
    Method signatures mirror the Starlette methods of the same name, so callers
    that pass the optional arguments (``close(code=...)``,
    ``send_json(data, mode=...)``) do not hit a ``TypeError`` in the mock.
    """

    def __init__(self):
        self.accepted = False
        self.closed = False
        # Arguments passed to close(); None until close() has been awaited.
        self.close_code = None
        self.close_reason = None
        # Every payload handed to send_json(), in call order.
        self.sent_messages = []
        # Placeholder state. Tests that need the manager to treat the socket
        # as live overwrite this with starlette's WebSocketState.CONNECTED.
        self.client_state = MagicMock()
        self.client_state.CONNECTED = "CONNECTED"

    async def accept(self, subprotocol=None, headers=None):
        """Mark the socket as accepted (handshake arguments are ignored)."""
        self.accepted = True

    async def close(self, code=1000, reason=None):
        """Mark the socket as closed and remember the close code and reason."""
        self.closed = True
        self.close_code = code
        self.close_reason = reason

    async def send_json(self, data, mode="text"):
        """Record *data* instead of transmitting it; *mode* is ignored."""
        self.sent_messages.append(data)

    async def receive_text(self):
        """Receiving is not simulated; fail loudly if the code under test tries.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not implemented in mock")
class TestWebSocketManager:
    """Tests for WebSocketManager."""

    # Stream key the manager reports for binance / BTC/USDT / 1m
    # (see the "stream" field asserted in test_subscribe_success).
    STREAM_KEY = "binance:BTC/USDT:1m"

    @staticmethod
    def _connected_ws():
        """Return a MockWebSocket whose client_state reads as CONNECTED.

        Centralizes the starlette import and state assignment that every
        connected-client test needs.
        """
        from starlette.websockets import WebSocketState

        ws = MockWebSocket()
        ws.client_state = WebSocketState.CONNECTED
        return ws

    @staticmethod
    def _msg(**fields):
        """Serialize *fields* into the JSON text a client would send."""
        import json

        return json.dumps(fields)

    @classmethod
    def _stream_msg(cls, action, exchange="binance", symbol="BTC/USDT", timeframe="1m"):
        """Build a subscribe/unsubscribe request for a single stream."""
        return cls._msg(action=action, exchange=exchange, symbol=symbol, timeframe=timeframe)

    @staticmethod
    def _mock_connector():
        """Return a connector double whose coroutine methods can be awaited."""
        connector = MagicMock()
        connector.subscribe = AsyncMock(return_value="sub_123")
        connector.unsubscribe = AsyncMock()
        connector.close = AsyncMock()
        return connector

    @pytest.fixture
    def manager(self):
        return WebSocketManager()

    @pytest.fixture
    def mock_ws(self):
        return self._connected_ws()

    @pytest.mark.asyncio
    async def test_connect_accepts_websocket(self, manager, mock_ws):
        """Test that connecting accepts the WebSocket."""
        await manager.connect(mock_ws)
        assert mock_ws.accepted
        assert mock_ws in manager._clients

    @pytest.mark.asyncio
    async def test_disconnect_removes_client(self, manager, mock_ws):
        """Test that disconnecting removes the client."""
        await manager.connect(mock_ws)
        await manager.disconnect(mock_ws)
        assert mock_ws not in manager._clients

    @pytest.mark.asyncio
    async def test_handle_ping_responds_pong(self, manager, mock_ws):
        """Test that ping receives pong response."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, self._msg(action="ping"))
        assert len(mock_ws.sent_messages) == 1
        assert mock_ws.sent_messages[0]["action"] == "pong"

    @pytest.mark.asyncio
    async def test_handle_invalid_json(self, manager, mock_ws):
        """Test handling invalid JSON."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, "not valid json")
        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]

    @pytest.mark.asyncio
    async def test_handle_unknown_action(self, manager, mock_ws):
        """Test handling unknown action."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, self._msg(action="unknown"))
        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]
        assert "valid_actions" in mock_ws.sent_messages[0]

    @pytest.mark.asyncio
    async def test_subscribe_missing_fields(self, manager, mock_ws):
        """Test subscribe with missing fields."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, self._msg(action="subscribe", exchange="binance"))
        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]
        assert "Missing required fields" in mock_ws.sent_messages[0]["error"]

    @pytest.mark.asyncio
    async def test_subscribe_invalid_exchange(self, manager, mock_ws):
        """Test subscribe to invalid exchange."""
        await manager.connect(mock_ws)
        await manager.handle_message(
            mock_ws, self._stream_msg("subscribe", exchange="not_real_exchange")
        )
        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]
        assert "not supported" in mock_ws.sent_messages[0]["error"]

    @pytest.mark.asyncio
    async def test_subscribe_success(self, manager, mock_ws):
        """Test successful subscription."""
        await manager.connect(mock_ws)
        connector = self._mock_connector()
        with patch.object(manager, "get_or_create_connector", return_value=connector):
            await manager.handle_message(mock_ws, self._stream_msg("subscribe"))
        assert len(mock_ws.sent_messages) == 1
        assert mock_ws.sent_messages[0]["action"] == "subscribed"
        assert mock_ws.sent_messages[0]["stream"] == self.STREAM_KEY
        assert connector.subscribe.call_count == 1
        # Verify subscription is tracked on the client.
        client = manager._clients[mock_ws]
        assert self.STREAM_KEY in client.subscriptions

    @pytest.mark.asyncio
    async def test_unsubscribe_success(self, manager, mock_ws):
        """Test successful unsubscription."""
        await manager.connect(mock_ws)
        connector = self._mock_connector()
        manager._connectors["binance"] = connector
        # Keep both requests under the patch so neither can reach a real
        # connector, whichever lookup path the manager takes.
        with patch.object(manager, "get_or_create_connector", return_value=connector):
            await manager.handle_message(mock_ws, self._stream_msg("subscribe"))
            await manager.handle_message(mock_ws, self._stream_msg("unsubscribe"))
        assert mock_ws.sent_messages[-1]["action"] == "unsubscribed"
        # Verify subscription is removed from the client...
        client = manager._clients[mock_ws]
        assert self.STREAM_KEY not in client.subscriptions
        # ...and from the manager-wide index, or later broadcasts would
        # still be delivered to this client.
        assert mock_ws not in manager._stream_subscribers.get(self.STREAM_KEY, ())
        # The only subscriber left, so the upstream subscription is released.
        connector.unsubscribe.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_multiple_clients_same_stream(self, manager):
        """Test multiple clients subscribing to the same stream."""
        ws1 = self._connected_ws()
        ws2 = self._connected_ws()
        await manager.connect(ws1)
        await manager.connect(ws2)
        connector = self._mock_connector()
        request = self._stream_msg("subscribe")
        with patch.object(manager, "get_or_create_connector", return_value=connector):
            # Both clients subscribe to the same stream.
            await manager.handle_message(ws1, request)
            await manager.handle_message(ws2, request)
        # Only one polling subscription should be created.
        assert connector.subscribe.call_count == 1
        # Both clients should be registered as subscribers.
        assert ws1 in manager._stream_subscribers[self.STREAM_KEY]
        assert ws2 in manager._stream_subscribers[self.STREAM_KEY]

    @pytest.mark.asyncio
    async def test_broadcast_candle(self, manager):
        """Test broadcasting candle to subscribers only."""
        ws1 = self._connected_ws()
        ws2 = self._connected_ws()
        bystander = self._connected_ws()
        for ws in (ws1, ws2, bystander):
            await manager.connect(ws)
        # Manually set up the subscription; the bystander is connected but
        # not subscribed to this stream.
        manager._stream_subscribers[self.STREAM_KEY] = {ws1, ws2}
        candle = Candle(
            time=1709337600,
            open=50000.0,
            high=50100.0,
            low=49900.0,
            close=50050.0,
            volume=10.0,
        )
        await manager._broadcast_candle(self.STREAM_KEY, candle)
        # Every subscriber receives exactly one candle message...
        for ws in (ws1, ws2):
            assert len(ws.sent_messages) == 1
            assert ws.sent_messages[0]["action"] == "candle"
            assert ws.sent_messages[0]["data"]["close"] == 50050.0
        # ...and nothing leaks to clients outside the stream.
        assert bystander.sent_messages == []

    @pytest.mark.asyncio
    async def test_shutdown_cleans_up(self, manager, mock_ws):
        """Test that shutdown cleans up all resources."""
        await manager.connect(mock_ws)
        connector = self._mock_connector()
        manager._connectors["binance"] = connector
        manager._active_streams[self.STREAM_KEY] = "sub_123"
        await manager.shutdown()
        assert len(manager._clients) == 0
        assert len(manager._connectors) == 0
        assert len(manager._active_streams) == 0
        connector.close.assert_called_once()
class TestClientConnection:
    """Unit tests covering the ClientConnection dataclass."""

    def test_subscription_key(self):
        """The key for binance / BTC/USDT / 1m is "binance:BTC/USDT:1m"."""
        connection = ClientConnection(websocket=MockWebSocket())
        expected_key = "binance:BTC/USDT:1m"
        assert connection.subscription_key("binance", "BTC/USDT", "1m") == expected_key