519 lines
18 KiB
Python
519 lines
18 KiB
Python
"""
Tests for paper trading state persistence.

These tests verify that:

1. PaperBroker can serialize/deserialize its state
2. PaperBroker state persists via data_cache
3. PaperStrategyInstance saves broker state and restores it on restart
4. Position round-trips through to_dict/from_dict
"""
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
import json
|
|
import pandas as pd
|
|
import uuid
|
|
|
|
|
|
class TestPaperBrokerSerialization:
    """Tests for PaperBroker state serialization (to_state_dict / from_state_dict)."""

    @pytest.fixture
    def paper_broker(self):
        """Create a PaperBroker with a 10000.0 starting balance.

        Uses commission=0.001 and slippage=0.0005; test_roundtrip_serialization
        builds a second broker with the same settings.
        """
        # Import directly - eventlet is only needed when running with Flask/SocketIO
        from brokers import PaperBroker

        broker = PaperBroker(
            initial_balance=10000.0,
            commission=0.001,
            slippage=0.0005
        )
        return broker

    def test_to_state_dict_empty_broker(self, paper_broker):
        """Test serialization of empty broker."""
        state = paper_broker.to_state_dict()

        assert state['cash'] == 10000.0
        assert state['locked_balance'] == 0.0
        assert state['orders'] == {}
        assert state['positions'] == {}
        assert state['trade_history'] == []

    def test_to_state_dict_with_positions(self, paper_broker):
        """Test serialization with open positions."""
        from brokers import OrderSide, OrderType

        # Place a buy order that fills immediately
        paper_broker.update_price('BTC/USDT', 50000.0)
        result = paper_broker.place_order(
            symbol='BTC/USDT',
            side=OrderSide.BUY,
            order_type=OrderType.MARKET,
            size=0.1
        )

        # Verify order succeeded
        assert result.success, f"Order failed: {result.message}"
        assert result.filled_qty == 0.1

        state = paper_broker.to_state_dict()

        # Should have position
        assert 'BTC/USDT' in state['positions'], f"No position created. Orders: {state['orders']}, Cash: {state['cash']}"
        position = state['positions']['BTC/USDT']
        assert position['size'] == 0.1
        assert position['entry_price'] > 0

        # Should have trade history
        assert len(state['trade_history']) == 1

        # Cash should be reduced
        assert state['cash'] < 10000.0

    def test_to_state_dict_with_pending_orders(self, paper_broker):
        """Test serialization with pending limit orders."""
        from brokers import OrderSide, OrderType

        paper_broker.update_price('BTC/USDT', 50000.0)

        # Place a limit order that won't fill immediately
        result = paper_broker.place_order(
            symbol='BTC/USDT',
            side=OrderSide.BUY,
            order_type=OrderType.LIMIT,
            size=0.1,
            price=45000.0  # Below market, won't fill
        )

        # Fail fast with the broker's own message if the order was rejected,
        # instead of an opaque "len(orders) == 1" failure below.
        assert result.success, f"Order failed: {result.message}"

        state = paper_broker.to_state_dict()

        # Should have open order
        assert len(state['orders']) == 1
        order = list(state['orders'].values())[0]
        assert order['status'] == 'open'
        assert order['price'] == 45000.0

        # Funds should be locked
        assert state['locked_balance'] > 0

    def test_from_state_dict_restores_balances(self, paper_broker):
        """Test that from_state_dict restores balance correctly."""
        state = {
            'cash': 8500.0,
            'locked_balance': 500.0,
            'initial_balance': 10000.0,
            'commission': 0.001,
            'slippage': 0.0005,
            'orders': {},
            'positions': {},
            'trade_history': [],
            'current_prices': {}
        }

        paper_broker.from_state_dict(state)

        assert paper_broker._cash == 8500.0
        assert paper_broker._locked_balance == 500.0

    def test_from_state_dict_restores_positions(self, paper_broker):
        """Test that from_state_dict restores positions."""
        # Note: unlike the balances test, this state has no
        # initial_balance/commission/slippage keys.
        state = {
            'cash': 5000.0,
            'locked_balance': 0.0,
            'orders': {},
            'positions': {
                'BTC/USDT': {
                    'symbol': 'BTC/USDT',
                    'size': 0.1,
                    'entry_price': 50000.0,
                    'current_price': 52000.0,
                    'unrealized_pnl': 200.0,
                    'realized_pnl': 50.0
                }
            },
            'trade_history': [
                {'order_id': 'test-1', 'symbol': 'BTC/USDT', 'side': 'buy', 'size': 0.1}
            ],
            'current_prices': {'BTC/USDT': 52000.0}
        }

        paper_broker.from_state_dict(state)

        assert len(paper_broker._positions) == 1
        position = paper_broker.get_position('BTC/USDT')
        assert position is not None
        assert position.size == 0.1
        assert position.entry_price == 50000.0
        assert position.realized_pnl == 50.0

        assert len(paper_broker._trade_history) == 1
        assert paper_broker._current_prices['BTC/USDT'] == 52000.0

    def test_from_state_dict_restores_orders(self, paper_broker):
        """Test that from_state_dict restores pending orders."""
        state = {
            'cash': 5000.0,
            'locked_balance': 4500.0,
            'orders': {
                'order-1': {
                    'order_id': 'order-1',
                    'symbol': 'BTC/USDT',
                    'side': 'buy',
                    'order_type': 'limit',
                    'size': 0.1,
                    'price': 45000.0,
                    'status': 'open',
                    'filled_qty': 0.0,
                    'filled_price': 0.0,
                    'commission': 0.0,
                    'created_at': '2024-01-01T00:00:00+00:00',
                    'filled_at': None
                }
            },
            'positions': {},
            'trade_history': [],
            'current_prices': {}
        }

        paper_broker.from_state_dict(state)

        assert len(paper_broker._orders) == 1
        order = paper_broker._orders.get('order-1')
        assert order is not None
        assert order.symbol == 'BTC/USDT'
        assert order.price == 45000.0
        assert order.status.value == 'open'

    def test_roundtrip_serialization(self, paper_broker):
        """Test that serialization followed by deserialization preserves state."""
        from brokers import OrderSide, OrderType

        # Set up some state
        paper_broker.update_price('BTC/USDT', 50000.0)
        result = paper_broker.place_order(
            symbol='BTC/USDT',
            side=OrderSide.BUY,
            order_type=OrderType.MARKET,
            size=0.1
        )
        # Without this check a rejected order would leave both brokers empty
        # and the comparisons below would pass vacuously.
        assert result.success, f"Order failed: {result.message}"

        # Serialize
        state = paper_broker.to_state_dict()

        # Create new broker and restore
        new_broker = type(paper_broker)(
            initial_balance=10000.0,
            commission=0.001,
            slippage=0.0005
        )
        new_broker.from_state_dict(state)

        # Verify balances and collection sizes match
        assert new_broker._cash == paper_broker._cash
        assert new_broker._locked_balance == paper_broker._locked_balance
        assert len(new_broker._positions) == len(paper_broker._positions)
        assert len(new_broker._trade_history) == len(paper_broker._trade_history)

        # Verify the position itself survived, not just the count
        original_position = paper_broker.get_position('BTC/USDT')
        restored_position = new_broker.get_position('BTC/USDT')
        assert restored_position is not None
        assert restored_position.size == original_position.size
        assert restored_position.entry_price == original_position.entry_price
