Fixed query_uptodate function and added comprehensive test cases

This commit is contained in:
Rob 2024-08-02 01:36:45 -03:00
parent 0bb9780ff6
commit 0b1ad39476
23 changed files with 920 additions and 1066 deletions

View File

@ -56,16 +56,19 @@ legend top
=<b>/index</b> =<b>/index</b>
———————————————————————————————— ————————————————————————————————
Handles the main landing page, logs in the user, Handles the main landing page, logs in the user,
ensures a valid connection with an exchange,
loads dynamic data, and renders the landing page. loads dynamic data, and renders the landing page.
end legend end legend
start start
:Log in user; :Log in user;
:Ensure a valid connection with an exchange;
:Load dynamic data; :Load dynamic data;
:Render Landing Page; :Render Landing Page;
stop stop
@enduml @enduml
``` ```
### /ws ### /ws

View File

@ -1,10 +1,14 @@
numpy==1.24.3 numpy==1.24.3
flask==2.3.2 flask==2.3.2
flask_cors==3.0.10
flask_sock==0.7.0
config~=0.5.1 config~=0.5.1
PyYAML~=6.0 PyYAML~=6.0
binance~=0.3
requests==2.30.0 requests==2.30.0
pandas==2.0.1 pandas==2.0.1
passlib~=1.7.4 passlib~=1.7.4
SQLAlchemy==2.0.13 SQLAlchemy==2.0.13
ccxt~=3.0.105 ccxt==4.3.65
email-validator~=2.2.0
TA-Lib~=0.4.32
bcrypt~=4.2.0

View File

@ -1,271 +0,0 @@
import json
import math
import time
from datetime import datetime, date, time, timedelta, timezone
import pandas
import pandas as pd
from alpaca.trading import GetAssetsRequest, AssetClass, MarketOrderRequest, OrderSide, TimeInForce
from Exchange import Exchange
import config
from alpaca.trading.client import TradingClient
from alpaca.data.historical import CryptoHistoricalDataClient
from alpaca.data.requests import CryptoBarsRequest, CryptoLatestQuoteRequest
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
import alpaca_trade_api as tradeapi
class AlpacaPaperExchange(Exchange):
def __init__(self, name, api_keys):
super().__init__(name, api_keys=api_keys)
def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt=None):
"""
Return a dataframe containing rows of candle attributes.
[(symbol, start_datetime, open, high, low,
close, Volume, trade_count, vwap)]
:param symbol: str: - Trading symbol.
:param interval: str: - Timeframe for the candle sticks.
:param start_dt: - Start datetime of data stream.
:param end_dt: - End datetime of data stream.
:return: pd.dataframe - The requested data.
"""
if end_dt is None:
end_dt = datetime.utcnow()
# Extract the numerical part of the timeframe param.
digits = int("".join([i if i.isdigit() else "" for i in interval]))
# Extract the alpha part of the timeframe param.
letter = "".join([i if i.isalpha() else "" for i in interval])
# Convert the abbreviation to the name of a TimeFrameUnit object attribute.
time_frame = "".join(['Minute' if i == 'm' else
'Hour' if i == 'h' else
'Day' if i == 'd' else
'Week' if i == 'w' else
'Month' if i == 'M' else 'Year' for i in letter])
# Fill in the request parameters.
request_params = CryptoBarsRequest(
symbol_or_symbols=symbol,
timeframe=TimeFrame(amount=digits, unit=TimeFrameUnit[time_frame]),
start=start_dt,
end=end_dt
)
if start_dt > end_dt:
raise ValueError(start_dt, end_dt)
# Request the bars from the exchange_interface.
bars = self.client['historical_data'].get_crypto_bars(request_params)
# Convert the response object into a list of dictionaries.
candles_dict = [c.__dict__ for c in bars[symbol]]
# Convert the data into a dataframe.
candles = pd.DataFrame(candles_dict)
# Reformat this dataframe to make it consistent with other exchanges.
del candles['symbol']
candles.rename({'timestamp': 'open_time'}, axis='columns', inplace=True)
# Convert the date column to a start_datetime to make it consistent with the other exchanges.
candles['open_time'] = candles.open_time.values.astype('int64') // 10 ** 6
# Return the data.
return candles
def _fetch_price(self, symbol) -> float:
"""
Get the latest price for a single symbol.
:param symbol: str - The symbol of the symbol.
:return: float - The minimum quantity sold per trade.
"""
request_params = CryptoLatestQuoteRequest(symbol_or_symbols=symbol)
lq = self.client['historical_data'].get_crypto_latest_quote(request_params)
lqd = lq[symbol]
return lqd.bid_price
def _fetch_min_qty(self, symbol) -> float:
"""
Get the minimum base quantity the exchange_interface allows per order.
:param symbol: str - The symbol of the trading pair.
:return: float - The minimum quantity sold per trade.
"""
assets = self.client['trading'].get_asset(symbol)
return assets.min_order_size
def _fetch_min_notional_qty(self, symbol) -> float:
"""
Get the minimum amount of the quote currency allowed per trade.
:param symbol: The symbol of the trading pair.
:return: The minimum quantity sold per trade.
"""
assets = self.client['trading'].get_asset(symbol)
return assets.price_increment
def _fetch_order(self, symbol, order_id) -> object:
"""Return the order status for a specific order."""
return self.client.query_order(symbol=symbol, orderId=order_id)
def _connect_exchange(self) -> object:
"""Connects to the exchange_interface and sets a reference the client."""
client = {
'historical_data': CryptoHistoricalDataClient(),
'trading': TradingClient(self.api_key, self.api_key_secret, paper=True),
'trade_api': tradeapi.REST(self.api_key, self.api_key_secret, "https://paper-api.alpaca.markets/")
}
return client
def _set_exchange_info(self) -> dict:
"""Fetches Info on all symbols from the exchange_interface."""
return self.client['trading'].get_account()
def _set_symbols(self) -> list:
"""Get information of coins (available for deposit and withdraw) for user"""
params = GetAssetsRequest(asset_class=AssetClass.CRYPTO)
# read all the symbol data into a dataframe.
assets = pandas.DataFrame.from_records(self.client['trading'].get_all_assets(params))
# This parses things weird, so I did some format manipulation.
# Extract the proper headers
headers = assets.head(1).applymap(lambda row: row[0]).values[0]
# Rename the headers that were assigned.
assets = assets.set_axis(headers, axis=1)
# remove the headers from the data.
assets = assets.applymap(lambda row: row[1])
# remove rows that aren't tradable.
assets = assets.query("tradable != False")
# Return a list of just the symbols.
return assets.symbol.values
def _set_balances(self) -> list:
"""
Retrieves the account balances, including cash balance and current P&L (Profit & Loss).
:return: list - A list of dictionaries containing the asset balances and P&L information.
Each dictionary has the following keys:
- 'asset': The asset symbol (e.g., 'USD', 'BTC', 'ETH').
- 'balance': The balance of the asset.
- 'pnl': The current P&L (Profit & Loss) of the account.
"""
account = self.client['trade_api'].get_account()
cash_balance = account.cash
portfolio_value = float(account.equity)
buying_power = float(account.buying_power)
# Calculate P&L as the difference between portfolio value and buying power,
# or set to 0 if no positions or open orders
positions = self.client['trade_api'].list_positions()
open_orders = self.client['trade_api'].list_orders(status='open')
if len(positions) == 0 and len(open_orders) == 0:
current_pl = 0.0
else:
current_pl = portfolio_value - buying_power
non_zero_assets = [
{
'asset': 'USD', # Assuming the cash balance is in USD
'balance': cash_balance,
'pnl': current_pl
}
]
return non_zero_assets
def get_asset_balances(self):
""" Get the individual asset balances"""
# Get the account information
account = self.client['trading'].get_account()
# Retrieve the balances of each non-zero asset
balances = []
for asset in account.assets:
if float(asset.qty) > 0:
balance = {
'asset': asset.symbol,
'balance': asset.qty
}
balances.append(balance)
return balances
def get_active_trades(self):
formatted_trades = []
# Get open trades
positions = self.client['trade_api'].list_positions()
for position in positions:
if float(position.qty) != 0.0:
active_trade = {
'symbol': position.symbol,
'side': 'buy' if float(position.qty) > 0 else 'sell',
'quantity': abs(float(position.qty)),
'price': float(position.avg_entry_price)
}
formatted_trades.append(active_trade)
return formatted_trades
def get_open_orders(self):
formatted_orders = []
# Get open orders
open_orders = self.client['trade_api'].list_orders(status='open')
for order in open_orders:
open_order = {
'symbol': order.symbol,
'side': order.side,
'quantity': order.qty,
'price': order.filled_avg_price
}
formatted_orders.append(open_order)
return formatted_orders
def _set_precision_rule(self, symbol) -> None:
"""Appends a record of places after the decimal required by the exchange_interface indexed by symbol."""
# I don't think this exchange_interface has precision rule.
# return the number of decimal places of min_trade_increment
assets = self.client['trading'].get_asset(symbol)
min_trade_rule = assets.min_trade_increment
precision = len(str(min_trade_rule).split('.')[1])
self.symbols_n_precision[symbol] = precision
return
def _place_order(self, symbol, side, type, timeInForce, quantity, price):
def format_arg(value, arg_name):
"""
Function to convert an arguments value to a desired precision and type float.
:param value: The Value of the argument.
:param arg_name: The name of the argument.
:return: float : The formatted output.
"""
# The required level of precision for this trading pair.
precision = self.get_symbol_precision_rule(symbol)
# If quantity was too low, set to the smallest allowable amount.
minimum = 0
if arg_name == 'quantity':
# The minimum quantity aloud to be traded.
minimum = self.get_min_qty(symbol)
elif arg_name == 'price':
# The minimum price aloud to be traded.
minimum = self.get_min_notional_qty(symbol)
if value < minimum:
value = minimum
return float(f"{value:.{precision}f}")
# Set price and quantity to desired precision and type.
quantity = format_arg(quantity, 'quantity')
price = format_arg(price, 'price')
if side == 'buy':
side = OrderSide.BUY
else:
side = OrderSide.SELL
market_order_data = MarketOrderRequest(
symbol=symbol,
qty=quantity,
side=side,
time_in_force=TimeInForce[timeInForce]
)
market_order = self.client['trading'].submit_order(market_order_data)
result = 'Success'
return result, market_order

View File

@ -1,218 +0,0 @@
import pandas as pd
from Exchange import Exchange
import ccxt
class BinanceFuturesExchange(Exchange):
def __init__(self, name, api_keys):
super().__init__(name, api_keys=api_keys)
import ccxt
def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt=None):
"""
Return a list [(Open time, Open, High, Low, Close,
Volume, Close time, Quote symbol volume,
Number of trades, Taker buy base symbol volume,
Taker buy quote symbol volume)]
:param symbol: str: Trading symbol.
:param interval: str: Timeframe for the candlesticks.
:param start_dt: start datetime of data stream.
:param end_dt: end datetime of data stream.
:return: pandas DataFrame containing the candlestick data.
"""
# Convert the start date to a Unix timestamp in seconds.
start_str = int(start_dt.timestamp())
end_str = None
if end_dt is not None:
end_str = int(end_dt.timestamp())
candles = self.client.fetch_ohlcv(symbol=symbol, timeframe=interval, since=start_str, limit=None, end=end_str)
# Create and return a pandas DataFrame.
df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume']
candles = pd.DataFrame(candles, columns=df_columns)
# Return the data
return candles
def _fetch_price(self, symbol) -> float:
"""
Get the latest price for a single symbol.
:param symbol: The symbol of the trading pair.
:return: The latest price for the symbol.
"""
ticker = self.client.fetch_ticker(symbol)
return float(ticker['last'])
def _fetch_min_qty(self, symbol) -> float:
"""
Get the minimum base quantity the exchange allows per order.
:param symbol: The symbol of the trading pair.
:return: The minimum quantity sold per trade.
"""
exchange_info = self.client.fetch_tickers()
symbol_info = exchange_info.get(symbol)
if symbol_info is None:
raise ValueError(f'No symbol: {symbol}')
# Extract the minimum quantity value from the symbol info
min_qty = symbol_info.get('info').get('filters')[2].get('minQty')
return float(min_qty)
def _fetch_min_notional_qty(self, symbol) -> float:
"""
Get the minimum amount of the quote currency allowed per trade.
:param symbol: The symbol of the trading pair.
:return: The minimum quantity sold per trade.
"""
exchange_info = self.client.fetch_tickers()
symbol_info = exchange_info.get(symbol)
if symbol_info is None:
raise ValueError(f'No symbol: {symbol}')
# Extract the minimum notional value from the symbol info
min_notional = symbol_info.get('info').get('filters')[0].get('minNotional')
return float(min_notional)
def _fetch_order(self, symbol, order_id) -> object:
"""Return the order status for a specific order."""
order = self.client.fetch_order(id=order_id, symbol=symbol)
return order
def _connect_exchange(self) -> object:
"""Connects to the exchange_interface and sets a reference to the client."""
exchange = ccxt.binance({
'apiKey': self.api_key,
'secret': self.api_key_secret,
'enableRateLimit': True,
'options': {
'defaultType': 'future',
}
})
return exchange
def _set_exchange_info(self) -> dict:
"""Fetches Info on all symbols from the exchange."""
exchange_info = self.client.load_markets()
return exchange_info
def _set_symbols(self) -> list:
"""Get information of coins (available for deposit and withdraw) for user"""
markets = self.client.load_markets()
# Extract symbol from each market
symbols = list(markets.keys())
return symbols
def _set_balances(self) -> list:
acc_info = self.client.fetch_balance()
non_zero_assets = []
for asset, balance in acc_info['total'].items():
if float(balance) > 0:
non_zero_asset = {
'asset': asset,
'balance': format(balance, '.8f'),
'pnl': acc_info['info']['totalUnrealizedProfit']
}
non_zero_assets.append(non_zero_asset)
return non_zero_assets
def _set_precision_rule(self, symbol) -> None:
# Isolate a list of symbol rules from the exchange_info
all_rules = self.exchange_info['symbols']
# Get the rules for a specific symbol.
symbol_rules = next(filter(lambda rules: (rules['pair'] == symbol), all_rules), None)
# Add the precision rule to a local index.
self.symbols_n_precision[symbol] = symbol_rules['baseAssetPrecision']
return
def _place_order(self, symbol, side, type, timeInForce, quantity, price):
def format_arg(value, arg_name):
"""
Function to convert an arguments value to a desired precision and type float.
:param value: The Value of the argument.
:param arg_name: The name of the argument.
:return: float : The formatted output.
"""
# The required level of precision for this trading pair.
precision = self.client.precision_for_symbol(symbol)
# If quantity was too low, set to the smallest allowable amount.
minimum = 0
if arg_name == 'quantity':
# The minimum quantity allowed to be traded.
minimum = self.client.amount_to_precision(symbol,
self.client.markets[symbol]['limits']['amount']['min'])
elif arg_name == 'price':
# The minimum price allowed to be traded.
minimum = self.client.price_to_precision(symbol, self.client.markets[symbol]['limits']['price']['min'])
if value < minimum:
value = minimum
return float(f"{value:.{precision}f}")
# Set price and quantity to desired precision and type.
quantity = format_arg(quantity, 'quantity')
price = format_arg(price, 'price')
order_params = {
'symbol': symbol,
'side': side,
'type': type,
'timeInForce': timeInForce,
'quantity': quantity,
'price': price
}
data = self.client.create_order(**order_params)
result = 'Success'
return result, data
def get_active_trades(self):
formatted_trades = []
# Get open positions
positions = self.client.fetch_positions()
for position in positions:
if float(position['contracts']) != 0.0:
active_trade = {
'symbol': position['symbol'],
'side': 'buy' if float(position['contracts']) > 0 else 'sell',
'quantity': abs(float(position['contracts'])),
'price': float(position['entryPrice'])
}
formatted_trades.append(active_trade)
return formatted_trades
def get_open_orders(self):
formatted_orders = []
# Get open orders
# open_orders = self.client.fetch_open_orders()
# Todo: fetching open orders without specifying a symbol is rate-limited to one call per 1252 seconds.
open_orders = {}
for order in open_orders:
open_order = {
'symbol': order['symbol'],
'side': order['side'],
'quantity': order['amount'],
'price': order['price']
}
formatted_orders.append(open_order)
return formatted_orders
class BinanceCoinExchange(BinanceFuturesExchange):
def __init__(self, name, api_keys):
super().__init__(name, api_keys=api_keys)
def _connect_exchange(self) -> object:
"""Connects to the exchange_interface and sets a reference to the client."""
return ccxt.binance({
'apiKey': self.api_key,
'secret': self.api_key_secret,
'enableRateLimit': True,
# Additional exchange-specific options can be added here
'options': {
'defaultType': 'coinfuture', # Connect to the Binance Coin Futures API
}
})

View File

@ -1,199 +0,0 @@
import pandas as pd
from binance.enums import HistoricalKlinesType
from shared_utilities import unix_time_millis
import config
from Exchange import Exchange
from binance.client import Client as BinanceClient
from binance.exceptions import BinanceAPIException
class BinanceSpotExchange(Exchange):
def __init__(self, name, api_keys):
super().__init__(name, api_keys)
def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt=None):
"""
Return a dataframe containing rows of candle attributes.
[(Open time, Open, High, Low, Close,
Volume, Close time, Quote symbol volume,
Number of trades, Taker buy base symbol volume,
Taker buy quote symbol volume, Ignore)]
:param symbol: str: Trading symbol.
:param interval: str: Timeframe for the candle sticks.
:param start_dt: start time-date of data stream.
:param end_dt: end datetime of data stream.
:return: pd.dataframe
"""
# Set this to Spot.
klines_type = HistoricalKlinesType.SPOT
# Convert the start date to a Unix start_datetime milliseconds.
start_str = int(unix_time_millis(start_dt))
end_str = end_dt
if end_dt is not None:
end_str = int(unix_time_millis(end_dt))
candles = self.client.get_historical_klines(symbol=symbol, interval=interval, start_str=start_str,
end_str=end_str, klines_type=klines_type)
# Create a pandas DataFrame.
df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume', 'close_time',
'quote_volume', 'num_trades', 'taker_buy_base_volume', 'taker_buy_quote_volume', 'ignore']
candles = pd.DataFrame(candles, columns=df_columns)
# Delete this useless column
del candles['ignore']
# Return the data
return candles
def _fetch_price(self, symbol) -> float:
"""
Get the latest price for a single symbol.
:param symbol: str - The symbol of the symbol.
:return: float - The minimum quantity sold per trade.
"""
record = self.client.get_symbol_ticker(**{'symbol': symbol})
return float(record['price'])
def _fetch_min_qty(self, symbol) -> float:
"""
Get the minimum base quantity the exchange_interface allows per order.
:param symbol: str - The symbol of the trading pair.
:return: float - The minimum quantity sold per trade.
"""
info = self.client.get_symbol_info(symbol=symbol)
for f_index, item in enumerate(info['filters']):
if item['filterType'] == 'LOT_SIZE':
break
return float(info['filters'][f_index]['minQty']) # 'minQty'
def _fetch_min_notional_qty(self, symbol) -> float:
"""
Get the minimum amount of the quote currency allowed per trade.
:param symbol: The symbol of the trading pair.
:return: The minimum quantity sold per trade.
"""
info = self.client.get_symbol_info(symbol=symbol)
for f_index, item in enumerate(info['filters']):
if item['filterType'] == 'MIN_NOTIONAL':
break
return float(info['filters'][f_index]['minNotional']) # 'minNotional'
def _fetch_order(self, symbol, order_id) -> object:
"""Return the order status for a specific order."""
return self.client.get_order(symbol=symbol, orderId=order_id)
def _connect_exchange(self) -> object:
"""Connects to the exchange_interface and sets a reference the client."""
return BinanceClient(self.api_key, self.api_key_secret)
def _set_exchange_info(self) -> dict:
"""Fetches Info on all symbols from the exchange_interface."""
return self.client.get_exchange_info()
def _set_symbols(self) -> list:
"""Get information of coins (available for deposit and withdraw) for user"""
# load into a dataframe
data = pd.DataFrame.from_records(self.client.get_all_tickers())
# unload into a list
return data.symbol.to_list()
def _set_balances(self) -> list:
account_info = self.client.get_account()
balances = []
for asset in account_info['balances']:
asset_symbol = asset['asset']
asset_balance = float(asset['free'])
asset_pnl = self.calculate_asset_pnl(asset_symbol)
if asset_balance > 0 or asset_pnl != 0:
asset_data = {
'asset': asset_symbol,
'balance': asset_balance,
'pnl': asset_pnl
}
balances.append(asset_data)
return balances
def calculate_asset_pnl(self, asset_symbol: str) -> float:
# Implement your logic to calculate the P&L for the given asset
# This may involve retrieving historical data, calculating positions, etc.
# Return the calculated P&L value as a float
return 0.0 # Placeholder value, replace with actual calculation
def _set_precision_rule(self, symbol) -> None:
# Isolate a list of symbol rules from the exchange_info
all_rules = self.exchange_info['symbols']
# Get the rules for a specific symbol.
symbol_rules = next(filter(lambda rules: (rules['symbol'] == symbol), all_rules), None)
# Add the precision rule to a local index.
self.symbols_n_precision[symbol] = symbol_rules['baseAssetPrecision']
return
def _place_order(self, symbol, side, type, timeInForce, quantity, price):
def format_arg(value, arg_name):
"""
Function to convert an arguments value to a desired precision and type float.
:param value: The Value of the argument.
:param arg_name: The name of the argument.
:return: float : The formatted output.
"""
# The required level of precision for this trading pair.
precision = self.get_symbol_precision_rule(symbol)
# If quantity was too low, set to the smallest allowable amount.
minimum = 0
if arg_name == 'quantity':
# The minimum quantity aloud to be traded.
minimum = self.get_min_qty(symbol)
elif arg_name == 'price':
# The minimum price aloud to be traded.
minimum = self.get_min_notional_qty(symbol)
if value < minimum:
value = minimum
return float(f"{value:.{precision}f}")
# Set price and quantity to desired precision and type.
quantity = format_arg(quantity, 'quantity')
price = format_arg(price, 'price')
data = self.client.create_test_order(symbol=symbol, side=side, type=type,
timeInForce=timeInForce, quantity=quantity,
price=price)
result = 'Success'
return result, data
def get_active_trades(self):
formatted_trades = []
# Get open trades
positions = self.client.get_account().get('positions')
if positions is None:
return formatted_trades
for position in positions:
if float(position['qty']) != 0.0:
active_trade = {
'symbol': position['symbol'],
'side': 'buy' if float(position['qty']) > 0 else 'sell',
'quantity': abs(float(position['qty'])),
'price': float(position['entryPrice'])
}
formatted_trades.append(active_trade)
return formatted_trades
def get_open_orders(self):
formatted_orders = []
# Get open orders
open_orders = self.client.get_open_orders()
for order in open_orders:
open_order = {
'symbol': order['symbol'],
'side': order['side'],
'quantity': order['origQty'],
'price': order['price']
}
formatted_orders.append(open_order)
return formatted_orders

View File

@ -172,7 +172,7 @@ class BrighterTrades:
return self.indicators.get_indicator_data(user_name=user_name, source=source, start_ts=start_ts, return self.indicators.get_indicator_data(user_name=user_name, source=source, start_ts=start_ts,
num_results=num_results) num_results=num_results)
def connect_user_to_exchange(self, user_name: str, default_exchange: str, default_keys: dict) -> bool: def connect_user_to_exchange(self, user_name: str, default_exchange: str, default_keys: dict = None) -> bool:
""" """
Connects an exchange if it is not already connected. Connects an exchange if it is not already connected.
@ -189,6 +189,7 @@ class BrighterTrades:
exchange_name=exchange, exchange_name=exchange,
api_keys=keys) api_keys=keys)
if not success: if not success:
# If no active exchange was successfully connected, connect to the default exchange
success = self.connect_or_config_exchange(user_name=user_name, success = self.connect_or_config_exchange(user_name=user_name,
exchange_name=default_exchange, exchange_name=default_exchange,
api_keys=default_keys) api_keys=default_keys)
@ -237,7 +238,7 @@ class BrighterTrades:
r_data = {} r_data = {}
r_data['title'] = self.config.app_data.get('application_title', '') r_data['title'] = self.config.app_data.get('application_title', '')
r_data['chart_interval'] = chart_view.get('timeframe', '') r_data['chart_interval'] = chart_view.get('timeframe', '')
r_data['selected_exchange'] = chart_view.get('exchange_name', '') r_data['selected_exchange'] = chart_view.get('exchange', '')
r_data['intervals'] = exchange.intervals if exchange else [] r_data['intervals'] = exchange.intervals if exchange else []
r_data['symbols'] = exchange.get_symbols() if exchange else {} r_data['symbols'] = exchange.get_symbols() if exchange else {}
r_data['available_exchanges'] = self.exchanges.get_available_exchanges() or [] r_data['available_exchanges'] = self.exchanges.get_available_exchanges() or []
@ -359,7 +360,7 @@ class BrighterTrades:
""" """
return self.strategies.get_strategies('json') return self.strategies.get_strategies('json')
def connect_or_config_exchange(self, user_name: str, exchange_name: str, api_keys: dict) -> bool: def connect_or_config_exchange(self, user_name: str, exchange_name: str, api_keys: dict = None) -> bool:
""" """
Connects to an exchange if not already connected, or configures the exchange connection for a single user. Connects to an exchange if not already connected, or configures the exchange connection for a single user.
@ -377,7 +378,9 @@ class BrighterTrades:
api_keys=api_keys) api_keys=api_keys)
if success: if success:
self.config.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set') self.config.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set')
self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name) if api_keys:
self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name,
user_name=user_name)
return True return True
else: else:
return False # Failed to connect return False # Failed to connect

View File

@ -38,6 +38,8 @@ class DataCache:
records = pd.concat([more_records, self.get_cache(key)], axis=0, ignore_index=True) records = pd.concat([more_records, self.get_cache(key)], axis=0, ignore_index=True)
# Drop any duplicates from overlap. # Drop any duplicates from overlap.
records = records.drop_duplicates(subset="open_time", keep='first') records = records.drop_duplicates(subset="open_time", keep='first')
# Sort the records by open_time.
records = records.sort_values(by='open_time').reset_index(drop=True)
# Replace the incomplete dataframe with the modified one. # Replace the incomplete dataframe with the modified one.
self.set_cache(data=records, key=key) self.set_cache(data=records, key=key)
return return

View File

