Completed unittests for Exchange.
This commit is contained in:
parent
0b1ad39476
commit
e601f8c23e
|
|
@ -115,7 +115,7 @@ class DataCache:
|
|||
# Check if the records in the cache go far enough back to satisfy the query.
|
||||
first_timestamp = query_satisfied(start_datetime=start_datetime,
|
||||
records=records,
|
||||
r_length=record_length)
|
||||
r_length_min=record_length)
|
||||
if first_timestamp:
|
||||
# The records didn't go far enough back if a timestamp was returned.
|
||||
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
||||
|
|
@ -129,7 +129,7 @@ class DataCache:
|
|||
self.update_candle_cache(additional_records, key)
|
||||
|
||||
# Check if the records received are up-to-date.
|
||||
last_timestamp = query_uptodate(records=records, r_length=record_length)
|
||||
last_timestamp = query_uptodate(records=records, r_length_min=record_length)
|
||||
|
||||
if last_timestamp:
|
||||
# The query was not up-to-date if a timestamp was returned.
|
||||
|
|
|
|||
160
src/Exchange.py
160
src/Exchange.py
|
|
@ -3,80 +3,58 @@ import pandas as pd
|
|||
from datetime import datetime, timedelta
|
||||
from typing import Tuple, Dict, List, Union
|
||||
import time
|
||||
import logging
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Exchange:
|
||||
# Class attribute for caching market data
|
||||
_market_cache = {}
|
||||
|
||||
def __init__(self, name: str, api_keys: Dict[str, str], exchange_id: str):
|
||||
self.name = name
|
||||
|
||||
# The API key for the exchange.
|
||||
self.api_key = api_keys['key'] if api_keys else None
|
||||
|
||||
# The API secret key for the exchange.
|
||||
self.api_key_secret = api_keys['secret'] if api_keys else None
|
||||
|
||||
# The exchange id for the exchange.
|
||||
self.exchange_id = exchange_id
|
||||
|
||||
# The connection to the exchange_interface.
|
||||
self.client: ccxt.Exchange = self._connect_exchange()
|
||||
|
||||
# Info on all symbols and exchange_interface rules.
|
||||
self.exchange_info = self._set_exchange_info()
|
||||
|
||||
# List of time intervals available for trading.
|
||||
self.intervals = self._set_avail_intervals()
|
||||
|
||||
# All symbols available for trading.
|
||||
self.symbols = self._set_symbols()
|
||||
|
||||
# Any non-zero balance-info for all assets.
|
||||
self.balances = self._set_balances()
|
||||
|
||||
# Dictionary of places after the decimal requires by the exchange_interface indexed by symbol.
|
||||
self.symbols_n_precision = {}
|
||||
|
||||
def _connect_exchange(self) -> ccxt.Exchange:
|
||||
"""
|
||||
Connects to the exchange_interface and sets a reference the client.
|
||||
Handles cases where api_keys might be None.
|
||||
"""
|
||||
exchange_class = getattr(ccxt, self.exchange_id)
|
||||
if not exchange_class:
|
||||
logger.error(f"Exchange {self.exchange_id} is not supported by CCXT.")
|
||||
raise ValueError(f"Exchange {self.exchange_id} is not supported by CCXT.")
|
||||
|
||||
logger.info(f"Connecting to exchange {self.exchange_id}.")
|
||||
if self.api_key and self.api_key_secret:
|
||||
return exchange_class({
|
||||
'apiKey': self.api_key,
|
||||
'secret': self.api_key_secret,
|
||||
'enableRateLimit': True,
|
||||
'verbose': False # Disable verbose debugging output
|
||||
'verbose': False
|
||||
})
|
||||
else:
|
||||
return exchange_class({
|
||||
'enableRateLimit': True,
|
||||
'verbose': False # Disable verbose debugging output
|
||||
'verbose': False
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def datetime_to_unix_millis(dt: datetime) -> int:
|
||||
"""
|
||||
Convert a datetime object to Unix timestamp in milliseconds.
|
||||
|
||||
:param dt: datetime - The datetime object to convert.
|
||||
:return: int - The Unix timestamp in milliseconds.
|
||||
"""
|
||||
return int(dt.timestamp() * 1000)
|
||||
|
||||
def _fetch_historical_klines(self, symbol: str, interval: str, start_dt: datetime,
|
||||
end_dt: datetime = None) -> pd.DataFrame:
|
||||
def _fetch_historical_klines(self, symbol: str, interval: str,
|
||||
start_dt: datetime, end_dt: datetime = None) -> pd.DataFrame:
|
||||
if end_dt is None:
|
||||
end_dt = datetime.utcnow()
|
||||
|
||||
max_interval = timedelta(days=200) # Binance's maximum interval
|
||||
max_interval = timedelta(days=200)
|
||||
data_frames = []
|
||||
current_start = start_dt
|
||||
|
||||
|
|
@ -86,9 +64,11 @@ class Exchange:
|
|||
end_str = self.datetime_to_unix_millis(current_end)
|
||||
|
||||
try:
|
||||
candles = self.client.fetch_ohlcv(symbol=symbol, timeframe=interval, since=start_str,
|
||||
params={'endTime': end_str})
|
||||
logger.info(f"Fetching OHLCV data for {symbol} from {current_start} to {current_end}.")
|
||||
candles = self.client.fetch_ohlcv(symbol=symbol, timeframe=interval,
|
||||
since=start_str, params={'endTime': end_str})
|
||||
if not candles:
|
||||
logger.warning(f"No OHLCV data returned for {symbol} from {current_start} to {current_end}.")
|
||||
break
|
||||
|
||||
df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume']
|
||||
|
|
@ -96,41 +76,60 @@ class Exchange:
|
|||
candles_df['open_time'] = candles_df['open_time'] // 1000
|
||||
data_frames.append(candles_df)
|
||||
|
||||
# Move the start to the end of the current chunk to get the next chunk
|
||||
current_start = current_end
|
||||
|
||||
# Sleep for 1 second to avoid hitting rate limits
|
||||
time.sleep(1)
|
||||
|
||||
except ccxt.BaseError as e:
|
||||
print(f"Error fetching OHLCV data: {str(e)}")
|
||||
logger.error(f"Error fetching OHLCV data for {symbol}: {str(e)}")
|
||||
break
|
||||
|
||||
if data_frames:
|
||||
result_df = pd.concat(data_frames)
|
||||
logger.info(f"Successfully fetched OHLCV data for {symbol}.")
|
||||
return result_df
|
||||
else:
|
||||
logger.warning(f"No OHLCV data fetched for {symbol}.")
|
||||
return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume'])
|
||||
|
||||
def _fetch_price(self, symbol: str) -> float:
|
||||
ticker = self.client.fetch_ticker(symbol)
|
||||
return float(ticker['last'])
|
||||
try:
|
||||
ticker = self.client.fetch_ticker(symbol)
|
||||
return float(ticker['last'])
|
||||
except ccxt.BaseError as e:
|
||||
logger.error(f"Error fetching price for {symbol}: {str(e)}")
|
||||
return 0.0
|
||||
|
||||
def _fetch_min_qty(self, symbol: str) -> float:
|
||||
market_data = self.exchange_info[symbol]
|
||||
return float(market_data['limits']['amount']['min'])
|
||||
try:
|
||||
market_data = self.exchange_info[symbol]
|
||||
return float(market_data['limits']['amount']['min'])
|
||||
except KeyError as e:
|
||||
logger.error(f"Error fetching minimum quantity for {symbol}: {str(e)}")
|
||||
return 0.0
|
||||
|
||||
def _fetch_min_notional_qty(self, symbol: str) -> float:
|
||||
market_data = self.exchange_info[symbol]
|
||||
return float(market_data['limits']['cost']['min'])
|
||||
try:
|
||||
market_data = self.exchange_info[symbol]
|
||||
return float(market_data['limits']['cost']['min'])
|
||||
except KeyError as e:
|
||||
logger.error(f"Error fetching minimum notional quantity for {symbol}: {str(e)}")
|
||||
return 0.0
|
||||
|
||||
def _fetch_order(self, symbol: str, order_id: str) -> object:
|
||||
return self.client.fetch_order(order_id, symbol)
|
||||
try:
|
||||
return self.client.fetch_order(order_id, symbol)
|
||||
except ccxt.BaseError as e:
|
||||
logger.error(f"Error fetching order {order_id} for {symbol}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _set_symbols(self) -> List[str]:
|
||||
markets = self.client.fetch_markets()
|
||||
symbols = [market['symbol'] for market in markets if market['active']]
|
||||
return symbols
|
||||
try:
|
||||
markets = self.client.fetch_markets()
|
||||
symbols = [market['symbol'] for market in markets if market['active']]
|
||||
return symbols
|
||||
except ccxt.BaseError as e:
|
||||
logger.error(f"Error fetching symbols: {str(e)}")
|
||||
return []
|
||||
|
||||
def _set_balances(self) -> List[Dict[str, Union[str, float]]]:
|
||||
if self.api_key and self.api_key_secret:
|
||||
|
|
@ -142,18 +141,13 @@ class Exchange:
|
|||
if asset_balance > 0:
|
||||
balances.append({'asset': asset, 'balance': asset_balance, 'pnl': 0})
|
||||
return balances
|
||||
except NotImplementedError:
|
||||
# Handle the case where fetch_balance is not supported
|
||||
except ccxt.BaseError as e:
|
||||
logger.error(f"Error fetching balances: {str(e)}")
|
||||
return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}]
|
||||
else:
|
||||
return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}]
|
||||
|
||||
def _set_exchange_info(self) -> dict:
|
||||
"""
|
||||
Fetches market data for all symbols from the exchange.
|
||||
This includes details on trading pairs, limits, fees, and other relevant information.
|
||||
Caches the market info to speed up subsequent calls.
|
||||
"""
|
||||
if self.exchange_id in Exchange._market_cache:
|
||||
return Exchange._market_cache[self.exchange_id]
|
||||
|
||||
|
|
@ -162,40 +156,35 @@ class Exchange:
|
|||
Exchange._market_cache[self.exchange_id] = markets_info
|
||||
return markets_info
|
||||
except ccxt.BaseError as e:
|
||||
print(f"Error fetching market info: {str(e)}")
|
||||
logger.error(f"Error fetching market info: {str(e)}")
|
||||
return {}
|
||||
|
||||
def get_client(self) -> object:
|
||||
""" Return a reference to the exchange_interface client."""
|
||||
return self.client
|
||||
|
||||
def get_avail_intervals(self) -> Tuple[str, ...]:
|
||||
"""Returns a list of time intervals available for trading."""
|
||||
return self.intervals
|
||||
|
||||
def get_exchange_info(self) -> dict:
|
||||
"""Returns Info on all symbols."""
|
||||
return self.exchange_info
|
||||
|
||||
def get_symbols(self) -> List[str]:
|
||||
"""Returns all symbols available for trading."""
|
||||
return self.symbols
|
||||
|
||||
def get_balances(self) -> List[Dict[str, Union[str, float]]]:
|
||||
"""Returns any non-zero balance-info for all assets."""
|
||||
return self.balances
|
||||
|
||||
def get_symbol_precision_rule(self, symbol: str) -> int:
|
||||
"""Returns number of places after the decimal required by the exchange_interface for a specific symbol."""
|
||||
r_value = self.symbols_n_precision.get(symbol)
|
||||
if r_value is None:
|
||||
self._set_precision_rule(symbol)
|
||||
r_value = self.symbols_n_precision.get(symbol)
|
||||
return r_value
|
||||
|
||||
def get_historical_klines(self, symbol: str, interval: str, start_dt: datetime,
|
||||
end_dt: datetime = None) -> pd.DataFrame:
|
||||
return self._fetch_historical_klines(symbol=symbol, interval=interval, start_dt=start_dt, end_dt=end_dt)
|
||||
def get_historical_klines(self, symbol: str, interval: str,
|
||||
start_dt: datetime, end_dt: datetime = None) -> pd.DataFrame:
|
||||
return self._fetch_historical_klines(symbol=symbol, interval=interval,
|
||||
start_dt=start_dt, end_dt=end_dt)
|
||||
|
||||
def get_price(self, symbol: str) -> float:
|
||||
return self._fetch_price(symbol)
|
||||
|
|
@ -209,26 +198,21 @@ class Exchange:
|
|||
def get_order(self, symbol: str, order_id: str) -> object:
|
||||
return self._fetch_order(symbol, order_id)
|
||||
|
||||
def place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> \
|
||||
Tuple[str, object]:
|
||||
result, msg = self._place_order(symbol=symbol, side=side, type=type, timeInForce=timeInForce, quantity=quantity,
|
||||
price=price)
|
||||
def place_order(self, symbol: str, side: str, type: str, timeInForce: str,
|
||||
quantity: float, price: float = None) -> Tuple[str, object]:
|
||||
result, msg = self._place_order(symbol=symbol, side=side, type=type,
|
||||
timeInForce=timeInForce, quantity=quantity, price=price)
|
||||
return result, msg
|
||||
|
||||
def _set_avail_intervals(self) -> Tuple[str, ...]:
|
||||
"""
|
||||
Sets a list of time intervals available for trading on the exchange_interface.
|
||||
"""
|
||||
return tuple(self.client.timeframes.keys())
|
||||
|
||||
def _set_precision_rule(self, symbol: str) -> None:
|
||||
market_data = self.exchange_info[symbol]
|
||||
precision = market_data['precision']['amount']
|
||||
self.symbols_n_precision[symbol] = precision
|
||||
return
|
||||
|
||||
def _place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> \
|
||||
Tuple[str, object]:
|
||||
def _place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> Tuple[str, object]:
|
||||
def format_arg(value: float) -> float:
|
||||
precision = self.symbols_n_precision.get(symbol, 8)
|
||||
return float(f"{value:.{precision}f}")
|
||||
|
|
@ -250,13 +234,14 @@ class Exchange:
|
|||
if price is not None:
|
||||
order_params['price'] = price
|
||||
|
||||
order = self.client.create_order(**order_params)
|
||||
return 'Success', order
|
||||
try:
|
||||
order = self.client.create_order(**order_params)
|
||||
return 'Success', order
|
||||
except ccxt.BaseError as e:
|
||||
logger.error(f"Error placing order for {symbol}: {str(e)}")
|
||||
return 'Failure', None
|
||||
|
||||
def get_active_trades(self) -> List[Dict[str, Union[str, float]]]:
|
||||
"""
|
||||
Get the active trades (open positions).
|
||||
"""
|
||||
if self.api_key and self.api_key_secret:
|
||||
try:
|
||||
positions = self.client.fetch_positions()
|
||||
|
|
@ -270,16 +255,13 @@ class Exchange:
|
|||
}
|
||||
formatted_trades.append(active_trade)
|
||||
return formatted_trades
|
||||
except NotImplementedError:
|
||||
# Handle the case where fetch_positions is not supported
|
||||
except ccxt.BaseError as e:
|
||||
logger.error(f"Error fetching active trades: {str(e)}")
|
||||
return []
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_open_orders(self) -> List[Dict[str, Union[str, float]]]:
|
||||
"""
|
||||
Get the open orders.
|
||||
"""
|
||||
if self.api_key and self.api_key_secret:
|
||||
try:
|
||||
open_orders = self.client.fetch_open_orders()
|
||||
|
|
@ -293,8 +275,8 @@ class Exchange:
|
|||
}
|
||||
formatted_orders.append(open_order)
|
||||
return formatted_orders
|
||||
except NotImplementedError:
|
||||
# Handle the case where fetch_balance is not supported
|
||||
except ccxt.BaseError as e:
|
||||
logger.error(f"Error fetching open orders: {str(e)}")
|
||||
return []
|
||||
else:
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -344,7 +344,7 @@ class Database:
|
|||
print(f'Got {len(records.index)} records from {ex_details[2]}')
|
||||
|
||||
# Check if the records in the db go far enough back to satisfy the query.
|
||||
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length=rl)
|
||||
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
|
||||
if first_timestamp:
|
||||
# The records didn't go far enough back if a timestamp was returned.
|
||||
print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
|
||||
|
|
@ -355,7 +355,7 @@ class Database:
|
|||
records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
|
||||
|
||||
# Check if the records received are up-to-date.
|
||||
last_timestamp = query_uptodate(records=records, r_length=rl)
|
||||
last_timestamp = query_uptodate(records=records, r_length_min=rl)
|
||||
if last_timestamp:
|
||||
# The query was not up-to-date if a timestamp was returned.
|
||||
print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
|
||||
|
|
|
|||
|
|
@ -0,0 +1,193 @@
|
|||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import ccxt
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
from Exchange import Exchange
|
||||
|
||||
|
||||
class TestExchange(unittest.TestCase):
|
||||
|
||||
@patch('ccxt.binance')
|
||||
def setUp(self, mock_exchange):
|
||||
self.mock_client = MagicMock()
|
||||
mock_exchange.return_value = self.mock_client
|
||||
|
||||
self.api_keys = {'key': 'test_key', 'secret': 'test_secret'}
|
||||
|
||||
self.mock_client.fetch_markets.return_value = [
|
||||
{'symbol': 'BTC/USDT', 'active': True},
|
||||
{'symbol': 'ETH/USDT', 'active': True},
|
||||
]
|
||||
self.mock_client.load_markets.return_value = {
|
||||
'BTC/USDT': {
|
||||
'limits': {
|
||||
'amount': {'min': 0.001},
|
||||
'cost': {'min': 10.0} # Ensure the cost limit is included
|
||||
},
|
||||
'precision': {
|
||||
'amount': 8,
|
||||
'price': 8
|
||||
},
|
||||
'active': True,
|
||||
'symbol': 'BTC/USDT'
|
||||
},
|
||||
'ETH/USDT': {
|
||||
'limits': {
|
||||
'amount': {'min': 0.01},
|
||||
'cost': {'min': 20.0} # Ensure the cost limit is included
|
||||
},
|
||||
'precision': {
|
||||
'amount': 8,
|
||||
'price': 8
|
||||
},
|
||||
'active': True,
|
||||
'symbol': 'ETH/USDT'
|
||||
}
|
||||
}
|
||||
self.mock_client.fetch_ticker.return_value = {'last': 30000.0}
|
||||
self.mock_client.fetch_ohlcv.return_value = [
|
||||
[1609459200000, 29000.0, 29500.0, 28800.0, 29400.0, 1000]
|
||||
]
|
||||
self.mock_client.fetch_balance.return_value = {
|
||||
'total': {'BTC': 1.0, 'USDT': 1000.0}
|
||||
}
|
||||
self.mock_client.create_order.return_value = {
|
||||
'id': 'test_order_id',
|
||||
'symbol': 'BTC/USDT',
|
||||
'type': 'limit',
|
||||
'side': 'buy',
|
||||
'price': 30000.0,
|
||||
'amount': 1.0,
|
||||
'status': 'open'
|
||||
}
|
||||
self.mock_client.fetch_open_orders.return_value = [
|
||||
{'id': 'test_order_id', 'symbol': 'BTC/USDT', 'side': 'buy', 'amount': 1.0, 'price': 30000.0}
|
||||
]
|
||||
self.mock_client.fetch_positions.return_value = [
|
||||
{'symbol': 'BTC/USDT', 'quantity': 1.0, 'entry_price': 29000.0}
|
||||
]
|
||||
self.mock_client.fetch_order.return_value = {
|
||||
'id': 'test_order_id',
|
||||
'symbol': 'BTC/USDT',
|
||||
'side': 'buy',
|
||||
'amount': 1.0,
|
||||
'price': 30000.0,
|
||||
'status': 'open'
|
||||
}
|
||||
self.mock_client.timeframes = {
|
||||
'1m': '1m',
|
||||
'5m': '5m',
|
||||
'1h': '1h',
|
||||
'1d': '1d',
|
||||
}
|
||||
|
||||
self.exchange = Exchange(name='Binance', api_keys=self.api_keys, exchange_id='binance')
|
||||
|
||||
def test_connect_exchange(self):
|
||||
self.assertIsInstance(self.exchange.client, MagicMock)
|
||||
self.mock_client.load_markets.assert_called_once()
|
||||
self.assertEqual(self.exchange.symbols, ['BTC/USDT', 'ETH/USDT'])
|
||||
self.assertIn('BTC/USDT', self.exchange.exchange_info)
|
||||
self.assertIn('ETH/USDT', self.exchange.exchange_info)
|
||||
self.assertEqual(self.exchange.balances, [
|
||||
{'asset': 'BTC', 'balance': 1.0, 'pnl': 0},
|
||||
{'asset': 'USDT', 'balance': 1000.0, 'pnl': 0}
|
||||
])
|
||||
|
||||
def test_get_symbols(self):
|
||||
symbols = self.exchange.get_symbols()
|
||||
self.assertEqual(symbols, ['BTC/USDT', 'ETH/USDT'])
|
||||
|
||||
def test_get_price(self):
|
||||
price = self.exchange.get_price('BTC/USDT')
|
||||
self.assertEqual(price, 30000.0)
|
||||
self.mock_client.fetch_ticker.assert_called_with('BTC/USDT')
|
||||
|
||||
def test_get_balances(self):
|
||||
balances = self.exchange.get_balances()
|
||||
self.assertEqual(balances, [
|
||||
{'asset': 'BTC', 'balance': 1.0, 'pnl': 0},
|
||||
{'asset': 'USDT', 'balance': 1000.0, 'pnl': 0}
|
||||
])
|
||||
self.mock_client.fetch_balance.assert_called_once()
|
||||
|
||||
def test_get_historical_klines(self):
|
||||
start_dt = datetime(2021, 1, 1)
|
||||
end_dt = datetime(2021, 1, 2)
|
||||
klines = self.exchange.get_historical_klines('BTC/USDT', '1d', start_dt, end_dt)
|
||||
expected_df = pd.DataFrame([
|
||||
{'open_time': 1609459200, 'open': 29000.0, 'high': 29500.0, 'low': 28800.0, 'close': 29400.0,
|
||||
'volume': 1000}
|
||||
])
|
||||
pd.testing.assert_frame_equal(klines, expected_df)
|
||||
self.mock_client.fetch_ohlcv.assert_called()
|
||||
|
||||
def test_get_min_qty(self):
|
||||
min_qty = self.exchange.get_min_qty('BTC/USDT')
|
||||
self.assertEqual(min_qty, 0.001)
|
||||
|
||||
def test_get_min_notional_qty(self):
|
||||
min_notional_qty = self.exchange.get_min_notional_qty('BTC/USDT')
|
||||
self.assertEqual(min_notional_qty, 10.0)
|
||||
|
||||
def test_get_order(self):
|
||||
order = self.exchange.get_order('BTC/USDT', 'test_order_id')
|
||||
self.assertIsInstance(order, dict)
|
||||
self.assertEqual(order['id'], 'test_order_id')
|
||||
self.mock_client.fetch_order.assert_called_with('test_order_id', 'BTC/USDT')
|
||||
|
||||
def test_place_order(self):
|
||||
result, order = self.exchange.place_order(
|
||||
symbol='BTC/USDT',
|
||||
side='buy',
|
||||
type='limit',
|
||||
timeInForce='GTC',
|
||||
quantity=1.0,
|
||||
price=30000.0
|
||||
)
|
||||
self.assertEqual(result, 'Success')
|
||||
self.assertIsInstance(order, dict)
|
||||
self.assertEqual(order.get('id'), 'test_order_id')
|
||||
self.mock_client.create_order.assert_called_once()
|
||||
|
||||
def test_get_open_orders(self):
|
||||
open_orders = self.exchange.get_open_orders()
|
||||
self.assertEqual(open_orders, [
|
||||
{'symbol': 'BTC/USDT', 'side': 'buy', 'quantity': 1.0, 'price': 30000.0}
|
||||
])
|
||||
self.mock_client.fetch_open_orders.assert_called_once()
|
||||
|
||||
def test_get_active_trades(self):
|
||||
active_trades = self.exchange.get_active_trades()
|
||||
self.assertEqual(active_trades, [
|
||||
{'symbol': 'BTC/USDT', 'side': 'buy', 'quantity': 1.0, 'price': 29000.0}
|
||||
])
|
||||
self.mock_client.fetch_positions.assert_called_once()
|
||||
|
||||
@patch('ccxt.binance')
|
||||
def test_fetch_ohlcv_network_failure(self, mock_exchange):
|
||||
self.mock_client.fetch_ohlcv.side_effect = ccxt.NetworkError('Network error')
|
||||
start_dt = datetime(2021, 1, 1)
|
||||
end_dt = datetime(2021, 1, 2)
|
||||
klines = self.exchange.get_historical_klines('BTC/USDT', '1d', start_dt, end_dt)
|
||||
self.assertTrue(klines.empty)
|
||||
self.mock_client.fetch_ohlcv.assert_called()
|
||||
|
||||
@patch('ccxt.binance')
|
||||
def test_fetch_ticker_invalid_response(self, mock_exchange):
|
||||
self.mock_client.fetch_ticker.side_effect = ccxt.ExchangeError('Invalid response')
|
||||
price = self.exchange.get_price('BTC/USDT')
|
||||
self.assertEqual(price, 0.0)
|
||||
self.mock_client.fetch_ticker.assert_called_with('BTC/USDT')
|
||||
|
||||
@patch('ccxt.binance')
|
||||
def test_fetch_order_invalid_response(self, mock_exchange):
|
||||
self.mock_client.fetch_order.side_effect = ccxt.ExchangeError('Invalid response')
|
||||
order = self.exchange.get_order('BTC/USDT', 'invalid_order_id')
|
||||
self.assertIsNone(order)
|
||||
self.mock_client.fetch_order.assert_called_with('invalid_order_id', 'BTC/USDT')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
import unittest
|
||||
import ccxt
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
from Exchange import Exchange
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class TestExchange(unittest.TestCase):
|
||||
api_keys: Optional[Dict[str, str]]
|
||||
exchange: Exchange
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
exchange_name = 'binance'
|
||||
cls.api_keys = None
|
||||
"""Uncomment and Provide api keys to connect to exchange."""
|
||||
# cls.api_keys = {'key': 'EXCHANGE_API_KEY', 'secret': 'EXCHANGE_API_SECRET'}
|
||||
|
||||
cls.exchange = Exchange(name=exchange_name, api_keys=cls.api_keys, exchange_id=exchange_name)
|
||||
|
||||
def test_connect_exchange(self):
|
||||
print("\nRunning test_connect_exchange...")
|
||||
exchange_class = getattr(ccxt, self.exchange.name)
|
||||
self.assertIsInstance(self.exchange.client, exchange_class)
|
||||
print("Exchange client instance test passed.")
|
||||
|
||||
self.assertIn('BTC/USDT', self.exchange.symbols)
|
||||
self.assertIn('ETH/USDT', self.exchange.symbols)
|
||||
print("Symbols test passed.")
|
||||
|
||||
self.assertIn('BTC/USDT', self.exchange.exchange_info)
|
||||
self.assertIn('ETH/USDT', self.exchange.exchange_info)
|
||||
print("Exchange info test passed.")
|
||||
|
||||
def test_get_symbols(self):
|
||||
print("\nRunning test_get_symbols...")
|
||||
symbols = self.exchange.get_symbols()
|
||||
self.assertIn('BTC/USDT', symbols)
|
||||
self.assertIn('ETH/USDT', symbols)
|
||||
print("Get symbols test passed.")
|
||||
|
||||
def test_get_price(self):
|
||||
print("\nRunning test_get_price...")
|
||||
price = self.exchange.get_price('BTC/USDT')
|
||||
self.assertGreater(price, 0)
|
||||
print(f"Get price test passed. Price: {price}")
|
||||
|
||||
def test_get_balances(self):
|
||||
print("\nRunning test_get_balances...")
|
||||
if self.api_keys is not None:
|
||||
balances = self.exchange.get_balances()
|
||||
self.assertTrue(any(balance['asset'] == 'BTC' for balance in balances))
|
||||
self.assertTrue(any(balance['asset'] == 'USDT' for balance in balances))
|
||||
print("Get balances test passed.")
|
||||
else:
|
||||
print("test_get_balances(): Can not test without providing API keys")
|
||||
|
||||
def test_get_historical_klines(self):
|
||||
print("\nRunning test_get_historical_klines...")
|
||||
start_dt = datetime(2021, 1, 1)
|
||||
end_dt = datetime(2021, 1, 2)
|
||||
klines = self.exchange.get_historical_klines('BTC/USDT', '1d', start_dt, end_dt)
|
||||
self.assertIsInstance(klines, pd.DataFrame)
|
||||
self.assertFalse(klines.empty)
|
||||
print("Get historical klines test passed.")
|
||||
print(klines.head())
|
||||
|
||||
def test_get_min_qty(self):
|
||||
print("\nRunning test_get_min_qty...")
|
||||
min_qty = self.exchange.get_min_qty('BTC/USDT')
|
||||
self.assertGreater(min_qty, 0)
|
||||
print(f"Get min qty test passed. Min qty: {min_qty}")
|
||||
|
||||
def test_get_min_notional_qty(self):
|
||||
print("\nRunning test_get_min_notional_qty...")
|
||||
min_notional_qty = self.exchange.get_min_notional_qty('BTC/USDT')
|
||||
self.assertGreater(min_notional_qty, 0)
|
||||
print(f"Get min notional qty test passed. Min notional qty: {min_notional_qty}")
|
||||
|
||||
def test_get_order(self):
|
||||
print("\nRunning test_get_order...")
|
||||
if self.api_keys is not None:
|
||||
# You need to create an order manually on exchange to test this
|
||||
order_id = 'your_order_id'
|
||||
order = self.exchange.get_order('BTC/USDT', order_id)
|
||||
self.assertIsInstance(order, dict)
|
||||
self.assertEqual(order['id'], order_id)
|
||||
print("Get order test passed.")
|
||||
else:
|
||||
print("test_get_order(): Can not test without providing API keys")
|
||||
|
||||
def test_place_order(self):
|
||||
print("\nRunning test_place_order...")
|
||||
if self.api_keys is not None:
|
||||
# Be cautious with creating real orders. This is for testing purposes.
|
||||
symbol = 'BTC/USDT'
|
||||
order_type = 'limit'
|
||||
side = 'buy'
|
||||
amount = 0.001
|
||||
price = 30000 # Adjust accordingly
|
||||
|
||||
result, order = self.exchange.place_order(
|
||||
symbol=symbol,
|
||||
side=side,
|
||||
type=order_type,
|
||||
timeInForce='GTC',
|
||||
quantity=amount,
|
||||
price=price
|
||||
)
|
||||
|
||||
self.assertEqual(result, 'Success')
|
||||
self.assertIsInstance(order, dict)
|
||||
self.assertIn('id', order)
|
||||
print("Place order test passed.")
|
||||
else:
|
||||
print("test_place_order(): Can not test without providing API keys")
|
||||
|
||||
def test_get_open_orders(self):
|
||||
print("\nRunning test_get_open_orders...")
|
||||
if self.api_keys is not None:
|
||||
open_orders = self.exchange.get_open_orders()
|
||||
self.assertIsInstance(open_orders, list)
|
||||
print("Get open orders test passed.")
|
||||
else:
|
||||
print("test_get_open_orders(): Can not test without providing API keys")
|
||||
|
||||
def test_get_active_trades(self):
|
||||
print("\nRunning test_get_active_trades...")
|
||||
if self.api_keys is not None:
|
||||
active_trades = self.exchange.get_active_trades()
|
||||
self.assertIsInstance(active_trades, list)
|
||||
print("Get active trades test passed.")
|
||||
else:
|
||||
print("test_get_active_trades(): Can not test without providing API keys")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue