276 lines
9.5 KiB
Python
276 lines
9.5 KiB
Python
"""Tests for WebSocket streaming API."""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from exchange_data_manager.api.websocket import WebSocketManager, ClientConnection, Subscription
|
|
from exchange_data_manager.candles.models import Candle
|
|
|
|
|
|
class MockWebSocket:
    """In-memory stand-in for a starlette WebSocket used by these tests.

    Records what the code under test does to it instead of doing network I/O:
    ``accepted`` / ``closed`` flip to True when accept() / close() are awaited,
    ``close_code`` / ``close_reason`` capture the arguments given to close(),
    and every payload passed to send_json() is appended to ``sent_messages``.
    Method signatures mirror starlette's ``WebSocket`` so callers may pass the
    same optional arguments they would pass to a real socket.
    """

    def __init__(self):
        self.accepted = False
        self.closed = False
        # Populated by close(); None until the socket has been closed.
        self.close_code = None
        self.close_reason = None
        self.sent_messages = []
        # Placeholder state: a MagicMock compares by identity, so it never
        # equals starlette's WebSocketState.CONNECTED. Tests in this file that
        # need a "live" socket overwrite this attribute with the real enum.
        self.client_state = MagicMock()
        self.client_state.CONNECTED = "CONNECTED"

    async def accept(self, subprotocol=None, headers=None):
        """Mark the socket as accepted; the arguments mirror starlette and are ignored."""
        self.accepted = True

    async def close(self, code=1000, reason=None):
        """Mark the socket closed and record the close *code* and *reason*.

        Accepts the same arguments as starlette's ``WebSocket.close`` so code
        under test can pass a close code without breaking the mock.
        """
        self.closed = True
        self.close_code = code
        self.close_reason = reason

    async def send_json(self, data, mode="text"):
        """Record *data* instead of sending it; *mode* mirrors starlette and is ignored."""
        self.sent_messages.append(data)

    async def receive_text(self):
        """Receiving is not supported by this mock.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not implemented in mock")
|
|
|
|
|
|
class TestWebSocketManager:
    """Tests for WebSocketManager."""

    # Stream key the tests expect for exchange "binance", symbol "BTC/USDT",
    # timeframe "1m" (the defaults of _stream_message below).
    STREAM_KEY = "binance:BTC/USDT:1m"

    @staticmethod
    def _connected_ws():
        """Return a MockWebSocket whose client_state is starlette's CONNECTED."""
        from starlette.websockets import WebSocketState

        ws = MockWebSocket()
        ws.client_state = WebSocketState.CONNECTED
        return ws

    @staticmethod
    def _stream_message(action, exchange="binance", symbol="BTC/USDT", timeframe="1m"):
        """Serialize a stream request (e.g. subscribe/unsubscribe) as client JSON text."""
        import json

        return json.dumps(
            {
                "action": action,
                "exchange": exchange,
                "symbol": symbol,
                "timeframe": timeframe,
            }
        )

    @staticmethod
    def _mock_connector():
        """Return a connector stand-in whose coroutine methods are AsyncMocks.

        ``subscribe`` resolves to the subscription id ``"sub_123"``;
        ``unsubscribe`` and ``close`` resolve to None.
        """
        connector = MagicMock()
        connector.subscribe = AsyncMock(return_value="sub_123")
        connector.unsubscribe = AsyncMock()
        connector.close = AsyncMock()
        return connector

    @pytest.fixture
    def manager(self):
        """Provide a fresh WebSocketManager for each test."""
        return WebSocketManager()

    @pytest.fixture
    def mock_ws(self):
        """Provide a mock WebSocket that reports itself as connected."""
        return self._connected_ws()

    @pytest.mark.asyncio
    async def test_connect_accepts_websocket(self, manager, mock_ws):
        """Test that connecting accepts the WebSocket."""
        await manager.connect(mock_ws)

        assert mock_ws.accepted
        assert mock_ws in manager._clients

    @pytest.mark.asyncio
    async def test_disconnect_removes_client(self, manager, mock_ws):
        """Test that disconnecting removes the client."""
        await manager.connect(mock_ws)
        await manager.disconnect(mock_ws)

        assert mock_ws not in manager._clients

    @pytest.mark.asyncio
    async def test_handle_ping_responds_pong(self, manager, mock_ws):
        """Test that ping receives pong response."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, '{"action": "ping"}')

        assert len(mock_ws.sent_messages) == 1
        assert mock_ws.sent_messages[0]["action"] == "pong"

    @pytest.mark.asyncio
    async def test_handle_invalid_json(self, manager, mock_ws):
        """Test handling invalid JSON."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, "not valid json")

        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]

    @pytest.mark.asyncio
    async def test_handle_unknown_action(self, manager, mock_ws):
        """Test handling unknown action."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, '{"action": "unknown"}')

        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]
        assert "valid_actions" in mock_ws.sent_messages[0]

    @pytest.mark.asyncio
    async def test_subscribe_missing_fields(self, manager, mock_ws):
        """Test subscribe with missing fields."""
        await manager.connect(mock_ws)
        await manager.handle_message(mock_ws, '{"action": "subscribe", "exchange": "binance"}')

        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]
        assert "Missing required fields" in mock_ws.sent_messages[0]["error"]

    @pytest.mark.asyncio
    async def test_subscribe_invalid_exchange(self, manager, mock_ws):
        """Test subscribe to invalid exchange."""
        await manager.connect(mock_ws)
        await manager.handle_message(
            mock_ws, self._stream_message("subscribe", exchange="not_real_exchange")
        )

        assert len(mock_ws.sent_messages) == 1
        assert "error" in mock_ws.sent_messages[0]
        assert "not supported" in mock_ws.sent_messages[0]["error"]

    @pytest.mark.asyncio
    async def test_subscribe_success(self, manager, mock_ws):
        """Test successful subscription."""
        await manager.connect(mock_ws)

        # Replace connector creation so no real exchange connection is made.
        with patch.object(manager, "get_or_create_connector") as mock_get_connector:
            mock_connector = self._mock_connector()
            mock_get_connector.return_value = mock_connector

            await manager.handle_message(mock_ws, self._stream_message("subscribe"))

            assert len(mock_ws.sent_messages) == 1
            reply = mock_ws.sent_messages[0]
            assert reply["action"] == "subscribed"
            assert reply["stream"] == self.STREAM_KEY

            # A brand-new stream opens exactly one upstream subscription, and
            # the coroutine must actually be awaited (not just created).
            mock_connector.subscribe.assert_awaited_once()

            # Verify subscription is tracked on the client.
            client = manager._clients[mock_ws]
            assert self.STREAM_KEY in client.subscriptions

    @pytest.mark.asyncio
    async def test_unsubscribe_success(self, manager, mock_ws):
        """Test successful unsubscription."""
        await manager.connect(mock_ws)

        with patch.object(manager, "get_or_create_connector") as mock_get_connector:
            mock_connector = self._mock_connector()
            mock_get_connector.return_value = mock_connector
            # Also register the connector directly, in case unsubscribe reads
            # _connectors instead of calling get_or_create_connector.
            manager._connectors["binance"] = mock_connector

            # Subscribe first, then unsubscribe from the same stream.
            await manager.handle_message(mock_ws, self._stream_message("subscribe"))
            await manager.handle_message(mock_ws, self._stream_message("unsubscribe"))

            assert mock_ws.sent_messages[-1]["action"] == "unsubscribed"

            # Verify subscription is removed from the client.
            client = manager._clients[mock_ws]
            assert self.STREAM_KEY not in client.subscriptions

    @pytest.mark.asyncio
    async def test_multiple_clients_same_stream(self, manager):
        """Test multiple clients subscribing to the same stream."""
        ws1, ws2 = self._connected_ws(), self._connected_ws()
        await manager.connect(ws1)
        await manager.connect(ws2)

        with patch.object(manager, "get_or_create_connector") as mock_get_connector:
            mock_connector = self._mock_connector()
            mock_get_connector.return_value = mock_connector

            # Both clients subscribe to the same stream.
            msg = self._stream_message("subscribe")
            await manager.handle_message(ws1, msg)
            await manager.handle_message(ws2, msg)

            # Only one upstream subscription should be created. call_count alone
            # would also pass for a never-awaited coroutine, so check both.
            assert mock_connector.subscribe.call_count == 1
            mock_connector.subscribe.assert_awaited_once()

            # Both clients should be registered as subscribers of the stream.
            subscribers = manager._stream_subscribers[self.STREAM_KEY]
            assert ws1 in subscribers
            assert ws2 in subscribers

    @pytest.mark.asyncio
    async def test_broadcast_candle(self, manager):
        """Test broadcasting candle to subscribers."""
        ws1, ws2 = self._connected_ws(), self._connected_ws()
        await manager.connect(ws1)
        await manager.connect(ws2)

        # Register both clients on the stream directly, bypassing subscribe.
        manager._stream_subscribers[self.STREAM_KEY] = {ws1, ws2}

        candle = Candle(
            time=1709337600,
            open=50000.0,
            high=50100.0,
            low=49900.0,
            close=50050.0,
            volume=10.0,
        )
        await manager._broadcast_candle(self.STREAM_KEY, candle)

        # Every subscriber receives exactly one candle message with the data.
        for ws in (ws1, ws2):
            assert len(ws.sent_messages) == 1
            assert ws.sent_messages[0]["action"] == "candle"
            assert ws.sent_messages[0]["data"]["close"] == 50050.0

    @pytest.mark.asyncio
    async def test_shutdown_cleans_up(self, manager, mock_ws):
        """Test that shutdown cleans up all resources."""
        await manager.connect(mock_ws)

        # Seed a connector and an active stream for shutdown to tear down.
        mock_connector = self._mock_connector()
        manager._connectors["binance"] = mock_connector
        manager._active_streams[self.STREAM_KEY] = "sub_123"

        await manager.shutdown()

        assert not manager._clients
        assert not manager._connectors
        assert not manager._active_streams
        # assert_called_once() would pass even if close() was never awaited.
        mock_connector.close.assert_awaited_once()
|
|
|
|
|
|
class TestClientConnection:
    """Unit tests covering the ClientConnection dataclass."""

    def test_subscription_key(self):
        """The subscription key combines exchange, symbol and timeframe."""
        expected_key = "binance:BTC/USDT:1m"
        connection = ClientConnection(websocket=MockWebSocket())

        assert connection.subscription_key("binance", "BTC/USDT", "1m") == expected_key
|