@ -1,40 +1,175 @@
from abc import ABC, abstractmethod import ccxt
from typing import Tuple import pandas as pd
from datetime import datetime, timedelta
from typing import Tuple, Dict, List, Union
import time
class Exchange(ABC): class Exchange:
def __init__(self, name, api_keys): # Class attribute for caching market data
_market_cache = {}
def __init__(self, name: str, api_keys: Dict[str, str], exchange_id: str):
self.name = name self.name = name
# 1 The api key for the exchange. # The API key for the exchange.
self.api_key = api_keys['key'] self.api_key = api_keys['key'] if api_keys else None
# 2 The api secret key for the exchange. # The API secret key for the exchange.
self.api_key_secret = api_keys['secret'] self.api_key_secret = api_keys['secret'] if api_keys else None
# 3 The connection to the exchange_interface. # The exchange id for the exchange.
self.client = self._connect_exchange() self.exchange_id = exchange_id
# 4 Info on all symbols and exchange_interface rules. # The connection to the exchange_interface.
self.client: ccxt.Exchange = self._connect_exchange()
# Info on all symbols and exchange_interface rules.
self.exchange_info = self._set_exchange_info() self.exchange_info = self._set_exchange_info()
# 5 List of time intervals available for trading. # List of time intervals available for trading.
self.intervals = self._set_avail_intervals() self.intervals = self._set_avail_intervals()
# 6 All symbols available for trading. # All symbols available for trading.
self.symbols = self._set_symbols() self.symbols = self._set_symbols()
# 7 Any non-zero balance-info for all assets. # Any non-zero balance-info for all assets.
self.balances = self._set_balances() self.balances = self._set_balances()
# 8 Dictionary of places after the decimal requires by the exchange_interface indexed by symbol. # Dictionary of places after the decimal requires by the exchange_interface indexed by symbol.
self.symbols_n_precision = {} self.symbols_n_precision = {}
def _connect_exchange(self) -> ccxt.Exchange:
"""
Connects to the exchange_interface and sets a reference the client.
Handles cases where api_keys might be None.
"""
exchange_class = getattr(ccxt, self.exchange_id)
if not exchange_class:
raise ValueError(f"Exchange {self.exchange_id} is not supported by CCXT.")
if self.api_key and self.api_key_secret:
return exchange_class({
'apiKey': self.api_key,
'secret': self.api_key_secret,
'enableRateLimit': True,
'verbose': False # Disable verbose debugging output
})
else:
return exchange_class({
'enableRateLimit': True,
'verbose': False # Disable verbose debugging output
})
@staticmethod
def datetime_to_unix_millis(dt: datetime) -> int:
"""
Convert a datetime object to Unix timestamp in milliseconds.
:param dt: datetime - The datetime object to convert.
:return: int - The Unix timestamp in milliseconds.
"""
return int(dt.timestamp() * 1000)
def _fetch_historical_klines(self, symbol: str, interval: str, start_dt: datetime,
end_dt: datetime = None) -> pd.DataFrame:
if end_dt is None:
end_dt = datetime.utcnow()
max_interval = timedelta(days=200) # Binance's maximum interval
data_frames = []
current_start = start_dt
while current_start < end_dt:
current_end = min(current_start + max_interval, end_dt)
start_str = self.datetime_to_unix_millis(current_start)
end_str = self.datetime_to_unix_millis(current_end)
try:
candles = self.client.fetch_ohlcv(symbol=symbol, timeframe=interval, since=start_str,
params={'endTime': end_str})
if not candles:
break
df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume']
candles_df = pd.DataFrame(candles, columns=df_columns)
candles_df['open_time'] = candles_df['open_time'] // 1000
data_frames.append(candles_df)
# Move the start to the end of the current chunk to get the next chunk
current_start = current_end
# Sleep for 1 second to avoid hitting rate limits
time.sleep(1)
except ccxt.BaseError as e:
print(f"Error fetching OHLCV data: {str(e)}")
break
if data_frames:
result_df = pd.concat(data_frames)
return result_df
else:
return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume'])
def _fetch_price(self, symbol: str) -> float:
ticker = self.client.fetch_ticker(symbol)
return float(ticker['last'])
def _fetch_min_qty(self, symbol: str) -> float:
market_data = self.exchange_info[symbol]
return float(market_data['limits']['amount']['min'])
def _fetch_min_notional_qty(self, symbol: str) -> float:
market_data = self.exchange_info[symbol]
return float(market_data['limits']['cost']['min'])
def _fetch_order(self, symbol: str, order_id: str) -> object:
return self.client.fetch_order(order_id, symbol)
def _set_symbols(self) -> List[str]:
markets = self.client.fetch_markets()
symbols = [market['symbol'] for market in markets if market['active']]
return symbols
def _set_balances(self) -> List[Dict[str, Union[str, float]]]:
if self.api_key and self.api_key_secret:
try:
account_info = self.client.fetch_balance()
balances = []
for asset, balance in account_info['total'].items():
asset_balance = float(balance)
if asset_balance > 0:
balances.append({'asset': asset, 'balance': asset_balance, 'pnl': 0})
return balances
except NotImplementedError:
# Handle the case where fetch_balance is not supported
return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}]
else:
return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}]
def _set_exchange_info(self) -> dict:
"""
Fetches market data for all symbols from the exchange.
This includes details on trading pairs, limits, fees, and other relevant information.
Caches the market info to speed up subsequent calls.
"""
if self.exchange_id in Exchange._market_cache:
return Exchange._market_cache[self.exchange_id]
try:
markets_info = self.client.load_markets()
Exchange._market_cache[self.exchange_id] = markets_info
return markets_info
except ccxt.BaseError as e:
print(f"Error fetching market info: {str(e)}")
return {}
def get_client(self) -> object: def get_client(self) -> object:
""" Return a reference to the exchange_interface client.""" """ Return a reference to the exchange_interface client."""
return self.client return self.client
def get_avail_intervals(self) -> tuple: def get_avail_intervals(self) -> Tuple[str, ...]:
"""Returns a list of time intervals available for trading.""" """Returns a list of time intervals available for trading."""
return self.intervals return self.intervals
@ -42,15 +177,15 @@ class Exchange(ABC):
"""Returns Info on all symbols.""" """Returns Info on all symbols."""
return self.exchange_info return self.exchange_info
def get_symbols(self) -> list: def get_symbols(self) -> List[str]:
"""Returns all symbols available for trading.""" """Returns all symbols available for trading."""
return self.symbols return self.symbols
def get_balances(self) -> list: def get_balances(self) -> List[Dict[str, Union[str, float]]]:
"""Returns any non-zero balance-info for all assets.""" """Returns any non-zero balance-info for all assets."""
return self.balances return self.balances
def get_symbol_precision_rule(self, symbol) -> int: def get_symbol_precision_rule(self, symbol: str) -> int:
"""Returns number of places after the decimal required by the exchange_interface for a specific symbol.""" """Returns number of places after the decimal required by the exchange_interface for a specific symbol."""
r_value = self.symbols_n_precision.get(symbol) r_value = self.symbols_n_precision.get(symbol)
if r_value is None: if r_value is None:
@ -58,149 +193,108 @@ class Exchange(ABC):
r_value = self.symbols_n_precision.get(symbol) r_value = self.symbols_n_precision.get(symbol)
return r_value return r_value
def get_historical_klines(self, symbol, interval, start_dt, end_dt) -> object: def get_historical_klines(self, symbol: str, interval: str, start_dt: datetime,
""" end_dt: datetime = None) -> pd.DataFrame:
Return a dataframe containing rows of candle attributes. Attributes very between different exchanges. return self._fetch_historical_klines(symbol=symbol, interval=interval, start_dt=start_dt, end_dt=end_dt)
For example: [(Open time, Open, High, Low, Close,
Volume, Close time, Quote symbol volume,
Number of trades, Taker buy base symbol volume,
Taker buy quote symbol volume, Ignore)]
"""
return self._fetch_historical_klines(symbol=symbol, interval=interval,
start_dt=start_dt, end_dt=end_dt)
def get_price(self, symbol) -> float: def get_price(self, symbol: str) -> float:
"""
Get the latest price for a single symbol.
:param symbol: str - The symbol of the symbol.
:return: float - The minimum quantity sold per trade.
"""
return self._fetch_price(symbol) return self._fetch_price(symbol)
def get_min_qty(self, symbol) -> float: def get_min_qty(self, symbol: str) -> float:
"""
Get the minimum base quantity the exchange_interface allows per order.
:param symbol: str - The symbol of the trading pair.
:return: float - The minimum quantity sold per trade.
"""
return self._fetch_min_qty(symbol) return self._fetch_min_qty(symbol)
def get_min_notional_qty(self, symbol) -> float: def get_min_notional_qty(self, symbol: str) -> float:
"""
Get the minimum amount of the quote currency allowed per trade.
:param symbol: The symbol of the trading pair.
:return: The minimum quantity sold per trade.
"""
return self._fetch_min_notional_qty(symbol) return self._fetch_min_notional_qty(symbol)
@abstractmethod def get_order(self, symbol: str, order_id: str) -> object:
def get_active_trades(self):
pass
@abstractmethod
def get_open_orders(self):
pass
def get_order(self, symbol, order_id) -> object:
"""
Get an order by id.
:param symbol: The trading pair
:param order_id: The order id
:return: object
"""
return self._fetch_order(symbol, order_id) return self._fetch_order(symbol, order_id)
def place_order(self, symbol, side, type, timeInForce, quantity, price): def place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> \
result, msg = self._place_order(symbol=symbol, side=side, type=type, Tuple[str, object]:
timeInForce=timeInForce, quantity=quantity, price=price) result, msg = self._place_order(symbol=symbol, side=side, type=type, timeInForce=timeInForce, quantity=quantity,
price=price)
return result, msg return result, msg
def _set_avail_intervals(self) -> tuple: def _set_avail_intervals(self) -> Tuple[str, ...]:
"""Sets a list of time intervals available for trading on the exchange_interface."""
return '1m', '3m', '5m', '15m', '30m', '1h', '2h', '4h', '6h', '8h', '12h', '1d', '3d', '1w', '1M'
# noinspection PyMethodMayBeStatic
@abstractmethod
def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt) -> list:
""" """
Return a list [(Open time, Open, High, Low, Close, Sets a list of time intervals available for trading on the exchange_interface.
Volume, Close time, Quote symbol volume,
Number of trades, Taker buy base symbol volume,
Taker buy quote symbol volume, Ignore)]
""" """
pass return tuple(self.client.timeframes.keys())
@abstractmethod def _set_precision_rule(self, symbol: str) -> None:
def _fetch_price(self, symbol) -> float: market_data = self.exchange_info[symbol]
precision = market_data['precision']['amount']
self.symbols_n_precision[symbol] = precision
return
def _place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> \
Tuple[str, object]:
def format_arg(value: float) -> float:
precision = self.symbols_n_precision.get(symbol, 8)
return float(f"{value:.{precision}f}")
quantity = format_arg(quantity)
if price is not None:
price = format_arg(price)
order_params = {
'symbol': symbol,
'type': type,
'side': side,
'amount': quantity,
'params': {
'timeInForce': timeInForce
}
}
if price is not None:
order_params['price'] = price
order = self.client.create_order(**order_params)
return 'Success', order
def get_active_trades(self) -> List[Dict[str, Union[str, float]]]:
""" """
Get the latest price for a single symbol. Get the active trades (open positions).
:param symbol: str - The symbol of the symbol.
:return: float - The minimum quantity sold per trade.
""" """
pass if self.api_key and self.api_key_secret:
try:
positions = self.client.fetch_positions()
formatted_trades = []
for position in positions:
active_trade = {
'symbol': position['symbol'],
'side': 'buy' if float(position['quantity']) > 0 else 'sell',
'quantity': abs(float(position['quantity'])),
'price': float(position['entry_price'])
}
formatted_trades.append(active_trade)
return formatted_trades
except NotImplementedError:
# Handle the case where fetch_positions is not supported
return []
else:
return []
@abstractmethod def get_open_orders(self) -> List[Dict[str, Union[str, float]]]:
def _fetch_min_qty(self, symbol) -> float:
""" """
Get the minimum base quantity the exchange_interface allows per order. Get the open orders.
:param symbol: str - The symbol of the trading pair.
:return: float - The minimum quantity sold per trade.
""" """
pass if self.api_key and self.api_key_secret:
try:
@abstractmethod open_orders = self.client.fetch_open_orders()
def _fetch_min_notional_qty(self, symbol) -> float: formatted_orders = []
""" for order in open_orders:
Get the minimum amount of the quote currency allowed per trade. open_order = {
'symbol': order['symbol'],
:param symbol: The symbol of the trading pair. 'side': order['side'],
:return: The minimum quantity sold per trade. 'quantity': order['amount'],
""" 'price': order['price']
pass }
formatted_orders.append(open_order)
@abstractmethod return formatted_orders
def _fetch_order(self, symbol, order_id) -> object: except NotImplementedError:
""" # Handle the case where fetch_balance is not supported
Get an order by id. return []
:param symbol: The trading pair else:
:param order_id: The order id return []
:return: object
"""
pass
@abstractmethod
def _connect_exchange(self) -> object:
"""Connects to the exchange_interface and sets a reference the client."""
pass
@abstractmethod
def _set_exchange_info(self) -> dict:
"""Fetches Info on all symbols from the exchange_interface."""
pass
@abstractmethod
def _set_symbols(self) -> list:
"""Fetches all symbols available for trading on the exchange_interface."""
pass
@abstractmethod
def _set_balances(self) -> list:
"""Fetches any non-zero balance-info for all assets from exchange_interface."""
pass
@abstractmethod
def _set_precision_rule(self, symbol) -> dict:
"""Appends a record of places after the decimal required by the exchange_interface indexed by symbol."""
pass
@abstractmethod
def _place_order(self, symbol, side, type, timeInForce, quantity, price):
result = 'Success'
msg = ''
return result, msg

View File

@ -541,8 +541,10 @@ class Users:
# Get the user records from the database. # Get the user records from the database.
user = self.get_user_from_db(user_name) user = self.get_user_from_db(user_name)
# Get the exchanges list based on the field. # Get the exchanges list based on the field.
return json.loads(user.loc[0, category]) exchanges = user.loc[0, category]
except (KeyError, IndexError) as e: # Return the list if it exists, otherwise return an empty list.
return json.loads(exchanges) if exchanges else []
except (KeyError, IndexError, json.JSONDecodeError) as e:
# Log the error to the console # Log the error to the console
print(f"Error retrieving exchanges for user '{user_name}' and field '{category}': {str(e)}") print(f"Error retrieving exchanges for user '{user_name}' and field '{category}': {str(e)}")
return None return None
@ -558,8 +560,12 @@ class Users:
""" """
# Get the user records from the database. # Get the user records from the database.
user = self.get_user_from_db(user_name) user = self.get_user_from_db(user_name)
# Get the old active_exchanges list. # Get the old active_exchanges list, or initialize as an empty list if it is None.
active_exchanges = json.loads(user.loc[0, 'active_exchanges']) active_exchanges = user.loc[0, 'active_exchanges']
if active_exchanges is None:
active_exchanges = []
else:
active_exchanges = json.loads(active_exchanges)
# Define the actions for each command # Define the actions for each command
actions = { actions = {

View File

@ -30,7 +30,7 @@ cors = CORS(app, supports_credentials=True,
r"/api/history": {"origins": ["http://127.0.0.1:5000", "http://localhost:5000"]}, r"/api/history": {"origins": ["http://127.0.0.1:5000", "http://localhost:5000"]},
r"/api/indicator_init": {"origins": ["http://127.0.0.1:5000", "http://localhost:5000"]} r"/api/indicator_init": {"origins": ["http://127.0.0.1:5000", "http://localhost:5000"]}
}, },
headers=['Content-Type']) allow_headers=['Content-Type']) # Change from headers to allow_headers
@app.after_request @app.after_request
@ -46,6 +46,8 @@ def index():
Fetches data from brighter_trades and inject it into an HTML template. Fetches data from brighter_trades and inject it into an HTML template.
Renders the html template and serves the web application. Renders the html template and serves the web application.
""" """
# Clear the session to simulate a new visitor
session.clear()
try: try:
# Log the user in. # Log the user in.
user_name = brighter_trades.config.users.load_or_create_user(username=session.get('user')) user_name = brighter_trades.config.users.load_or_create_user(username=session.get('user'))
@ -61,13 +63,15 @@ def index():
print('[SERVING INDEX] (USERNAME):', user_name) print('[SERVING INDEX] (USERNAME):', user_name)
# Ensure that a valid connection with an exchange exist. default_exchange = 'binance'
keys = {'key': config.ALPACA_API_KEY, 'secret': config.ALPACA_API_SECRET} default_keys = None # No keys needed for public market data
# Ensure that a valid connection with an exchange exists
result = brighter_trades.connect_user_to_exchange(user_name=user_name, result = brighter_trades.connect_user_to_exchange(user_name=user_name,
default_exchange='alpaca', default_exchange=default_exchange,
default_keys=keys) default_keys=default_keys)
if not result: if not result:
raise ValueError("Couldn't connect to the alpaca exchange.") raise ValueError("Couldn't connect to the default exchange.")
# A dict of data required to build the html of the user app. # A dict of data required to build the html of the user app.
# Dynamic content like options and titles and balances to display. # Dynamic content like options and titles and balances to display.
@ -310,4 +314,4 @@ def indicator_init():
if __name__ == '__main__': if __name__ == '__main__':
app.run(debug=True, use_reloader=False) app.run(debug=False, use_reloader=False)

