From c398a423a353174244c4d803c4ad987f221a9a81 Mon Sep 17 00:00:00 2001 From: Rob Date: Sun, 4 Aug 2024 02:06:23 -0300 Subject: [PATCH] implemented tests for Exchangeinterface.py --- src/BrighterTrades.py | 6 +- src/DataCache.py | 3 +- src/Exchange.py | 49 ++++---- src/ExchangeInterface.py | 214 ++++++++++++++++++++++++++++++++ src/app.py | 5 + src/exchangeinterface.py | 210 ------------------------------- src/templates/price_chart.html | 2 +- src/trade.py | 2 +- tests/test_DataCache.py | 2 +- tests/test_Exchange.py | 7 -- tests/test_Users.py | 2 +- tests/test_candles.py | 2 +- tests/test_exchangeinterface.py | 119 ++++++++++++++++++ tests/test_indicators.py | 2 +- tests/test_trade.py | 4 +- 15 files changed, 375 insertions(+), 254 deletions(-) create mode 100644 src/ExchangeInterface.py delete mode 100644 src/exchangeinterface.py create mode 100644 tests/test_exchangeinterface.py diff --git a/src/BrighterTrades.py b/src/BrighterTrades.py index 61d31d4..28ea178 100644 --- a/src/BrighterTrades.py +++ b/src/BrighterTrades.py @@ -6,7 +6,7 @@ from Strategies import Strategies from backtesting import Backtester from candles import Candles from Configuration import Configuration -from exchangeinterface import ExchangeInterface +from ExchangeInterface import ExchangeInterface from indicators import Indicators from Signals import Signals from trade import Trades @@ -429,14 +429,14 @@ class BrighterTrades: return None # Forward the request to trades. - status, result = self.trades.new_trade(target=vld('target'), symbol=vld('symbol'), price=vld('price'), + status, result = self.trades.new_trade(target=vld('exchange_name'), symbol=vld('symbol'), price=vld('price'), side=vld('side'), order_type=vld('orderType'), qty=vld('quantity')) if status == 'Error': print(f'Error placing the trade: {result}') return None - print(f'Trade order received: target={vld("target")}, ' + print(f'Trade order received: exchange_name={vld("exchange_name")}, ' f'symbol={vld("symbol")}, ' f'side={vld("side")}, ' f'type={vld("orderType")}, ' diff --git a/src/DataCache.py b/src/DataCache.py index 0d2f4de..cbbb5c8 100644 --- a/src/DataCache.py +++ b/src/DataCache.py @@ -5,8 +5,7 @@ from Database import Database from shared_utilities import query_satisfied, query_uptodate, unix_time_millis, timeframe_to_minutes import logging -# Set up logging -logging.basicConfig(level=logging.DEBUG) + logger = logging.getLogger(__name__) diff --git a/src/Exchange.py b/src/Exchange.py index 44f0ff1..ea8693a 100644 --- a/src/Exchange.py +++ b/src/Exchange.py @@ -1,12 +1,10 @@ import ccxt import pandas as pd from datetime import datetime, timedelta -from typing import Tuple, Dict, List, Union +from typing import Tuple, Dict, List, Union, Any import time import logging -# Configure logging -logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -29,8 +27,11 @@ class Exchange: self.name = name self.api_key = api_keys['key'] if api_keys else None self.api_key_secret = api_keys['secret'] if api_keys else None + self.configured = False self.exchange_id = exchange_id self.client: ccxt.Exchange = self._connect_exchange() + if self.client: + self._check_authentication() self.exchange_info = self._set_exchange_info() self.intervals = self._set_avail_intervals() self.symbols = self._set_symbols() @@ -63,6 +64,17 @@ class Exchange: 'verbose': False }) + def _check_authentication(self): + try: + # Perform an authenticated request to check if the API keys are valid + self.client.fetch_balance() + self.configured = True + logger.info("Authentication successful. Trading bot configured.") + except ccxt.AuthenticationError: + logger.error("Authentication failed. Please check your API keys.") + except Exception as e: + logger.error(f"An error occurred: {e}") + @staticmethod def datetime_to_unix_millis(dt: datetime) -> int: """ @@ -179,23 +191,6 @@ class Exchange: logger.error(f"Error fetching minimum notional quantity for {symbol}: {str(e)}") return 0.0 - def _fetch_order(self, symbol: str, order_id: str) -> object: - """ - Fetches an order by its ID for a given symbol. - - Parameters: - symbol (str): The trading symbol (e.g., 'BTC/USDT'). - order_id (str): The ID of the order. - - Returns: - object: The order details. - """ - try: - return self.client.fetch_order(order_id, symbol) - except ccxt.BaseError as e: - logger.error(f"Error fetching order {order_id} for {symbol}: {str(e)}") - return None - def _set_symbols(self) -> List[str]: """ Sets the list of available symbols on the exchange. @@ -365,7 +360,7 @@ class Exchange: """ return self._fetch_min_notional_qty(symbol) - def get_order(self, symbol: str, order_id: str) -> object: + def get_order(self, symbol: str, order_id: str) -> Dict[str, Any] | None: """ Returns an order by its ID for a given symbol. @@ -374,9 +369,13 @@ class Exchange: order_id (str): The ID of the order. Returns: - object: The order details. + Dict[str, Any]: The order details or None on error. """ - return self._fetch_order(symbol, order_id) + try: + return self.client.fetch_order(order_id, symbol) + except ccxt.BaseError as e: + logger.error(f"Error fetching order {order_id} for {symbol}: {str(e)}") + return None def place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> Tuple[str, object]: @@ -418,7 +417,8 @@ class Exchange: precision = market_data['precision']['amount'] self.symbols_n_precision[symbol] = precision - def _place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> Tuple[str, object]: + def _place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> \ + Tuple[str, object]: """ Places an order on the exchange. @@ -433,6 +433,7 @@ class Exchange: Returns: Tuple[str, object]: A tuple containing the result ('Success' or 'Failure') and the order details or None. """ + def format_arg(value: float) -> float: precision = self.symbols_n_precision.get(symbol, 8) return float(f"{value:.{precision}f}") diff --git a/src/ExchangeInterface.py b/src/ExchangeInterface.py new file mode 100644 index 0000000..345b034 --- /dev/null +++ b/src/ExchangeInterface.py @@ -0,0 +1,214 @@ +import logging +import json +from typing import List, Any, Dict +import pandas as pd +import requests +import ccxt +from Exchange import Exchange + +logger = logging.getLogger(__name__) + + +# Utility function to add a row to a DataFrame +def add_row(df: pd.DataFrame, dic: Dict[str, Any]) -> pd.DataFrame: + return pd.concat([df, pd.DataFrame([dic])], ignore_index=True) + + +class ExchangeInterface: + """ + Connects, maintains, and routes data requests to/from multiple exchanges. + """ + + def __init__(self): + self.exchange_data = pd.DataFrame(columns=['user', 'name', 'reference', 'balances']) + self.available_exchanges = self.get_ccxt_exchanges() + + # Create a default user and exchange for unsigned requests + default_ex_name = 'binance' + self.connect_exchange(exchange_name=default_ex_name, user_name='default') + self.default_exchange = self.get_exchange(ename=default_ex_name, uname='default') + + def get_ccxt_exchanges(self) -> List[str]: + """Retrieve the list of available exchanges from CCXT.""" + return ccxt.exchanges + + def connect_exchange(self, exchange_name: str, user_name: str, api_keys: Dict[str, str] = None) -> bool: + """ + Initialize and store a reference to the specified exchange. + + :param exchange_name: The name of the exchange. + :param user_name: The name of the user connecting the exchange. + :param api_keys: Optional API keys for the exchange. + :return: True if successful, False otherwise. + """ + try: + exchange = Exchange(name=exchange_name, api_keys=api_keys, exchange_id=exchange_name.lower()) + self.add_exchange(user_name, exchange) + return True + except Exception as e: + logging.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}") + return False + + def add_exchange(self, user_name: str, exchange: Exchange): + """ + Add an exchange to the user's list of exchanges. + + :param user_name: The name of the user. + :param exchange: The Exchange object to add. + """ + try: + row = {'user': user_name, 'name': exchange.name, 'reference': exchange, 'balances': exchange.balances} + self.exchange_data = add_row(self.exchange_data, row) + except Exception as e: + logging.error(f"Couldn't create an instance of the exchange! {str(e)}") + raise + + def get_exchange(self, ename: str, uname: str) -> Exchange: + """ + Get a reference to the specified exchange for a user. + + :param ename: The name of the exchange. + :param uname: The name of the user. + :return: The Exchange object. + """ + if not ename or not uname: + raise ValueError('Missing argument!') + + exchange_data = self.exchange_data.query("name == @ename and user == @uname") + if exchange_data.empty: + raise ValueError('No matching exchange found.') + + return exchange_data.at[exchange_data.index[0], 'reference'] + + def get_connected_exchanges(self, user_name: str) -> List[str]: + """ + Get a list of connected exchanges for a user. + + :param user_name: The name of the user. + :return: A list of connected exchange names. + """ + return self.exchange_data.loc[self.exchange_data['user'] == user_name, 'name'].tolist() + + def get_available_exchanges(self) -> List[str]: + """Get a list of available exchanges.""" + return self.available_exchanges + + def get_exchange_balances(self, user_name: str, name: str) -> pd.Series: + """ + Get the balances of a specified exchange for a specific user. + + :param user_name: The name of the user. + :param name: The name of the exchange. + :return: A Series containing the balances. + """ + filtered_data = self.exchange_data.query("user == @user_name and name == @name") + if not filtered_data.empty: + return filtered_data.iloc[0]['balances'] + else: + return pd.Series(dtype='object') # Return an empty Series if no match is found + + def get_all_balances(self, user_name: str) -> Dict[str, List[Dict[str, Any]]]: + """ + Get the balances of all connected exchanges for a user. + + :param user_name: The name of the user. + :return: A dictionary containing the balances of all connected exchanges. + """ + filtered_data = self.exchange_data.loc[self.exchange_data['user'] == user_name, ['name', 'balances']] + if filtered_data.empty: + return {} + + balances_dict = {row['name']: row['balances'] for _, row in filtered_data.iterrows()} + return balances_dict + + def get_all_activated(self, user_name: str, fetch_type: str = 'trades') -> Dict[str, List[Dict[str, Any]]]: + """ + Get active trades or open orders for all connected exchanges. + + :param user_name: The name of the user. + :param fetch_type: The type of data to fetch ('trades' or 'orders'). + :return: A dictionary indexed by exchange name with lists of active trades or open orders. + """ + filtered_data = self.exchange_data.loc[self.exchange_data['user'] == user_name, ['name', 'reference']] + if filtered_data.empty: + return {} + + data_dict = {} + for name, reference in filtered_data.itertuples(index=False): + if pd.isna(reference): + continue + + try: + if fetch_type == 'trades': + data = reference.get_active_trades() + elif fetch_type == 'orders': + data = reference.get_open_orders() + else: + logging.error(f"Invalid fetch type: {fetch_type}") + return {} + + data_dict[name] = data + except Exception as e: + logging.error(f"Error retrieving data for {name}: {str(e)}") + + return data_dict + + def get_order(self, symbol: str, order_id: str, exchange_name: str, user_name: str) -> Any: + """ + Get an order from a specified exchange. + + :param symbol: The trading symbol. + :param order_id: The order ID. + :param exchange_name: The name of the exchange. + :param user_name: The name of the user. + :return: The order details. + """ + exchange = self.get_exchange(ename=exchange_name, uname=user_name) + return exchange.get_order(symbol=symbol, order_id=order_id) + + def get_trade_info(self, trade, user_name: str, info_type: str) -> Dict[str, Any] | None: + """ + Get information about a trade (status, executed quantity, executed price). + + :param trade: The trade object. + :param user_name: The name of the user. + :param info_type: The type of information ('status', 'executed_qty', 'executed_price'). + :return: The requested information or None if the order is not found. + """ + exchange = self.get_exchange(ename=trade.target, uname=user_name) + if exchange.configured is False: + logger.error("Must configure API keys to request trade info.") + return None + + order = exchange.get_order(symbol=trade.symbol, order_id=trade.order.orderId) + + if order is None: + logger.error(f"Order {trade.order.orderId} for {trade.symbol} not found.") + return None + + if isinstance(order, dict): + if info_type == 'status': + return order.get('status') + elif info_type == 'executed_qty': + return order.get('filled') + elif info_type == 'executed_price': + return order.get('average') + else: + logger.error(f"Invalid info type: {info_type}") + return None + else: + logger.error("Order object is not a dictionary") + return None + + def get_price(self, symbol: str, price_source: str = None) -> float: + """ + Get the current price of a trading pair. + + :param symbol: The trading symbol. + :param price_source: Optional alternative source for price. + :return: The current price. + """ + if price_source is None: + return self.default_exchange.get_price(symbol=symbol) + else: + raise ValueError(f'No implementation for price source: {price_source}') diff --git a/src/app.py b/src/app.py index 3c80252..d49fa1a 100644 --- a/src/app.py +++ b/src/app.py @@ -1,4 +1,6 @@ import json +import logging + from flask import Flask, render_template, request, redirect, jsonify, session, flash from flask_cors import CORS from flask_sock import Sock @@ -8,6 +10,9 @@ from email_validator import validate_email, EmailNotValidError import config from BrighterTrades import BrighterTrades +# Set up logging +logging.basicConfig(level=logging.DEBUG) + # Create a BrighterTrades object. This the main application that maintains access to the server, local storage, # and manages objects that process trade data. brighter_trades = BrighterTrades() diff --git a/src/exchangeinterface.py b/src/exchangeinterface.py deleted file mode 100644 index 5101223..0000000 --- a/src/exchangeinterface.py +++ /dev/null @@ -1,210 +0,0 @@ -import logging -import json -from typing import List, Any, Dict -import pandas as pd -import requests -import ccxt - -from Exchange import Exchange - -# Setup logging -# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') - - -# This just makes this method cleaner. -def add_row(df, dic): - df = pd.concat([df, pd.DataFrame.from_records([dic])], ignore_index=True) - return df - - -class ExchangeInterface: - """ - Connects and maintains and routes data requests from exchanges. - """ - - def __init__(self): - # Create a dataframe to hold all the references and info for each user's configured exchanges. - self.exchange_data = pd.DataFrame(columns=['user', 'name', 'reference', 'balances']) - - # Populate the list of available exchanges from CCXT - self.available_exchanges = self.get_ccxt_exchanges() - - def get_ccxt_exchanges(self) -> List[str]: - """Retrieve the list of available exchanges from CCXT.""" - return ccxt.exchanges - - def connect_exchange(self, exchange_name: str, user_name: str, api_keys: dict = None) -> bool: - """ - Initialize and store a reference to the available exchanges. - - :param user_name: The name of the user connecting the exchange. - :param api_keys: dict - {api: key, api-secret: key} - :param exchange_name: str - The name of the exchange. - :return: True if success | None on fail. - """ - # logging.debug( - # f"Attempting to connect to exchange '{exchange_name}' for user '{user_name}' with API keys: {api_keys}") - - try: - # Initialize the exchange object - exchange = Exchange(name=exchange_name, api_keys=api_keys, exchange_id=exchange_name.lower()) - - # Update exchange data with the new connection - self.add_exchange(user_name, exchange) - # logging.debug(f"Successfully connected to exchange '{exchange_name}' for user '{user_name}'") - - return True - except Exception as e: - logging.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}") - return False # Failed to connect - - def add_exchange(self, user_name: str, exchange: Exchange): - try: - row = {'user': user_name, 'name': exchange.name, 'reference': exchange, 'balances': exchange.balances} - self.exchange_data = add_row(self.exchange_data, row) - except Exception as e: - if hasattr(e, 'status_code') and e.status_code == 400 and e.error_code == -1021: - logging.error("Timestamp ahead of server's time error: Sync your system clock to fix this.") - logging.error("Couldn't create an instance of the exchange!:\n", e) - raise - - def get_exchange(self, ename: str, uname: str) -> Any: - """Return a reference to the exchange_name.""" - if not ename or not uname: - raise ValueError('Missing argument!') - - exchange_data = self.exchange_data.query("name == @ename and user == @uname") - if exchange_data.empty: - raise ValueError('No matching exchange found.') - - return exchange_data.at[exchange_data.index[0], 'reference'] - - def get_connected_exchanges(self, user_name: str) -> List[str]: - """Return a list of the connected exchanges.""" - connected_exchanges = self.exchange_data.loc[self.exchange_data['user'] == user_name, 'name'].tolist() - return connected_exchanges - - def get_available_exchanges(self) -> List[str]: - """ Return a list of the exchanges available to connect to""" - return self.available_exchanges - - def get_exchange_balances(self, name: str) -> pd.Series: - """ Return the balances of a single exchange_name""" - return self.exchange_data.query("name == @name")['balances'] - - def get_all_balances(self, user_name: str) -> Dict[str, List[Dict[str, Any]]]: - """ - Return the balances of all connected exchanges indexed by name. - - :param user_name: str - The name of the user. - :return: dict - A dictionary containing the balances of all connected exchanges. - The dictionary is indexed by exchange name, and the values are lists of dictionaries - containing the asset balances and P&L information for each exchange. - """ - filtered_data = self.exchange_data.loc[self.exchange_data['user'] == user_name, ['name', 'balances']] - if filtered_data.empty: - return {} - - balances_dict = {} - for _, row in filtered_data.iterrows(): - exchange_name = row['name'] - balances = row['balances'] - balances_dict[exchange_name] = balances - - return balances_dict - - def get_all_activated(self, user_name: str, fetch_type: str = 'trades') -> Dict[str, List[Dict[str, Any]]]: - """Get active trades or open orders as a dictionary indexed by name""" - filtered_data = self.exchange_data.loc[self.exchange_data['user'] == user_name, ['name', 'reference']] - if filtered_data.empty: - return {} - - data_dict = {} - for name, reference in filtered_data.itertuples(index=False): - if pd.isna(reference): - continue - - try: - if fetch_type == 'trades': - data = reference.get_active_trades() - elif fetch_type == 'orders': - data = reference.get_open_orders() - else: - logging.error(f"Invalid fetch type: {fetch_type}") - return {} - - data_dict[name] = data - except Exception as e: - logging.error(f"Error retrieving data for {name}: {str(e)}") - - return data_dict - - def get_order(self, symbol: str, order_id: str, target: str, user_name: str) -> Any: - """ - Return order - from a target exchange_interface. - :param user_name: The name of the user making the request. - :param symbol: trading symbol - :param order_id: The order ID - :param target: The exchange_interface to fetch this info. - :return: { Success: order| Fail: None } - """ - # Target exchange_interface. - exchange = self.get_exchange(ename=target, uname=user_name) - # Return the order. - return exchange.get_order(symbol=symbol, order_id=order_id) - - def get_trade_status(self, trade, user_name: str) -> str: - """ - Return the status of a trade - Todo: trade order.status might be outdated this request the status from the exchanges order record. - Todo You could just update the trade and get the status from there. - """ - # Target exchange_interface. - exchange = self.get_exchange(ename=trade.target, uname=user_name) - # Get the order from the target. - order = exchange.get_order(symbol=trade.symbol, order_id=trade.order.orderId) - # Return status. - return order['status'] - - def get_trade_executed_qty(self, trade, user_name: str) -> float: - """ - Return the executed quantity of a trade. - - :param user_name: The name of the user executing the command. - :param trade: todo: - - """ - # Target exchange_interface. - exchange = self.get_exchange(ename=trade.target, uname=user_name) - # Get the order from the target. - order = exchange.get_order(symbol=trade.symbol, order_id=trade.order.orderId) - # Return quantity. - return order['executedQty'] - - def get_trade_executed_price(self, trade, user_name: str) -> float: - """ - Return the average price of executed quantity of a trade - - :param user_name: The name of the user executing this trade. - :param trade: - """ - # Target exchange_interface. - exchange = self.get_exchange(ename=trade.target, uname=user_name) - # Get the order from the target. - order = exchange.get_order(symbol=trade.symbol, order_id=trade.order.orderId) - # Return quantity. - return order['price'] - - @staticmethod - def get_price(symbol: str, price_source: str = None) -> float: - """ - :param price_source: alternative sources for price. - :param symbol: The symbol of the trading pair. - :return: The current ticker price. - """ - if price_source is None: - request = requests.get(f'https://api.binance.com/api/v3/ticker/price?symbol={symbol}') - json_obj = json.loads(request.text) - return float(json_obj['price']) - else: - raise ValueError(f'No implementation for price source: {price_source}') diff --git a/src/templates/price_chart.html b/src/templates/price_chart.html index fcc8512..f823ffc 100644 --- a/src/templates/price_chart.html +++ b/src/templates/price_chart.html @@ -3,7 +3,7 @@
- +
diff --git a/src/trade.py b/src/trade.py index 15d1716..263b678 100644 --- a/src/trade.py +++ b/src/trade.py @@ -513,7 +513,7 @@ class Trades: # Required fields. if not target or not symbol or not side or not order_type: - return 'Error', 'Missing argument: target, symbol, side and order_type required.' + return 'Error', 'Missing argument: exchange_name, symbol, side and order_type required.' # If quantity is not provided set it to a small amount. # It will be rounded up to the minimum required amount by the exchange_interface. diff --git a/tests/test_DataCache.py b/tests/test_DataCache.py index 9ef4376..901c5aa 100644 --- a/tests/test_DataCache.py +++ b/tests/test_DataCache.py @@ -1,5 +1,5 @@ from DataCache import DataCache -from exchangeinterface import ExchangeInterface +from ExchangeInterface import ExchangeInterface import unittest import pandas as pd import datetime as dt diff --git a/tests/test_Exchange.py b/tests/test_Exchange.py index b61ee70..2f13c8e 100644 --- a/tests/test_Exchange.py +++ b/tests/test_Exchange.py @@ -181,13 +181,6 @@ class TestExchange(unittest.TestCase): self.assertEqual(price, 0.0) self.mock_client.fetch_ticker.assert_called_with('BTC/USDT') - @patch('ccxt.binance') - def test_fetch_order_invalid_response(self, mock_exchange): - self.mock_client.fetch_order.side_effect = ccxt.ExchangeError('Invalid response') - order = self.exchange.get_order('BTC/USDT', 'invalid_order_id') - self.assertIsNone(order) - self.mock_client.fetch_order.assert_called_with('invalid_order_id', 'BTC/USDT') - if __name__ == '__main__': unittest.main() diff --git a/tests/test_Users.py b/tests/test_Users.py index df2061d..3f33a2c 100644 --- a/tests/test_Users.py +++ b/tests/test_Users.py @@ -2,7 +2,7 @@ import json from Configuration import Configuration from DataCache import DataCache -from exchangeinterface import ExchangeInterface +from ExchangeInterface import ExchangeInterface # Object that interacts and maintains exchange_interface and account data exchanges = ExchangeInterface() diff --git a/tests/test_candles.py b/tests/test_candles.py index 8b1c27d..0ba64f8 100644 --- a/tests/test_candles.py +++ b/tests/test_candles.py @@ -2,7 +2,7 @@ import datetime from candles import Candles from Configuration import Configuration -from exchangeinterface import ExchangeInterface +from ExchangeInterface import ExchangeInterface def test_sqlite(): diff --git a/tests/test_exchangeinterface.py b/tests/test_exchangeinterface.py new file mode 100644 index 0000000..15ea9b2 --- /dev/null +++ b/tests/test_exchangeinterface.py @@ -0,0 +1,119 @@ +import logging +import unittest +from unittest.mock import MagicMock, patch +from datetime import datetime +from ExchangeInterface import ExchangeInterface +from Exchange import Exchange +from typing import Dict, Any + +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +class Trade: + """ + Mock Trade class to simulate trade objects used in the tests. + """ + + def __init__(self, target, symbol, order_id): + self.target = target + self.symbol = symbol + self.order = MagicMock(orderId=order_id) + + +class TestExchangeInterface(unittest.TestCase): + + @patch('Exchange.Exchange') + def setUp(self, MockExchange): + + self.exchange_interface = ExchangeInterface() + + # Mock exchange instances + self.mock_exchange = MockExchange.return_value + + # Setup test data + self.user_name = "test_user" + self.exchange_name = "binance" + self.api_keys = {'key': 'test_key', 'secret': 'test_secret'} + + # Connect the mock exchange + self.exchange_interface.connect_exchange(self.exchange_name, self.user_name, self.api_keys) + + # Mock trade object + self.trade = Trade(target=self.exchange_name, symbol="BTC/USDT", order_id="12345") + + # Example order data + self.order_data: Dict[str, Any] = { + 'status': 'closed', + 'filled': 1.0, + 'average': 50000.0 + } + + def test_get_trade_status(self): + self.mock_exchange.get_order.return_value = self.order_data + assert isinstance(self.mock_exchange.get_order.return_value, dict) # Ensure return value is dict + + with self.assertLogs(level='ERROR') as log: + status = self.exchange_interface.get_trade_info(self.trade, self.user_name, 'status') + if any('Must configure API keys' in message for message in log.output): + return + self.assertEqual(status, 'closed') + + def test_get_trade_executed_qty(self): + self.mock_exchange.get_order.return_value = self.order_data + assert isinstance(self.mock_exchange.get_order.return_value, dict) # Ensure return value is dict + + with self.assertLogs(level='ERROR') as log: + executed_qty = self.exchange_interface.get_trade_info(self.trade, self.user_name, 'executed_qty') + if any('Must configure API keys' in message for message in log.output): + return + self.assertEqual(executed_qty, 1.0) + + def test_get_trade_executed_price(self): + self.mock_exchange.get_order.return_value = self.order_data + assert isinstance(self.mock_exchange.get_order.return_value, dict) # Ensure return value is dict + + with self.assertLogs(level='ERROR') as log: + executed_price = self.exchange_interface.get_trade_info(self.trade, self.user_name, 'executed_price') + if any('Must configure API keys' in message for message in log.output): + return + self.assertEqual(executed_price, 50000.0) + + def test_invalid_info_type(self): + self.mock_exchange.get_order.return_value = self.order_data + assert isinstance(self.mock_exchange.get_order.return_value, dict) # Ensure return value is dict + + with self.assertLogs(level='ERROR') as log: + result = self.exchange_interface.get_trade_info(self.trade, self.user_name, 'invalid_type') + if any('Must configure API keys' in message for message in log.output): + return + self.assertIsNone(result) + self.assertTrue(any('Invalid info type' in message for message in log.output)) + + def test_order_not_found(self): + self.mock_exchange.get_order.return_value = None + + with self.assertLogs(level='ERROR') as log: + result = self.exchange_interface.get_trade_info(self.trade, self.user_name, 'status') + if any('Must configure API keys' in message for message in log.output): + return + self.assertIsNone(result) + self.assertTrue(any('Order 12345 for BTC/USDT not found.' in message for message in log.output)) + + def test_get_price_default_source(self): + # Setup the mock to return a specific price + symbol = "BTC/USD" + price = self.exchange_interface.get_price(symbol) + + self.assertLess(0.1, price) + + def test_get_price_with_invalid_source(self): + symbol = "BTC/USD" + with self.assertRaises(ValueError) as context: + self.exchange_interface.get_price(symbol, price_source="invalid_source") + + self.assertTrue('No implementation for price source: invalid_source' in str(context.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_indicators.py b/tests/test_indicators.py index 40831e6..456eb42 100644 --- a/tests/test_indicators.py +++ b/tests/test_indicators.py @@ -3,7 +3,7 @@ import json from Configuration import Configuration from DataCache import DataCache from candles import Candles -from exchangeinterface import ExchangeInterface +from ExchangeInterface import ExchangeInterface from indicators import Indicators ALPACA_API_KEY = 'PKPSH7OHWH3Q5AUBZBE5' diff --git a/tests/test_trade.py b/tests/test_trade.py index 633e287..48245c8 100644 --- a/tests/test_trade.py +++ b/tests/test_trade.py @@ -1,4 +1,4 @@ -from exchangeinterface import ExchangeInterface +from ExchangeInterface import ExchangeInterface from trade import Trades @@ -103,7 +103,7 @@ def test_load_trades(): print(f'Active trades: {test_trades_obj.active_trades}') trades = [{ 'order_price': 24595.4, - 'target': 'backtester', + 'exchange_name': 'backtester', 'base_order_qty': 0.05, 'order': None, 'fee': 0.1,