# brighter-trading/tests/test_exchangeinterface.py
#
# 120 lines
# 4.8 KiB
# Python

import logging
import unittest
from unittest.mock import MagicMock, patch
from datetime import datetime
from ExchangeInterface import ExchangeInterface
from Exchange import Exchange
from typing import Dict, Any
# Root logging configuration: DEBUG and above, each record rendered as
# "<time> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
# Logger named after this module.
logger = logging.getLogger(__name__)
class Trade:
    """
    Simple stand-in for the trade objects handed to the code under test.

    The constructor stores ``target`` and ``symbol`` as given and exposes
    ``order`` as a ``MagicMock`` whose ``orderId`` attribute is ``order_id``.
    """

    def __init__(self, target, symbol, order_id):
        # Build the order stub first, then pin the one attribute we control.
        order_stub = MagicMock()
        order_stub.orderId = order_id

        self.order = order_stub
        self.symbol = symbol
        self.target = target
class _ErrorLogCollector(logging.Handler):
    """Logging handler that keeps the text of every ERROR-or-higher record it receives."""

    def __init__(self):
        super().__init__(level=logging.ERROR)
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


class TestExchangeInterface(unittest.TestCase):
    """Tests for ``ExchangeInterface.get_trade_info`` and ``ExchangeInterface.get_price``."""

    @patch('Exchange.Exchange')
    def setUp(self, MockExchange):
        """Create the interface, connect a (mocked) exchange and prepare shared test data."""
        # NOTE(review): this patch is only active while setUp runs, and it only
        # affects code that looks the class up as ``Exchange.Exchange`` at call
        # time. If ExchangeInterface binds the class with
        # ``from Exchange import Exchange`` the mock is never used - confirm
        # against ExchangeInterface and patch the name there if so.
        self.exchange_interface = ExchangeInterface()

        # Mock exchange instance whose get_order() result each test configures.
        self.mock_exchange = MockExchange.return_value

        # Setup test data
        self.user_name = "test_user"
        self.exchange_name = "binance"
        self.api_keys = {'key': 'test_key', 'secret': 'test_secret'}

        # Connect the mock exchange
        self.exchange_interface.connect_exchange(self.exchange_name, self.user_name, self.api_keys)

        # Mock trade object
        self.trade = Trade(target=self.exchange_name, symbol="BTC/USDT", order_id="12345")

        # Example order data
        self.order_data: Dict[str, Any] = {
            'status': 'closed',
            'filled': 1.0,
            'average': 50000.0,
        }

    def _get_trade_info(self, info_type):
        """Call ``get_trade_info`` for the shared trade and capture ERROR-level log messages.

        ``assertLogs`` is deliberately not used here: it fails the test when
        nothing is logged, which is exactly what a successful lookup does.

        Args:
            info_type: The ``info_type`` argument forwarded to ``get_trade_info``.

        Returns:
            A ``(result, error_messages)`` tuple, where ``error_messages`` is the
            list of ERROR-or-higher messages that reached the root logger
            during the call.

        Raises:
            unittest.SkipTest: If 'Must configure API keys' was logged, so that an
                unwired mock shows up as a skip instead of a silent pass.
        """
        collector = _ErrorLogCollector()
        root_logger = logging.getLogger()
        root_logger.addHandler(collector)
        try:
            result = self.exchange_interface.get_trade_info(self.trade, self.user_name, info_type)
        finally:
            # Always detach so captured state never leaks into other tests.
            root_logger.removeHandler(collector)

        if any('Must configure API keys' in message for message in collector.messages):
            self.skipTest("'Must configure API keys' was logged; the mocked exchange is not in use")
        return result, collector.messages

    def _use_order(self, order):
        """Make the mocked exchange's ``get_order`` return ``order``."""
        self.mock_exchange.get_order.return_value = order

    def test_get_trade_status(self):
        """'status' is read from the order's 'status' field."""
        self._use_order(self.order_data)
        self.assertIsInstance(self.mock_exchange.get_order.return_value, dict)
        status, _ = self._get_trade_info('status')
        self.assertEqual(status, 'closed')

    def test_get_trade_executed_qty(self):
        """'executed_qty' is read from the order's 'filled' field."""
        self._use_order(self.order_data)
        self.assertIsInstance(self.mock_exchange.get_order.return_value, dict)
        executed_qty, _ = self._get_trade_info('executed_qty')
        self.assertEqual(executed_qty, 1.0)

    def test_get_trade_executed_price(self):
        """'executed_price' is read from the order's 'average' field."""
        self._use_order(self.order_data)
        self.assertIsInstance(self.mock_exchange.get_order.return_value, dict)
        executed_price, _ = self._get_trade_info('executed_price')
        self.assertEqual(executed_price, 50000.0)

    def test_invalid_info_type(self):
        """An unknown info type yields ``None`` and logs an 'Invalid info type' error."""
        self._use_order(self.order_data)
        self.assertIsInstance(self.mock_exchange.get_order.return_value, dict)
        result, errors = self._get_trade_info('invalid_type')
        self.assertIsNone(result)
        self.assertTrue(any('Invalid info type' in message for message in errors))

    def test_order_not_found(self):
        """A missing order yields ``None`` and logs a 'not found' error."""
        self._use_order(None)
        result, errors = self._get_trade_info('status')
        self.assertIsNone(result)
        self.assertTrue(any('Order 12345 for BTC/USDT not found.' in message for message in errors))

    def test_get_price_default_source(self):
        """``get_price`` with its default source returns a price above 0.1."""
        # No mock price is configured here; the value comes from whatever
        # get_price uses when no price_source is given.
        symbol = "BTC/USD"
        price = self.exchange_interface.get_price(symbol)
        self.assertGreater(price, 0.1)

    def test_get_price_with_invalid_source(self):
        """An unknown price source raises ``ValueError`` naming that source."""
        symbol = "BTC/USD"
        with self.assertRaises(ValueError) as context:
            self.exchange_interface.get_price(symbol, price_source="invalid_source")
        self.assertIn('No implementation for price source: invalid_source', str(context.exception))
# When this file is executed directly, hand control to unittest's runner.
if __name__ == '__main__':
    unittest.main()