View File

@ -1,5 +1,5 @@
BINANCE_API_KEY = 'rkp1Xflb5nnwt6jys0PG27KXcqwn0q9lKCLryKcSp4mKW2UOlkPRuAHPg45rQVgj' BINANCE_API_KEY = 'rkp1Xflb5nnwt6jys0PG27KXcqwn0q9lKCLryKcSp4mKW2UOlkPRuAHPg45rQVgj'
BINANCE_API_SECRET = 'DiFhhYhF64nkPe5f3V7TRJX2bSVA7ZQZlozSdX7O7uYmBMdK985eA6Kp2B2zKvbK' BINANCE_API_SECRET = 'DiFhhYhF64nkPe5f3V7TRJX2bSVA7ZQZlozSdX7O7uYmBMdK985eA6Kp2B2zKvbK'
ALPACA_API_KEY = 'PKN0WFYT9VZYUVRBG1HM' ALPACA_API_KEY = 'PKE4RD999SJ8L53OUI8O'
ALPACA_API_SECRET = '0C1I6UcBSR2B0SZrBC3DoKGtcglAny8znorvganx' ALPACA_API_SECRET = 'buwlMoSSfZWGih8Er30quQt4d7brsBWdJXD1KB7C'
DB_FILE = "C:/Users/Rob/PycharmProjects/BrighterTrading/data/BrighterTrading.db" DB_FILE = "C:/Users/Rob/PycharmProjects/BrighterTrading/data/BrighterTrading.db"

View File

@ -242,7 +242,7 @@ class Database:
:param exchange_name: str - The name of the exchange_name. :param exchange_name: str - The name of the exchange_name.
:return: int - The primary id of the exchange_name. :return: int - The primary id of the exchange_name.
""" """
return self.get_from_static_table(item='id', table='exchange', indexes=HDict({'name': exchange_name})) return self.get_from_static_table(item='id', table='exchange', create_id=True,indexes=HDict({'name': exchange_name}))
def _fetch_market_id(self, symbol: str, exchange_name: str) -> int: def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
""" """
@ -276,24 +276,23 @@ class Database:
market_id = self._fetch_market_id(symbol, exchange_name) market_id = self._fetch_market_id(symbol, exchange_name)
# Insert the market id into the dataframe. # Insert the market id into the dataframe.
candlesticks.insert(0, 'market_id', market_id) candlesticks.insert(0, 'market_id', market_id)
# Create a table schema. # Create a table schema. todo delete these line if not needed anymore
# Get a list of all the columns in the dataframe. # # Get a list of all the columns in the dataframe.
columns = list(candlesticks.columns.values) # columns = list(candlesticks.columns.values)
# Isolate any extra columns specific to individual exchanges. # # Isolate any extra columns specific to individual exchanges.
# The carriage return and tabs are unnecessary, they just tidy output for debugging. # # The carriage return and tabs are unnecessary, they just tidy output for debugging.
columns = ',\n\t\t\t\t\t'.join(columns[7:], ) # columns = ',\n\t\t\t\t\t'.join(columns[7:], )
# Define the columns common with all exchanges and append any extras columns. # # Define the columns common with all exchanges and append any extras columns.
sql_create = f""" sql_create = f"""
CREATE TABLE IF NOT EXISTS '{table_name}' ( CREATE TABLE IF NOT EXISTS '{table_name}' (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
market_id INTEGER, market_id INTEGER,
open_time UNIQUE ON CONFLICT IGNORE, open_time INTEGER UNIQUE ON CONFLICT IGNORE,
open NOT NULL, open REAL NOT NULL,
high NOT NULL, high REAL NOT NULL,
low NOT NULL, low REAL NOT NULL,
close NOT NULL, close REAL NOT NULL,
volume NOT NULL, volume REAL NOT NULL,
{columns},
FOREIGN KEY (market_id) REFERENCES market (id) FOREIGN KEY (market_id) REFERENCES market (id)
)""" )"""
# Connect to the database. # Connect to the database.
@ -338,17 +337,17 @@ class Database:
print(f'Got {len(records.index)} records from db') print(f'Got {len(records.index)} records from db')
else: else:
# If the table doesn't exist, get them from the exchange_name. # If the table doesn't exist, get them from the exchange_name.
print('\nTable didnt exist fetching from exchange_name') print(f'\nTable didnt exist fetching from {ex_details[2]}')
temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
print(f'Requesting from {st} to {et}, Should be {temp} records') print(f'Requesting from {st} to {et}, Should be {temp} records')
records = self._populate_table(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details) records = self._populate_table(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
print(f'Got {len(records.index)} records from exchange_name') print(f'Got {len(records.index)} records from {ex_details[2]}')
# Check if the records in the db go far enough back to satisfy the query. # Check if the records in the db go far enough back to satisfy the query.
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length=rl) first_timestamp = query_satisfied(start_datetime=st, records=records, r_length=rl)
if first_timestamp: if first_timestamp:
# The records didn't go far enough back if a timestamp was returned. # The records didn't go far enough back if a timestamp was returned.
print('\nRecords did not go far enough back. Requesting from exchange_name') print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
print(f'first ts on record is: {first_timestamp}') print(f'first ts on record is: {first_timestamp}')
end_time = dt.datetime.utcfromtimestamp(first_timestamp) end_time = dt.datetime.utcfromtimestamp(first_timestamp)
print(f'Requesting from {st} to {end_time}') print(f'Requesting from {st} to {end_time}')
@ -359,7 +358,7 @@ class Database:
last_timestamp = query_uptodate(records=records, r_length=rl) last_timestamp = query_uptodate(records=records, r_length=rl)
if last_timestamp: if last_timestamp:
# The query was not up-to-date if a timestamp was returned. # The query was not up-to-date if a timestamp was returned.
print('\nRecords were not updated. Requesting from exchange_name.') print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
print(f'the last record on file is: {last_timestamp}') print(f'the last record on file is: {last_timestamp}')
start_time = dt.datetime.utcfromtimestamp(last_timestamp) start_time = dt.datetime.utcfromtimestamp(last_timestamp)
print(f'Requesting from {start_time} to {et}') print(f'Requesting from {start_time} to {et}')

View File

@ -1,12 +1,14 @@
import logging
import json import json
from typing import List, Any from typing import List, Any, Dict
import pandas as pd import pandas as pd
import requests import requests
import ccxt
from BinanceFutures import BinanceFuturesExchange, BinanceCoinExchange from Exchange import Exchange
from BinanceSpot import BinanceSpotExchange
from AlpacaPaperExchange import AlpacaPaperExchange # Setup logging
# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
# This just makes this method cleaner. # This just makes this method cleaner.
@ -17,48 +19,54 @@ def add_row(df, dic):
class ExchangeInterface: class ExchangeInterface:
""" """
Connects and maintains and routs data requests from exchanges. Connects and maintains and routes data requests from exchanges.
""" """
def __init__(self): def __init__(self):
# Create a dataframe to hold all the references and info for each user's configured exchanges. # Create a dataframe to hold all the references and info for each user's configured exchanges.
self.exchange_data = pd.DataFrame(columns=['user', 'name', 'reference', 'balance']) self.exchange_data = pd.DataFrame(columns=['user', 'name', 'reference', 'balances'])
# List of available exchanges
self.available_exchanges = ['alpaca', 'binance_coin', 'binance_futures', 'binance_spot'] # Populate the list of available exchanges from CCXT
self.available_exchanges = self.get_ccxt_exchanges()
def get_ccxt_exchanges(self) -> List[str]:
"""Retrieve the list of available exchanges from CCXT."""
return ccxt.exchanges
def connect_exchange(self, exchange_name: str, user_name: str, api_keys: dict = None) -> bool: def connect_exchange(self, exchange_name: str, user_name: str, api_keys: dict = None) -> bool:
""" """
Initialize and store a reference to the available exchanges. Initialize and store a reference to the available exchanges.
:param user_name: The name of the user connecting the exchange. :param user_name: The name of the user connecting the exchange.
:param api_keys: dict - {api: key, api-secret: key} :param api_keys: dict - {api: key, api-secret: key}
:param exchange_name: str - The name of the exchange. :param exchange_name: str - The name of the exchange.
:return: True if success | None on fail. :return: True if success | None on fail.
""" """
if exchange_name == 'alpaca': # logging.debug(
success = self.add_exchange(user_name, AlpacaPaperExchange, (exchange_name, api_keys)) # f"Attempting to connect to exchange '{exchange_name}' for user '{user_name}' with API keys: {api_keys}")
elif exchange_name == 'binance_coin':
success = self.add_exchange(user_name, BinanceCoinExchange, (exchange_name, api_keys))
elif exchange_name == 'binance_futures':
success = self.add_exchange(user_name, BinanceFuturesExchange, (exchange_name, api_keys))
elif exchange_name == 'binance_spot':
success = self.add_exchange(user_name, BinanceSpotExchange, (exchange_name, api_keys))
else:
success = False
return success
def add_exchange(self, user_name, _class, arg):
try: try:
ref = _class(*arg) # Initialize the exchange object
row = {'user': user_name, 'name': ref.name, 'reference': ref, 'balances': ref.balances} exchange = Exchange(name=exchange_name, api_keys=api_keys, exchange_id=exchange_name.lower())
# Update exchange data with the new connection
self.add_exchange(user_name, exchange)
# logging.debug(f"Successfully connected to exchange '{exchange_name}' for user '{user_name}'")
return True
except Exception as e:
logging.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}")
return False # Failed to connect
def add_exchange(self, user_name: str, exchange: Exchange):
try:
row = {'user': user_name, 'name': exchange.name, 'reference': exchange, 'balances': exchange.balances}
self.exchange_data = add_row(self.exchange_data, row) self.exchange_data = add_row(self.exchange_data, row)
except Exception as e: except Exception as e:
if e.status_code == 400 and e.error_code == -1021: if hasattr(e, 'status_code') and e.status_code == 400 and e.error_code == -1021:
print("Timestamp ahead of server's time error: Sync your system clock to fix this.") logging.error("Timestamp ahead of server's time error: Sync your system clock to fix this.")
print("Couldn't create an instance of the exchange!:\n", e) logging.error("Couldn't create an instance of the exchange!:\n", e)
raise raise
return True
def get_exchange(self, ename: str, uname: str) -> Any: def get_exchange(self, ename: str, uname: str) -> Any:
"""Return a reference to the exchange_name.""" """Return a reference to the exchange_name."""
@ -76,15 +84,15 @@ class ExchangeInterface:
connected_exchanges = self.exchange_data.loc[self.exchange_data['user'] == user_name, 'name'].tolist() connected_exchanges = self.exchange_data.loc[self.exchange_data['user'] == user_name, 'name'].tolist()
return connected_exchanges return connected_exchanges
def get_available_exchanges(self): def get_available_exchanges(self) -> List[str]:
""" Return a list of the exchanges available to connect to""" """ Return a list of the exchanges available to connect to"""
return self.available_exchanges return self.available_exchanges
def get_exchange_balances(self, name): def get_exchange_balances(self, name: str) -> pd.Series:
""" Return the balances of a single exchange_name""" """ Return the balances of a single exchange_name"""
return self.exchange_data.query("name == @name")['balances'] return self.exchange_data.query("name == @name")['balances']
def get_all_balances(self, user_name: str) -> dict: def get_all_balances(self, user_name: str) -> Dict[str, List[Dict[str, Any]]]:
""" """
Return the balances of all connected exchanges indexed by name. Return the balances of all connected exchanges indexed by name.
@ -105,7 +113,7 @@ class ExchangeInterface:
return balances_dict return balances_dict
def get_all_activated(self, user_name, fetch_type='trades'): def get_all_activated(self, user_name: str, fetch_type: str = 'trades') -> Dict[str, List[Dict[str, Any]]]:
"""Get active trades or open orders as a dictionary indexed by name""" """Get active trades or open orders as a dictionary indexed by name"""
filtered_data = self.exchange_data.loc[self.exchange_data['user'] == user_name, ['name', 'reference']] filtered_data = self.exchange_data.loc[self.exchange_data['user'] == user_name, ['name', 'reference']]
if filtered_data.empty: if filtered_data.empty:
@ -122,16 +130,16 @@ class ExchangeInterface:
elif fetch_type == 'orders': elif fetch_type == 'orders':
data = reference.get_open_orders() data = reference.get_open_orders()
else: else:
print(f"Invalid fetch type: {fetch_type}") logging.error(f"Invalid fetch type: {fetch_type}")
return {} return {}
data_dict[name] = data data_dict[name] = data
except Exception as e: except Exception as e:
print(f"Error retrieving data for {name}: {str(e)}") logging.error(f"Error retrieving data for {name}: {str(e)}")
return data_dict return data_dict
def get_order(self, symbol, order_id, target, user_name): def get_order(self, symbol: str, order_id: str, target: str, user_name: str) -> Any:
""" """
Return order - from a target exchange_interface. Return order - from a target exchange_interface.
:param user_name: The name of the user making the request. :param user_name: The name of the user making the request.
@ -145,7 +153,7 @@ class ExchangeInterface:
# Return the order. # Return the order.
return exchange.get_order(symbol=symbol, order_id=order_id) return exchange.get_order(symbol=symbol, order_id=order_id)
def get_trade_status(self, trade, user_name): def get_trade_status(self, trade, user_name: str) -> str:
""" """
Return the status of a trade Return the status of a trade
Todo: trade order.status might be outdated this request the status from the exchanges order record. Todo: trade order.status might be outdated this request the status from the exchanges order record.
@ -158,7 +166,7 @@ class ExchangeInterface:
# Return status. # Return status.
return order['status'] return order['status']
def get_trade_executed_qty(self, trade, user_name): def get_trade_executed_qty(self, trade, user_name: str) -> float:
""" """
Return the executed quantity of a trade. Return the executed quantity of a trade.
@ -173,7 +181,7 @@ class ExchangeInterface:
# Return quantity. # Return quantity.
return order['executedQty'] return order['executedQty']
def get_trade_executed_price(self, trade, user_name): def get_trade_executed_price(self, trade, user_name: str) -> float:
""" """
Return the average price of executed quantity of a trade Return the average price of executed quantity of a trade
@ -188,7 +196,7 @@ class ExchangeInterface:
return order['price'] return order['price']
@staticmethod @staticmethod
def get_price(symbol, price_source=None): def get_price(symbol: str, price_source: str = None) -> float:
""" """
:param price_source: alternative sources for price. :param price_source: alternative sources for price.
:param symbol: The symbol of the trading pair. :param symbol: The symbol of the trading pair.

