Completed unittests for Exchange.
This commit is contained in:
parent
0b1ad39476
commit
e601f8c23e
|
|
@ -115,7 +115,7 @@ class DataCache:
|
||||||
# Check if the records in the cache go far enough back to satisfy the query.
|
# Check if the records in the cache go far enough back to satisfy the query.
|
||||||
first_timestamp = query_satisfied(start_datetime=start_datetime,
|
first_timestamp = query_satisfied(start_datetime=start_datetime,
|
||||||
records=records,
|
records=records,
|
||||||
r_length=record_length)
|
r_length_min=record_length)
|
||||||
if first_timestamp:
|
if first_timestamp:
|
||||||
# The records didn't go far enough back if a timestamp was returned.
|
# The records didn't go far enough back if a timestamp was returned.
|
||||||
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
||||||
|
|
@ -129,7 +129,7 @@ class DataCache:
|
||||||
self.update_candle_cache(additional_records, key)
|
self.update_candle_cache(additional_records, key)
|
||||||
|
|
||||||
# Check if the records received are up-to-date.
|
# Check if the records received are up-to-date.
|
||||||
last_timestamp = query_uptodate(records=records, r_length=record_length)
|
last_timestamp = query_uptodate(records=records, r_length_min=record_length)
|
||||||
|
|
||||||
if last_timestamp:
|
if last_timestamp:
|
||||||
# The query was not up-to-date if a timestamp was returned.
|
# The query was not up-to-date if a timestamp was returned.
|
||||||
|
|
|
||||||
160
src/Exchange.py
160
src/Exchange.py
|
|
@ -3,80 +3,58 @@ import pandas as pd
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Tuple, Dict, List, Union
|
from typing import Tuple, Dict, List, Union
|
||||||
import time
|
import time
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Exchange:
|
class Exchange:
|
||||||
# Class attribute for caching market data
|
|
||||||
_market_cache = {}
|
_market_cache = {}
|
||||||
|
|
||||||
def __init__(self, name: str, api_keys: Dict[str, str], exchange_id: str):
|
def __init__(self, name: str, api_keys: Dict[str, str], exchange_id: str):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
# The API key for the exchange.
|
|
||||||
self.api_key = api_keys['key'] if api_keys else None
|
self.api_key = api_keys['key'] if api_keys else None
|
||||||
|
|
||||||
# The API secret key for the exchange.
|
|
||||||
self.api_key_secret = api_keys['secret'] if api_keys else None
|
self.api_key_secret = api_keys['secret'] if api_keys else None
|
||||||
|
|
||||||
# The exchange id for the exchange.
|
|
||||||
self.exchange_id = exchange_id
|
self.exchange_id = exchange_id
|
||||||
|
|
||||||
# The connection to the exchange_interface.
|
|
||||||
self.client: ccxt.Exchange = self._connect_exchange()
|
self.client: ccxt.Exchange = self._connect_exchange()
|
||||||
|
|
||||||
# Info on all symbols and exchange_interface rules.
|
|
||||||
self.exchange_info = self._set_exchange_info()
|
self.exchange_info = self._set_exchange_info()
|
||||||
|
|
||||||
# List of time intervals available for trading.
|
|
||||||
self.intervals = self._set_avail_intervals()
|
self.intervals = self._set_avail_intervals()
|
||||||
|
|
||||||
# All symbols available for trading.
|
|
||||||
self.symbols = self._set_symbols()
|
self.symbols = self._set_symbols()
|
||||||
|
|
||||||
# Any non-zero balance-info for all assets.
|
|
||||||
self.balances = self._set_balances()
|
self.balances = self._set_balances()
|
||||||
|
|
||||||
# Dictionary of places after the decimal requires by the exchange_interface indexed by symbol.
|
|
||||||
self.symbols_n_precision = {}
|
self.symbols_n_precision = {}
|
||||||
|
|
||||||
def _connect_exchange(self) -> ccxt.Exchange:
|
def _connect_exchange(self) -> ccxt.Exchange:
|
||||||
"""
|
|
||||||
Connects to the exchange_interface and sets a reference the client.
|
|
||||||
Handles cases where api_keys might be None.
|
|
||||||
"""
|
|
||||||
exchange_class = getattr(ccxt, self.exchange_id)
|
exchange_class = getattr(ccxt, self.exchange_id)
|
||||||
if not exchange_class:
|
if not exchange_class:
|
||||||
|
logger.error(f"Exchange {self.exchange_id} is not supported by CCXT.")
|
||||||
raise ValueError(f"Exchange {self.exchange_id} is not supported by CCXT.")
|
raise ValueError(f"Exchange {self.exchange_id} is not supported by CCXT.")
|
||||||
|
|
||||||
|
logger.info(f"Connecting to exchange {self.exchange_id}.")
|
||||||
if self.api_key and self.api_key_secret:
|
if self.api_key and self.api_key_secret:
|
||||||
return exchange_class({
|
return exchange_class({
|
||||||
'apiKey': self.api_key,
|
'apiKey': self.api_key,
|
||||||
'secret': self.api_key_secret,
|
'secret': self.api_key_secret,
|
||||||
'enableRateLimit': True,
|
'enableRateLimit': True,
|
||||||
'verbose': False # Disable verbose debugging output
|
'verbose': False
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
return exchange_class({
|
return exchange_class({
|
||||||
'enableRateLimit': True,
|
'enableRateLimit': True,
|
||||||
'verbose': False # Disable verbose debugging output
|
'verbose': False
|
||||||
})
|
})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def datetime_to_unix_millis(dt: datetime) -> int:
|
def datetime_to_unix_millis(dt: datetime) -> int:
|
||||||
"""
|
|
||||||
Convert a datetime object to Unix timestamp in milliseconds.
|
|
||||||
|
|
||||||
:param dt: datetime - The datetime object to convert.
|
|
||||||
:return: int - The Unix timestamp in milliseconds.
|
|
||||||
"""
|
|
||||||
return int(dt.timestamp() * 1000)
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
def _fetch_historical_klines(self, symbol: str, interval: str, start_dt: datetime,
|
def _fetch_historical_klines(self, symbol: str, interval: str,
|
||||||
end_dt: datetime = None) -> pd.DataFrame:
|
start_dt: datetime, end_dt: datetime = None) -> pd.DataFrame:
|
||||||
if end_dt is None:
|
if end_dt is None:
|
||||||
end_dt = datetime.utcnow()
|
end_dt = datetime.utcnow()
|
||||||
|
|
||||||
max_interval = timedelta(days=200) # Binance's maximum interval
|
max_interval = timedelta(days=200)
|
||||||
data_frames = []
|
data_frames = []
|
||||||
current_start = start_dt
|
current_start = start_dt
|
||||||
|
|
||||||
|
|
@ -86,9 +64,11 @@ class Exchange:
|
||||||
end_str = self.datetime_to_unix_millis(current_end)
|
end_str = self.datetime_to_unix_millis(current_end)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
candles = self.client.fetch_ohlcv(symbol=symbol, timeframe=interval, since=start_str,
|
logger.info(f"Fetching OHLCV data for {symbol} from {current_start} to {current_end}.")
|
||||||
params={'endTime': end_str})
|
candles = self.client.fetch_ohlcv(symbol=symbol, timeframe=interval,
|
||||||
|
since=start_str, params={'endTime': end_str})
|
||||||
if not candles:
|
if not candles:
|
||||||
|
logger.warning(f"No OHLCV data returned for {symbol} from {current_start} to {current_end}.")
|
||||||
break
|
break
|
||||||
|
|
||||||
df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume']
|
df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume']
|
||||||
|
|
@ -96,41 +76,60 @@ class Exchange:
|
||||||
candles_df['open_time'] = candles_df['open_time'] // 1000
|
candles_df['open_time'] = candles_df['open_time'] // 1000
|
||||||
data_frames.append(candles_df)
|
data_frames.append(candles_df)
|
||||||
|
|
||||||
# Move the start to the end of the current chunk to get the next chunk
|
|
||||||
current_start = current_end
|
current_start = current_end
|
||||||
|
|
||||||
# Sleep for 1 second to avoid hitting rate limits
|
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
except ccxt.BaseError as e:
|
except ccxt.BaseError as e:
|
||||||
print(f"Error fetching OHLCV data: {str(e)}")
|
logger.error(f"Error fetching OHLCV data for {symbol}: {str(e)}")
|
||||||
break
|
break
|
||||||
|
|
||||||
if data_frames:
|
if data_frames:
|
||||||
result_df = pd.concat(data_frames)
|
result_df = pd.concat(data_frames)
|
||||||
|
logger.info(f"Successfully fetched OHLCV data for {symbol}.")
|
||||||
return result_df
|
return result_df
|
||||||
else:
|
else:
|
||||||
|
logger.warning(f"No OHLCV data fetched for {symbol}.")
|
||||||
return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume'])
|
return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume'])
|
||||||
|
|
||||||
def _fetch_price(self, symbol: str) -> float:
|
def _fetch_price(self, symbol: str) -> float:
|
||||||
ticker = self.client.fetch_ticker(symbol)
|
try:
|
||||||
return float(ticker['last'])
|
ticker = self.client.fetch_ticker(symbol)
|
||||||
|
return float(ticker['last'])
|
||||||
|
except ccxt.BaseError as e:
|
||||||
|
logger.error(f"Error fetching price for {symbol}: {str(e)}")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def _fetch_min_qty(self, symbol: str) -> float:
|
def _fetch_min_qty(self, symbol: str) -> float:
|
||||||
market_data = self.exchange_info[symbol]
|
try:
|
||||||
return float(market_data['limits']['amount']['min'])
|
market_data = self.exchange_info[symbol]
|
||||||
|
return float(market_data['limits']['amount']['min'])
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error(f"Error fetching minimum quantity for {symbol}: {str(e)}")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def _fetch_min_notional_qty(self, symbol: str) -> float:
|
def _fetch_min_notional_qty(self, symbol: str) -> float:
|
||||||
market_data = self.exchange_info[symbol]
|
try:
|
||||||
return float(market_data['limits']['cost']['min'])
|
market_data = self.exchange_info[symbol]
|
||||||
|
return float(market_data['limits']['cost']['min'])
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error(f"Error fetching minimum notional quantity for {symbol}: {str(e)}")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def _fetch_order(self, symbol: str, order_id: str) -> object:
|
def _fetch_order(self, symbol: str, order_id: str) -> object:
|
||||||
return self.client.fetch_order(order_id, symbol)
|
try:
|
||||||
|
return self.client.fetch_order(order_id, symbol)
|
||||||
|
except ccxt.BaseError as e:
|
||||||
|
logger.error(f"Error fetching order {order_id} for {symbol}: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
def _set_symbols(self) -> List[str]:
|
def _set_symbols(self) -> List[str]:
|
||||||
markets = self.client.fetch_markets()
|
try:
|
||||||
symbols = [market['symbol'] for market in markets if market['active']]
|
markets = self.client.fetch_markets()
|
||||||
return symbols
|
symbols = [market['symbol'] for market in markets if market['active']]
|
||||||
|
return symbols
|
||||||
|
except ccxt.BaseError as e:
|
||||||
|
logger.error(f"Error fetching symbols: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
def _set_balances(self) -> List[Dict[str, Union[str, float]]]:
|
def _set_balances(self) -> List[Dict[str, Union[str, float]]]:
|
||||||
if self.api_key and self.api_key_secret:
|
if self.api_key and self.api_key_secret:
|
||||||
|
|
@ -142,18 +141,13 @@ class Exchange:
|
||||||
if asset_balance > 0:
|
if asset_balance > 0:
|
||||||
balances.append({'asset': asset, 'balance': asset_balance, 'pnl': 0})
|
balances.append({'asset': asset, 'balance': asset_balance, 'pnl': 0})
|
||||||
return balances
|
return balances
|
||||||
except NotImplementedError:
|
except ccxt.BaseError as e:
|
||||||
# Handle the case where fetch_balance is not supported
|
logger.error(f"Error fetching balances: {str(e)}")
|
||||||
return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}]
|
return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}]
|
||||||
else:
|
else:
|
||||||
return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}]
|
return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}]
|
||||||
|
|
||||||
def _set_exchange_info(self) -> dict:
|
def _set_exchange_info(self) -> dict:
|
||||||
"""
|
|
||||||
Fetches market data for all symbols from the exchange.
|
|
||||||
This includes details on trading pairs, limits, fees, and other relevant information.
|
|
||||||
Caches the market info to speed up subsequent calls.
|
|
||||||
"""
|
|
||||||
if self.exchange_id in Exchange._market_cache:
|
if self.exchange_id in Exchange._market_cache:
|
||||||
return Exchange._market_cache[self.exchange_id]
|
return Exchange._market_cache[self.exchange_id]
|
||||||
|
|
||||||
|
|
@ -162,40 +156,35 @@ class Exchange:
|
||||||
Exchange._market_cache[self.exchange_id] = markets_info
|
Exchange._market_cache[self.exchange_id] = markets_info
|
||||||
return markets_info
|
return markets_info
|
||||||
except ccxt.BaseError as e:
|
except ccxt.BaseError as e:
|
||||||
print(f"Error fetching market info: {str(e)}")
|
logger.error(f"Error fetching market info: {str(e)}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def get_client(self) -> object:
|
def get_client(self) -> object:
|
||||||
""" Return a reference to the exchange_interface client."""
|
|
||||||
return self.client
|
return self.client
|
||||||
|
|
||||||
def get_avail_intervals(self) -> Tuple[str, ...]:
|
def get_avail_intervals(self) -> Tuple[str, ...]:
|
||||||
"""Returns a list of time intervals available for trading."""
|
|
||||||
return self.intervals
|
return self.intervals
|
||||||
|
|
||||||
def get_exchange_info(self) -> dict:
|
def get_exchange_info(self) -> dict:
|
||||||
"""Returns Info on all symbols."""
|
|
||||||
return self.exchange_info
|
return self.exchange_info
|
||||||
|
|
||||||
def get_symbols(self) -> List[str]:
|
def get_symbols(self) -> List[str]:
|
||||||
"""Returns all symbols available for trading."""
|
|
||||||
return self.symbols
|
return self.symbols
|
||||||
|
|
||||||
def get_balances(self) -> List[Dict[str, Union[str, float]]]:
|
def get_balances(self) -> List[Dict[str, Union[str, float]]]:
|
||||||
"""Returns any non-zero balance-info for all assets."""
|
|
||||||
return self.balances
|
return self.balances
|
||||||
|
|
||||||
def get_symbol_precision_rule(self, symbol: str) -> int:
|
def get_symbol_precision_rule(self, symbol: str) -> int:
|
||||||
"""Returns number of places after the decimal required by the exchange_interface for a specific symbol."""
|
|
||||||
r_value = self.symbols_n_precision.get(symbol)
|
r_value = self.symbols_n_precision.get(symbol)
|
||||||
if r_value is None:
|
if r_value is None:
|
||||||
self._set_precision_rule(symbol)
|
self._set_precision_rule(symbol)
|
||||||
r_value = self.symbols_n_precision.get(symbol)
|
r_value = self.symbols_n_precision.get(symbol)
|
||||||
return r_value
|
return r_value
|
||||||
|
|
||||||
def get_historical_klines(self, symbol: str, interval: str, start_dt: datetime,
|
def get_historical_klines(self, symbol: str, interval: str,
|
||||||
end_dt: datetime = None) -> pd.DataFrame:
|
start_dt: datetime, end_dt: datetime = None) -> pd.DataFrame:
|
||||||
return self._fetch_historical_klines(symbol=symbol, interval=interval, start_dt=start_dt, end_dt=end_dt)
|
return self._fetch_historical_klines(symbol=symbol, interval=interval,
|
||||||
|
start_dt=start_dt, end_dt=end_dt)
|
||||||
|
|
||||||
def get_price(self, symbol: str) -> float:
|
def get_price(self, symbol: str) -> float:
|
||||||
return self._fetch_price(symbol)
|
return self._fetch_price(symbol)
|
||||||
|
|
@ -209,26 +198,21 @@ class Exchange:
|
||||||
def get_order(self, symbol: str, order_id: str) -> object:
|
def get_order(self, symbol: str, order_id: str) -> object:
|
||||||
return self._fetch_order(symbol, order_id)
|
return self._fetch_order(symbol, order_id)
|
||||||
|
|
||||||
def place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> \
|
def place_order(self, symbol: str, side: str, type: str, timeInForce: str,
|
||||||
Tuple[str, object]:
|
quantity: float, price: float = None) -> Tuple[str, object]:
|
||||||
result, msg = self._place_order(symbol=symbol, side=side, type=type, timeInForce=timeInForce, quantity=quantity,
|
result, msg = self._place_order(symbol=symbol, side=side, type=type,
|
||||||
price=price)
|
timeInForce=timeInForce, quantity=quantity, price=price)
|
||||||
return result, msg
|
return result, msg
|
||||||
|
|
||||||
def _set_avail_intervals(self) -> Tuple[str, ...]:
|
def _set_avail_intervals(self) -> Tuple[str, ...]:
|
||||||
"""
|
|
||||||
Sets a list of time intervals available for trading on the exchange_interface.
|
|
||||||
"""
|
|
||||||
return tuple(self.client.timeframes.keys())
|
return tuple(self.client.timeframes.keys())
|
||||||
|
|
||||||
def _set_precision_rule(self, symbol: str) -> None:
|
def _set_precision_rule(self, symbol: str) -> None:
|
||||||
market_data = self.exchange_info[symbol]
|
market_data = self.exchange_info[symbol]
|
||||||
precision = market_data['precision']['amount']
|
precision = market_data['precision']['amount']
|
||||||
self.symbols_n_precision[symbol] = precision
|
self.symbols_n_precision[symbol] = precision
|
||||||
return
|
|
||||||
|
|
||||||
def _place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> \
|
def _place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> Tuple[str, object]:
|
||||||
Tuple[str, object]:
|
|
||||||
def format_arg(value: float) -> float:
|
def format_arg(value: float) -> float:
|
||||||
precision = self.symbols_n_precision.get(symbol, 8)
|
precision = self.symbols_n_precision.get(symbol, 8)
|
||||||
return float(f"{value:.{precision}f}")
|
return float(f"{value:.{precision}f}")
|
||||||
|
|
@ -250,13 +234,14 @@ class Exchange:
|
||||||
if price is not None:
|
if price is not None:
|
||||||
order_params['price'] = price
|
order_params['price'] = price
|
||||||
|
|
||||||
order = self.client.create_order(**order_params)
|
try:
|
||||||
return 'Success', order
|
order = self.client.create_order(**order_params)
|
||||||
|
return 'Success', order
|
||||||
|
except ccxt.BaseError as e:
|
||||||
|
logger.error(f"Error placing order for {symbol}: {str(e)}")
|
||||||
|
return 'Failure', None
|
||||||
|
|
||||||
def get_active_trades(self) -> List[Dict[str, Union[str, float]]]:
|
def get_active_trades(self) -> List[Dict[str, Union[str, float]]]:
|
||||||
"""
|
|
||||||
Get the active trades (open positions).
|
|
||||||
"""
|
|
||||||
if self.api_key and self.api_key_secret:
|
if self.api_key and self.api_key_secret:
|
||||||
try:
|
try:
|
||||||
positions = self.client.fetch_positions()
|
positions = self.client.fetch_positions()
|
||||||
|
|
@ -270,16 +255,13 @@ class Exchange:
|
||||||
}
|
}
|
||||||
formatted_trades.append(active_trade)
|
formatted_trades.append(active_trade)
|
||||||
return formatted_trades
|
return formatted_trades
|
||||||
except NotImplementedError:
|
except ccxt.BaseError as e:
|
||||||
# Handle the case where fetch_positions is not supported
|
logger.error(f"Error fetching active trades: {str(e)}")
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_open_orders(self) -> List[Dict[str, Union[str, float]]]:
|
def get_open_orders(self) -> List[Dict[str, Union[str, float]]]:
|
||||||
"""
|
|
||||||
Get the open orders.
|
|
||||||
"""
|
|
||||||
if self.api_key and self.api_key_secret:
|
if self.api_key and self.api_key_secret:
|
||||||
try:
|
try:
|
||||||
open_orders = self.client.fetch_open_orders()
|
open_orders = self.client.fetch_open_orders()
|
||||||
|
|
@ -293,8 +275,8 @@ class Exchange:
|
||||||
}
|
}
|
||||||
formatted_orders.append(open_order)
|
formatted_orders.append(open_order)
|
||||||
return formatted_orders
|
return formatted_orders
|
||||||
except NotImplementedError:
|
except ccxt.BaseError as e:
|
||||||
# Handle the case where fetch_balance is not supported
|
logger.error(f"Error fetching open orders: {str(e)}")
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -344,7 +344,7 @@ class Database:
|
||||||
print(f'Got {len(records.index)} records from {ex_details[2]}')
|
print(f'Got {len(records.index)} records from {ex_details[2]}')
|
||||||
|
|
||||||
# Check if the records in the db go far enough back to satisfy the query.
|
# Check if the records in the db go far enough back to satisfy the query.
|
||||||
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length=rl)
|
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length_min=rl)
|
||||||
if first_timestamp:
|
if first_timestamp:
|
||||||
# The records didn't go far enough back if a timestamp was returned.
|
# The records didn't go far enough back if a timestamp was returned.
|
||||||
print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
|
print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
|
||||||
|
|
@ -355,7 +355,7 @@ class Database:
|
||||||
records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
|
records = add_data(data=records, tn=table_name, start_t=st, end_t=end_time)
|
||||||
|
|
||||||
# Check if the records received are up-to-date.
|
# Check if the records received are up-to-date.
|
||||||
last_timestamp = query_uptodate(records=records, r_length=rl)
|
last_timestamp = query_uptodate(records=records, r_length_min=rl)
|
||||||
if last_timestamp:
|
if last_timestamp:
|
||||||
# The query was not up-to-date if a timestamp was returned.
|
# The query was not up-to-date if a timestamp was returned.
|
||||||
print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
|
print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,193 @@
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
import ccxt
|
||||||
|
from datetime import datetime
|
||||||
|
import pandas as pd
|
||||||
|
from Exchange import Exchange
|
||||||
|
|
||||||
|
|
||||||
|
class TestExchange(unittest.TestCase):
|
||||||
|
|
||||||
|
@patch('ccxt.binance')
|
||||||
|
def setUp(self, mock_exchange):
|
||||||
|
self.mock_client = MagicMock()
|
||||||
|
mock_exchange.return_value = self.mock_client
|
||||||
|
|
||||||
|
self.api_keys = {'key': 'test_key', 'secret': 'test_secret'}
|
||||||
|
|
||||||
|
self.mock_client.fetch_markets.return_value = [
|
||||||
|
{'symbol': 'BTC/USDT', 'active': True},
|
||||||
|
{'symbol': 'ETH/USDT', 'active': True},
|
||||||
|
]
|
||||||
|
self.mock_client.load_markets.return_value = {
|
||||||
|
'BTC/USDT': {
|
||||||
|
'limits': {
|
||||||
|
'amount': {'min': 0.001},
|
||||||
|
'cost': {'min': 10.0} # Ensure the cost limit is included
|
||||||
|
},
|
||||||
|
'precision': {
|
||||||
|
'amount': 8,
|
||||||
|
'price': 8
|
||||||
|
},
|
||||||
|
'active': True,
|
||||||
|
'symbol': 'BTC/USDT'
|
||||||
|
},
|
||||||
|
'ETH/USDT': {
|
||||||
|
'limits': {
|
||||||
|
'amount': {'min': 0.01},
|
||||||
|
'cost': {'min': 20.0} # Ensure the cost limit is included
|
||||||
|
},
|
||||||
|
'precision': {
|
||||||
|
'amount': 8,
|
||||||
|
'price': 8
|
||||||
|
},
|
||||||
|
'active': True,
|
||||||
|
'symbol': 'ETH/USDT'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.mock_client.fetch_ticker.return_value = {'last': 30000.0}
|
||||||
|
self.mock_client.fetch_ohlcv.return_value = [
|
||||||
|
[1609459200000, 29000.0, 29500.0, 28800.0, 29400.0, 1000]
|
||||||
|
]
|
||||||
|
self.mock_client.fetch_balance.return_value = {
|
||||||
|
'total': {'BTC': 1.0, 'USDT': 1000.0}
|
||||||
|
}
|
||||||
|
self.mock_client.create_order.return_value = {
|
||||||
|
'id': 'test_order_id',
|
||||||
|
'symbol': 'BTC/USDT',
|
||||||
|
'type': 'limit',
|
||||||
|
'side': 'buy',
|
||||||
|
'price': 30000.0,
|
||||||
|
'amount': 1.0,
|
||||||
|
'status': 'open'
|
||||||
|
}
|
||||||
|
self.mock_client.fetch_open_orders.return_value = [
|
||||||
|
{'id': 'test_order_id', 'symbol': 'BTC/USDT', 'side': 'buy', 'amount': 1.0, 'price': 30000.0}
|
||||||
|
]
|
||||||
|
self.mock_client.fetch_positions.return_value = [
|
||||||
|
{'symbol': 'BTC/USDT', 'quantity': 1.0, 'entry_price': 29000.0}
|
||||||
|
]
|
||||||
|
self.mock_client.fetch_order.return_value = {
|
||||||
|
'id': 'test_order_id',
|
||||||
|
'symbol': 'BTC/USDT',
|
||||||
|
'side': 'buy',
|
||||||
|
'amount': 1.0,
|
||||||
|
'price': 30000.0,
|
||||||
|
'status': 'open'
|
||||||
|
}
|
||||||
|
self.mock_client.timeframes = {
|
||||||
|
'1m': '1m',
|
||||||
|
'5m': '5m',
|
||||||
|
'1h': '1h',
|
||||||
|
'1d': '1d',
|
||||||
|
}
|
||||||
|
|
||||||
|
self.exchange = Exchange(name='Binance', api_keys=self.api_keys, exchange_id='binance')
|
||||||
|
|
||||||
|
def test_connect_exchange(self):
|
||||||
|
self.assertIsInstance(self.exchange.client, MagicMock)
|
||||||
|
self.mock_client.load_markets.assert_called_once()
|
||||||
|
self.assertEqual(self.exchange.symbols, ['BTC/USDT', 'ETH/USDT'])
|
||||||
|
self.assertIn('BTC/USDT', self.exchange.exchange_info)
|
||||||
|
self.assertIn('ETH/USDT', self.exchange.exchange_info)
|
||||||
|
self.assertEqual(self.exchange.balances, [
|
||||||
|
{'asset': 'BTC', 'balance': 1.0, 'pnl': 0},
|
||||||
|
{'asset': 'USDT', 'balance': 1000.0, 'pnl': 0}
|
||||||
|
])
|
||||||
|
|
||||||
|
def test_get_symbols(self):
|
||||||
|
symbols = self.exchange.get_symbols()
|
||||||
|
self.assertEqual(symbols, ['BTC/USDT', 'ETH/USDT'])
|
||||||
|
|
||||||
|
def test_get_price(self):
|
||||||
|
price = self.exchange.get_price('BTC/USDT')
|
||||||
|
self.assertEqual(price, 30000.0)
|
||||||
|
self.mock_client.fetch_ticker.assert_called_with('BTC/USDT')
|
||||||
|
|
||||||
|
def test_get_balances(self):
|
||||||
|
balances = self.exchange.get_balances()
|
||||||
|
self.assertEqual(balances, [
|
||||||
|
{'asset': 'BTC', 'balance': 1.0, 'pnl': 0},
|
||||||
|
{'asset': 'USDT', 'balance': 1000.0, 'pnl': 0}
|
||||||
|
])
|
||||||
|
self.mock_client.fetch_balance.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_historical_klines(self):
|
||||||
|
start_dt = datetime(2021, 1, 1)
|
||||||
|
end_dt = datetime(2021, 1, 2)
|
||||||
|
klines = self.exchange.get_historical_klines('BTC/USDT', '1d', start_dt, end_dt)
|
||||||
|
expected_df = pd.DataFrame([
|
||||||
|
{'open_time': 1609459200, 'open': 29000.0, 'high': 29500.0, 'low': 28800.0, 'close': 29400.0,
|
||||||
|
'volume': 1000}
|
||||||
|
])
|
||||||
|
pd.testing.assert_frame_equal(klines, expected_df)
|
||||||
|
self.mock_client.fetch_ohlcv.assert_called()
|
||||||
|
|
||||||
|
def test_get_min_qty(self):
|
||||||
|
min_qty = self.exchange.get_min_qty('BTC/USDT')
|
||||||
|
self.assertEqual(min_qty, 0.001)
|
||||||
|
|
||||||
|
def test_get_min_notional_qty(self):
|
||||||
|
min_notional_qty = self.exchange.get_min_notional_qty('BTC/USDT')
|
||||||
|
self.assertEqual(min_notional_qty, 10.0)
|
||||||
|
|
||||||
|
def test_get_order(self):
|
||||||
|
order = self.exchange.get_order('BTC/USDT', 'test_order_id')
|
||||||
|
self.assertIsInstance(order, dict)
|
||||||
|
self.assertEqual(order['id'], 'test_order_id')
|
||||||
|
self.mock_client.fetch_order.assert_called_with('test_order_id', 'BTC/USDT')
|
||||||
|
|
||||||
|
def test_place_order(self):
|
||||||
|
result, order = self.exchange.place_order(
|
||||||
|
symbol='BTC/USDT',
|
||||||
|
side='buy',
|
||||||
|
type='limit',
|
||||||
|
timeInForce='GTC',
|
||||||
|
quantity=1.0,
|
||||||
|
price=30000.0
|
||||||
|
)
|
||||||
|
self.assertEqual(result, 'Success')
|
||||||
|
self.assertIsInstance(order, dict)
|
||||||
|
self.assertEqual(order.get('id'), 'test_order_id')
|
||||||
|
self.mock_client.create_order.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_open_orders(self):
|
||||||
|
open_orders = self.exchange.get_open_orders()
|
||||||
|
self.assertEqual(open_orders, [
|
||||||
|
{'symbol': 'BTC/USDT', 'side': 'buy', 'quantity': 1.0, 'price': 30000.0}
|
||||||
|
])
|
||||||
|
self.mock_client.fetch_open_orders.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_active_trades(self):
|
||||||
|
active_trades = self.exchange.get_active_trades()
|
||||||
|
self.assertEqual(active_trades, [
|
||||||
|
{'symbol': 'BTC/USDT', 'side': 'buy', 'quantity': 1.0, 'price': 29000.0}
|
||||||
|
])
|
||||||
|
self.mock_client.fetch_positions.assert_called_once()
|
||||||
|
|
||||||
|
@patch('ccxt.binance')
|
||||||
|
def test_fetch_ohlcv_network_failure(self, mock_exchange):
|
||||||
|
self.mock_client.fetch_ohlcv.side_effect = ccxt.NetworkError('Network error')
|
||||||
|
start_dt = datetime(2021, 1, 1)
|
||||||
|
end_dt = datetime(2021, 1, 2)
|
||||||
|
klines = self.exchange.get_historical_klines('BTC/USDT', '1d', start_dt, end_dt)
|
||||||
|
self.assertTrue(klines.empty)
|
||||||
|
self.mock_client.fetch_ohlcv.assert_called()
|
||||||
|
|
||||||
|
@patch('ccxt.binance')
|
||||||
|
def test_fetch_ticker_invalid_response(self, mock_exchange):
|
||||||
|
self.mock_client.fetch_ticker.side_effect = ccxt.ExchangeError('Invalid response')
|
||||||
|
price = self.exchange.get_price('BTC/USDT')
|
||||||
|
self.assertEqual(price, 0.0)
|
||||||
|
self.mock_client.fetch_ticker.assert_called_with('BTC/USDT')
|
||||||
|
|
||||||
|
@patch('ccxt.binance')
|
||||||
|
def test_fetch_order_invalid_response(self, mock_exchange):
|
||||||
|
self.mock_client.fetch_order.side_effect = ccxt.ExchangeError('Invalid response')
|
||||||
|
order = self.exchange.get_order('BTC/USDT', 'invalid_order_id')
|
||||||
|
self.assertIsNone(order)
|
||||||
|
self.mock_client.fetch_order.assert_called_with('invalid_order_id', 'BTC/USDT')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
|
|
@ -0,0 +1,139 @@
|
||||||
|
import unittest
|
||||||
|
import ccxt
|
||||||
|
from datetime import datetime
|
||||||
|
import pandas as pd
|
||||||
|
from Exchange import Exchange
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class TestExchange(unittest.TestCase):
|
||||||
|
api_keys: Optional[Dict[str, str]]
|
||||||
|
exchange: Exchange
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
exchange_name = 'binance'
|
||||||
|
cls.api_keys = None
|
||||||
|
"""Uncomment and Provide api keys to connect to exchange."""
|
||||||
|
# cls.api_keys = {'key': 'EXCHANGE_API_KEY', 'secret': 'EXCHANGE_API_SECRET'}
|
||||||
|
|
||||||
|
cls.exchange = Exchange(name=exchange_name, api_keys=cls.api_keys, exchange_id=exchange_name)
|
||||||
|
|
||||||
|
def test_connect_exchange(self):
|
||||||
|
print("\nRunning test_connect_exchange...")
|
||||||
|
exchange_class = getattr(ccxt, self.exchange.name)
|
||||||
|
self.assertIsInstance(self.exchange.client, exchange_class)
|
||||||
|
print("Exchange client instance test passed.")
|
||||||
|
|
||||||
|
self.assertIn('BTC/USDT', self.exchange.symbols)
|
||||||
|
self.assertIn('ETH/USDT', self.exchange.symbols)
|
||||||
|
print("Symbols test passed.")
|
||||||
|
|
||||||
|
self.assertIn('BTC/USDT', self.exchange.exchange_info)
|
||||||
|
self.assertIn('ETH/USDT', self.exchange.exchange_info)
|
||||||
|
print("Exchange info test passed.")
|
||||||
|
|
||||||
|
def test_get_symbols(self):
|
||||||
|
print("\nRunning test_get_symbols...")
|
||||||
|
symbols = self.exchange.get_symbols()
|
||||||
|
self.assertIn('BTC/USDT', symbols)
|
||||||
|
self.assertIn('ETH/USDT', symbols)
|
||||||
|
print("Get symbols test passed.")
|
||||||
|
|
||||||
|
def test_get_price(self):
|
||||||
|
print("\nRunning test_get_price...")
|
||||||
|
price = self.exchange.get_price('BTC/USDT')
|
||||||
|
self.assertGreater(price, 0)
|
||||||
|
print(f"Get price test passed. Price: {price}")
|
||||||
|
|
||||||
|
def test_get_balances(self):
|
||||||
|
print("\nRunning test_get_balances...")
|
||||||
|
if self.api_keys is not None:
|
||||||
|
balances = self.exchange.get_balances()
|
||||||
|
self.assertTrue(any(balance['asset'] == 'BTC' for balance in balances))
|
||||||
|
self.assertTrue(any(balance['asset'] == 'USDT' for balance in balances))
|
||||||
|
print("Get balances test passed.")
|
||||||
|
else:
|
||||||
|
print("test_get_balances(): Can not test without providing API keys")
|
||||||
|
|
||||||
|
def test_get_historical_klines(self):
|
||||||
|
print("\nRunning test_get_historical_klines...")
|
||||||
|
start_dt = datetime(2021, 1, 1)
|
||||||
|
end_dt = datetime(2021, 1, 2)
|
||||||
|
klines = self.exchange.get_historical_klines('BTC/USDT', '1d', start_dt, end_dt)
|
||||||
|
self.assertIsInstance(klines, pd.DataFrame)
|
||||||
|
self.assertFalse(klines.empty)
|
||||||
|
print("Get historical klines test passed.")
|
||||||
|
print(klines.head())
|
||||||
|
|
||||||
|
def test_get_min_qty(self):
|
||||||
|
print("\nRunning test_get_min_qty...")
|
||||||
|
min_qty = self.exchange.get_min_qty('BTC/USDT')
|
||||||
|
self.assertGreater(min_qty, 0)
|
||||||
|
print(f"Get min qty test passed. Min qty: {min_qty}")
|
||||||
|
|
||||||
|
def test_get_min_notional_qty(self):
|
||||||
|
print("\nRunning test_get_min_notional_qty...")
|
||||||
|
min_notional_qty = self.exchange.get_min_notional_qty('BTC/USDT')
|
||||||
|
self.assertGreater(min_notional_qty, 0)
|
||||||
|
print(f"Get min notional qty test passed. Min notional qty: {min_notional_qty}")
|
||||||
|
|
||||||
|
def test_get_order(self):
|
||||||
|
print("\nRunning test_get_order...")
|
||||||
|
if self.api_keys is not None:
|
||||||
|
# You need to create an order manually on exchange to test this
|
||||||
|
order_id = 'your_order_id'
|
||||||
|
order = self.exchange.get_order('BTC/USDT', order_id)
|
||||||
|
self.assertIsInstance(order, dict)
|
||||||
|
self.assertEqual(order['id'], order_id)
|
||||||
|
print("Get order test passed.")
|
||||||
|
else:
|
||||||
|
print("test_get_order(): Can not test without providing API keys")
|
||||||
|
|
||||||
|
def test_place_order(self):
|
||||||
|
print("\nRunning test_place_order...")
|
||||||
|
if self.api_keys is not None:
|
||||||
|
# Be cautious with creating real orders. This is for testing purposes.
|
||||||
|
symbol = 'BTC/USDT'
|
||||||
|
order_type = 'limit'
|
||||||
|
side = 'buy'
|
||||||
|
amount = 0.001
|
||||||
|
price = 30000 # Adjust accordingly
|
||||||
|
|
||||||
|
result, order = self.exchange.place_order(
|
||||||
|
symbol=symbol,
|
||||||
|
side=side,
|
||||||
|
type=order_type,
|
||||||
|
timeInForce='GTC',
|
||||||
|
quantity=amount,
|
||||||
|
price=price
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(result, 'Success')
|
||||||
|
self.assertIsInstance(order, dict)
|
||||||
|
self.assertIn('id', order)
|
||||||
|
print("Place order test passed.")
|
||||||
|
else:
|
||||||
|
print("test_place_order(): Can not test without providing API keys")
|
||||||
|
|
||||||
|
def test_get_open_orders(self):
|
||||||
|
print("\nRunning test_get_open_orders...")
|
||||||
|
if self.api_keys is not None:
|
||||||
|
open_orders = self.exchange.get_open_orders()
|
||||||
|
self.assertIsInstance(open_orders, list)
|
||||||
|
print("Get open orders test passed.")
|
||||||
|
else:
|
||||||
|
print("test_get_open_orders(): Can not test without providing API keys")
|
||||||
|
|
||||||
|
def test_get_active_trades(self):
|
||||||
|
print("\nRunning test_get_active_trades...")
|
||||||
|
if self.api_keys is not None:
|
||||||
|
active_trades = self.exchange.get_active_trades()
|
||||||
|
self.assertIsInstance(active_trades, list)
|
||||||
|
print("Get active trades test passed.")
|
||||||
|
else:
|
||||||
|
print("test_get_active_trades(): Can not test without providing API keys")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue