# File: brighter-trading/tests/test_exchangeinterface.py
#
# Source listing metadata: 115 lines, 4.6 KiB, Python.

import logging
import unittest
from unittest.mock import MagicMock, patch, PropertyMock
from datetime import datetime
from ExchangeInterface import ExchangeInterface
from Exchange import Exchange
from DataCache_v3 import DataCache
from typing import Dict, Any
# Module logger for this test file; the root configuration below is independent of it.
logger = logging.getLogger(__name__)
# Configure root logging at DEBUG so diagnostic output from the code under test is visible.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
class Trade:
    """
    Lightweight stand-in for a trade object, used only by the tests below.

    Exposes ``target`` and ``symbol`` as given, plus ``order``: a ``MagicMock``
    whose ``orderId`` attribute holds the supplied ``order_id``.
    """

    def __init__(self, target, symbol, order_id):
        order_stub = MagicMock()
        order_stub.orderId = order_id
        self.order = order_stub
        self.symbol = symbol
        self.target = target
class TestExchangeInterface(unittest.TestCase):
    """Unit tests for ``ExchangeInterface.get_trade_info`` and ``get_price``.

    Exchange objects are replaced with ``MagicMock`` instances so each test
    exercises only ``ExchangeInterface``'s own lookup/dispatch logic.
    """

    def setUp(self):
        """Build a fresh interface and the shared fixtures used by every test."""
        self.cache_manager = DataCache()
        self.exchange_interface = ExchangeInterface(self.cache_manager)

        # Setup test data
        self.user_name = "test_user"
        self.exchange_name = "binance"
        self.api_keys = {'key': 'test_key', 'secret': 'test_secret'}

        # Mock trade object
        self.trade = Trade(target=self.exchange_name, symbol="BTC/USDT", order_id="12345")

        # Example order data returned by the mocked exchange's get_order.
        self.order_data: Dict[str, Any] = {
            'status': 'closed',
            'filled': 1.0,
            'average': 50000.0,
        }

    def _get_trade_info(self, info_type, order):
        """Call ``get_trade_info`` for ``self.trade`` with ``get_exchange`` mocked out.

        Args:
            info_type: The info type forwarded to ``get_trade_info``.
            order: The value the mocked exchange's ``get_order`` returns.

        Returns:
            ``(result, mock_exchange)``: what ``get_trade_info`` returned and
            the mock exchange, so callers can inspect how it was used.
        """
        mock_exchange = MagicMock()
        mock_exchange.get_order.return_value = order
        # Patch only for the duration of the call so the mock never leaks.
        with patch.object(self.exchange_interface, 'get_exchange', return_value=mock_exchange):
            result = self.exchange_interface.get_trade_info(self.trade, self.user_name, info_type)
        return result, mock_exchange

    def test_get_trade_status(self):
        """Test getting trade status with mocked exchange."""
        status, mock_exchange = self._get_trade_info('status', self.order_data)
        self.assertEqual(status, 'closed')
        # The expected value exists only in the mocked order, so get_order must have been used.
        mock_exchange.get_order.assert_called()

    def test_get_trade_executed_qty(self):
        """Test getting executed quantity with mocked exchange."""
        executed_qty, mock_exchange = self._get_trade_info('executed_qty', self.order_data)
        self.assertEqual(executed_qty, 1.0)
        mock_exchange.get_order.assert_called()

    def test_get_trade_executed_price(self):
        """Test getting executed price with mocked exchange."""
        executed_price, mock_exchange = self._get_trade_info('executed_price', self.order_data)
        self.assertEqual(executed_price, 50000.0)
        mock_exchange.get_order.assert_called()

    def test_invalid_info_type(self):
        """Test invalid info type returns None."""
        with self.assertLogs(level='ERROR') as log:
            result, _ = self._get_trade_info('invalid_type', self.order_data)
        self.assertIsNone(result)
        self.assertTrue(any('Invalid info type' in message for message in log.output))

    def test_order_not_found(self):
        """Test order not found returns None."""
        with self.assertLogs(level='ERROR') as log:
            result, _ = self._get_trade_info('status', None)
        self.assertIsNone(result)
        self.assertTrue(any('not found' in message for message in log.output))

    def test_get_price_default_source(self):
        """Test get_price with default exchange."""
        mock_exchange = MagicMock()
        mock_exchange.get_price.return_value = 50000.0
        with patch.object(self.exchange_interface, 'connect_default_exchange'):
            self.exchange_interface.default_exchange = mock_exchange
            price = self.exchange_interface.get_price("BTC/USDT")
        self.assertEqual(price, 50000.0)
        # The price can only have come from the mock, so it must have been queried.
        mock_exchange.get_price.assert_called()

    def test_get_price_with_invalid_exchange(self):
        """Test get_price with invalid exchange name returns 0."""
        # NOTE(review): this goes through the real get_price path without mocking;
        # confirm it cannot reach a live exchange or the network.
        # Unknown exchange should return 0.0
        price = self.exchange_interface.get_price("BTC/USDT", exchange_name="invalid_exchange_xyz")
        self.assertEqual(price, 0.0)
if __name__ == '__main__':
    # Allow running this module directly, e.g. `python test_exchangeinterface.py`.
    unittest.main()