View File

@ -0,0 +1,20 @@
import ccxt
def main():
# Create an instance of the Binance exchange
binance = ccxt.binance({
'enableRateLimit': True,
'verbose': False, # Ensure verbose mode is disabled
})
try:
# Load markets to test the connection
markets = binance.load_markets()
print("Markets loaded successfully")
except ccxt.BaseError as e:
print(f"Error loading markets: {str(e)}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,60 @@
import ccxt
def test_public_endpoints(exchange_id, symbol):
try:
exchange_class = getattr(ccxt, exchange_id)
exchange = exchange_class()
# Check and test fetch_ticker
if exchange.has.get('fetchTicker', False):
exchange.fetch_ticker(symbol)
# Check and test fetch_tickers
if exchange.has.get('fetchTickers', False):
exchange.fetch_tickers([symbol])
# Check and test fetch_order_book
if exchange.has.get('fetchOrderBook', False):
exchange.fetch_order_book(symbol)
# Check and test fetch_trades
if exchange.has.get('fetchTrades', False):
exchange.fetch_trades(symbol)
# Check and test fetch_ohlcv
if exchange.has.get('fetchOHLCV', False):
timeframe = '1m'
since = exchange.parse8601('2022-01-01T00:00:00Z')
exchange.fetch_ohlcv(symbol, timeframe, since)
return True
except Exception as e:
print(f"Exchange {exchange_id} failed: {e}")
return False
# List of all supported exchanges
exchange_ids = ccxt.exchanges
# Symbol to test (using a common symbol that is likely available on many exchanges)
test_symbol = 'BTC/USDT'
# Dictionary to store the exchanges that support public market data endpoints
working_exchanges = []
# Iterate through each exchange and test public endpoints
for exchange_id in exchange_ids:
print(f"Testing exchange: {exchange_id}")
if test_public_endpoints(exchange_id, test_symbol):
working_exchanges.append(exchange_id)
# Print the working exchanges
print("Exchanges with working public market data endpoints:")
for exchange in working_exchanges:
print(exchange)
# Optionally save the results to a file
with open('working_public_exchanges.txt', 'w') as f:
for exchange in working_exchanges:
f.write(f"{exchange}\n")

View File

@ -0,0 +1,63 @@
ascendex
bequant
bigone
binance
binanceus
binanceusdm
bitbay
bitbns
bitcoincom
bitfinex
bitfinex2
bitget
bitmart
bitmex
bitopro
bitrue
bitstamp
bitteam
blockchaincom
btcmarkets
bybit
coinbase
coinbaseadvanced
coinbaseexchange
coinex
coinlist
coinmate
coinsph
cryptocom
currencycom
delta
digifinex
exmo
fmfwio
gemini
hitbtc
hitbtc3
htx
huobi
indodax
kraken
kucoin
latoken
lbank
lykke
mexc
ndax
novadax
oceanex
okx
phemex
poloniex
probit
timex
tokocrypto
tradeogre
upbit
wazirx
whitebit
woo
xt
yobit
zonda

View File

@ -1,35 +1,41 @@
from functools import lru_cache from functools import lru_cache
import datetime as dt import datetime as dt
import pandas as pd from typing import Union
import pandas as pd
epoch = dt.datetime.utcfromtimestamp(0) epoch = dt.datetime.utcfromtimestamp(0)
def query_uptodate(records: pd.DataFrame, r_length: float): def query_uptodate(records: pd.DataFrame, r_length_min: float) -> Union[float, None]:
""" """
Check if records that span a period of time are up-to-date. Check if records that span a period of time are up-to-date.
:param records: - The dataframe holding results from a query. :param records: The dataframe holding results from a query.
:param r_length: - The timespan in minutes of each record in the data. :param r_length_min: The timespan in minutes of each record in the data.
:return: timestamp - The closest timestamp to start_datetime on record. :return: timestamp - None if records are up-to-date otherwise the newest timestamp on record.
""" """
print('\nChecking if the records are up-to-date...') print('\nChecking if the records are up-to-date...')
# Get the newest timestamp from the records passed in stored in ms. # Get the newest timestamp from the records passed in stored in ms
last_timestamp = float(records.open_time.max()) last_timestamp = float(records.open_time.max())
print(f'The last ts on record is {last_timestamp}') print(f'The last timestamp on record is {last_timestamp}')
# Get a timestamp of the UTC time in millisecond to match the records in the DB.
# Get a timestamp of the UTC time in milliseconds to match the records in the DB
now_timestamp = unix_time_millis(dt.datetime.utcnow()) now_timestamp = unix_time_millis(dt.datetime.utcnow())
print(f'The timestamp now is {now_timestamp}') print(f'The timestamp now is {now_timestamp}')
# Get the seconds since the records have been updated.
# Get the seconds since the records have been updated
seconds_since_update = ms_to_seconds(now_timestamp - last_timestamp) seconds_since_update = ms_to_seconds(now_timestamp - last_timestamp)
# Convert to minutes # Convert to minutes
minutes_since_update = seconds_since_update / 60 minutes_since_update = seconds_since_update / 60
print(f'The minutes since last update is {minutes_since_update}') print(f'The minutes since last update is {minutes_since_update}')
print(f'And the length of each record is {r_length}') print(f'And the length of each record is {r_length_min}')
# Return the timestamp if the time since last update is more than the timespan each record covers.
if minutes_since_update > (r_length - 0.3): # Return the timestamp if the time since last update is more than the timespan each record covers
# Return the last timestamp in seconds. tolerance_minutes = 10 / 60 # 10 seconds tolerance in minutes
if minutes_since_update > (r_length_min - tolerance_minutes):
# Return the last timestamp in seconds
return ms_to_seconds(last_timestamp) return ms_to_seconds(last_timestamp)
return None return None
@ -38,82 +44,98 @@ def ms_to_seconds(timestamp):
return timestamp / 1000 return timestamp / 1000
def unix_time_seconds(d_time):
return (d_time - epoch).total_seconds()
def unix_time_millis(d_time): def unix_time_millis(d_time):
return (d_time - epoch).total_seconds() * 1000.0 return (d_time - epoch).total_seconds() * 1000.0
def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length: float): def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length_min: float) -> Union[float, None]:
""" """
Check if a records that spans a set time goes far enough back to satisfy a query. Check if records span back far enough to satisfy a query.
:param start_datetime: - The datetime to be satisfied. This function determines whether the records provided cover the required start_datetime. It calculates
:param records: - The dataframe holding results from a query. the total duration covered by the records and checks if this duration, starting from the earliest record,
:param r_length: - The timespan in minutes of each record in the data. reaches back to include the start_datetime.
:return: timestamp - The closest timestamp to start_datetime on record.
:param start_datetime: The datetime the query starts at.
:param records: The dataframe holding results from a query.
:param r_length_min: The timespan in minutes of each record in the data.
:return: None if the query is satisfied (records span back far enough),
otherwise returns the earliest timestamp in records in seconds.
""" """
# Convert the datetime to a timestamp in milliseconds to match format in the DB. print('\nChecking if the query is satisfied...')
# Convert start_datetime to Unix timestamp in milliseconds
start_timestamp = unix_time_millis(start_datetime) start_timestamp = unix_time_millis(start_datetime)
print('Checking if we went far enough back.') print(f'Start timestamp: {start_timestamp}')
print('Requested: start_timestamp:', start_timestamp)
# Get the oldest timestamp from the records passed in. Convert from str to float. # Get the oldest timestamp from the records passed in
first_timestamp = float(records.open_time.min()) first_timestamp = float(records.open_time.min())
print('Received: first_timestamp:', first_timestamp) print(f'First timestamp in records: {first_timestamp}')
if pd.isna(first_timestamp):
# If there were no records returned. Signal a need for update by returning the current timestamp. # Calculate the total duration of the records in milliseconds
return dt.datetime.utcnow().timestamp() total_duration = len(records) * (r_length_min * 60 * 1000)
# Get the minutes between the first timestamp on record and the one requested. print(f'Total duration of records: {total_duration}')
minutes_between = ms_to_seconds(first_timestamp - start_timestamp) / 60
print('minutes_between:', minutes_between) # Check if the first timestamp plus the total duration is greater than or equal to the start timestamp
# Return the timestamp if the difference is greater than the timespan of a single record. if start_timestamp <= first_timestamp + total_duration:
if minutes_between > r_length: return None
# Return timestamp in seconds.
return ms_to_seconds(first_timestamp) return first_timestamp / 1000 # Return in seconds
return None
@lru_cache(maxsize=500) @lru_cache(maxsize=500)
def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime.timestamp: def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime.timestamp:
""" """
Returns an approximate start_datetime of a candle n minutes ago. Returns the approximate datetime for the start of a candle that was 'n' candles ago.
:param int - n: The number of records ago to calculate. :param n: int - The number of candles ago to calculate.
:param float - candle_length: The time in minutes that each candle represents. :param candle_length: float - The length of each candle in minutes.
:return datetime.start_datetime - The approximate start_datetime slightly less than the record is expected to have. :return: datetime - The approximate datetime for the start of the 'n'-th candle ago.
""" """
# Increment 'n' by 1 to ensure we account for the time that has passed since the last candle closed.
# Time will have passed since the last candle has closed. The start_datetime for nth candle back will be
# less than now-(n * timeframe) but greater than now-(n+1 * timeframe). Add 1 so we don't miss a candle.
n += 1 n += 1
# Calculate the time.
# Calculate the total minutes ago the 'n'-th candle started.
minutes_ago = n * candle_length minutes_ago = n * candle_length
# Get the date and time of n candle_length ago.
date_of = dt.datetime.utcnow() - dt.timedelta(minutes=minutes_ago) # Get the current UTC datetime.
# Return the result as a start_datetime. now = dt.datetime.utcnow()
# Calculate the datetime for 'n' candles ago.
date_of = now - dt.timedelta(minutes=minutes_ago)
# Return the calculated datetime.
return date_of return date_of
@lru_cache(maxsize=20) @lru_cache(maxsize=20)
def timeframe_to_minutes(timeframe): def timeframe_to_minutes(timeframe):
""" """
Converts a string representing a timeframe into candle_length. Converts a string representing a timeframe into an integer representing the approximate minutes.
:param timeframe: str - Timeframe format is [multiplier:focus]. eg '15m', '4h', '1d' :param timeframe: str - Timeframe format is [multiplier:focus]. eg '15m', '4h', '1d'
:return: int - Minutes the timeframe represents ex. '2h'-> 120(candle_length). :return: int - Minutes the timeframe represents ex. '2h'-> 120(minutes).
""" """
# Extract the numerical part of the timeframe param. # Extract the numerical part of the timeframe param.
digits = int("".join([i if i.isdigit() else "" for i in timeframe])) digits = int("".join([i if i.isdigit() else "" for i in timeframe]))
# Extract the alpha part of the timeframe param. # Extract the alpha part of the timeframe param.
letter = "".join([i if i.isalpha() else "" for i in timeframe]) letter = "".join([i if i.isalpha() else "" for i in timeframe])
if letter == 'm': if letter == 'm':
pass pass
elif letter == 'h': elif letter == 'h':
digits *= 60 digits *= 60
elif letter == 'd': elif letter == 'd':
digits *= (60 * 24) digits *= 60 * 24
elif letter == 'w': elif letter == 'w':
digits *= (60 * 24 * 7) digits *= 60 * 24 * 7
elif letter == 'M': elif letter == 'M':
digits *= (60 * 24 * 7 * 31) digits *= 60 * 24 * 31 # Maximum number of days in a month
elif letter == 'Y': elif letter == 'Y':
digits *= (60 * 24 * 7 * 31 * 365) digits *= 60 * 24 * 365 # Exact number of days in a year
return digits return digits