|
|
|
|
|
|
class TestPaperBrokerCachePersistence:
    """Tests for PaperBroker data_cache persistence (save_state / load_state)."""

    @pytest.fixture
    def mock_data_cache(self):
        """Create a mock data cache whose row lookups return an empty DataFrame."""
        cache = MagicMock()
        cache.create_cache = MagicMock()
        cache.get_rows_from_datacache = MagicMock(return_value=pd.DataFrame())
        cache.insert_row_into_datacache = MagicMock()
        cache.modify_datacache_item = MagicMock()
        return cache

    @pytest.fixture
    def paper_broker_with_cache(self, mock_data_cache):
        """Create a PaperBroker (10000.0 balance) backed by the mock data cache."""
        # eventlet is stubbed out in sys.modules only for the duration of the import.
        with patch.dict('sys.modules', {'eventlet': MagicMock()}):
            from brokers import PaperBroker

            broker = PaperBroker(
                initial_balance=10000.0,
                data_cache=mock_data_cache
            )
            return broker

    def test_save_state_inserts_new(self, paper_broker_with_cache, mock_data_cache):
        """Test save_state inserts when no existing state."""
        mock_data_cache.get_rows_from_datacache.return_value = pd.DataFrame()

        result = paper_broker_with_cache.save_state('test-instance-1')

        assert result is True
        mock_data_cache.insert_row_into_datacache.assert_called_once()
        call_args = mock_data_cache.insert_row_into_datacache.call_args

        # call_args[1] holds the keyword arguments of the recorded call.
        assert call_args[1]['cache_name'] == 'paper_broker_states'
        assert 'test-instance-1' in call_args[1]['values']

    def test_save_state_updates_existing(self, paper_broker_with_cache, mock_data_cache):
        """Test save_state updates when state exists."""
        # Return existing row
        mock_data_cache.get_rows_from_datacache.return_value = pd.DataFrame([{
            'strategy_instance_id': 'test-instance-1',
            'broker_state': '{}',
            'updated_at': '2024-01-01T00:00:00'
        }])

        result = paper_broker_with_cache.save_state('test-instance-1')

        assert result is True
        mock_data_cache.modify_datacache_item.assert_called_once()
        call_args = mock_data_cache.modify_datacache_item.call_args

        assert call_args[1]['cache_name'] == 'paper_broker_states'

    def test_load_state_returns_false_when_empty(self, paper_broker_with_cache, mock_data_cache):
        """Test load_state returns False when no saved state."""
        mock_data_cache.get_rows_from_datacache.return_value = pd.DataFrame()

        result = paper_broker_with_cache.load_state('nonexistent')

        assert result is False
        # Broker should still have initial balance
        assert paper_broker_with_cache._cash == 10000.0

    def test_load_state_restores_from_cache(self, paper_broker_with_cache, mock_data_cache):
        """Test load_state restores state from cache."""
        saved_state = {
            'cash': 8000.0,
            'locked_balance': 0.0,
            'orders': {},
            'positions': {
                'ETH/USDT': {
                    'symbol': 'ETH/USDT',
                    'size': 1.0,
                    'entry_price': 3000.0,
                    'current_price': 3200.0,
                    'unrealized_pnl': 200.0,
                    'realized_pnl': 0.0
                }
            },
            'trade_history': [{'order_id': 'order-123'}],
            'current_prices': {'ETH/USDT': 3200.0}
        }

        # The broker state is stored as a JSON string in the 'broker_state' column.
        mock_data_cache.get_rows_from_datacache.return_value = pd.DataFrame([{
            'strategy_instance_id': 'test-instance',
            'broker_state': json.dumps(saved_state),
            'updated_at': '2024-01-01T00:00:00'
        }])

        result = paper_broker_with_cache.load_state('test-instance')

        assert result is True
        assert paper_broker_with_cache._cash == 8000.0
        assert len(paper_broker_with_cache._positions) == 1
        assert paper_broker_with_cache.get_position('ETH/USDT') is not None

    def test_save_state_without_cache_returns_false(self):
        """Test save_state returns False when no data cache."""
        with patch.dict('sys.modules', {'eventlet': MagicMock()}):
            from brokers import PaperBroker

            broker = PaperBroker(initial_balance=10000.0, data_cache=None)
            result = broker.save_state('test')

            assert result is False

    def test_save_and_load_state_with_real_datacache(self):
        """Test persistence with the real DataCache implementation.

        Saves a broker holding one filled BTC/USDT buy, loads it into a fresh
        broker sharing the same cache, and checks that the restored position
        size and cash equal the originals.
        """
        with patch.dict('sys.modules', {'eventlet': MagicMock()}):
            from DataCache_v3 import DataCache
            from brokers import PaperBroker, OrderSide, OrderType

            cache = DataCache()
            # Unique id so repeated runs never collide with previously saved state.
            strategy_instance_id = f"persist-{uuid.uuid4()}"

            broker = PaperBroker(initial_balance=10000.0, data_cache=cache)
            broker.update_price('BTC/USDT', 50000.0)
            buy_result = broker.place_order(
                symbol='BTC/USDT',
                side=OrderSide.BUY,
                order_type=OrderType.MARKET,
                size=0.1
            )
            assert buy_result.success is True
            assert broker.save_state(strategy_instance_id) is True

            restored = PaperBroker(initial_balance=10000.0, data_cache=cache)
            assert restored.load_state(strategy_instance_id) is True

            # Compare against what was saved rather than just "non-empty", so a
            # corrupted size or an unrestored cash balance is caught.
            original_position = broker.get_position('BTC/USDT')
            restored_position = restored.get_position('BTC/USDT')
            assert restored_position is not None
            assert restored_position.size == pytest.approx(original_position.size)
            assert restored._cash == pytest.approx(broker._cash)
