# File-viewer metadata captured with this source: 120 lines, 4.8 KiB, Python.
import logging
|
|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
from datetime import datetime
|
|
from ExchangeInterface import ExchangeInterface
|
|
from Exchange import Exchange
|
|
from typing import Dict, Any
|
|
|
|
# Configure the root logger once at import time (DEBUG level, timestamped
# format). basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger for this test module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Trade:
    """
    Mock Trade class to simulate trade objects used in the tests.

    Attributes:
        target: Stored exactly as passed; the tests in this file pass the
            exchange name here.
        symbol: Stored exactly as passed, e.g. ``"BTC/USDT"``.
        order: A ``MagicMock`` whose ``orderId`` attribute is ``order_id``.
    """

    def __init__(self, target, symbol, order_id):
        self.target = target
        self.symbol = symbol
        # Only ``orderId`` is pinned; any other attribute access on the mock
        # yields an auto-created child MagicMock.
        self.order = MagicMock(orderId=order_id)
|
|
|
|
|
|
class TestExchangeInterface(unittest.TestCase):
    """Tests for ``ExchangeInterface.get_trade_info`` and ``ExchangeInterface.get_price``.

    The trade-info tests go through ``_get_trade_info``, which captures
    ERROR-level log messages without requiring any to be emitted, and skips
    (rather than silently passes) when the interface reports that API keys
    are not configured.
    """

    @patch('Exchange.Exchange')
    def setUp(self, MockExchange):
        """Create the interface, connect a test exchange and build fixtures."""
        # As a decorator, the patch is active only while setUp runs; it is
        # undone before the test method starts.
        # NOTE(review): this target only affects code that resolves
        # ``Exchange.Exchange`` through the module at call time. If
        # ExchangeInterface does ``from Exchange import Exchange``, the target
        # should be 'ExchangeInterface.Exchange' - confirm against that module.
        self.exchange_interface = ExchangeInterface()

        # Mock exchange instances
        self.mock_exchange = MockExchange.return_value

        # Setup test data
        self.user_name = "test_user"
        self.exchange_name = "binance"
        self.api_keys = {'key': 'test_key', 'secret': 'test_secret'}

        # Connect the mock exchange
        self.exchange_interface.connect_exchange(self.exchange_name, self.user_name, self.api_keys)

        # Mock trade object
        self.trade = Trade(target=self.exchange_name, symbol="BTC/USDT", order_id="12345")

        # Example order data
        self.order_data: Dict[str, Any] = {
            'status': 'closed',
            'filled': 1.0,
            'average': 50000.0,
        }

    def _get_trade_info(self, info_type):
        """Call ``get_trade_info`` for ``self.trade`` and collect ERROR logs.

        Args:
            info_type: The info-type string forwarded to ``get_trade_info``.

        Returns:
            A ``(result, error_messages)`` tuple, where ``error_messages`` is
            the list of ERROR-or-higher messages that reached the root logger
            during the call (empty when nothing was logged).

        Skips the running test if a 'Must configure API keys' error was
        logged, so an unconfigured environment is reported as a skip instead
        of a vacuous pass.
        """
        error_messages = []

        class _ErrorCollector(logging.Handler):
            """Append each handled record's rendered message to the list."""

            def emit(self, record):
                error_messages.append(record.getMessage())

        # A plain handler is used instead of assertLogs because assertLogs
        # fails the test when *no* record is emitted, which is the expected
        # situation on the success path.
        collector = _ErrorCollector(level=logging.ERROR)
        root_logger = logging.getLogger()
        root_logger.addHandler(collector)
        try:
            result = self.exchange_interface.get_trade_info(self.trade, self.user_name, info_type)
        finally:
            root_logger.removeHandler(collector)

        if any('Must configure API keys' in message for message in error_messages):
            self.skipTest('Exchange API keys are not configured')
        return result, error_messages

    def test_get_trade_status(self):
        """'status' is expected to yield the order's status value."""
        self.mock_exchange.get_order.return_value = self.order_data
        status, _ = self._get_trade_info('status')
        self.assertEqual(status, 'closed')

    def test_get_trade_executed_qty(self):
        """'executed_qty' is expected to yield the order's filled amount."""
        self.mock_exchange.get_order.return_value = self.order_data
        executed_qty, _ = self._get_trade_info('executed_qty')
        self.assertEqual(executed_qty, 1.0)

    def test_get_trade_executed_price(self):
        """'executed_price' is expected to yield the order's average price."""
        self.mock_exchange.get_order.return_value = self.order_data
        executed_price, _ = self._get_trade_info('executed_price')
        self.assertEqual(executed_price, 50000.0)

    def test_invalid_info_type(self):
        """An unknown info type must return None and log an error."""
        self.mock_exchange.get_order.return_value = self.order_data
        result, error_messages = self._get_trade_info('invalid_type')
        self.assertIsNone(result)
        self.assertTrue(any('Invalid info type' in message for message in error_messages))

    def test_order_not_found(self):
        """A missing order must return None and log a 'not found' error."""
        self.mock_exchange.get_order.return_value = None
        result, error_messages = self._get_trade_info('status')
        self.assertIsNone(result)
        self.assertTrue(any('Order 12345 for BTC/USDT not found.' in message for message in error_messages))

    def test_get_price_default_source(self):
        """get_price with the default source must return a price above 0.1."""
        # No mock is configured in this test; the value comes from whatever
        # source get_price uses by default.
        # NOTE(review): presumably a live price feed - confirm, since that
        # would make this test network-dependent.
        symbol = "BTC/USD"
        price = self.exchange_interface.get_price(symbol)

        self.assertGreater(price, 0.1)

    def test_get_price_with_invalid_source(self):
        """An unsupported price source must raise ValueError naming the source."""
        symbol = "BTC/USD"
        with self.assertRaises(ValueError) as context:
            self.exchange_interface.get_price(symbol, price_source="invalid_source")

        # Checked outside the ``with`` block: statements after the raising
        # call inside the block would never run.
        self.assertIn('No implementation for price source: invalid_source', str(context.exception))
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|