115 lines
4.6 KiB
Python
115 lines
4.6 KiB
Python
import logging
|
|
import unittest
|
|
from unittest.mock import MagicMock, patch, PropertyMock
|
|
from datetime import datetime
|
|
from ExchangeInterface import ExchangeInterface
|
|
from Exchange import Exchange
|
|
from DataCache_v3 import DataCache
|
|
from typing import Dict, Any
|
|
|
|
# Configure root logging once for the test run: DEBUG and above, timestamped.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# Per-module logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
class Trade:
    """Lightweight stand-in for a trade object, used by the tests below.

    Attributes:
        target: Value identifying where the trade is routed; the tests pass an
            exchange name here.
        symbol: Market symbol string, e.g. ``"BTC/USDT"``.
        order: A ``MagicMock`` whose ``orderId`` attribute is the supplied
            ``order_id``.
    """

    def __init__(self, target, symbol, order_id):
        # Build the order stub first; only its ``orderId`` attribute is preset.
        order_stub = MagicMock(orderId=order_id)
        self.target, self.symbol = target, symbol
        self.order = order_stub
|
|
|
|
class TestExchangeInterface(unittest.TestCase):
    """Tests for ``ExchangeInterface.get_trade_info`` and ``ExchangeInterface.get_price``.

    The trade-info tests swap ``get_exchange`` for a ``MagicMock`` exchange via
    ``patch.object`` so the order payload is fully controlled by the test.
    """

    def setUp(self):
        # A fresh cache and interface per test, so no state leaks between tests.
        self.cache_manager = DataCache()
        self.exchange_interface = ExchangeInterface(self.cache_manager)

        # Identity and credentials fixture data.
        self.user_name = "test_user"
        self.exchange_name = "binance"
        self.api_keys = {'key': 'test_key', 'secret': 'test_secret'}

        # Stand-in trade routed to the exchange named above.
        self.trade = Trade(target=self.exchange_name, symbol="BTC/USDT", order_id="12345")

        # Payload the mocked exchange hands back from get_order().
        self.order_data: Dict[str, Any] = {
            'status': 'closed',
            'filled': 1.0,
            'average': 50000.0,
        }

    def _trade_info(self, info_type, order):
        """Run ``get_trade_info`` for ``self.trade`` against a mocked exchange.

        ``get_exchange`` is patched on the interface instance to return a
        ``MagicMock`` whose ``get_order()`` returns *order*; the result of
        ``get_trade_info`` is returned unchanged.
        """
        fake_exchange = MagicMock()
        fake_exchange.get_order.return_value = order
        with patch.object(self.exchange_interface, 'get_exchange', return_value=fake_exchange):
            return self.exchange_interface.get_trade_info(self.trade, self.user_name, info_type)

    def test_get_trade_status(self):
        """Test getting trade status with mocked exchange."""
        self.assertEqual(self._trade_info('status', self.order_data), 'closed')

    def test_get_trade_executed_qty(self):
        """Test getting executed quantity with mocked exchange."""
        self.assertEqual(self._trade_info('executed_qty', self.order_data), 1.0)

    def test_get_trade_executed_price(self):
        """Test getting executed price with mocked exchange."""
        self.assertEqual(self._trade_info('executed_price', self.order_data), 50000.0)

    def test_invalid_info_type(self):
        """Test invalid info type returns None."""
        # assertLogs without a logger argument captures the root logger.
        with self.assertLogs(level='ERROR') as captured:
            outcome = self._trade_info('invalid_type', self.order_data)
        self.assertIsNone(outcome)
        self.assertTrue(any('Invalid info type' in line for line in captured.output))

    def test_order_not_found(self):
        """Test order not found returns None."""
        # A None payload from get_order() stands in for a missing order.
        with self.assertLogs(level='ERROR') as captured:
            outcome = self._trade_info('status', None)
        self.assertIsNone(outcome)
        self.assertTrue(any('not found' in line for line in captured.output))

    def test_get_price_default_source(self):
        """Test get_price with default exchange."""
        fake_default = MagicMock()
        fake_default.get_price.return_value = 50000.0

        # connect_default_exchange is replaced by a MagicMock while the mock
        # default exchange is injected directly onto the instance.
        with patch.object(self.exchange_interface, 'connect_default_exchange'):
            self.exchange_interface.default_exchange = fake_default
            quoted = self.exchange_interface.get_price("BTC/USDT")
        self.assertEqual(quoted, 50000.0)

    def test_get_price_with_invalid_exchange(self):
        """Test get_price with invalid exchange name returns 0."""
        # NOTE(review): this test is not mocked and calls the real get_price path.
        quoted = self.exchange_interface.get_price("BTC/USDT", exchange_name="invalid_exchange_xyz")
        self.assertEqual(quoted, 0.0)
|
|
|
|
if __name__ == '__main__':
    # Executed only when run as a script: hand control to unittest's CLI runner,
    # which collects the TestCase classes defined in this module.
    unittest.main()