|
|
|
|
|
|
class TestPaperStrategyInstancePersistence:
    """Checks that PaperStrategyInstance saves and reloads its paper broker state."""

    @pytest.fixture
    def mock_data_cache(self):
        """Return a MagicMock data cache whose row lookups yield an empty DataFrame."""
        cache = MagicMock()
        cache.get_rows_from_datacache = MagicMock(return_value=pd.DataFrame())
        for method_name in ('create_cache', 'insert_row_into_datacache', 'modify_datacache_item'):
            setattr(cache, method_name, MagicMock())
        return cache

    def test_save_context_saves_broker_state(self, mock_data_cache):
        """save_context() should call paper_broker.save_state once with the instance id."""
        with patch.dict('sys.modules', {'eventlet': MagicMock()}):
            from paper_strategy_instance import PaperStrategyInstance

            # Bypass the real constructor; a minimal attribute set is assigned by hand.
            with patch.object(PaperStrategyInstance, '__init__', lambda x: None):
                instance = PaperStrategyInstance()

                fake_start = MagicMock()
                fake_start.isoformat.return_value = '2024-01-01T00:00:00'

                fake_broker = MagicMock()
                fake_broker.get_balance.return_value = 9500.0
                fake_broker.get_available_balance.return_value = 9500.0

                # Applied in insertion order.
                minimal_state = {
                    'strategy_instance_id': 'test-instance',
                    'strategy_id': 'test-strategy',
                    'flags': {},
                    'variables': {},
                    'profit_loss': 0.0,
                    'active': True,
                    'paused': False,
                    'exit': False,
                    'exit_method': 'all',
                    'start_time': fake_start,
                    'starting_balance': 10000.0,
                    'current_balance': 9500.0,
                    'available_balance': 9500.0,
                    'available_strategy_balance': 9500.0,
                    'paper_broker': fake_broker,
                    'data_cache': mock_data_cache,
                    'exec_context': {},
                }
                for attr_name, attr_value in minimal_state.items():
                    setattr(instance, attr_name, attr_value)

                instance.save_context()

                fake_broker.save_state.assert_called_once_with('test-instance')

    def test_init_loads_broker_state_if_exists(self, mock_data_cache):
        """Constructing an instance should pick up a previously saved broker state."""
        persisted_broker = {
            'cash': 8500.0,
            'locked_balance': 0.0,
            'orders': {},
            'positions': {},
            'trade_history': [],
            'current_prices': {}
        }

        def mock_get_rows(cache_name, filter_vals=None):
            # Only the broker-state cache has a saved row; 'strategy_contexts'
            # and any other cache report nothing saved.
            if cache_name != 'paper_broker_states':
                return pd.DataFrame()
            return pd.DataFrame([{
                'strategy_instance_id': 'test-instance',
                'broker_state': json.dumps(persisted_broker),
                'updated_at': '2024-01-01T00:00:00'
            }])

        mock_data_cache.get_rows_from_datacache.side_effect = mock_get_rows

        with patch.dict('sys.modules', {'eventlet': MagicMock()}):
            from paper_strategy_instance import PaperStrategyInstance

            instance = PaperStrategyInstance(
                strategy_instance_id='test-instance',
                strategy_id='test-strategy',
                strategy_name='Test Strategy',
                user_id=1,
                generated_code='def next(): pass',
                data_cache=mock_data_cache,
                indicators=MagicMock(),
                trades=MagicMock(),
                initial_balance=10000.0
            )

            # The saved cash (8500.0), not initial_balance, should drive the balance.
            assert instance.current_balance == 8500.0