View File

@ -1,23 +0,0 @@
from datetime import datetime, timedelta
from AlpacaPaperExchange import AlpacaPaperExchange
def test_alpaca_paper_exchange():
ape = AlpacaPaperExchange('ape')
# print(f'\nTesting name assignment: {ape.name}')
# print(f'\nTesting get_client(): {ape.get_client()}')
# print(f'\nTesting get_exchange_info(): {ape.get_exchange_info()}')
# print(f'\nTesting get_symbols(): {ape.get_symbols()}')
# print(f'\nTesting get_balances(): {ape.get_balances()}')
# print(f'\nTesting get_avail_intervals(): {ape.get_avail_intervals()}')
last_hour_date_time = datetime.now() - timedelta(hours=24)
k = ape.get_historical_klines(symbol="BTC/USDT", interval="5m",start_str=last_hour_date_time, klines_type=0)
{print(k) for k in k}
# print(f'\nTesting get_symbol_precision_rule(): {ape.get_symbol_precision_rule("BTC/USDT")}')
# print(f'\nTesting get_min_qty(): {ape.get_min_qty("ETH/USDT")}')
# print(f'\nTesting get_min_notional_qty(): {ape.get_min_notional_qty("BTC/USDT")}')
# print(f'\nTesting get_price(): {ape.get_price("ETH/USDT")}')
# ape.get_order()
assert True

View File

@ -1,21 +0,0 @@
from binance.enums import HistoricalKlinesType
from BinanceFutures import BinanceFuturesExchange
def test_binancefuturesexchange():
bfe = BinanceFuturesExchange('bfe')
print(f'\nTesting name assignment: {bfe.name}')
print(f'\nTesting get_client(): {bfe.get_client()}')
# print(f'\nTesting get_exchange_info(): {bfe.get_exchange_info()}')
print(f'\nTesting get_symbols(): {bfe.get_symbols()}')
print(f'\nTesting get_balances(): {bfe.get_balances()}')
print(f'\nTesting get_avail_intervals(): {bfe.get_avail_intervals()}')
print(f'\nTesting get_historical_klines(): {bfe.get_historical_klines(symbol="BTCUSDT",interval="15m",start_str="1 hour ago UTC", klines_type=HistoricalKlinesType.SPOT)}')
print(f'\nTesting get_symbol_precision_rule(): {bfe.get_symbol_precision_rule("ETHUSDT")}')
print(f'\nTesting get_min_qty(): {bfe.get_min_qty("ETHUSDT")}')
print(f'\nTesting get_min_notional_qty(): {bfe.get_min_notional_qty("BTCUSDT")}')
print(f'\nTesting get_price(): {bfe.get_price("BTCUSDT")}')
# bse.get_order()
assert True

View File

@ -1,24 +0,0 @@
from binance.enums import HistoricalKlinesType
from BinanceSpot import BinanceSpotExchange
def test_binance_spot_exchange():
bse = BinanceSpotExchange('bse')
print(f'\nTesting name assignment: {bse.name}')
print(f'Testing get_client(): {bse.get_client()}')
# print(f'Testing get_symbol_info(): {bse.get_exchange_info()}')
print(f'Testing get_symbols(): {bse.get_symbols()}')
print(f'Testing get_balances(): {bse.get_balances()}')
print(f'Testing get_avail_intervals(): {bse.get_avail_intervals()}')
print(f'Testing get_historical_klines(): {bse.get_historical_klines(symbol="ETHUSDT",interval="15m",start_str="1 hour ago UTC",klines_type=HistoricalKlinesType.SPOT)}')
print(f'Testing get_symbol_precision_rule(): {bse.get_symbol_precision_rule("ETHUSDT")}')
print(f'Testing get_min_qty(): {bse.get_min_qty("ETHUSDT")}')
print(f'Testing get_min_notional_qty(): {bse.get_min_notional_qty("ETHUSDT")}')
print(f'Testing get_price(): {bse.get_price("BTCUSDT")}')
# #bse.get_order()

View File

