diff --git a/markdown/App/App.md b/markdown/App/App.md index 980e493..3e7564a 100644 --- a/markdown/App/App.md +++ b/markdown/App/App.md @@ -56,16 +56,19 @@ legend top =/index ———————————————————————————————— Handles the main landing page, logs in the user, + ensures a valid connection with an exchange, loads dynamic data, and renders the landing page. end legend start :Log in user; + :Ensure a valid connection with an exchange; :Load dynamic data; :Render Landing Page; stop @enduml + ``` ### /ws diff --git a/requirements.txt b/requirements.txt index e76a63e..f217fdf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,14 @@ numpy==1.24.3 flask==2.3.2 +flask_cors==3.0.10 +flask_sock==0.7.0 config~=0.5.1 PyYAML~=6.0 -binance~=0.3 requests==2.30.0 pandas==2.0.1 passlib~=1.7.4 SQLAlchemy==2.0.13 -ccxt~=3.0.105 \ No newline at end of file +ccxt==4.3.65 +email-validator~=2.2.0 +TA-Lib~=0.4.32 +bcrypt~=4.2.0 diff --git a/src/AlpacaPaperExchange.py b/src/AlpacaPaperExchange.py deleted file mode 100644 index cb1c170..0000000 --- a/src/AlpacaPaperExchange.py +++ /dev/null @@ -1,271 +0,0 @@ -import json -import math -import time -from datetime import datetime, date, time, timedelta, timezone - -import pandas -import pandas as pd -from alpaca.trading import GetAssetsRequest, AssetClass, MarketOrderRequest, OrderSide, TimeInForce - -from Exchange import Exchange -import config -from alpaca.trading.client import TradingClient -from alpaca.data.historical import CryptoHistoricalDataClient -from alpaca.data.requests import CryptoBarsRequest, CryptoLatestQuoteRequest -from alpaca.data.timeframe import TimeFrame, TimeFrameUnit -import alpaca_trade_api as tradeapi - - -class AlpacaPaperExchange(Exchange): - def __init__(self, name, api_keys): - super().__init__(name, api_keys=api_keys) - - def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt=None): - """ - Return a dataframe containing rows of candle attributes. - [(symbol, start_datetime, open, high, low, - close, Volume, trade_count, vwap)] - - :param symbol: str: - Trading symbol. - :param interval: str: - Timeframe for the candle sticks. - :param start_dt: - Start datetime of data stream. - :param end_dt: - End datetime of data stream. - :return: pd.dataframe - The requested data. - """ - - if end_dt is None: - end_dt = datetime.utcnow() - - # Extract the numerical part of the timeframe param. - digits = int("".join([i if i.isdigit() else "" for i in interval])) - # Extract the alpha part of the timeframe param. - letter = "".join([i if i.isalpha() else "" for i in interval]) - # Convert the abbreviation to the name of a TimeFrameUnit object attribute. - time_frame = "".join(['Minute' if i == 'm' else - 'Hour' if i == 'h' else - 'Day' if i == 'd' else - 'Week' if i == 'w' else - 'Month' if i == 'M' else 'Year' for i in letter]) - # Fill in the request parameters. - request_params = CryptoBarsRequest( - symbol_or_symbols=symbol, - timeframe=TimeFrame(amount=digits, unit=TimeFrameUnit[time_frame]), - start=start_dt, - end=end_dt - ) - if start_dt > end_dt: - raise ValueError(start_dt, end_dt) - # Request the bars from the exchange_interface. - bars = self.client['historical_data'].get_crypto_bars(request_params) - # Convert the response object into a list of dictionaries. - candles_dict = [c.__dict__ for c in bars[symbol]] - # Convert the data into a dataframe. - candles = pd.DataFrame(candles_dict) - # Reformat this dataframe to make it consistent with other exchanges. - del candles['symbol'] - candles.rename({'timestamp': 'open_time'}, axis='columns', inplace=True) - # Convert the date column to a start_datetime to make it consistent with the other exchanges. - candles['open_time'] = candles.open_time.values.astype('int64') // 10 ** 6 - # Return the data. - return candles - - def _fetch_price(self, symbol) -> float: - """ - Get the latest price for a single symbol. - - :param symbol: str - The symbol of the symbol. - :return: float - The minimum quantity sold per trade. - """ - request_params = CryptoLatestQuoteRequest(symbol_or_symbols=symbol) - lq = self.client['historical_data'].get_crypto_latest_quote(request_params) - lqd = lq[symbol] - return lqd.bid_price - - def _fetch_min_qty(self, symbol) -> float: - """ - Get the minimum base quantity the exchange_interface allows per order. - - :param symbol: str - The symbol of the trading pair. - :return: float - The minimum quantity sold per trade. - """ - assets = self.client['trading'].get_asset(symbol) - return assets.min_order_size - - def _fetch_min_notional_qty(self, symbol) -> float: - """ - Get the minimum amount of the quote currency allowed per trade. - - :param symbol: The symbol of the trading pair. - :return: The minimum quantity sold per trade. - """ - assets = self.client['trading'].get_asset(symbol) - return assets.price_increment - - def _fetch_order(self, symbol, order_id) -> object: - """Return the order status for a specific order.""" - return self.client.query_order(symbol=symbol, orderId=order_id) - - def _connect_exchange(self) -> object: - """Connects to the exchange_interface and sets a reference the client.""" - client = { - 'historical_data': CryptoHistoricalDataClient(), - 'trading': TradingClient(self.api_key, self.api_key_secret, paper=True), - 'trade_api': tradeapi.REST(self.api_key, self.api_key_secret, "https://paper-api.alpaca.markets/") - } - - return client - - def _set_exchange_info(self) -> dict: - """Fetches Info on all symbols from the exchange_interface.""" - return self.client['trading'].get_account() - - def _set_symbols(self) -> list: - """Get information of coins (available for deposit and withdraw) for user""" - params = GetAssetsRequest(asset_class=AssetClass.CRYPTO) - # read all the symbol data into a dataframe. - assets = pandas.DataFrame.from_records(self.client['trading'].get_all_assets(params)) - # This parses things weird, so I did some format manipulation. - # Extract the proper headers - headers = assets.head(1).applymap(lambda row: row[0]).values[0] - # Rename the headers that were assigned. - assets = assets.set_axis(headers, axis=1) - # remove the headers from the data. - assets = assets.applymap(lambda row: row[1]) - # remove rows that aren't tradable. - assets = assets.query("tradable != False") - # Return a list of just the symbols. - return assets.symbol.values - - def _set_balances(self) -> list: - """ - Retrieves the account balances, including cash balance and current P&L (Profit & Loss). - - :return: list - A list of dictionaries containing the asset balances and P&L information. - Each dictionary has the following keys: - - 'asset': The asset symbol (e.g., 'USD', 'BTC', 'ETH'). - - 'balance': The balance of the asset. - - 'pnl': The current P&L (Profit & Loss) of the account. - """ - account = self.client['trade_api'].get_account() - cash_balance = account.cash - portfolio_value = float(account.equity) - buying_power = float(account.buying_power) - - # Calculate P&L as the difference between portfolio value and buying power, - # or set to 0 if no positions or open orders - positions = self.client['trade_api'].list_positions() - open_orders = self.client['trade_api'].list_orders(status='open') - - if len(positions) == 0 and len(open_orders) == 0: - current_pl = 0.0 - else: - current_pl = portfolio_value - buying_power - - non_zero_assets = [ - { - 'asset': 'USD', # Assuming the cash balance is in USD - 'balance': cash_balance, - 'pnl': current_pl - } - ] - return non_zero_assets - - def get_asset_balances(self): - """ Get the individual asset balances""" - # Get the account information - account = self.client['trading'].get_account() - - # Retrieve the balances of each non-zero asset - balances = [] - for asset in account.assets: - if float(asset.qty) > 0: - balance = { - 'asset': asset.symbol, - 'balance': asset.qty - } - balances.append(balance) - - return balances - - def get_active_trades(self): - formatted_trades = [] - # Get open trades - positions = self.client['trade_api'].list_positions() - for position in positions: - if float(position.qty) != 0.0: - active_trade = { - 'symbol': position.symbol, - 'side': 'buy' if float(position.qty) > 0 else 'sell', - 'quantity': abs(float(position.qty)), - 'price': float(position.avg_entry_price) - } - formatted_trades.append(active_trade) - return formatted_trades - - def get_open_orders(self): - formatted_orders = [] - # Get open orders - open_orders = self.client['trade_api'].list_orders(status='open') - for order in open_orders: - open_order = { - 'symbol': order.symbol, - 'side': order.side, - 'quantity': order.qty, - 'price': order.filled_avg_price - } - formatted_orders.append(open_order) - return formatted_orders - - def _set_precision_rule(self, symbol) -> None: - """Appends a record of places after the decimal required by the exchange_interface indexed by symbol.""" - - # I don't think this exchange_interface has precision rule. - # return the number of decimal places of min_trade_increment - assets = self.client['trading'].get_asset(symbol) - min_trade_rule = assets.min_trade_increment - precision = len(str(min_trade_rule).split('.')[1]) - self.symbols_n_precision[symbol] = precision - return - - def _place_order(self, symbol, side, type, timeInForce, quantity, price): - - def format_arg(value, arg_name): - """ - Function to convert an arguments value to a desired precision and type float. - - :param value: The Value of the argument. - :param arg_name: The name of the argument. - :return: float : The formatted output. - """ - # The required level of precision for this trading pair. - precision = self.get_symbol_precision_rule(symbol) - # If quantity was too low, set to the smallest allowable amount. - minimum = 0 - if arg_name == 'quantity': - # The minimum quantity aloud to be traded. - minimum = self.get_min_qty(symbol) - elif arg_name == 'price': - # The minimum price aloud to be traded. - minimum = self.get_min_notional_qty(symbol) - if value < minimum: - value = minimum - return float(f"{value:.{precision}f}") - - # Set price and quantity to desired precision and type. - quantity = format_arg(quantity, 'quantity') - price = format_arg(price, 'price') - - if side == 'buy': - side = OrderSide.BUY - else: - side = OrderSide.SELL - market_order_data = MarketOrderRequest( - symbol=symbol, - qty=quantity, - side=side, - time_in_force=TimeInForce[timeInForce] - - ) - market_order = self.client['trading'].submit_order(market_order_data) - result = 'Success' - return result, market_order diff --git a/src/BinanceFutures.py b/src/BinanceFutures.py deleted file mode 100644 index 93357fb..0000000 --- a/src/BinanceFutures.py +++ /dev/null @@ -1,218 +0,0 @@ -import pandas as pd -from Exchange import Exchange -import ccxt - - -class BinanceFuturesExchange(Exchange): - def __init__(self, name, api_keys): - super().__init__(name, api_keys=api_keys) - - import ccxt - - def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt=None): - """ - Return a list [(Open time, Open, High, Low, Close, - Volume, Close time, Quote symbol volume, - Number of trades, Taker buy base symbol volume, - Taker buy quote symbol volume)] - - :param symbol: str: Trading symbol. - :param interval: str: Timeframe for the candlesticks. - :param start_dt: start datetime of data stream. - :param end_dt: end datetime of data stream. - :return: pandas DataFrame containing the candlestick data. - """ - - # Convert the start date to a Unix timestamp in seconds. - start_str = int(start_dt.timestamp()) - end_str = None - if end_dt is not None: - end_str = int(end_dt.timestamp()) - candles = self.client.fetch_ohlcv(symbol=symbol, timeframe=interval, since=start_str, limit=None, end=end_str) - # Create and return a pandas DataFrame. - df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume'] - candles = pd.DataFrame(candles, columns=df_columns) - # Return the data - return candles - - def _fetch_price(self, symbol) -> float: - """ - Get the latest price for a single symbol. - - :param symbol: The symbol of the trading pair. - :return: The latest price for the symbol. - """ - ticker = self.client.fetch_ticker(symbol) - return float(ticker['last']) - - def _fetch_min_qty(self, symbol) -> float: - """ - Get the minimum base quantity the exchange allows per order. - - :param symbol: The symbol of the trading pair. - :return: The minimum quantity sold per trade. - """ - exchange_info = self.client.fetch_tickers() - symbol_info = exchange_info.get(symbol) - if symbol_info is None: - raise ValueError(f'No symbol: {symbol}') - # Extract the minimum quantity value from the symbol info - min_qty = symbol_info.get('info').get('filters')[2].get('minQty') - return float(min_qty) - - def _fetch_min_notional_qty(self, symbol) -> float: - """ - Get the minimum amount of the quote currency allowed per trade. - - :param symbol: The symbol of the trading pair. - :return: The minimum quantity sold per trade. - """ - exchange_info = self.client.fetch_tickers() - symbol_info = exchange_info.get(symbol) - if symbol_info is None: - raise ValueError(f'No symbol: {symbol}') - # Extract the minimum notional value from the symbol info - min_notional = symbol_info.get('info').get('filters')[0].get('minNotional') - return float(min_notional) - - def _fetch_order(self, symbol, order_id) -> object: - """Return the order status for a specific order.""" - order = self.client.fetch_order(id=order_id, symbol=symbol) - return order - - def _connect_exchange(self) -> object: - """Connects to the exchange_interface and sets a reference to the client.""" - exchange = ccxt.binance({ - 'apiKey': self.api_key, - 'secret': self.api_key_secret, - 'enableRateLimit': True, - 'options': { - 'defaultType': 'future', - } - }) - return exchange - - def _set_exchange_info(self) -> dict: - """Fetches Info on all symbols from the exchange.""" - exchange_info = self.client.load_markets() - return exchange_info - - def _set_symbols(self) -> list: - """Get information of coins (available for deposit and withdraw) for user""" - markets = self.client.load_markets() - - # Extract symbol from each market - symbols = list(markets.keys()) - - return symbols - - def _set_balances(self) -> list: - acc_info = self.client.fetch_balance() - non_zero_assets = [] - for asset, balance in acc_info['total'].items(): - if float(balance) > 0: - non_zero_asset = { - 'asset': asset, - 'balance': format(balance, '.8f'), - 'pnl': acc_info['info']['totalUnrealizedProfit'] - } - non_zero_assets.append(non_zero_asset) - return non_zero_assets - - def _set_precision_rule(self, symbol) -> None: - # Isolate a list of symbol rules from the exchange_info - all_rules = self.exchange_info['symbols'] - # Get the rules for a specific symbol. - symbol_rules = next(filter(lambda rules: (rules['pair'] == symbol), all_rules), None) - # Add the precision rule to a local index. - self.symbols_n_precision[symbol] = symbol_rules['baseAssetPrecision'] - return - - def _place_order(self, symbol, side, type, timeInForce, quantity, price): - def format_arg(value, arg_name): - """ - Function to convert an arguments value to a desired precision and type float. - - :param value: The Value of the argument. - :param arg_name: The name of the argument. - :return: float : The formatted output. - """ - # The required level of precision for this trading pair. - precision = self.client.precision_for_symbol(symbol) - # If quantity was too low, set to the smallest allowable amount. - minimum = 0 - if arg_name == 'quantity': - # The minimum quantity allowed to be traded. - minimum = self.client.amount_to_precision(symbol, - self.client.markets[symbol]['limits']['amount']['min']) - elif arg_name == 'price': - # The minimum price allowed to be traded. - minimum = self.client.price_to_precision(symbol, self.client.markets[symbol]['limits']['price']['min']) - if value < minimum: - value = minimum - return float(f"{value:.{precision}f}") - - # Set price and quantity to desired precision and type. - quantity = format_arg(quantity, 'quantity') - price = format_arg(price, 'price') - - order_params = { - 'symbol': symbol, - 'side': side, - 'type': type, - 'timeInForce': timeInForce, - 'quantity': quantity, - 'price': price - } - data = self.client.create_order(**order_params) - result = 'Success' - return result, data - - def get_active_trades(self): - formatted_trades = [] - # Get open positions - positions = self.client.fetch_positions() - for position in positions: - if float(position['contracts']) != 0.0: - active_trade = { - 'symbol': position['symbol'], - 'side': 'buy' if float(position['contracts']) > 0 else 'sell', - 'quantity': abs(float(position['contracts'])), - 'price': float(position['entryPrice']) - } - formatted_trades.append(active_trade) - return formatted_trades - - def get_open_orders(self): - formatted_orders = [] - # Get open orders - # open_orders = self.client.fetch_open_orders() - # Todo: fetching open orders without specifying a symbol is rate-limited to one call per 1252 seconds. - open_orders = {} - for order in open_orders: - open_order = { - 'symbol': order['symbol'], - 'side': order['side'], - 'quantity': order['amount'], - 'price': order['price'] - } - formatted_orders.append(open_order) - return formatted_orders - - -class BinanceCoinExchange(BinanceFuturesExchange): - def __init__(self, name, api_keys): - super().__init__(name, api_keys=api_keys) - - def _connect_exchange(self) -> object: - """Connects to the exchange_interface and sets a reference to the client.""" - return ccxt.binance({ - 'apiKey': self.api_key, - 'secret': self.api_key_secret, - 'enableRateLimit': True, - # Additional exchange-specific options can be added here - 'options': { - 'defaultType': 'coinfuture', # Connect to the Binance Coin Futures API - } - }) - diff --git a/src/BinanceSpot.py b/src/BinanceSpot.py deleted file mode 100644 index 6cd0a38..0000000 --- a/src/BinanceSpot.py +++ /dev/null @@ -1,199 +0,0 @@ -import pandas as pd -from binance.enums import HistoricalKlinesType -from shared_utilities import unix_time_millis - -import config - -from Exchange import Exchange -from binance.client import Client as BinanceClient -from binance.exceptions import BinanceAPIException - - -class BinanceSpotExchange(Exchange): - def __init__(self, name, api_keys): - super().__init__(name, api_keys) - - def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt=None): - """ - Return a dataframe containing rows of candle attributes. - [(Open time, Open, High, Low, Close, - Volume, Close time, Quote symbol volume, - Number of trades, Taker buy base symbol volume, - Taker buy quote symbol volume, Ignore)] - - :param symbol: str: Trading symbol. - :param interval: str: Timeframe for the candle sticks. - :param start_dt: start time-date of data stream. - :param end_dt: end datetime of data stream. - :return: pd.dataframe - """ - # Set this to Spot. - klines_type = HistoricalKlinesType.SPOT - # Convert the start date to a Unix start_datetime milliseconds. - start_str = int(unix_time_millis(start_dt)) - end_str = end_dt - if end_dt is not None: - end_str = int(unix_time_millis(end_dt)) - - candles = self.client.get_historical_klines(symbol=symbol, interval=interval, start_str=start_str, - end_str=end_str, klines_type=klines_type) - # Create a pandas DataFrame. - df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume', 'close_time', - 'quote_volume', 'num_trades', 'taker_buy_base_volume', 'taker_buy_quote_volume', 'ignore'] - candles = pd.DataFrame(candles, columns=df_columns) - # Delete this useless column - del candles['ignore'] - # Return the data - return candles - - def _fetch_price(self, symbol) -> float: - """ - Get the latest price for a single symbol. - - :param symbol: str - The symbol of the symbol. - :return: float - The minimum quantity sold per trade. - """ - record = self.client.get_symbol_ticker(**{'symbol': symbol}) - return float(record['price']) - - def _fetch_min_qty(self, symbol) -> float: - """ - Get the minimum base quantity the exchange_interface allows per order. - - :param symbol: str - The symbol of the trading pair. - :return: float - The minimum quantity sold per trade. - """ - info = self.client.get_symbol_info(symbol=symbol) - for f_index, item in enumerate(info['filters']): - if item['filterType'] == 'LOT_SIZE': - break - return float(info['filters'][f_index]['minQty']) # 'minQty' - - def _fetch_min_notional_qty(self, symbol) -> float: - """ - Get the minimum amount of the quote currency allowed per trade. - - :param symbol: The symbol of the trading pair. - :return: The minimum quantity sold per trade. - """ - info = self.client.get_symbol_info(symbol=symbol) - for f_index, item in enumerate(info['filters']): - if item['filterType'] == 'MIN_NOTIONAL': - break - return float(info['filters'][f_index]['minNotional']) # 'minNotional' - - def _fetch_order(self, symbol, order_id) -> object: - """Return the order status for a specific order.""" - return self.client.get_order(symbol=symbol, orderId=order_id) - - def _connect_exchange(self) -> object: - """Connects to the exchange_interface and sets a reference the client.""" - return BinanceClient(self.api_key, self.api_key_secret) - - def _set_exchange_info(self) -> dict: - """Fetches Info on all symbols from the exchange_interface.""" - return self.client.get_exchange_info() - - def _set_symbols(self) -> list: - """Get information of coins (available for deposit and withdraw) for user""" - # load into a dataframe - data = pd.DataFrame.from_records(self.client.get_all_tickers()) - # unload into a list - return data.symbol.to_list() - - def _set_balances(self) -> list: - account_info = self.client.get_account() - balances = [] - for asset in account_info['balances']: - asset_symbol = asset['asset'] - asset_balance = float(asset['free']) - asset_pnl = self.calculate_asset_pnl(asset_symbol) - if asset_balance > 0 or asset_pnl != 0: - asset_data = { - 'asset': asset_symbol, - 'balance': asset_balance, - 'pnl': asset_pnl - } - balances.append(asset_data) - return balances - - def calculate_asset_pnl(self, asset_symbol: str) -> float: - # Implement your logic to calculate the P&L for the given asset - # This may involve retrieving historical data, calculating positions, etc. - # Return the calculated P&L value as a float - return 0.0 # Placeholder value, replace with actual calculation - - def _set_precision_rule(self, symbol) -> None: - # Isolate a list of symbol rules from the exchange_info - all_rules = self.exchange_info['symbols'] - # Get the rules for a specific symbol. - symbol_rules = next(filter(lambda rules: (rules['symbol'] == symbol), all_rules), None) - # Add the precision rule to a local index. - self.symbols_n_precision[symbol] = symbol_rules['baseAssetPrecision'] - return - - def _place_order(self, symbol, side, type, timeInForce, quantity, price): - - def format_arg(value, arg_name): - """ - Function to convert an arguments value to a desired precision and type float. - - :param value: The Value of the argument. - :param arg_name: The name of the argument. - :return: float : The formatted output. - """ - # The required level of precision for this trading pair. - precision = self.get_symbol_precision_rule(symbol) - # If quantity was too low, set to the smallest allowable amount. - minimum = 0 - if arg_name == 'quantity': - # The minimum quantity aloud to be traded. - minimum = self.get_min_qty(symbol) - elif arg_name == 'price': - # The minimum price aloud to be traded. - minimum = self.get_min_notional_qty(symbol) - if value < minimum: - value = minimum - return float(f"{value:.{precision}f}") - - # Set price and quantity to desired precision and type. - quantity = format_arg(quantity, 'quantity') - price = format_arg(price, 'price') - - data = self.client.create_test_order(symbol=symbol, side=side, type=type, - timeInForce=timeInForce, quantity=quantity, - price=price) - result = 'Success' - return result, data - - def get_active_trades(self): - formatted_trades = [] - # Get open trades - positions = self.client.get_account().get('positions') - if positions is None: - return formatted_trades - for position in positions: - if float(position['qty']) != 0.0: - active_trade = { - 'symbol': position['symbol'], - 'side': 'buy' if float(position['qty']) > 0 else 'sell', - 'quantity': abs(float(position['qty'])), - 'price': float(position['entryPrice']) - } - formatted_trades.append(active_trade) - return formatted_trades - - def get_open_orders(self): - formatted_orders = [] - # Get open orders - open_orders = self.client.get_open_orders() - for order in open_orders: - open_order = { - 'symbol': order['symbol'], - 'side': order['side'], - 'quantity': order['origQty'], - 'price': order['price'] - } - formatted_orders.append(open_order) - return formatted_orders - diff --git a/src/BrighterTrades.py b/src/BrighterTrades.py index 58f9c8b..c1d524b 100644 --- a/src/BrighterTrades.py +++ b/src/BrighterTrades.py @@ -172,7 +172,7 @@ class BrighterTrades: return self.indicators.get_indicator_data(user_name=user_name, source=source, start_ts=start_ts, num_results=num_results) - def connect_user_to_exchange(self, user_name: str, default_exchange: str, default_keys: dict) -> bool: + def connect_user_to_exchange(self, user_name: str, default_exchange: str, default_keys: dict = None) -> bool: """ Connects an exchange if it is not already connected. @@ -189,6 +189,7 @@ class BrighterTrades: exchange_name=exchange, api_keys=keys) if not success: + # If no active exchange was successfully connected, connect to the default exchange success = self.connect_or_config_exchange(user_name=user_name, exchange_name=default_exchange, api_keys=default_keys) @@ -237,7 +238,7 @@ class BrighterTrades: r_data = {} r_data['title'] = self.config.app_data.get('application_title', '') r_data['chart_interval'] = chart_view.get('timeframe', '') - r_data['selected_exchange'] = chart_view.get('exchange_name', '') + r_data['selected_exchange'] = chart_view.get('exchange', '') r_data['intervals'] = exchange.intervals if exchange else [] r_data['symbols'] = exchange.get_symbols() if exchange else {} r_data['available_exchanges'] = self.exchanges.get_available_exchanges() or [] @@ -359,7 +360,7 @@ class BrighterTrades: """ return self.strategies.get_strategies('json') - def connect_or_config_exchange(self, user_name: str, exchange_name: str, api_keys: dict) -> bool: + def connect_or_config_exchange(self, user_name: str, exchange_name: str, api_keys: dict = None) -> bool: """ Connects to an exchange if not already connected, or configures the exchange connection for a single user. @@ -377,7 +378,9 @@ class BrighterTrades: api_keys=api_keys) if success: self.config.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set') - self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) + if api_keys: + self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, + user_name=user_name) return True else: return False # Failed to connect diff --git a/src/DataCache.py b/src/DataCache.py index e15f344..2773606 100644 --- a/src/DataCache.py +++ b/src/DataCache.py @@ -38,6 +38,8 @@ class DataCache: records = pd.concat([more_records, self.get_cache(key)], axis=0, ignore_index=True) # Drop any duplicates from overlap. records = records.drop_duplicates(subset="open_time", keep='first') + # Sort the records by open_time. + records = records.sort_values(by='open_time').reset_index(drop=True) # Replace the incomplete dataframe with the modified one. self.set_cache(data=records, key=key) return diff --git a/src/Exchange.py b/src/Exchange.py index 6e3d105..c7bad2d 100644 --- a/src/Exchange.py +++ b/src/Exchange.py @@ -1,40 +1,175 @@ -from abc import ABC, abstractmethod -from typing import Tuple +import ccxt +import pandas as pd +from datetime import datetime, timedelta +from typing import Tuple, Dict, List, Union +import time -class Exchange(ABC): - def __init__(self, name, api_keys): +class Exchange: + # Class attribute for caching market data + _market_cache = {} + + def __init__(self, name: str, api_keys: Dict[str, str], exchange_id: str): self.name = name - # 1 The api key for the exchange. - self.api_key = api_keys['key'] + # The API key for the exchange. + self.api_key = api_keys['key'] if api_keys else None - # 2 The api secret key for the exchange. - self.api_key_secret = api_keys['secret'] + # The API secret key for the exchange. + self.api_key_secret = api_keys['secret'] if api_keys else None - # 3 The connection to the exchange_interface. - self.client = self._connect_exchange() + # The exchange id for the exchange. + self.exchange_id = exchange_id - # 4 Info on all symbols and exchange_interface rules. + # The connection to the exchange_interface. + self.client: ccxt.Exchange = self._connect_exchange() + + # Info on all symbols and exchange_interface rules. self.exchange_info = self._set_exchange_info() - # 5 List of time intervals available for trading. + # List of time intervals available for trading. self.intervals = self._set_avail_intervals() - # 6 All symbols available for trading. + # All symbols available for trading. self.symbols = self._set_symbols() - # 7 Any non-zero balance-info for all assets. + # Any non-zero balance-info for all assets. self.balances = self._set_balances() - # 8 Dictionary of places after the decimal requires by the exchange_interface indexed by symbol. + # Dictionary of places after the decimal requires by the exchange_interface indexed by symbol. self.symbols_n_precision = {} + def _connect_exchange(self) -> ccxt.Exchange: + """ + Connects to the exchange_interface and sets a reference the client. + Handles cases where api_keys might be None. + """ + exchange_class = getattr(ccxt, self.exchange_id) + if not exchange_class: + raise ValueError(f"Exchange {self.exchange_id} is not supported by CCXT.") + + if self.api_key and self.api_key_secret: + return exchange_class({ + 'apiKey': self.api_key, + 'secret': self.api_key_secret, + 'enableRateLimit': True, + 'verbose': False # Disable verbose debugging output + }) + else: + return exchange_class({ + 'enableRateLimit': True, + 'verbose': False # Disable verbose debugging output + }) + + @staticmethod + def datetime_to_unix_millis(dt: datetime) -> int: + """ + Convert a datetime object to Unix timestamp in milliseconds. + + :param dt: datetime - The datetime object to convert. + :return: int - The Unix timestamp in milliseconds. + """ + return int(dt.timestamp() * 1000) + + def _fetch_historical_klines(self, symbol: str, interval: str, start_dt: datetime, + end_dt: datetime = None) -> pd.DataFrame: + if end_dt is None: + end_dt = datetime.utcnow() + + max_interval = timedelta(days=200) # Binance's maximum interval + data_frames = [] + current_start = start_dt + + while current_start < end_dt: + current_end = min(current_start + max_interval, end_dt) + start_str = self.datetime_to_unix_millis(current_start) + end_str = self.datetime_to_unix_millis(current_end) + + try: + candles = self.client.fetch_ohlcv(symbol=symbol, timeframe=interval, since=start_str, + params={'endTime': end_str}) + if not candles: + break + + df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume'] + candles_df = pd.DataFrame(candles, columns=df_columns) + candles_df['open_time'] = candles_df['open_time'] // 1000 + data_frames.append(candles_df) + + # Move the start to the end of the current chunk to get the next chunk + current_start = current_end + + # Sleep for 1 second to avoid hitting rate limits + time.sleep(1) + + except ccxt.BaseError as e: + print(f"Error fetching OHLCV data: {str(e)}") + break + + if data_frames: + result_df = pd.concat(data_frames) + return result_df + else: + return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume']) + + def _fetch_price(self, symbol: str) -> float: + ticker = self.client.fetch_ticker(symbol) + return float(ticker['last']) + + def _fetch_min_qty(self, symbol: str) -> float: + market_data = self.exchange_info[symbol] + return float(market_data['limits']['amount']['min']) + + def _fetch_min_notional_qty(self, symbol: str) -> float: + market_data = self.exchange_info[symbol] + return float(market_data['limits']['cost']['min']) + + def _fetch_order(self, symbol: str, order_id: str) -> object: + return self.client.fetch_order(order_id, symbol) + + def _set_symbols(self) -> List[str]: + markets = self.client.fetch_markets() + symbols = [market['symbol'] for market in markets if market['active']] + return symbols + + def _set_balances(self) -> List[Dict[str, Union[str, float]]]: + if self.api_key and self.api_key_secret: + try: + account_info = self.client.fetch_balance() + balances = [] + for asset, balance in account_info['total'].items(): + asset_balance = float(balance) + if asset_balance > 0: + balances.append({'asset': asset, 'balance': asset_balance, 'pnl': 0}) + return balances + except NotImplementedError: + # Handle the case where fetch_balance is not supported + return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}] + else: + return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}] + + def _set_exchange_info(self) -> dict: + """ + Fetches market data for all symbols from the exchange. + This includes details on trading pairs, limits, fees, and other relevant information. + Caches the market info to speed up subsequent calls. + """ + if self.exchange_id in Exchange._market_cache: + return Exchange._market_cache[self.exchange_id] + + try: + markets_info = self.client.load_markets() + Exchange._market_cache[self.exchange_id] = markets_info + return markets_info + except ccxt.BaseError as e: + print(f"Error fetching market info: {str(e)}") + return {} + def get_client(self) -> object: """ Return a reference to the exchange_interface client.""" return self.client - def get_avail_intervals(self) -> tuple: + def get_avail_intervals(self) -> Tuple[str, ...]: """Returns a list of time intervals available for trading.""" return self.intervals @@ -42,15 +177,15 @@ class Exchange(ABC): """Returns Info on all symbols.""" return self.exchange_info - def get_symbols(self) -> list: + def get_symbols(self) -> List[str]: """Returns all symbols available for trading.""" return self.symbols - def get_balances(self) -> list: + def get_balances(self) -> List[Dict[str, Union[str, float]]]: """Returns any non-zero balance-info for all assets.""" return self.balances - def get_symbol_precision_rule(self, symbol) -> int: + def get_symbol_precision_rule(self, symbol: str) -> int: """Returns number of places after the decimal required by the exchange_interface for a specific symbol.""" r_value = self.symbols_n_precision.get(symbol) if r_value is None: @@ -58,149 +193,108 @@ class Exchange(ABC): r_value = self.symbols_n_precision.get(symbol) return r_value - def get_historical_klines(self, symbol, interval, start_dt, end_dt) -> object: - """ - Return a dataframe containing rows of candle attributes. Attributes very between different exchanges. - For example: [(Open time, Open, High, Low, Close, - Volume, Close time, Quote symbol volume, - Number of trades, Taker buy base symbol volume, - Taker buy quote symbol volume, Ignore)] - """ - return self._fetch_historical_klines(symbol=symbol, interval=interval, - start_dt=start_dt, end_dt=end_dt) + def get_historical_klines(self, symbol: str, interval: str, start_dt: datetime, + end_dt: datetime = None) -> pd.DataFrame: + return self._fetch_historical_klines(symbol=symbol, interval=interval, start_dt=start_dt, end_dt=end_dt) - def get_price(self, symbol) -> float: - """ - Get the latest price for a single symbol. - - :param symbol: str - The symbol of the symbol. - :return: float - The minimum quantity sold per trade. - """ + def get_price(self, symbol: str) -> float: return self._fetch_price(symbol) - def get_min_qty(self, symbol) -> float: - """ - Get the minimum base quantity the exchange_interface allows per order. - - :param symbol: str - The symbol of the trading pair. - :return: float - The minimum quantity sold per trade. - """ + def get_min_qty(self, symbol: str) -> float: return self._fetch_min_qty(symbol) - def get_min_notional_qty(self, symbol) -> float: - """ - Get the minimum amount of the quote currency allowed per trade. - - :param symbol: The symbol of the trading pair. - :return: The minimum quantity sold per trade. - """ + def get_min_notional_qty(self, symbol: str) -> float: return self._fetch_min_notional_qty(symbol) - @abstractmethod - def get_active_trades(self): - pass - - @abstractmethod - def get_open_orders(self): - pass - - def get_order(self, symbol, order_id) -> object: - """ - Get an order by id. - - :param symbol: The trading pair - :param order_id: The order id - :return: object - """ + def get_order(self, symbol: str, order_id: str) -> object: return self._fetch_order(symbol, order_id) - def place_order(self, symbol, side, type, timeInForce, quantity, price): - result, msg = self._place_order(symbol=symbol, side=side, type=type, - timeInForce=timeInForce, quantity=quantity, price=price) + def place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> \ + Tuple[str, object]: + result, msg = self._place_order(symbol=symbol, side=side, type=type, timeInForce=timeInForce, quantity=quantity, + price=price) return result, msg - def _set_avail_intervals(self) -> tuple: - """Sets a list of time intervals available for trading on the exchange_interface.""" - return '1m', '3m', '5m', '15m', '30m', '1h', '2h', '4h', '6h', '8h', '12h', '1d', '3d', '1w', '1M' - - # noinspection PyMethodMayBeStatic - @abstractmethod - def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt) -> list: + def _set_avail_intervals(self) -> Tuple[str, ...]: """ - Return a list [(Open time, Open, High, Low, Close, - Volume, Close time, Quote symbol volume, - Number of trades, Taker buy base symbol volume, - Taker buy quote symbol volume, Ignore)] + Sets a list of time intervals available for trading on the exchange_interface. """ - pass + return tuple(self.client.timeframes.keys()) - @abstractmethod - def _fetch_price(self, symbol) -> float: + def _set_precision_rule(self, symbol: str) -> None: + market_data = self.exchange_info[symbol] + precision = market_data['precision']['amount'] + self.symbols_n_precision[symbol] = precision + return + + def _place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> \ + Tuple[str, object]: + def format_arg(value: float) -> float: + precision = self.symbols_n_precision.get(symbol, 8) + return float(f"{value:.{precision}f}") + + quantity = format_arg(quantity) + if price is not None: + price = format_arg(price) + + order_params = { + 'symbol': symbol, + 'type': type, + 'side': side, + 'amount': quantity, + 'params': { + 'timeInForce': timeInForce + } + } + + if price is not None: + order_params['price'] = price + + order = self.client.create_order(**order_params) + return 'Success', order + + def get_active_trades(self) -> List[Dict[str, Union[str, float]]]: """ - Get the latest price for a single symbol. - - :param symbol: str - The symbol of the symbol. - :return: float - The minimum quantity sold per trade. + Get the active trades (open positions). """ - pass + if self.api_key and self.api_key_secret: + try: + positions = self.client.fetch_positions() + formatted_trades = [] + for position in positions: + active_trade = { + 'symbol': position['symbol'], + 'side': 'buy' if float(position['quantity']) > 0 else 'sell', + 'quantity': abs(float(position['quantity'])), + 'price': float(position['entry_price']) + } + formatted_trades.append(active_trade) + return formatted_trades + except NotImplementedError: + # Handle the case where fetch_positions is not supported + return [] + else: + return [] - @abstractmethod - def _fetch_min_qty(self, symbol) -> float: + def get_open_orders(self) -> List[Dict[str, Union[str, float]]]: """ - Get the minimum base quantity the exchange_interface allows per order. - - :param symbol: str - The symbol of the trading pair. - :return: float - The minimum quantity sold per trade. + Get the open orders. """ - pass - - @abstractmethod - def _fetch_min_notional_qty(self, symbol) -> float: - """ - Get the minimum amount of the quote currency allowed per trade. - - :param symbol: The symbol of the trading pair. - :return: The minimum quantity sold per trade. - """ - pass - - @abstractmethod - def _fetch_order(self, symbol, order_id) -> object: - """ - Get an order by id. - :param symbol: The trading pair - :param order_id: The order id - :return: object - """ - pass - - @abstractmethod - def _connect_exchange(self) -> object: - """Connects to the exchange_interface and sets a reference the client.""" - pass - - @abstractmethod - def _set_exchange_info(self) -> dict: - """Fetches Info on all symbols from the exchange_interface.""" - pass - - @abstractmethod - def _set_symbols(self) -> list: - """Fetches all symbols available for trading on the exchange_interface.""" - pass - - @abstractmethod - def _set_balances(self) -> list: - """Fetches any non-zero balance-info for all assets from exchange_interface.""" - pass - - @abstractmethod - def _set_precision_rule(self, symbol) -> dict: - """Appends a record of places after the decimal required by the exchange_interface indexed by symbol.""" - pass - - @abstractmethod - def _place_order(self, symbol, side, type, timeInForce, quantity, price): - result = 'Success' - msg = '' - return result, msg + if self.api_key and self.api_key_secret: + try: + open_orders = self.client.fetch_open_orders() + formatted_orders = [] + for order in open_orders: + open_order = { + 'symbol': order['symbol'], + 'side': order['side'], + 'quantity': order['amount'], + 'price': order['price'] + } + formatted_orders.append(open_order) + return formatted_orders + except NotImplementedError: + # Handle the case where fetch_balance is not supported + return [] + else: + return [] diff --git a/src/Users.py b/src/Users.py index 9b8d272..51c54af 100644 --- a/src/Users.py +++ b/src/Users.py @@ -541,8 +541,10 @@ class Users: # Get the user records from the database. user = self.get_user_from_db(user_name) # Get the exchanges list based on the field. - return json.loads(user.loc[0, category]) - except (KeyError, IndexError) as e: + exchanges = user.loc[0, category] + # Return the list if it exists, otherwise return an empty list. + return json.loads(exchanges) if exchanges else [] + except (KeyError, IndexError, json.JSONDecodeError) as e: # Log the error to the console print(f"Error retrieving exchanges for user '{user_name}' and field '{category}': {str(e)}") return None @@ -558,8 +560,12 @@ class Users: """ # Get the user records from the database. user = self.get_user_from_db(user_name) - # Get the old active_exchanges list. - active_exchanges = json.loads(user.loc[0, 'active_exchanges']) + # Get the old active_exchanges list, or initialize as an empty list if it is None. + active_exchanges = user.loc[0, 'active_exchanges'] + if active_exchanges is None: + active_exchanges = [] + else: + active_exchanges = json.loads(active_exchanges) # Define the actions for each command actions = { diff --git a/src/app.py b/src/app.py index 758de9c..3c80252 100644 --- a/src/app.py +++ b/src/app.py @@ -30,7 +30,7 @@ cors = CORS(app, supports_credentials=True, r"/api/history": {"origins": ["http://127.0.0.1:5000", "http://localhost:5000"]}, r"/api/indicator_init": {"origins": ["http://127.0.0.1:5000", "http://localhost:5000"]} }, - headers=['Content-Type']) + allow_headers=['Content-Type']) # Change from headers to allow_headers @app.after_request @@ -46,6 +46,8 @@ def index(): Fetches data from brighter_trades and inject it into an HTML template. Renders the html template and serves the web application. """ + # Clear the session to simulate a new visitor + session.clear() try: # Log the user in. user_name = brighter_trades.config.users.load_or_create_user(username=session.get('user')) @@ -61,13 +63,15 @@ def index(): print('[SERVING INDEX] (USERNAME):', user_name) - # Ensure that a valid connection with an exchange exist. - keys = {'key': config.ALPACA_API_KEY, 'secret': config.ALPACA_API_SECRET} + default_exchange = 'binance' + default_keys = None # No keys needed for public market data + + # Ensure that a valid connection with an exchange exists result = brighter_trades.connect_user_to_exchange(user_name=user_name, - default_exchange='alpaca', - default_keys=keys) + default_exchange=default_exchange, + default_keys=default_keys) if not result: - raise ValueError("Couldn't connect to the alpaca exchange.") + raise ValueError("Couldn't connect to the default exchange.") # A dict of data required to build the html of the user app. # Dynamic content like options and titles and balances to display. @@ -310,4 +314,4 @@ def indicator_init(): if __name__ == '__main__': - app.run(debug=True, use_reloader=False) + app.run(debug=False, use_reloader=False) diff --git a/src/config.py b/src/config.py index 6ca7103..96a399c 100644 --- a/src/config.py +++ b/src/config.py @@ -1,5 +1,5 @@ BINANCE_API_KEY = 'rkp1Xflb5nnwt6jys0PG27KXcqwn0q9lKCLryKcSp4mKW2UOlkPRuAHPg45rQVgj' BINANCE_API_SECRET = 'DiFhhYhF64nkPe5f3V7TRJX2bSVA7ZQZlozSdX7O7uYmBMdK985eA6Kp2B2zKvbK' -ALPACA_API_KEY = 'PKN0WFYT9VZYUVRBG1HM' -ALPACA_API_SECRET = '0C1I6UcBSR2B0SZrBC3DoKGtcglAny8znorvganx' +ALPACA_API_KEY = 'PKE4RD999SJ8L53OUI8O' +ALPACA_API_SECRET = 'buwlMoSSfZWGih8Er30quQt4d7brsBWdJXD1KB7C' DB_FILE = "C:/Users/Rob/PycharmProjects/BrighterTrading/data/BrighterTrading.db" diff --git a/src/database.py b/src/database.py index 0f3892b..43d38ac 100644 --- a/src/database.py +++ b/src/database.py @@ -242,7 +242,7 @@ class Database: :param exchange_name: str - The name of the exchange_name. :return: int - The primary id of the exchange_name. """ - return self.get_from_static_table(item='id', table='exchange', indexes=HDict({'name': exchange_name})) + return self.get_from_static_table(item='id', table='exchange', create_id=True,indexes=HDict({'name': exchange_name})) def _fetch_market_id(self, symbol: str, exchange_name: str) -> int: """ @@ -276,24 +276,23 @@ class Database: market_id = self._fetch_market_id(symbol, exchange_name) # Insert the market id into the dataframe. candlesticks.insert(0, 'market_id', market_id) - # Create a table schema. - # Get a list of all the columns in the dataframe. - columns = list(candlesticks.columns.values) - # Isolate any extra columns specific to individual exchanges. - # The carriage return and tabs are unnecessary, they just tidy output for debugging. - columns = ',\n\t\t\t\t\t'.join(columns[7:], ) - # Define the columns common with all exchanges and append any extras columns. + # Create a table schema. todo delete these line if not needed anymore + # # Get a list of all the columns in the dataframe. + # columns = list(candlesticks.columns.values) + # # Isolate any extra columns specific to individual exchanges. + # # The carriage return and tabs are unnecessary, they just tidy output for debugging. + # columns = ',\n\t\t\t\t\t'.join(columns[7:], ) + # # Define the columns common with all exchanges and append any extras columns. sql_create = f""" CREATE TABLE IF NOT EXISTS '{table_name}' ( id INTEGER PRIMARY KEY, market_id INTEGER, - open_time UNIQUE ON CONFLICT IGNORE, - open NOT NULL, - high NOT NULL, - low NOT NULL, - close NOT NULL, - volume NOT NULL, - {columns}, + open_time INTEGER UNIQUE ON CONFLICT IGNORE, + open REAL NOT NULL, + high REAL NOT NULL, + low REAL NOT NULL, + close REAL NOT NULL, + volume REAL NOT NULL, FOREIGN KEY (market_id) REFERENCES market (id) )""" # Connect to the database. @@ -338,17 +337,17 @@ class Database: print(f'Got {len(records.index)} records from db') else: # If the table doesn't exist, get them from the exchange_name. - print('\nTable didnt exist fetching from exchange_name') + print(f'\nTable didnt exist fetching from {ex_details[2]}') temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl print(f'Requesting from {st} to {et}, Should be {temp} records') records = self._populate_table(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details) - print(f'Got {len(records.index)} records from exchange_name') + print(f'Got {len(records.index)} records from {ex_details[2]}') # Check if the records in the db go far enough back to satisfy the query. first_timestamp = query_satisfied(start_datetime=st, records=records, r_length=rl) if first_timestamp: # The records didn't go far enough back if a timestamp was returned. - print('\nRecords did not go far enough back. Requesting from exchange_name') + print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}') print(f'first ts on record is: {first_timestamp}') end_time = dt.datetime.utcfromtimestamp(first_timestamp) print(f'Requesting from {st} to {end_time}') @@ -359,7 +358,7 @@ class Database: last_timestamp = query_uptodate(records=records, r_length=rl) if last_timestamp: # The query was not up-to-date if a timestamp was returned. - print('\nRecords were not updated. Requesting from exchange_name.') + print(f'\nRecords were not updated. Requesting from {ex_details[2]}.') print(f'the last record on file is: {last_timestamp}') start_time = dt.datetime.utcfromtimestamp(last_timestamp) print(f'Requesting from {start_time} to {et}') diff --git a/src/exchangeinterface.py b/src/exchangeinterface.py index 0a2e060..5101223 100644 --- a/src/exchangeinterface.py +++ b/src/exchangeinterface.py @@ -1,12 +1,14 @@ +import logging import json -from typing import List, Any - +from typing import List, Any, Dict import pandas as pd import requests +import ccxt -from BinanceFutures import BinanceFuturesExchange, BinanceCoinExchange -from BinanceSpot import BinanceSpotExchange -from AlpacaPaperExchange import AlpacaPaperExchange +from Exchange import Exchange + +# Setup logging +# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') # This just makes this method cleaner. @@ -17,48 +19,54 @@ def add_row(df, dic): class ExchangeInterface: """ - Connects and maintains and routs data requests from exchanges. + Connects and maintains and routes data requests from exchanges. """ def __init__(self): - # Create a dataframe to hold all the references and info for each user's configured exchanges. - self.exchange_data = pd.DataFrame(columns=['user', 'name', 'reference', 'balance']) - # List of available exchanges - self.available_exchanges = ['alpaca', 'binance_coin', 'binance_futures', 'binance_spot'] + self.exchange_data = pd.DataFrame(columns=['user', 'name', 'reference', 'balances']) + + # Populate the list of available exchanges from CCXT + self.available_exchanges = self.get_ccxt_exchanges() + + def get_ccxt_exchanges(self) -> List[str]: + """Retrieve the list of available exchanges from CCXT.""" + return ccxt.exchanges def connect_exchange(self, exchange_name: str, user_name: str, api_keys: dict = None) -> bool: """ - Initialize and store a reference to the available exchanges. + Initialize and store a reference to the available exchanges. - :param user_name: The name of the user connecting the exchange. - :param api_keys: dict - {api: key, api-secret: key} - :param exchange_name: str - The name of the exchange. - :return: True if success | None on fail. + :param user_name: The name of the user connecting the exchange. + :param api_keys: dict - {api: key, api-secret: key} + :param exchange_name: str - The name of the exchange. + :return: True if success | None on fail. """ - if exchange_name == 'alpaca': - success = self.add_exchange(user_name, AlpacaPaperExchange, (exchange_name, api_keys)) - elif exchange_name == 'binance_coin': - success = self.add_exchange(user_name, BinanceCoinExchange, (exchange_name, api_keys)) - elif exchange_name == 'binance_futures': - success = self.add_exchange(user_name, BinanceFuturesExchange, (exchange_name, api_keys)) - elif exchange_name == 'binance_spot': - success = self.add_exchange(user_name, BinanceSpotExchange, (exchange_name, api_keys)) - else: - success = False - return success + # logging.debug( + # f"Attempting to connect to exchange '{exchange_name}' for user '{user_name}' with API keys: {api_keys}") - def add_exchange(self, user_name, _class, arg): try: - ref = _class(*arg) - row = {'user': user_name, 'name': ref.name, 'reference': ref, 'balances': ref.balances} + # Initialize the exchange object + exchange = Exchange(name=exchange_name, api_keys=api_keys, exchange_id=exchange_name.lower()) + + # Update exchange data with the new connection + self.add_exchange(user_name, exchange) + # logging.debug(f"Successfully connected to exchange '{exchange_name}' for user '{user_name}'") + + return True + except Exception as e: + logging.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}") + return False # Failed to connect + + def add_exchange(self, user_name: str, exchange: Exchange): + try: + row = {'user': user_name, 'name': exchange.name, 'reference': exchange, 'balances': exchange.balances} self.exchange_data = add_row(self.exchange_data, row) except Exception as e: - if e.status_code == 400 and e.error_code == -1021: - print("Timestamp ahead of server's time error: Sync your system clock to fix this.") - print("Couldn't create an instance of the exchange!:\n", e) + if hasattr(e, 'status_code') and e.status_code == 400 and e.error_code == -1021: + logging.error("Timestamp ahead of server's time error: Sync your system clock to fix this.") + logging.error("Couldn't create an instance of the exchange!:\n", e) raise - return True def get_exchange(self, ename: str, uname: str) -> Any: """Return a reference to the exchange_name.""" @@ -76,15 +84,15 @@ class ExchangeInterface: connected_exchanges = self.exchange_data.loc[self.exchange_data['user'] == user_name, 'name'].tolist() return connected_exchanges - def get_available_exchanges(self): + def get_available_exchanges(self) -> List[str]: """ Return a list of the exchanges available to connect to""" return self.available_exchanges - def get_exchange_balances(self, name): + def get_exchange_balances(self, name: str) -> pd.Series: """ Return the balances of a single exchange_name""" return self.exchange_data.query("name == @name")['balances'] - def get_all_balances(self, user_name: str) -> dict: + def get_all_balances(self, user_name: str) -> Dict[str, List[Dict[str, Any]]]: """ Return the balances of all connected exchanges indexed by name. @@ -105,7 +113,7 @@ class ExchangeInterface: return balances_dict - def get_all_activated(self, user_name, fetch_type='trades'): + def get_all_activated(self, user_name: str, fetch_type: str = 'trades') -> Dict[str, List[Dict[str, Any]]]: """Get active trades or open orders as a dictionary indexed by name""" filtered_data = self.exchange_data.loc[self.exchange_data['user'] == user_name, ['name', 'reference']] if filtered_data.empty: @@ -122,16 +130,16 @@ class ExchangeInterface: elif fetch_type == 'orders': data = reference.get_open_orders() else: - print(f"Invalid fetch type: {fetch_type}") + logging.error(f"Invalid fetch type: {fetch_type}") return {} data_dict[name] = data except Exception as e: - print(f"Error retrieving data for {name}: {str(e)}") + logging.error(f"Error retrieving data for {name}: {str(e)}") return data_dict - def get_order(self, symbol, order_id, target, user_name): + def get_order(self, symbol: str, order_id: str, target: str, user_name: str) -> Any: """ Return order - from a target exchange_interface. :param user_name: The name of the user making the request. @@ -145,7 +153,7 @@ class ExchangeInterface: # Return the order. return exchange.get_order(symbol=symbol, order_id=order_id) - def get_trade_status(self, trade, user_name): + def get_trade_status(self, trade, user_name: str) -> str: """ Return the status of a trade Todo: trade order.status might be outdated this request the status from the exchanges order record. @@ -158,7 +166,7 @@ class ExchangeInterface: # Return status. return order['status'] - def get_trade_executed_qty(self, trade, user_name): + def get_trade_executed_qty(self, trade, user_name: str) -> float: """ Return the executed quantity of a trade. @@ -173,7 +181,7 @@ class ExchangeInterface: # Return quantity. return order['executedQty'] - def get_trade_executed_price(self, trade, user_name): + def get_trade_executed_price(self, trade, user_name: str) -> float: """ Return the average price of executed quantity of a trade @@ -188,7 +196,7 @@ class ExchangeInterface: return order['price'] @staticmethod - def get_price(symbol, price_source=None): + def get_price(symbol: str, price_source: str = None) -> float: """ :param price_source: alternative sources for price. :param symbol: The symbol of the trading pair. diff --git a/src/maintenence/debuging_testing.py b/src/maintenence/debuging_testing.py new file mode 100644 index 0000000..d8f9e2d --- /dev/null +++ b/src/maintenence/debuging_testing.py @@ -0,0 +1,20 @@ +import ccxt + + +def main(): + # Create an instance of the Binance exchange + binance = ccxt.binance({ + 'enableRateLimit': True, + 'verbose': False, # Ensure verbose mode is disabled + }) + + try: + # Load markets to test the connection + markets = binance.load_markets() + print("Markets loaded successfully") + except ccxt.BaseError as e: + print(f"Error loading markets: {str(e)}") + + +if __name__ == "__main__": + main() diff --git a/src/maintenence/generate_list_exchanges_with_public_data.py b/src/maintenence/generate_list_exchanges_with_public_data.py new file mode 100644 index 0000000..d4ff19a --- /dev/null +++ b/src/maintenence/generate_list_exchanges_with_public_data.py @@ -0,0 +1,60 @@ +import ccxt + + +def test_public_endpoints(exchange_id, symbol): + try: + exchange_class = getattr(ccxt, exchange_id) + exchange = exchange_class() + + # Check and test fetch_ticker + if exchange.has.get('fetchTicker', False): + exchange.fetch_ticker(symbol) + + # Check and test fetch_tickers + if exchange.has.get('fetchTickers', False): + exchange.fetch_tickers([symbol]) + + # Check and test fetch_order_book + if exchange.has.get('fetchOrderBook', False): + exchange.fetch_order_book(symbol) + + # Check and test fetch_trades + if exchange.has.get('fetchTrades', False): + exchange.fetch_trades(symbol) + + # Check and test fetch_ohlcv + if exchange.has.get('fetchOHLCV', False): + timeframe = '1m' + since = exchange.parse8601('2022-01-01T00:00:00Z') + exchange.fetch_ohlcv(symbol, timeframe, since) + + return True + except Exception as e: + print(f"Exchange {exchange_id} failed: {e}") + return False + + +# List of all supported exchanges +exchange_ids = ccxt.exchanges + +# Symbol to test (using a common symbol that is likely available on many exchanges) +test_symbol = 'BTC/USDT' + +# Dictionary to store the exchanges that support public market data endpoints +working_exchanges = [] + +# Iterate through each exchange and test public endpoints +for exchange_id in exchange_ids: + print(f"Testing exchange: {exchange_id}") + if test_public_endpoints(exchange_id, test_symbol): + working_exchanges.append(exchange_id) + +# Print the working exchanges +print("Exchanges with working public market data endpoints:") +for exchange in working_exchanges: + print(exchange) + +# Optionally save the results to a file +with open('working_public_exchanges.txt', 'w') as f: + for exchange in working_exchanges: + f.write(f"{exchange}\n") diff --git a/src/maintenence/working_public_exchanges.txt b/src/maintenence/working_public_exchanges.txt new file mode 100644 index 0000000..7d493ba --- /dev/null +++ b/src/maintenence/working_public_exchanges.txt @@ -0,0 +1,63 @@ +ascendex +bequant +bigone +binance +binanceus +binanceusdm +bitbay +bitbns +bitcoincom +bitfinex +bitfinex2 +bitget +bitmart +bitmex +bitopro +bitrue +bitstamp +bitteam +blockchaincom +btcmarkets +bybit +coinbase +coinbaseadvanced +coinbaseexchange +coinex +coinlist +coinmate +coinsph +cryptocom +currencycom +delta +digifinex +exmo +fmfwio +gemini +hitbtc +hitbtc3 +htx +huobi +indodax +kraken +kucoin +latoken +lbank +lykke +mexc +ndax +novadax +oceanex +okx +phemex +poloniex +probit +timex +tokocrypto +tradeogre +upbit +wazirx +whitebit +woo +xt +yobit +zonda diff --git a/src/shared_utilities.py b/src/shared_utilities.py index 53bcfe5..fabe323 100644 --- a/src/shared_utilities.py +++ b/src/shared_utilities.py @@ -1,35 +1,41 @@ from functools import lru_cache import datetime as dt -import pandas as pd +from typing import Union +import pandas as pd epoch = dt.datetime.utcfromtimestamp(0) -def query_uptodate(records: pd.DataFrame, r_length: float): +def query_uptodate(records: pd.DataFrame, r_length_min: float) -> Union[float, None]: """ - Check if records that span a period of time are up-to-date. + Check if records that span a period of time are up-to-date. - :param records: - The dataframe holding results from a query. - :param r_length: - The timespan in minutes of each record in the data. - :return: timestamp - The closest timestamp to start_datetime on record. + :param records: The dataframe holding results from a query. + :param r_length_min: The timespan in minutes of each record in the data. + :return: timestamp - None if records are up-to-date otherwise the newest timestamp on record. """ print('\nChecking if the records are up-to-date...') - # Get the newest timestamp from the records passed in stored in ms. + # Get the newest timestamp from the records passed in stored in ms last_timestamp = float(records.open_time.max()) - print(f'The last ts on record is {last_timestamp}') - # Get a timestamp of the UTC time in millisecond to match the records in the DB. + print(f'The last timestamp on record is {last_timestamp}') + + # Get a timestamp of the UTC time in milliseconds to match the records in the DB now_timestamp = unix_time_millis(dt.datetime.utcnow()) print(f'The timestamp now is {now_timestamp}') - # Get the seconds since the records have been updated. + + # Get the seconds since the records have been updated seconds_since_update = ms_to_seconds(now_timestamp - last_timestamp) + # Convert to minutes minutes_since_update = seconds_since_update / 60 print(f'The minutes since last update is {minutes_since_update}') - print(f'And the length of each record is {r_length}') - # Return the timestamp if the time since last update is more than the timespan each record covers. - if minutes_since_update > (r_length - 0.3): - # Return the last timestamp in seconds. + print(f'And the length of each record is {r_length_min}') + + # Return the timestamp if the time since last update is more than the timespan each record covers + tolerance_minutes = 10 / 60 # 10 seconds tolerance in minutes + if minutes_since_update > (r_length_min - tolerance_minutes): + # Return the last timestamp in seconds return ms_to_seconds(last_timestamp) return None @@ -38,82 +44,98 @@ def ms_to_seconds(timestamp): return timestamp / 1000 +def unix_time_seconds(d_time): + return (d_time - epoch).total_seconds() + + def unix_time_millis(d_time): return (d_time - epoch).total_seconds() * 1000.0 -def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length: float): +def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length_min: float) -> Union[float, None]: """ - Check if a records that spans a set time goes far enough back to satisfy a query. + Check if records span back far enough to satisfy a query. - :param start_datetime: - The datetime to be satisfied. - :param records: - The dataframe holding results from a query. - :param r_length: - The timespan in minutes of each record in the data. - :return: timestamp - The closest timestamp to start_datetime on record. + This function determines whether the records provided cover the required start_datetime. It calculates + the total duration covered by the records and checks if this duration, starting from the earliest record, + reaches back to include the start_datetime. + + :param start_datetime: The datetime the query starts at. + :param records: The dataframe holding results from a query. + :param r_length_min: The timespan in minutes of each record in the data. + :return: None if the query is satisfied (records span back far enough), + otherwise returns the earliest timestamp in records in seconds. """ - # Convert the datetime to a timestamp in milliseconds to match format in the DB. + print('\nChecking if the query is satisfied...') + + # Convert start_datetime to Unix timestamp in milliseconds start_timestamp = unix_time_millis(start_datetime) - print('Checking if we went far enough back.') - print('Requested: start_timestamp:', start_timestamp) - # Get the oldest timestamp from the records passed in. Convert from str to float. + print(f'Start timestamp: {start_timestamp}') + + # Get the oldest timestamp from the records passed in first_timestamp = float(records.open_time.min()) - print('Received: first_timestamp:', first_timestamp) - if pd.isna(first_timestamp): - # If there were no records returned. Signal a need for update by returning the current timestamp. - return dt.datetime.utcnow().timestamp() - # Get the minutes between the first timestamp on record and the one requested. - minutes_between = ms_to_seconds(first_timestamp - start_timestamp) / 60 - print('minutes_between:', minutes_between) - # Return the timestamp if the difference is greater than the timespan of a single record. - if minutes_between > r_length: - # Return timestamp in seconds. - return ms_to_seconds(first_timestamp) - return None + print(f'First timestamp in records: {first_timestamp}') + + # Calculate the total duration of the records in milliseconds + total_duration = len(records) * (r_length_min * 60 * 1000) + print(f'Total duration of records: {total_duration}') + + # Check if the first timestamp plus the total duration is greater than or equal to the start timestamp + if start_timestamp <= first_timestamp + total_duration: + return None + + return first_timestamp / 1000 # Return in seconds @lru_cache(maxsize=500) def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime.timestamp: """ - Returns an approximate start_datetime of a candle n minutes ago. + Returns the approximate datetime for the start of a candle that was 'n' candles ago. - :param int - n: The number of records ago to calculate. - :param float - candle_length: The time in minutes that each candle represents. - :return datetime.start_datetime - The approximate start_datetime slightly less than the record is expected to have. + :param n: int - The number of candles ago to calculate. + :param candle_length: float - The length of each candle in minutes. + :return: datetime - The approximate datetime for the start of the 'n'-th candle ago. """ - - # Time will have passed since the last candle has closed. The start_datetime for nth candle back will be - # less than now-(n * timeframe) but greater than now-(n+1 * timeframe). Add 1 so we don't miss a candle. + # Increment 'n' by 1 to ensure we account for the time that has passed since the last candle closed. n += 1 - # Calculate the time. + + # Calculate the total minutes ago the 'n'-th candle started. minutes_ago = n * candle_length - # Get the date and time of n candle_length ago. - date_of = dt.datetime.utcnow() - dt.timedelta(minutes=minutes_ago) - # Return the result as a start_datetime. + + # Get the current UTC datetime. + now = dt.datetime.utcnow() + + # Calculate the datetime for 'n' candles ago. + date_of = now - dt.timedelta(minutes=minutes_ago) + + # Return the calculated datetime. return date_of @lru_cache(maxsize=20) def timeframe_to_minutes(timeframe): """ - Converts a string representing a timeframe into candle_length. + Converts a string representing a timeframe into an integer representing the approximate minutes. :param timeframe: str - Timeframe format is [multiplier:focus]. eg '15m', '4h', '1d' - :return: int - Minutes the timeframe represents ex. '2h'-> 120(candle_length). + :return: int - Minutes the timeframe represents ex. '2h'-> 120(minutes). """ # Extract the numerical part of the timeframe param. digits = int("".join([i if i.isdigit() else "" for i in timeframe])) # Extract the alpha part of the timeframe param. letter = "".join([i if i.isalpha() else "" for i in timeframe]) + if letter == 'm': pass elif letter == 'h': digits *= 60 elif letter == 'd': - digits *= (60 * 24) + digits *= 60 * 24 elif letter == 'w': - digits *= (60 * 24 * 7) + digits *= 60 * 24 * 7 elif letter == 'M': - digits *= (60 * 24 * 7 * 31) + digits *= 60 * 24 * 31 # Maximum number of days in a month elif letter == 'Y': - digits *= (60 * 24 * 7 * 31 * 365) + digits *= 60 * 24 * 365 # Exact number of days in a year + return digits diff --git a/tests/test_AlpacaPaperExchange.py b/tests/test_AlpacaPaperExchange.py deleted file mode 100644 index 807cc13..0000000 --- a/tests/test_AlpacaPaperExchange.py +++ /dev/null @@ -1,23 +0,0 @@ -from datetime import datetime, timedelta - -from AlpacaPaperExchange import AlpacaPaperExchange - - -def test_alpaca_paper_exchange(): - ape = AlpacaPaperExchange('ape') - # print(f'\nTesting name assignment: {ape.name}') - # print(f'\nTesting get_client(): {ape.get_client()}') - # print(f'\nTesting get_exchange_info(): {ape.get_exchange_info()}') - # print(f'\nTesting get_symbols(): {ape.get_symbols()}') - # print(f'\nTesting get_balances(): {ape.get_balances()}') - # print(f'\nTesting get_avail_intervals(): {ape.get_avail_intervals()}') - last_hour_date_time = datetime.now() - timedelta(hours=24) - k = ape.get_historical_klines(symbol="BTC/USDT", interval="5m",start_str=last_hour_date_time, klines_type=0) - {print(k) for k in k} - # print(f'\nTesting get_symbol_precision_rule(): {ape.get_symbol_precision_rule("BTC/USDT")}') - # print(f'\nTesting get_min_qty(): {ape.get_min_qty("ETH/USDT")}') - # print(f'\nTesting get_min_notional_qty(): {ape.get_min_notional_qty("BTC/USDT")}') - # print(f'\nTesting get_price(): {ape.get_price("ETH/USDT")}') - # ape.get_order() - - assert True diff --git a/tests/test_BinanceFutures.py b/tests/test_BinanceFutures.py deleted file mode 100644 index e74ae63..0000000 --- a/tests/test_BinanceFutures.py +++ /dev/null @@ -1,21 +0,0 @@ -from binance.enums import HistoricalKlinesType - -from BinanceFutures import BinanceFuturesExchange - - -def test_binancefuturesexchange(): - bfe = BinanceFuturesExchange('bfe') - print(f'\nTesting name assignment: {bfe.name}') - print(f'\nTesting get_client(): {bfe.get_client()}') - # print(f'\nTesting get_exchange_info(): {bfe.get_exchange_info()}') - print(f'\nTesting get_symbols(): {bfe.get_symbols()}') - print(f'\nTesting get_balances(): {bfe.get_balances()}') - print(f'\nTesting get_avail_intervals(): {bfe.get_avail_intervals()}') - print(f'\nTesting get_historical_klines(): {bfe.get_historical_klines(symbol="BTCUSDT",interval="15m",start_str="1 hour ago UTC", klines_type=HistoricalKlinesType.SPOT)}') - print(f'\nTesting get_symbol_precision_rule(): {bfe.get_symbol_precision_rule("ETHUSDT")}') - print(f'\nTesting get_min_qty(): {bfe.get_min_qty("ETHUSDT")}') - print(f'\nTesting get_min_notional_qty(): {bfe.get_min_notional_qty("BTCUSDT")}') - print(f'\nTesting get_price(): {bfe.get_price("BTCUSDT")}') - # bse.get_order() - assert True - diff --git a/tests/test_BinanceSpot.py b/tests/test_BinanceSpot.py deleted file mode 100644 index 32e2a60..0000000 --- a/tests/test_BinanceSpot.py +++ /dev/null @@ -1,24 +0,0 @@ -from binance.enums import HistoricalKlinesType - -from BinanceSpot import BinanceSpotExchange - - -def test_binance_spot_exchange(): - bse = BinanceSpotExchange('bse') - print(f'\nTesting name assignment: {bse.name}') - print(f'Testing get_client(): {bse.get_client()}') - # print(f'Testing get_symbol_info(): {bse.get_exchange_info()}') - print(f'Testing get_symbols(): {bse.get_symbols()}') - print(f'Testing get_balances(): {bse.get_balances()}') - print(f'Testing get_avail_intervals(): {bse.get_avail_intervals()}') - print(f'Testing get_historical_klines(): {bse.get_historical_klines(symbol="ETHUSDT",interval="15m",start_str="1 hour ago UTC",klines_type=HistoricalKlinesType.SPOT)}') - print(f'Testing get_symbol_precision_rule(): {bse.get_symbol_precision_rule("ETHUSDT")}') - print(f'Testing get_min_qty(): {bse.get_min_qty("ETHUSDT")}') - print(f'Testing get_min_notional_qty(): {bse.get_min_notional_qty("ETHUSDT")}') - print(f'Testing get_price(): {bse.get_price("BTCUSDT")}') - # #bse.get_order() - - - - - diff --git a/tests/test_DataCache.py b/tests/test_DataCache.py index 885dceb..ad9ce95 100644 --- a/tests/test_DataCache.py +++ b/tests/test_DataCache.py @@ -1,9 +1,124 @@ from DataCache import DataCache from exchangeinterface import ExchangeInterface +import unittest +import pandas as pd +import datetime as dt -def test_cache_exists(): - exchanges = ExchangeInterface() - # This object maintains all the cached data. Pass it connection to the exchanges. - data = DataCache(exchanges) - assert data.cache_exists(key='BTC/USD_2h_alpaca') is False +class TestDataCache(unittest.TestCase): + def setUp(self): + # Setup the database connection here + self.exchanges = ExchangeInterface() + self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None) + # This object maintains all the cached data. Pass it connection to the exchanges. + self.data = DataCache(self.exchanges) + + asset, timeframe, exchange = 'BTC/USD', '2h', 'binance' + self.key1 = f'{asset}_{timeframe}_{exchange}' + + asset, timeframe, exchange = 'ETH/USD', '2h', 'binance' + self.key2 = f'{asset}_{timeframe}_{exchange}' + + def test_set_cache(self): + # Tests + print('Testing set_cache flag not set:') + self.data.set_cache(data='data', key=self.key1) + attr = self.data.__getattribute__('cached_data') + self.assertEqual(attr[self.key1], 'data') + + self.data.set_cache(data='more_data', key=self.key1) + attr = self.data.__getattribute__('cached_data') + self.assertEqual(attr[self.key1], 'more_data') + + print('Testing set_cache no-overwrite flag set:') + self.data.set_cache(data='even_more_data', key=self.key1, do_not_overwrite=True) + attr = self.data.__getattribute__('cached_data') + self.assertEqual(attr[self.key1], 'more_data') + + def test_cache_exists(self): + # Tests + print('Testing cache_exists() method:') + self.assertFalse(self.data.cache_exists(key=self.key2)) + self.data.set_cache(data='data', key=self.key1) + self.assertTrue(self.data.cache_exists(key=self.key1)) + + def test_update_candle_cache(self): + print('Testing update_candle_cache() method:') + # Initial data + df_initial = pd.DataFrame({ + 'open_time': [1, 2, 3], + 'open': [100, 101, 102], + 'high': [110, 111, 112], + 'low': [90, 91, 92], + 'close': [105, 106, 107], + 'volume': [1000, 1001, 1002] + }) + + # Data to be added + df_new = pd.DataFrame({ + 'open_time': [3, 4, 5], + 'open': [102, 103, 104], + 'high': [112, 113, 114], + 'low': [92, 93, 94], + 'close': [107, 108, 109], + 'volume': [1002, 1003, 1004] + }) + + self.data.set_cache(data=df_initial, key=self.key1) + self.data.update_candle_cache(more_records=df_new, key=self.key1) + + result = self.data.get_cache(key=self.key1) + expected = pd.DataFrame({ + 'open_time': [1, 2, 3, 4, 5], + 'open': [100, 101, 102, 103, 104], + 'high': [110, 111, 112, 113, 114], + 'low': [90, 91, 92, 93, 94], + 'close': [105, 106, 107, 108, 109], + 'volume': [1000, 1001, 1002, 1003, 1004] + }) + + pd.testing.assert_frame_equal(result, expected) + + def test_update_cached_dict(self): + print('Testing update_cached_dict() method:') + self.data.set_cache(data={}, key=self.key1) + self.data.update_cached_dict(cache_key=self.key1, dict_key='sub_key', data='value') + + cache = self.data.get_cache(key=self.key1) + self.assertEqual(cache['sub_key'], 'value') + + def test_get_cache(self): + print('Testing get_cache() method:') + self.data.set_cache(data='data', key=self.key1) + result = self.data.get_cache(key=self.key1) + self.assertEqual(result, 'data') + + def test_get_records_since(self): + print('Testing get_records_since() method:') + df_initial = pd.DataFrame({ + 'open_time': [1, 2, 3], + 'open': [100, 101, 102], + 'high': [110, 111, 112], + 'low': [90, 91, 92], + 'close': [105, 106, 107], + 'volume': [1000, 1001, 1002] + }) + + self.data.set_cache(data=df_initial, key=self.key1) + start_datetime = dt.datetime.utcfromtimestamp(2) + result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60, ex_details=[]).sort_values(by='open_time').reset_index(drop=True) + + expected = pd.DataFrame({ + 'open_time': [2, 3], + 'open': [101, 102], + 'high': [111, 112], + 'low': [91, 92], + 'close': [106, 107], + 'volume': [1001, 1002] + }) + + pd.testing.assert_frame_equal(result, expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_app.py b/tests/test_app.py index 65c3d08..5c86b0c 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,28 +1,73 @@ -def test_index(): - from BrighterTrades import BrighterTrades - obj = BrighterTrades() - assert True +import unittest +from flask import Flask +from src.app import app +import json -def test_ws(): - assert False +class FlaskAppTests(unittest.TestCase): + def setUp(self): + """ + Set up the test client and any other test configuration. + """ + self.app = app.test_client() + self.app.testing = True + + def test_index(self): + """ + Test the index route. + """ + response = self.app.get('/') + self.assertEqual(response.status_code, 200) + self.assertIn(b'Welcome', response.data) # Adjust this based on your actual landing page content + + def test_login(self): + """ + Test the login route with valid and invalid credentials. + """ + # Valid credentials + valid_data = {'user_name': 'test_user', 'password': 'test_password'} + response = self.app.post('/login', data=valid_data) + self.assertEqual(response.status_code, 302) # Redirects on success + + # Invalid credentials + invalid_data = {'user_name': 'wrong_user', 'password': 'wrong_password'} + response = self.app.post('/login', data=invalid_data) + self.assertEqual(response.status_code, 302) # Redirects on failure + self.assertIn(b'Invalid user_name or password', response.data) + + def test_signup(self): + """ + Test the signup route. + """ + data = {'email': 'test@example.com', 'user_name': 'new_user', 'password': 'new_password'} + response = self.app.post('/signup_submit', data=data) + self.assertEqual(response.status_code, 302) # Redirects on success + + def test_signout(self): + """ + Test the signout route. + """ + response = self.app.get('/signout') + self.assertEqual(response.status_code, 302) # Redirects on signout + + def test_history(self): + """ + Test the history route. + """ + data = {"user_name": "test_user"} + response = self.app.post('/api/history', data=json.dumps(data), content_type='application/json') + self.assertEqual(response.status_code, 200) + self.assertIn(b'price_history', response.data) + + def test_indicator_init(self): + """ + Test the indicator initialization route. + """ + data = {"user_name": "test_user"} + response = self.app.post('/api/indicator_init', data=json.dumps(data), content_type='application/json') + self.assertEqual(response.status_code, 200) + self.assertIn(b'indicator_data', response.data) -def test_settings(): - assert False - - -def test_history(): - assert False - - -def test_signup(): - assert False - - -def test_signup_submit(): - assert False - - -def test_indicator_init(): - assert False +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_shared_utilities.py b/tests/test_shared_utilities.py new file mode 100644 index 0000000..07bbbcd --- /dev/null +++ b/tests/test_shared_utilities.py @@ -0,0 +1,162 @@ +import time +import unittest +import datetime as dt +import pandas as pd +from shared_utilities import ( + query_uptodate, ms_to_seconds, unix_time_seconds, unix_time_millis, + query_satisfied, ts_of_n_minutes_ago, timeframe_to_minutes +) + + +class TestSharedUtilities(unittest.TestCase): + + def test_query_uptodate(self): + print('Testing query_uptodate()') + + # (Test case 1) The records should not be up-to-date (very old timestamps) + records = pd.DataFrame({ + 'open_time': [1, 2, 3, 4, 5] + }) + result = query_uptodate(records, 1) + if result is None: + print('Records are up-to-date.') + else: + print('Records are not up-to-date.') + print(f'Result for the first test case: {result}') + self.assertIsNotNone(result) + + # (Test case 2) The records should be up-to-date (recent timestamps) + now = unix_time_millis(dt.datetime.utcnow()) + recent_records = pd.DataFrame({ + 'open_time': [now - 70000, now - 60000, now - 40000] + }) + result = query_uptodate(recent_records, 1) + if result is None: + print('Records are up-to-date.') + else: + print('Records are not up-to-date.') + print(f'Result for the second test case: {result}') + self.assertIsNone(result) + + # (Test case 3) The records just under the tolerance for a record length of 1 hour. + # The records should not be up-to-date (recent timestamps) + one_hour = 60 * 60 * 1000 # one hour in milliseconds + tolerance_milliseconds = 10 * 1000 # tolerance in milliseconds + recent_time = unix_time_millis(dt.datetime.utcnow()) + borderline_records = pd.DataFrame({ + 'open_time': [recent_time - one_hour + (tolerance_milliseconds - 3)] # just within the tolerance + }) + result = query_uptodate(borderline_records, 60) + if result is None: + print('Records are up-to-date.') + else: + print('Records are not up-to-date.') + print(f'Result for the third test case: {result}') + self.assertIsNotNone(result) + + # (Test case 4) The records just over the tolerance for a record length of 1 hour. + # The records should be up-to-date (recent timestamps) + one_hour = 60 * 60 * 1000 # one hour in milliseconds + tolerance_milliseconds = 10 * 1000 # tolerance in milliseconds + recent_time = unix_time_millis(dt.datetime.utcnow()) + borderline_records = pd.DataFrame({ + 'open_time': [recent_time - one_hour + (tolerance_milliseconds + 3)] # just within the tolerance + }) + result = query_uptodate(borderline_records, 60) + if result is None: + print('Records are up-to-date.') + else: + print('Records are not up-to-date.') + print(f'Result for the third test case: {result}') + self.assertIsNone(result) + + def test_ms_to_seconds(self): + print('Testing ms_to_seconds()') + self.assertEqual(ms_to_seconds(1000), 1) + self.assertEqual(ms_to_seconds(0), 0) + + def test_unix_time_seconds(self): + print('Testing unix_time_seconds()') + time = dt.datetime(2020, 1, 1) + self.assertEqual(unix_time_seconds(time), 1577836800) + + def test_unix_time_millis(self): + print('Testing unix_time_millis()') + time = dt.datetime(2020, 1, 1) + self.assertEqual(unix_time_millis(time), 1577836800000.0) + + def test_query_satisfied(self): + print('Testing query_satisfied()') + + # Test case where the records should satisfy the query (records cover the start time) + start_datetime = dt.datetime(2020, 1, 1) + records = pd.DataFrame({ + 'open_time': [1577836800000.0 - 600000, 1577836800000.0 - 300000, 1577836800000.0] + # Covering the start time + }) + result = query_satisfied(start_datetime, records, 1) + if result is None: + print('Query is satisfied: records span back far enough.') + else: + print('Query is not satisfied: records do not span back far enough.') + print(f'Result for the first test case: {result}') + self.assertIsNotNone(result) + + # Test case where the records should not satisfy the query (recent records but not enough) + recent_time = unix_time_millis(dt.datetime.utcnow()) + records = pd.DataFrame({ + 'open_time': [recent_time - 300 * 60 * 1000, recent_time - 240 * 60 * 1000, recent_time - 180 * 60 * 1000] + }) + result = query_satisfied(start_datetime, records, 1) + if result is None: + print('Query is satisfied: records span back far enough.') + else: + print('Query is not satisfied: records do not span back far enough.') + print(f'Result for the second test case: {result}') + self.assertIsNone(result) + + # Additional test case for partial coverage + start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=300) + records = pd.DataFrame({ + 'open_time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=240)), + unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=180)), + unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=120))] + }) + result = query_satisfied(start_datetime, records, 60) + if result is None: + print('Query is satisfied: records span back far enough.') + else: + print('Query is not satisfied: records do not span back far enough.') + print(f'Result for the third test case: {result}') + self.assertIsNone(result) + + def test_ts_of_n_minutes_ago(self): + print('Testing ts_of_n_minutes_ago()') + now = dt.datetime.utcnow() + + test_cases = [ + (60, 1), # 60 candles of 1 minute each + (10, 5), # 10 candles of 5 minutes each + (1, 1440), # 1 candle of 1 day (1440 minutes) + (7, 10080), # 7 candles of 1 week (10080 minutes) + (30, 60), # 30 candles of 1 hour (60 minutes) + ] + + for n, candle_length in test_cases: + with self.subTest(n=n, candle_length=candle_length): + result = ts_of_n_minutes_ago(n, candle_length) + expected_time = now - dt.timedelta(minutes=(n + 1) * candle_length) + self.assertAlmostEqual(unix_time_seconds(result), unix_time_seconds(expected_time), delta=60) + + def test_timeframe_to_minutes(self): + print('Testing timeframe_to_minutes()') + self.assertEqual(timeframe_to_minutes('15m'), 15) + self.assertEqual(timeframe_to_minutes('1h'), 60) + self.assertEqual(timeframe_to_minutes('1d'), 1440) + self.assertEqual(timeframe_to_minutes('1w'), 10080) + self.assertEqual(timeframe_to_minutes('1M'), 44640) # 31 days in a month + self.assertEqual(timeframe_to_minutes('1Y'), 525600) # 525600 minutes in a year + + +if __name__ == '__main__': + unittest.main()