|
|
|
|
|
|
class TestPositionSerialization:
    """Checks Position.to_dict / Position.from_dict field by field."""

    def test_position_to_dict(self):
        """to_dict() should expose every constructor field under its own name."""
        from brokers import Position

        fields = {
            'symbol': 'BTC/USDT',
            'size': 0.5,
            'entry_price': 50000.0,
            'current_price': 52000.0,
            'unrealized_pnl': 1000.0,
            'realized_pnl': 250.0,
        }
        serialized = Position(**fields).to_dict()

        for key, expected in fields.items():
            assert serialized[key] == expected, key

    def test_position_from_dict(self):
        """from_dict() should populate each attribute from the matching key."""
        from brokers import Position

        payload = {
            'symbol': 'ETH/USDT',
            'size': 2.0,
            'entry_price': 3000.0,
            'current_price': 3100.0,
            'unrealized_pnl': 200.0,
            'realized_pnl': 100.0,
        }
        # Hand over a copy so the expected values below stay untouched.
        rebuilt = Position.from_dict(dict(payload))

        for key, expected in payload.items():
            assert getattr(rebuilt, key) == expected, key

    def test_position_roundtrip(self):
        """from_dict(to_dict(p)) should reproduce every attribute of p."""
        from brokers import Position

        source = Position(
            symbol='SOL/USDT',
            size=10.0,
            entry_price=100.0,
            current_price=110.0,
            unrealized_pnl=100.0,
            realized_pnl=0.0
        )

        copy = Position.from_dict(source.to_dict())

        for attr in ('symbol', 'size', 'entry_price', 'current_price',
                     'unrealized_pnl', 'realized_pnl'):
            assert getattr(copy, attr) == getattr(source, attr), attr
|