@ -1,9 +1,124 @@
from DataCache import DataCache from DataCache import DataCache
from exchangeinterface import ExchangeInterface from exchangeinterface import ExchangeInterface
import unittest
import pandas as pd
import datetime as dt
def test_cache_exists(): class TestDataCache(unittest.TestCase):
exchanges = ExchangeInterface() def setUp(self):
# This object maintains all the cached data. Pass it connection to the exchanges. # Setup the database connection here
data = DataCache(exchanges) self.exchanges = ExchangeInterface()
assert data.cache_exists(key='BTC/USD_2h_alpaca') is False self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
# This object maintains all the cached data. Pass it connection to the exchanges.
self.data = DataCache(self.exchanges)
asset, timeframe, exchange = 'BTC/USD', '2h', 'binance'
self.key1 = f'{asset}_{timeframe}_{exchange}'
asset, timeframe, exchange = 'ETH/USD', '2h', 'binance'
self.key2 = f'{asset}_{timeframe}_{exchange}'
def test_set_cache(self):
# Tests
print('Testing set_cache flag not set:')
self.data.set_cache(data='data', key=self.key1)
attr = self.data.__getattribute__('cached_data')
self.assertEqual(attr[self.key1], 'data')
self.data.set_cache(data='more_data', key=self.key1)
attr = self.data.__getattribute__('cached_data')
self.assertEqual(attr[self.key1], 'more_data')
print('Testing set_cache no-overwrite flag set:')
self.data.set_cache(data='even_more_data', key=self.key1, do_not_overwrite=True)
attr = self.data.__getattribute__('cached_data')
self.assertEqual(attr[self.key1], 'more_data')
def test_cache_exists(self):
# Tests
print('Testing cache_exists() method:')
self.assertFalse(self.data.cache_exists(key=self.key2))
self.data.set_cache(data='data', key=self.key1)
self.assertTrue(self.data.cache_exists(key=self.key1))
def test_update_candle_cache(self):
print('Testing update_candle_cache() method:')
# Initial data
df_initial = pd.DataFrame({
'open_time': [1, 2, 3],
'open': [100, 101, 102],
'high': [110, 111, 112],
'low': [90, 91, 92],
'close': [105, 106, 107],
'volume': [1000, 1001, 1002]
})
# Data to be added
df_new = pd.DataFrame({
'open_time': [3, 4, 5],
'open': [102, 103, 104],
'high': [112, 113, 114],
'low': [92, 93, 94],
'close': [107, 108, 109],
'volume': [1002, 1003, 1004]
})
self.data.set_cache(data=df_initial, key=self.key1)
self.data.update_candle_cache(more_records=df_new, key=self.key1)
result = self.data.get_cache(key=self.key1)
expected = pd.DataFrame({
'open_time': [1, 2, 3, 4, 5],
'open': [100, 101, 102, 103, 104],
'high': [110, 111, 112, 113, 114],
'low': [90, 91, 92, 93, 94],
'close': [105, 106, 107, 108, 109],
'volume': [1000, 1001, 1002, 1003, 1004]
})
pd.testing.assert_frame_equal(result, expected)
def test_update_cached_dict(self):
print('Testing update_cached_dict() method:')
self.data.set_cache(data={}, key=self.key1)
self.data.update_cached_dict(cache_key=self.key1, dict_key='sub_key', data='value')
cache = self.data.get_cache(key=self.key1)
self.assertEqual(cache['sub_key'], 'value')
def test_get_cache(self):
print('Testing get_cache() method:')
self.data.set_cache(data='data', key=self.key1)
result = self.data.get_cache(key=self.key1)
self.assertEqual(result, 'data')
def test_get_records_since(self):
print('Testing get_records_since() method:')
df_initial = pd.DataFrame({
'open_time': [1, 2, 3],
'open': [100, 101, 102],
'high': [110, 111, 112],
'low': [90, 91, 92],
'close': [105, 106, 107],
'volume': [1000, 1001, 1002]
})
self.data.set_cache(data=df_initial, key=self.key1)
start_datetime = dt.datetime.utcfromtimestamp(2)
result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60, ex_details=[]).sort_values(by='open_time').reset_index(drop=True)
expected = pd.DataFrame({
'open_time': [2, 3],
'open': [101, 102],
'high': [111, 112],
'low': [91, 92],
'close': [106, 107],
'volume': [1001, 1002]
})
pd.testing.assert_frame_equal(result, expected)
if __name__ == '__main__':
unittest.main()

View File

@ -1,28 +1,73 @@
def test_index(): import unittest
from BrighterTrades import BrighterTrades from flask import Flask
obj = BrighterTrades() from src.app import app
assert True import json
def test_ws(): class FlaskAppTests(unittest.TestCase):
assert False def setUp(self):
"""
Set up the test client and any other test configuration.
"""
self.app = app.test_client()
self.app.testing = True
def test_index(self):
"""
Test the index route.
"""
response = self.app.get('/')
self.assertEqual(response.status_code, 200)
self.assertIn(b'Welcome', response.data) # Adjust this based on your actual landing page content
def test_login(self):
"""
Test the login route with valid and invalid credentials.
"""
# Valid credentials
valid_data = {'user_name': 'test_user', 'password': 'test_password'}
response = self.app.post('/login', data=valid_data)
self.assertEqual(response.status_code, 302) # Redirects on success
# Invalid credentials
invalid_data = {'user_name': 'wrong_user', 'password': 'wrong_password'}
response = self.app.post('/login', data=invalid_data)
self.assertEqual(response.status_code, 302) # Redirects on failure
self.assertIn(b'Invalid user_name or password', response.data)
def test_signup(self):
"""
Test the signup route.
"""
data = {'email': 'test@example.com', 'user_name': 'new_user', 'password': 'new_password'}
response = self.app.post('/signup_submit', data=data)
self.assertEqual(response.status_code, 302) # Redirects on success
def test_signout(self):
"""
Test the signout route.
"""
response = self.app.get('/signout')
self.assertEqual(response.status_code, 302) # Redirects on signout
def test_history(self):
"""
Test the history route.
"""
data = {"user_name": "test_user"}
response = self.app.post('/api/history', data=json.dumps(data), content_type='application/json')
self.assertEqual(response.status_code, 200)
self.assertIn(b'price_history', response.data)
def test_indicator_init(self):
"""
Test the indicator initialization route.
"""
data = {"user_name": "test_user"}
response = self.app.post('/api/indicator_init', data=json.dumps(data), content_type='application/json')
self.assertEqual(response.status_code, 200)
self.assertIn(b'indicator_data', response.data)
def test_settings(): if __name__ == '__main__':
assert False unittest.main()
def test_history():
assert False
def test_signup():
assert False
def test_signup_submit():
assert False
def test_indicator_init():
assert False

View File

@ -0,0 +1,162 @@
import time
import unittest
import datetime as dt
import pandas as pd
from shared_utilities import (
query_uptodate, ms_to_seconds, unix_time_seconds, unix_time_millis,
query_satisfied, ts_of_n_minutes_ago, timeframe_to_minutes
)
class TestSharedUtilities(unittest.TestCase):
def test_query_uptodate(self):
print('Testing query_uptodate()')
# (Test case 1) The records should not be up-to-date (very old timestamps)
records = pd.DataFrame({
'open_time': [1, 2, 3, 4, 5]
})
result = query_uptodate(records, 1)
if result is None:
print('Records are up-to-date.')
else:
print('Records are not up-to-date.')
print(f'Result for the first test case: {result}')
self.assertIsNotNone(result)
# (Test case 2) The records should be up-to-date (recent timestamps)
now = unix_time_millis(dt.datetime.utcnow())
recent_records = pd.DataFrame({
'open_time': [now - 70000, now - 60000, now - 40000]
})
result = query_uptodate(recent_records, 1)
if result is None:
print('Records are up-to-date.')
else:
print('Records are not up-to-date.')
print(f'Result for the second test case: {result}')
self.assertIsNone(result)
# (Test case 3) The records just under the tolerance for a record length of 1 hour.
# The records should not be up-to-date (recent timestamps)
one_hour = 60 * 60 * 1000 # one hour in milliseconds
tolerance_milliseconds = 10 * 1000 # tolerance in milliseconds
recent_time = unix_time_millis(dt.datetime.utcnow())
borderline_records = pd.DataFrame({
'open_time': [recent_time - one_hour + (tolerance_milliseconds - 3)] # just within the tolerance
})
result = query_uptodate(borderline_records, 60)
if result is None:
print('Records are up-to-date.')
else:
print('Records are not up-to-date.')
print(f'Result for the third test case: {result}')
self.assertIsNotNone(result)
# (Test case 4) The records just over the tolerance for a record length of 1 hour.
# The records should be up-to-date (recent timestamps)
one_hour = 60 * 60 * 1000 # one hour in milliseconds
tolerance_milliseconds = 10 * 1000 # tolerance in milliseconds
recent_time = unix_time_millis(dt.datetime.utcnow())
borderline_records = pd.DataFrame({
'open_time': [recent_time - one_hour + (tolerance_milliseconds + 3)] # just within the tolerance
})
result = query_uptodate(borderline_records, 60)
if result is None:
print('Records are up-to-date.')
else:
print('Records are not up-to-date.')
print(f'Result for the third test case: {result}')
self.assertIsNone(result)
def test_ms_to_seconds(self):
print('Testing ms_to_seconds()')
self.assertEqual(ms_to_seconds(1000), 1)
self.assertEqual(ms_to_seconds(0), 0)
def test_unix_time_seconds(self):
print('Testing unix_time_seconds()')
time = dt.datetime(2020, 1, 1)
self.assertEqual(unix_time_seconds(time), 1577836800)
def test_unix_time_millis(self):
print('Testing unix_time_millis()')
time = dt.datetime(2020, 1, 1)
self.assertEqual(unix_time_millis(time), 1577836800000.0)
def test_query_satisfied(self):
print('Testing query_satisfied()')
# Test case where the records should satisfy the query (records cover the start time)
start_datetime = dt.datetime(2020, 1, 1)
records = pd.DataFrame({
'open_time': [1577836800000.0 - 600000, 1577836800000.0 - 300000, 1577836800000.0]
# Covering the start time
})
result = query_satisfied(start_datetime, records, 1)
if result is None:
print('Query is satisfied: records span back far enough.')
else:
print('Query is not satisfied: records do not span back far enough.')
print(f'Result for the first test case: {result}')
self.assertIsNotNone(result)
# Test case where the records should not satisfy the query (recent records but not enough)
recent_time = unix_time_millis(dt.datetime.utcnow())
records = pd.DataFrame({
'open_time': [recent_time - 300 * 60 * 1000, recent_time - 240 * 60 * 1000, recent_time - 180 * 60 * 1000]
})
result = query_satisfied(start_datetime, records, 1)
if result is None:
print('Query is satisfied: records span back far enough.')
else:
print('Query is not satisfied: records do not span back far enough.')
print(f'Result for the second test case: {result}')
self.assertIsNone(result)
# Additional test case for partial coverage
start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=300)
records = pd.DataFrame({
'open_time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=240)),
unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=180)),
unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=120))]
})
result = query_satisfied(start_datetime, records, 60)
if result is None:
print('Query is satisfied: records span back far enough.')
else:
print('Query is not satisfied: records do not span back far enough.')
print(f'Result for the third test case: {result}')
self.assertIsNone(result)
def test_ts_of_n_minutes_ago(self):
print('Testing ts_of_n_minutes_ago()')
now = dt.datetime.utcnow()
test_cases = [
(60, 1), # 60 candles of 1 minute each
(10, 5), # 10 candles of 5 minutes each
(1, 1440), # 1 candle of 1 day (1440 minutes)
(7, 10080), # 7 candles of 1 week (10080 minutes)
(30, 60), # 30 candles of 1 hour (60 minutes)
]
for n, candle_length in test_cases:
with self.subTest(n=n, candle_length=candle_length):
result = ts_of_n_minutes_ago(n, candle_length)
expected_time = now - dt.timedelta(minutes=(n + 1) * candle_length)
self.assertAlmostEqual(unix_time_seconds(result), unix_time_seconds(expected_time), delta=60)
def test_timeframe_to_minutes(self):
print('Testing timeframe_to_minutes()')
self.assertEqual(timeframe_to_minutes('15m'), 15)
self.assertEqual(timeframe_to_minutes('1h'), 60)
self.assertEqual(timeframe_to_minutes('1d'), 1440)
self.assertEqual(timeframe_to_minutes('1w'), 10080)
self.assertEqual(timeframe_to_minutes('1M'), 44640) # 31 days in a month
self.assertEqual(timeframe_to_minutes('1Y'), 525600) # 525600 minutes in a year
if __name__ == '__main__':
unittest.main()