Fixed query_uptodate function and added comprehensive test cases
This commit is contained in:
parent
0bb9780ff6
commit
0b1ad39476
|
|
@ -56,16 +56,19 @@ legend top
|
|||
=<b>/index</b>
|
||||
————————————————————————————————
|
||||
Handles the main landing page, logs in the user,
|
||||
ensures a valid connection with an exchange,
|
||||
loads dynamic data, and renders the landing page.
|
||||
end legend
|
||||
start
|
||||
:Log in user;
|
||||
:Ensure a valid connection with an exchange;
|
||||
:Load dynamic data;
|
||||
:Render Landing Page;
|
||||
stop
|
||||
@enduml
|
||||
|
||||
|
||||
|
||||
```
|
||||
|
||||
### /ws
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
numpy==1.24.3
|
||||
flask==2.3.2
|
||||
flask_cors==3.0.10
|
||||
flask_sock==0.7.0
|
||||
config~=0.5.1
|
||||
PyYAML~=6.0
|
||||
binance~=0.3
|
||||
requests==2.30.0
|
||||
pandas==2.0.1
|
||||
passlib~=1.7.4
|
||||
SQLAlchemy==2.0.13
|
||||
ccxt~=3.0.105
|
||||
ccxt==4.3.65
|
||||
email-validator~=2.2.0
|
||||
TA-Lib~=0.4.32
|
||||
bcrypt~=4.2.0
|
||||
|
|
|
|||
|
|
@ -1,271 +0,0 @@
|
|||
import json
|
||||
import math
|
||||
import time
|
||||
from datetime import datetime, date, time, timedelta, timezone
|
||||
|
||||
import pandas
|
||||
import pandas as pd
|
||||
from alpaca.trading import GetAssetsRequest, AssetClass, MarketOrderRequest, OrderSide, TimeInForce
|
||||
|
||||
from Exchange import Exchange
|
||||
import config
|
||||
from alpaca.trading.client import TradingClient
|
||||
from alpaca.data.historical import CryptoHistoricalDataClient
|
||||
from alpaca.data.requests import CryptoBarsRequest, CryptoLatestQuoteRequest
|
||||
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
|
||||
import alpaca_trade_api as tradeapi
|
||||
|
||||
|
||||
class AlpacaPaperExchange(Exchange):
|
||||
def __init__(self, name, api_keys):
|
||||
super().__init__(name, api_keys=api_keys)
|
||||
|
||||
def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt=None):
|
||||
"""
|
||||
Return a dataframe containing rows of candle attributes.
|
||||
[(symbol, start_datetime, open, high, low,
|
||||
close, Volume, trade_count, vwap)]
|
||||
|
||||
:param symbol: str: - Trading symbol.
|
||||
:param interval: str: - Timeframe for the candle sticks.
|
||||
:param start_dt: - Start datetime of data stream.
|
||||
:param end_dt: - End datetime of data stream.
|
||||
:return: pd.dataframe - The requested data.
|
||||
"""
|
||||
|
||||
if end_dt is None:
|
||||
end_dt = datetime.utcnow()
|
||||
|
||||
# Extract the numerical part of the timeframe param.
|
||||
digits = int("".join([i if i.isdigit() else "" for i in interval]))
|
||||
# Extract the alpha part of the timeframe param.
|
||||
letter = "".join([i if i.isalpha() else "" for i in interval])
|
||||
# Convert the abbreviation to the name of a TimeFrameUnit object attribute.
|
||||
time_frame = "".join(['Minute' if i == 'm' else
|
||||
'Hour' if i == 'h' else
|
||||
'Day' if i == 'd' else
|
||||
'Week' if i == 'w' else
|
||||
'Month' if i == 'M' else 'Year' for i in letter])
|
||||
# Fill in the request parameters.
|
||||
request_params = CryptoBarsRequest(
|
||||
symbol_or_symbols=symbol,
|
||||
timeframe=TimeFrame(amount=digits, unit=TimeFrameUnit[time_frame]),
|
||||
start=start_dt,
|
||||
end=end_dt
|
||||
)
|
||||
if start_dt > end_dt:
|
||||
raise ValueError(start_dt, end_dt)
|
||||
# Request the bars from the exchange_interface.
|
||||
bars = self.client['historical_data'].get_crypto_bars(request_params)
|
||||
# Convert the response object into a list of dictionaries.
|
||||
candles_dict = [c.__dict__ for c in bars[symbol]]
|
||||
# Convert the data into a dataframe.
|
||||
candles = pd.DataFrame(candles_dict)
|
||||
# Reformat this dataframe to make it consistent with other exchanges.
|
||||
del candles['symbol']
|
||||
candles.rename({'timestamp': 'open_time'}, axis='columns', inplace=True)
|
||||
# Convert the date column to a start_datetime to make it consistent with the other exchanges.
|
||||
candles['open_time'] = candles.open_time.values.astype('int64') // 10 ** 6
|
||||
# Return the data.
|
||||
return candles
|
||||
|
||||
def _fetch_price(self, symbol) -> float:
|
||||
"""
|
||||
Get the latest price for a single symbol.
|
||||
|
||||
:param symbol: str - The symbol of the symbol.
|
||||
:return: float - The minimum quantity sold per trade.
|
||||
"""
|
||||
request_params = CryptoLatestQuoteRequest(symbol_or_symbols=symbol)
|
||||
lq = self.client['historical_data'].get_crypto_latest_quote(request_params)
|
||||
lqd = lq[symbol]
|
||||
return lqd.bid_price
|
||||
|
||||
def _fetch_min_qty(self, symbol) -> float:
|
||||
"""
|
||||
Get the minimum base quantity the exchange_interface allows per order.
|
||||
|
||||
:param symbol: str - The symbol of the trading pair.
|
||||
:return: float - The minimum quantity sold per trade.
|
||||
"""
|
||||
assets = self.client['trading'].get_asset(symbol)
|
||||
return assets.min_order_size
|
||||
|
||||
def _fetch_min_notional_qty(self, symbol) -> float:
|
||||
"""
|
||||
Get the minimum amount of the quote currency allowed per trade.
|
||||
|
||||
:param symbol: The symbol of the trading pair.
|
||||
:return: The minimum quantity sold per trade.
|
||||
"""
|
||||
assets = self.client['trading'].get_asset(symbol)
|
||||
return assets.price_increment
|
||||
|
||||
def _fetch_order(self, symbol, order_id) -> object:
|
||||
"""Return the order status for a specific order."""
|
||||
return self.client.query_order(symbol=symbol, orderId=order_id)
|
||||
|
||||
def _connect_exchange(self) -> object:
|
||||
"""Connects to the exchange_interface and sets a reference the client."""
|
||||
client = {
|
||||
'historical_data': CryptoHistoricalDataClient(),
|
||||
'trading': TradingClient(self.api_key, self.api_key_secret, paper=True),
|
||||
'trade_api': tradeapi.REST(self.api_key, self.api_key_secret, "https://paper-api.alpaca.markets/")
|
||||
}
|
||||
|
||||
return client
|
||||
|
||||
def _set_exchange_info(self) -> dict:
|
||||
"""Fetches Info on all symbols from the exchange_interface."""
|
||||
return self.client['trading'].get_account()
|
||||
|
||||
def _set_symbols(self) -> list:
|
||||
"""Get information of coins (available for deposit and withdraw) for user"""
|
||||
params = GetAssetsRequest(asset_class=AssetClass.CRYPTO)
|
||||
# read all the symbol data into a dataframe.
|
||||
assets = pandas.DataFrame.from_records(self.client['trading'].get_all_assets(params))
|
||||
# This parses things weird, so I did some format manipulation.
|
||||
# Extract the proper headers
|
||||
headers = assets.head(1).applymap(lambda row: row[0]).values[0]
|
||||
# Rename the headers that were assigned.
|
||||
assets = assets.set_axis(headers, axis=1)
|
||||
# remove the headers from the data.
|
||||
assets = assets.applymap(lambda row: row[1])
|
||||
# remove rows that aren't tradable.
|
||||
assets = assets.query("tradable != False")
|
||||
# Return a list of just the symbols.
|
||||
return assets.symbol.values
|
||||
|
||||
def _set_balances(self) -> list:
|
||||
"""
|
||||
Retrieves the account balances, including cash balance and current P&L (Profit & Loss).
|
||||
|
||||
:return: list - A list of dictionaries containing the asset balances and P&L information.
|
||||
Each dictionary has the following keys:
|
||||
- 'asset': The asset symbol (e.g., 'USD', 'BTC', 'ETH').
|
||||
- 'balance': The balance of the asset.
|
||||
- 'pnl': The current P&L (Profit & Loss) of the account.
|
||||
"""
|
||||
account = self.client['trade_api'].get_account()
|
||||
cash_balance = account.cash
|
||||
portfolio_value = float(account.equity)
|
||||
buying_power = float(account.buying_power)
|
||||
|
||||
# Calculate P&L as the difference between portfolio value and buying power,
|
||||
# or set to 0 if no positions or open orders
|
||||
positions = self.client['trade_api'].list_positions()
|
||||
open_orders = self.client['trade_api'].list_orders(status='open')
|
||||
|
||||
if len(positions) == 0 and len(open_orders) == 0:
|
||||
current_pl = 0.0
|
||||
else:
|
||||
current_pl = portfolio_value - buying_power
|
||||
|
||||
non_zero_assets = [
|
||||
{
|
||||
'asset': 'USD', # Assuming the cash balance is in USD
|
||||
'balance': cash_balance,
|
||||
'pnl': current_pl
|
||||
}
|
||||
]
|
||||
return non_zero_assets
|
||||
|
||||
def get_asset_balances(self):
|
||||
""" Get the individual asset balances"""
|
||||
# Get the account information
|
||||
account = self.client['trading'].get_account()
|
||||
|
||||
# Retrieve the balances of each non-zero asset
|
||||
balances = []
|
||||
for asset in account.assets:
|
||||
if float(asset.qty) > 0:
|
||||
balance = {
|
||||
'asset': asset.symbol,
|
||||
'balance': asset.qty
|
||||
}
|
||||
balances.append(balance)
|
||||
|
||||
return balances
|
||||
|
||||
def get_active_trades(self):
|
||||
formatted_trades = []
|
||||
# Get open trades
|
||||
positions = self.client['trade_api'].list_positions()
|
||||
for position in positions:
|
||||
if float(position.qty) != 0.0:
|
||||
active_trade = {
|
||||
'symbol': position.symbol,
|
||||
'side': 'buy' if float(position.qty) > 0 else 'sell',
|
||||
'quantity': abs(float(position.qty)),
|
||||
'price': float(position.avg_entry_price)
|
||||
}
|
||||
formatted_trades.append(active_trade)
|
||||
return formatted_trades
|
||||
|
||||
def get_open_orders(self):
|
||||
formatted_orders = []
|
||||
# Get open orders
|
||||
open_orders = self.client['trade_api'].list_orders(status='open')
|
||||
for order in open_orders:
|
||||
open_order = {
|
||||
'symbol': order.symbol,
|
||||
'side': order.side,
|
||||
'quantity': order.qty,
|
||||
'price': order.filled_avg_price
|
||||
}
|
||||
formatted_orders.append(open_order)
|
||||
return formatted_orders
|
||||
|
||||
def _set_precision_rule(self, symbol) -> None:
|
||||
"""Appends a record of places after the decimal required by the exchange_interface indexed by symbol."""
|
||||
|
||||
# I don't think this exchange_interface has precision rule.
|
||||
# return the number of decimal places of min_trade_increment
|
||||
assets = self.client['trading'].get_asset(symbol)
|
||||
min_trade_rule = assets.min_trade_increment
|
||||
precision = len(str(min_trade_rule).split('.')[1])
|
||||
self.symbols_n_precision[symbol] = precision
|
||||
return
|
||||
|
||||
def _place_order(self, symbol, side, type, timeInForce, quantity, price):
|
||||
|
||||
def format_arg(value, arg_name):
|
||||
"""
|
||||
Function to convert an arguments value to a desired precision and type float.
|
||||
|
||||
:param value: The Value of the argument.
|
||||
:param arg_name: The name of the argument.
|
||||
:return: float : The formatted output.
|
||||
"""
|
||||
# The required level of precision for this trading pair.
|
||||
precision = self.get_symbol_precision_rule(symbol)
|
||||
# If quantity was too low, set to the smallest allowable amount.
|
||||
minimum = 0
|
||||
if arg_name == 'quantity':
|
||||
# The minimum quantity aloud to be traded.
|
||||
minimum = self.get_min_qty(symbol)
|
||||
elif arg_name == 'price':
|
||||
# The minimum price aloud to be traded.
|
||||
minimum = self.get_min_notional_qty(symbol)
|
||||
if value < minimum:
|
||||
value = minimum
|
||||
return float(f"{value:.{precision}f}")
|
||||
|
||||
# Set price and quantity to desired precision and type.
|
||||
quantity = format_arg(quantity, 'quantity')
|
||||
price = format_arg(price, 'price')
|
||||
|
||||
if side == 'buy':
|
||||
side = OrderSide.BUY
|
||||
else:
|
||||
side = OrderSide.SELL
|
||||
market_order_data = MarketOrderRequest(
|
||||
symbol=symbol,
|
||||
qty=quantity,
|
||||
side=side,
|
||||
time_in_force=TimeInForce[timeInForce]
|
||||
|
||||
)
|
||||
market_order = self.client['trading'].submit_order(market_order_data)
|
||||
result = 'Success'
|
||||
return result, market_order
|
||||
|
|
@ -1,218 +0,0 @@
|
|||
import pandas as pd
|
||||
from Exchange import Exchange
|
||||
import ccxt
|
||||
|
||||
|
||||
class BinanceFuturesExchange(Exchange):
|
||||
def __init__(self, name, api_keys):
|
||||
super().__init__(name, api_keys=api_keys)
|
||||
|
||||
import ccxt
|
||||
|
||||
def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt=None):
|
||||
"""
|
||||
Return a list [(Open time, Open, High, Low, Close,
|
||||
Volume, Close time, Quote symbol volume,
|
||||
Number of trades, Taker buy base symbol volume,
|
||||
Taker buy quote symbol volume)]
|
||||
|
||||
:param symbol: str: Trading symbol.
|
||||
:param interval: str: Timeframe for the candlesticks.
|
||||
:param start_dt: start datetime of data stream.
|
||||
:param end_dt: end datetime of data stream.
|
||||
:return: pandas DataFrame containing the candlestick data.
|
||||
"""
|
||||
|
||||
# Convert the start date to a Unix timestamp in seconds.
|
||||
start_str = int(start_dt.timestamp())
|
||||
end_str = None
|
||||
if end_dt is not None:
|
||||
end_str = int(end_dt.timestamp())
|
||||
candles = self.client.fetch_ohlcv(symbol=symbol, timeframe=interval, since=start_str, limit=None, end=end_str)
|
||||
# Create and return a pandas DataFrame.
|
||||
df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume']
|
||||
candles = pd.DataFrame(candles, columns=df_columns)
|
||||
# Return the data
|
||||
return candles
|
||||
|
||||
def _fetch_price(self, symbol) -> float:
|
||||
"""
|
||||
Get the latest price for a single symbol.
|
||||
|
||||
:param symbol: The symbol of the trading pair.
|
||||
:return: The latest price for the symbol.
|
||||
"""
|
||||
ticker = self.client.fetch_ticker(symbol)
|
||||
return float(ticker['last'])
|
||||
|
||||
def _fetch_min_qty(self, symbol) -> float:
|
||||
"""
|
||||
Get the minimum base quantity the exchange allows per order.
|
||||
|
||||
:param symbol: The symbol of the trading pair.
|
||||
:return: The minimum quantity sold per trade.
|
||||
"""
|
||||
exchange_info = self.client.fetch_tickers()
|
||||
symbol_info = exchange_info.get(symbol)
|
||||
if symbol_info is None:
|
||||
raise ValueError(f'No symbol: {symbol}')
|
||||
# Extract the minimum quantity value from the symbol info
|
||||
min_qty = symbol_info.get('info').get('filters')[2].get('minQty')
|
||||
return float(min_qty)
|
||||
|
||||
def _fetch_min_notional_qty(self, symbol) -> float:
|
||||
"""
|
||||
Get the minimum amount of the quote currency allowed per trade.
|
||||
|
||||
:param symbol: The symbol of the trading pair.
|
||||
:return: The minimum quantity sold per trade.
|
||||
"""
|
||||
exchange_info = self.client.fetch_tickers()
|
||||
symbol_info = exchange_info.get(symbol)
|
||||
if symbol_info is None:
|
||||
raise ValueError(f'No symbol: {symbol}')
|
||||
# Extract the minimum notional value from the symbol info
|
||||
min_notional = symbol_info.get('info').get('filters')[0].get('minNotional')
|
||||
return float(min_notional)
|
||||
|
||||
def _fetch_order(self, symbol, order_id) -> object:
|
||||
"""Return the order status for a specific order."""
|
||||
order = self.client.fetch_order(id=order_id, symbol=symbol)
|
||||
return order
|
||||
|
||||
def _connect_exchange(self) -> object:
|
||||
"""Connects to the exchange_interface and sets a reference to the client."""
|
||||
exchange = ccxt.binance({
|
||||
'apiKey': self.api_key,
|
||||
'secret': self.api_key_secret,
|
||||
'enableRateLimit': True,
|
||||
'options': {
|
||||
'defaultType': 'future',
|
||||
}
|
||||
})
|
||||
return exchange
|
||||
|
||||
def _set_exchange_info(self) -> dict:
|
||||
"""Fetches Info on all symbols from the exchange."""
|
||||
exchange_info = self.client.load_markets()
|
||||
return exchange_info
|
||||
|
||||
def _set_symbols(self) -> list:
|
||||
"""Get information of coins (available for deposit and withdraw) for user"""
|
||||
markets = self.client.load_markets()
|
||||
|
||||
# Extract symbol from each market
|
||||
symbols = list(markets.keys())
|
||||
|
||||
return symbols
|
||||
|
||||
def _set_balances(self) -> list:
|
||||
acc_info = self.client.fetch_balance()
|
||||
non_zero_assets = []
|
||||
for asset, balance in acc_info['total'].items():
|
||||
if float(balance) > 0:
|
||||
non_zero_asset = {
|
||||
'asset': asset,
|
||||
'balance': format(balance, '.8f'),
|
||||
'pnl': acc_info['info']['totalUnrealizedProfit']
|
||||
}
|
||||
non_zero_assets.append(non_zero_asset)
|
||||
return non_zero_assets
|
||||
|
||||
def _set_precision_rule(self, symbol) -> None:
|
||||
# Isolate a list of symbol rules from the exchange_info
|
||||
all_rules = self.exchange_info['symbols']
|
||||
# Get the rules for a specific symbol.
|
||||
symbol_rules = next(filter(lambda rules: (rules['pair'] == symbol), all_rules), None)
|
||||
# Add the precision rule to a local index.
|
||||
self.symbols_n_precision[symbol] = symbol_rules['baseAssetPrecision']
|
||||
return
|
||||
|
||||
def _place_order(self, symbol, side, type, timeInForce, quantity, price):
|
||||
def format_arg(value, arg_name):
|
||||
"""
|
||||
Function to convert an arguments value to a desired precision and type float.
|
||||
|
||||
:param value: The Value of the argument.
|
||||
:param arg_name: The name of the argument.
|
||||
:return: float : The formatted output.
|
||||
"""
|
||||
# The required level of precision for this trading pair.
|
||||
precision = self.client.precision_for_symbol(symbol)
|
||||
# If quantity was too low, set to the smallest allowable amount.
|
||||
minimum = 0
|
||||
if arg_name == 'quantity':
|
||||
# The minimum quantity allowed to be traded.
|
||||
minimum = self.client.amount_to_precision(symbol,
|
||||
self.client.markets[symbol]['limits']['amount']['min'])
|
||||
elif arg_name == 'price':
|
||||
# The minimum price allowed to be traded.
|
||||
minimum = self.client.price_to_precision(symbol, self.client.markets[symbol]['limits']['price']['min'])
|
||||
if value < minimum:
|
||||
value = minimum
|
||||
return float(f"{value:.{precision}f}")
|
||||
|
||||
# Set price and quantity to desired precision and type.
|
||||
quantity = format_arg(quantity, 'quantity')
|
||||
price = format_arg(price, 'price')
|
||||
|
||||
order_params = {
|
||||
'symbol': symbol,
|
||||
'side': side,
|
||||
'type': type,
|
||||
'timeInForce': timeInForce,
|
||||
'quantity': quantity,
|
||||
'price': price
|
||||
}
|
||||
data = self.client.create_order(**order_params)
|
||||
result = 'Success'
|
||||
return result, data
|
||||
|
||||
def get_active_trades(self):
|
||||
formatted_trades = []
|
||||
# Get open positions
|
||||
positions = self.client.fetch_positions()
|
||||
for position in positions:
|
||||
if float(position['contracts']) != 0.0:
|
||||
active_trade = {
|
||||
'symbol': position['symbol'],
|
||||
'side': 'buy' if float(position['contracts']) > 0 else 'sell',
|
||||
'quantity': abs(float(position['contracts'])),
|
||||
'price': float(position['entryPrice'])
|
||||
}
|
||||
formatted_trades.append(active_trade)
|
||||
return formatted_trades
|
||||
|
||||
def get_open_orders(self):
|
||||
formatted_orders = []
|
||||
# Get open orders
|
||||
# open_orders = self.client.fetch_open_orders()
|
||||
# Todo: fetching open orders without specifying a symbol is rate-limited to one call per 1252 seconds.
|
||||
open_orders = {}
|
||||
for order in open_orders:
|
||||
open_order = {
|
||||
'symbol': order['symbol'],
|
||||
'side': order['side'],
|
||||
'quantity': order['amount'],
|
||||
'price': order['price']
|
||||
}
|
||||
formatted_orders.append(open_order)
|
||||
return formatted_orders
|
||||
|
||||
|
||||
class BinanceCoinExchange(BinanceFuturesExchange):
|
||||
def __init__(self, name, api_keys):
|
||||
super().__init__(name, api_keys=api_keys)
|
||||
|
||||
def _connect_exchange(self) -> object:
|
||||
"""Connects to the exchange_interface and sets a reference to the client."""
|
||||
return ccxt.binance({
|
||||
'apiKey': self.api_key,
|
||||
'secret': self.api_key_secret,
|
||||
'enableRateLimit': True,
|
||||
# Additional exchange-specific options can be added here
|
||||
'options': {
|
||||
'defaultType': 'coinfuture', # Connect to the Binance Coin Futures API
|
||||
}
|
||||
})
|
||||
|
||||
|
|
@ -1,199 +0,0 @@
|
|||
import pandas as pd
|
||||
from binance.enums import HistoricalKlinesType
|
||||
from shared_utilities import unix_time_millis
|
||||
|
||||
import config
|
||||
|
||||
from Exchange import Exchange
|
||||
from binance.client import Client as BinanceClient
|
||||
from binance.exceptions import BinanceAPIException
|
||||
|
||||
|
||||
class BinanceSpotExchange(Exchange):
|
||||
def __init__(self, name, api_keys):
|
||||
super().__init__(name, api_keys)
|
||||
|
||||
def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt=None):
|
||||
"""
|
||||
Return a dataframe containing rows of candle attributes.
|
||||
[(Open time, Open, High, Low, Close,
|
||||
Volume, Close time, Quote symbol volume,
|
||||
Number of trades, Taker buy base symbol volume,
|
||||
Taker buy quote symbol volume, Ignore)]
|
||||
|
||||
:param symbol: str: Trading symbol.
|
||||
:param interval: str: Timeframe for the candle sticks.
|
||||
:param start_dt: start time-date of data stream.
|
||||
:param end_dt: end datetime of data stream.
|
||||
:return: pd.dataframe
|
||||
"""
|
||||
# Set this to Spot.
|
||||
klines_type = HistoricalKlinesType.SPOT
|
||||
# Convert the start date to a Unix start_datetime milliseconds.
|
||||
start_str = int(unix_time_millis(start_dt))
|
||||
end_str = end_dt
|
||||
if end_dt is not None:
|
||||
end_str = int(unix_time_millis(end_dt))
|
||||
|
||||
candles = self.client.get_historical_klines(symbol=symbol, interval=interval, start_str=start_str,
|
||||
end_str=end_str, klines_type=klines_type)
|
||||
# Create a pandas DataFrame.
|
||||
df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume', 'close_time',
|
||||
'quote_volume', 'num_trades', 'taker_buy_base_volume', 'taker_buy_quote_volume', 'ignore']
|
||||
candles = pd.DataFrame(candles, columns=df_columns)
|
||||
# Delete this useless column
|
||||
del candles['ignore']
|
||||
# Return the data
|
||||
return candles
|
||||
|
||||
def _fetch_price(self, symbol) -> float:
|
||||
"""
|
||||
Get the latest price for a single symbol.
|
||||
|
||||
:param symbol: str - The symbol of the symbol.
|
||||
:return: float - The minimum quantity sold per trade.
|
||||
"""
|
||||
record = self.client.get_symbol_ticker(**{'symbol': symbol})
|
||||
return float(record['price'])
|
||||
|
||||
def _fetch_min_qty(self, symbol) -> float:
|
||||
"""
|
||||
Get the minimum base quantity the exchange_interface allows per order.
|
||||
|
||||
:param symbol: str - The symbol of the trading pair.
|
||||
:return: float - The minimum quantity sold per trade.
|
||||
"""
|
||||
info = self.client.get_symbol_info(symbol=symbol)
|
||||
for f_index, item in enumerate(info['filters']):
|
||||
if item['filterType'] == 'LOT_SIZE':
|
||||
break
|
||||
return float(info['filters'][f_index]['minQty']) # 'minQty'
|
||||
|
||||
def _fetch_min_notional_qty(self, symbol) -> float:
|
||||
"""
|
||||
Get the minimum amount of the quote currency allowed per trade.
|
||||
|
||||
:param symbol: The symbol of the trading pair.
|
||||
:return: The minimum quantity sold per trade.
|
||||
"""
|
||||
info = self.client.get_symbol_info(symbol=symbol)
|
||||
for f_index, item in enumerate(info['filters']):
|
||||
if item['filterType'] == 'MIN_NOTIONAL':
|
||||
break
|
||||
return float(info['filters'][f_index]['minNotional']) # 'minNotional'
|
||||
|
||||
def _fetch_order(self, symbol, order_id) -> object:
|
||||
"""Return the order status for a specific order."""
|
||||
return self.client.get_order(symbol=symbol, orderId=order_id)
|
||||
|
||||
def _connect_exchange(self) -> object:
|
||||
"""Connects to the exchange_interface and sets a reference the client."""
|
||||
return BinanceClient(self.api_key, self.api_key_secret)
|
||||
|
||||
def _set_exchange_info(self) -> dict:
|
||||
"""Fetches Info on all symbols from the exchange_interface."""
|
||||
return self.client.get_exchange_info()
|
||||
|
||||
def _set_symbols(self) -> list:
|
||||
"""Get information of coins (available for deposit and withdraw) for user"""
|
||||
# load into a dataframe
|
||||
data = pd.DataFrame.from_records(self.client.get_all_tickers())
|
||||
# unload into a list
|
||||
return data.symbol.to_list()
|
||||
|
||||
def _set_balances(self) -> list:
|
||||
account_info = self.client.get_account()
|
||||
balances = []
|
||||
for asset in account_info['balances']:
|
||||
asset_symbol = asset['asset']
|
||||
asset_balance = float(asset['free'])
|
||||
asset_pnl = self.calculate_asset_pnl(asset_symbol)
|
||||
if asset_balance > 0 or asset_pnl != 0:
|
||||
asset_data = {
|
||||
'asset': asset_symbol,
|
||||
'balance': asset_balance,
|
||||
'pnl': asset_pnl
|
||||
}
|
||||
balances.append(asset_data)
|
||||
return balances
|
||||
|
||||
def calculate_asset_pnl(self, asset_symbol: str) -> float:
|
||||
# Implement your logic to calculate the P&L for the given asset
|
||||
# This may involve retrieving historical data, calculating positions, etc.
|
||||
# Return the calculated P&L value as a float
|
||||
return 0.0 # Placeholder value, replace with actual calculation
|
||||
|
||||
def _set_precision_rule(self, symbol) -> None:
|
||||
# Isolate a list of symbol rules from the exchange_info
|
||||
all_rules = self.exchange_info['symbols']
|
||||
# Get the rules for a specific symbol.
|
||||
symbol_rules = next(filter(lambda rules: (rules['symbol'] == symbol), all_rules), None)
|
||||
# Add the precision rule to a local index.
|
||||
self.symbols_n_precision[symbol] = symbol_rules['baseAssetPrecision']
|
||||
return
|
||||
|
||||
def _place_order(self, symbol, side, type, timeInForce, quantity, price):
|
||||
|
||||
def format_arg(value, arg_name):
|
||||
"""
|
||||
Function to convert an arguments value to a desired precision and type float.
|
||||
|
||||
:param value: The Value of the argument.
|
||||
:param arg_name: The name of the argument.
|
||||
:return: float : The formatted output.
|
||||
"""
|
||||
# The required level of precision for this trading pair.
|
||||
precision = self.get_symbol_precision_rule(symbol)
|
||||
# If quantity was too low, set to the smallest allowable amount.
|
||||
minimum = 0
|
||||
if arg_name == 'quantity':
|
||||
# The minimum quantity aloud to be traded.
|
||||
minimum = self.get_min_qty(symbol)
|
||||
elif arg_name == 'price':
|
||||
# The minimum price aloud to be traded.
|
||||
minimum = self.get_min_notional_qty(symbol)
|
||||
if value < minimum:
|
||||
value = minimum
|
||||
return float(f"{value:.{precision}f}")
|
||||
|
||||
# Set price and quantity to desired precision and type.
|
||||
quantity = format_arg(quantity, 'quantity')
|
||||
price = format_arg(price, 'price')
|
||||
|
||||
data = self.client.create_test_order(symbol=symbol, side=side, type=type,
|
||||
timeInForce=timeInForce, quantity=quantity,
|
||||
price=price)
|
||||
result = 'Success'
|
||||
return result, data
|
||||
|
||||
def get_active_trades(self):
|
||||
formatted_trades = []
|
||||
# Get open trades
|
||||
positions = self.client.get_account().get('positions')
|
||||
if positions is None:
|
||||
return formatted_trades
|
||||
for position in positions:
|
||||
if float(position['qty']) != 0.0:
|
||||
active_trade = {
|
||||
'symbol': position['symbol'],
|
||||
'side': 'buy' if float(position['qty']) > 0 else 'sell',
|
||||
'quantity': abs(float(position['qty'])),
|
||||
'price': float(position['entryPrice'])
|
||||
}
|
||||
formatted_trades.append(active_trade)
|
||||
return formatted_trades
|
||||
|
||||
def get_open_orders(self):
|
||||
formatted_orders = []
|
||||
# Get open orders
|
||||
open_orders = self.client.get_open_orders()
|
||||
for order in open_orders:
|
||||
open_order = {
|
||||
'symbol': order['symbol'],
|
||||
'side': order['side'],
|
||||
'quantity': order['origQty'],
|
||||
'price': order['price']
|
||||
}
|
||||
formatted_orders.append(open_order)
|
||||
return formatted_orders
|
||||
|
||||
|
|
@ -172,7 +172,7 @@ class BrighterTrades:
|
|||
return self.indicators.get_indicator_data(user_name=user_name, source=source, start_ts=start_ts,
|
||||
num_results=num_results)
|
||||
|
||||
def connect_user_to_exchange(self, user_name: str, default_exchange: str, default_keys: dict) -> bool:
|
||||
def connect_user_to_exchange(self, user_name: str, default_exchange: str, default_keys: dict = None) -> bool:
|
||||
"""
|
||||
Connects an exchange if it is not already connected.
|
||||
|
||||
|
|
@ -189,6 +189,7 @@ class BrighterTrades:
|
|||
exchange_name=exchange,
|
||||
api_keys=keys)
|
||||
if not success:
|
||||
# If no active exchange was successfully connected, connect to the default exchange
|
||||
success = self.connect_or_config_exchange(user_name=user_name,
|
||||
exchange_name=default_exchange,
|
||||
api_keys=default_keys)
|
||||
|
|
@ -237,7 +238,7 @@ class BrighterTrades:
|
|||
r_data = {}
|
||||
r_data['title'] = self.config.app_data.get('application_title', '')
|
||||
r_data['chart_interval'] = chart_view.get('timeframe', '')
|
||||
r_data['selected_exchange'] = chart_view.get('exchange_name', '')
|
||||
r_data['selected_exchange'] = chart_view.get('exchange', '')
|
||||
r_data['intervals'] = exchange.intervals if exchange else []
|
||||
r_data['symbols'] = exchange.get_symbols() if exchange else {}
|
||||
r_data['available_exchanges'] = self.exchanges.get_available_exchanges() or []
|
||||
|
|
@ -359,7 +360,7 @@ class BrighterTrades:
|
|||
"""
|
||||
return self.strategies.get_strategies('json')
|
||||
|
||||
def connect_or_config_exchange(self, user_name: str, exchange_name: str, api_keys: dict) -> bool:
|
||||
def connect_or_config_exchange(self, user_name: str, exchange_name: str, api_keys: dict = None) -> bool:
|
||||
"""
|
||||
Connects to an exchange if not already connected, or configures the exchange connection for a single user.
|
||||
|
||||
|
|
@ -377,7 +378,9 @@ class BrighterTrades:
|
|||
api_keys=api_keys)
|
||||
if success:
|
||||
self.config.users.active_exchange(exchange=exchange_name, user_name=user_name, cmd='set')
|
||||
self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name, user_name=user_name)
|
||||
if api_keys:
|
||||
self.config.users.update_api_keys(api_keys=api_keys, exchange=exchange_name,
|
||||
user_name=user_name)
|
||||
return True
|
||||
else:
|
||||
return False # Failed to connect
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ class DataCache:
|
|||
records = pd.concat([more_records, self.get_cache(key)], axis=0, ignore_index=True)
|
||||
# Drop any duplicates from overlap.
|
||||
records = records.drop_duplicates(subset="open_time", keep='first')
|
||||
# Sort the records by open_time.
|
||||
records = records.sort_values(by='open_time').reset_index(drop=True)
|
||||
# Replace the incomplete dataframe with the modified one.
|
||||
self.set_cache(data=records, key=key)
|
||||
return
|
||||
|
|
|
|||
386
src/Exchange.py
386
src/Exchange.py
|
|
@ -1,40 +1,175 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
import ccxt
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Tuple, Dict, List, Union
|
||||
import time
|
||||
|
||||
|
||||
class Exchange(ABC):
|
||||
def __init__(self, name, api_keys):
|
||||
class Exchange:
|
||||
# Class attribute for caching market data
|
||||
_market_cache = {}
|
||||
|
||||
def __init__(self, name: str, api_keys: Dict[str, str], exchange_id: str):
|
||||
self.name = name
|
||||
|
||||
# 1 The api key for the exchange.
|
||||
self.api_key = api_keys['key']
|
||||
# The API key for the exchange.
|
||||
self.api_key = api_keys['key'] if api_keys else None
|
||||
|
||||
# 2 The api secret key for the exchange.
|
||||
self.api_key_secret = api_keys['secret']
|
||||
# The API secret key for the exchange.
|
||||
self.api_key_secret = api_keys['secret'] if api_keys else None
|
||||
|
||||
# 3 The connection to the exchange_interface.
|
||||
self.client = self._connect_exchange()
|
||||
# The exchange id for the exchange.
|
||||
self.exchange_id = exchange_id
|
||||
|
||||
# 4 Info on all symbols and exchange_interface rules.
|
||||
# The connection to the exchange_interface.
|
||||
self.client: ccxt.Exchange = self._connect_exchange()
|
||||
|
||||
# Info on all symbols and exchange_interface rules.
|
||||
self.exchange_info = self._set_exchange_info()
|
||||
|
||||
# 5 List of time intervals available for trading.
|
||||
# List of time intervals available for trading.
|
||||
self.intervals = self._set_avail_intervals()
|
||||
|
||||
# 6 All symbols available for trading.
|
||||
# All symbols available for trading.
|
||||
self.symbols = self._set_symbols()
|
||||
|
||||
# 7 Any non-zero balance-info for all assets.
|
||||
# Any non-zero balance-info for all assets.
|
||||
self.balances = self._set_balances()
|
||||
|
||||
# 8 Dictionary of places after the decimal requires by the exchange_interface indexed by symbol.
|
||||
# Dictionary of places after the decimal requires by the exchange_interface indexed by symbol.
|
||||
self.symbols_n_precision = {}
|
||||
|
||||
def _connect_exchange(self) -> ccxt.Exchange:
|
||||
"""
|
||||
Connects to the exchange_interface and sets a reference the client.
|
||||
Handles cases where api_keys might be None.
|
||||
"""
|
||||
exchange_class = getattr(ccxt, self.exchange_id)
|
||||
if not exchange_class:
|
||||
raise ValueError(f"Exchange {self.exchange_id} is not supported by CCXT.")
|
||||
|
||||
if self.api_key and self.api_key_secret:
|
||||
return exchange_class({
|
||||
'apiKey': self.api_key,
|
||||
'secret': self.api_key_secret,
|
||||
'enableRateLimit': True,
|
||||
'verbose': False # Disable verbose debugging output
|
||||
})
|
||||
else:
|
||||
return exchange_class({
|
||||
'enableRateLimit': True,
|
||||
'verbose': False # Disable verbose debugging output
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def datetime_to_unix_millis(dt: datetime) -> int:
|
||||
"""
|
||||
Convert a datetime object to Unix timestamp in milliseconds.
|
||||
|
||||
:param dt: datetime - The datetime object to convert.
|
||||
:return: int - The Unix timestamp in milliseconds.
|
||||
"""
|
||||
return int(dt.timestamp() * 1000)
|
||||
|
||||
def _fetch_historical_klines(self, symbol: str, interval: str, start_dt: datetime,
|
||||
end_dt: datetime = None) -> pd.DataFrame:
|
||||
if end_dt is None:
|
||||
end_dt = datetime.utcnow()
|
||||
|
||||
max_interval = timedelta(days=200) # Binance's maximum interval
|
||||
data_frames = []
|
||||
current_start = start_dt
|
||||
|
||||
while current_start < end_dt:
|
||||
current_end = min(current_start + max_interval, end_dt)
|
||||
start_str = self.datetime_to_unix_millis(current_start)
|
||||
end_str = self.datetime_to_unix_millis(current_end)
|
||||
|
||||
try:
|
||||
candles = self.client.fetch_ohlcv(symbol=symbol, timeframe=interval, since=start_str,
|
||||
params={'endTime': end_str})
|
||||
if not candles:
|
||||
break
|
||||
|
||||
df_columns = ['open_time', 'open', 'high', 'low', 'close', 'volume']
|
||||
candles_df = pd.DataFrame(candles, columns=df_columns)
|
||||
candles_df['open_time'] = candles_df['open_time'] // 1000
|
||||
data_frames.append(candles_df)
|
||||
|
||||
# Move the start to the end of the current chunk to get the next chunk
|
||||
current_start = current_end
|
||||
|
||||
# Sleep for 1 second to avoid hitting rate limits
|
||||
time.sleep(1)
|
||||
|
||||
except ccxt.BaseError as e:
|
||||
print(f"Error fetching OHLCV data: {str(e)}")
|
||||
break
|
||||
|
||||
if data_frames:
|
||||
result_df = pd.concat(data_frames)
|
||||
return result_df
|
||||
else:
|
||||
return pd.DataFrame(columns=['open_time', 'open', 'high', 'low', 'close', 'volume'])
|
||||
|
||||
def _fetch_price(self, symbol: str) -> float:
|
||||
ticker = self.client.fetch_ticker(symbol)
|
||||
return float(ticker['last'])
|
||||
|
||||
def _fetch_min_qty(self, symbol: str) -> float:
|
||||
market_data = self.exchange_info[symbol]
|
||||
return float(market_data['limits']['amount']['min'])
|
||||
|
||||
def _fetch_min_notional_qty(self, symbol: str) -> float:
|
||||
market_data = self.exchange_info[symbol]
|
||||
return float(market_data['limits']['cost']['min'])
|
||||
|
||||
def _fetch_order(self, symbol: str, order_id: str) -> object:
|
||||
return self.client.fetch_order(order_id, symbol)
|
||||
|
||||
def _set_symbols(self) -> List[str]:
|
||||
markets = self.client.fetch_markets()
|
||||
symbols = [market['symbol'] for market in markets if market['active']]
|
||||
return symbols
|
||||
|
||||
def _set_balances(self) -> List[Dict[str, Union[str, float]]]:
|
||||
if self.api_key and self.api_key_secret:
|
||||
try:
|
||||
account_info = self.client.fetch_balance()
|
||||
balances = []
|
||||
for asset, balance in account_info['total'].items():
|
||||
asset_balance = float(balance)
|
||||
if asset_balance > 0:
|
||||
balances.append({'asset': asset, 'balance': asset_balance, 'pnl': 0})
|
||||
return balances
|
||||
except NotImplementedError:
|
||||
# Handle the case where fetch_balance is not supported
|
||||
return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}]
|
||||
else:
|
||||
return [{'asset': 'N/A', 'balance': 0, 'pnl': 0}]
|
||||
|
||||
def _set_exchange_info(self) -> dict:
|
||||
"""
|
||||
Fetches market data for all symbols from the exchange.
|
||||
This includes details on trading pairs, limits, fees, and other relevant information.
|
||||
Caches the market info to speed up subsequent calls.
|
||||
"""
|
||||
if self.exchange_id in Exchange._market_cache:
|
||||
return Exchange._market_cache[self.exchange_id]
|
||||
|
||||
try:
|
||||
markets_info = self.client.load_markets()
|
||||
Exchange._market_cache[self.exchange_id] = markets_info
|
||||
return markets_info
|
||||
except ccxt.BaseError as e:
|
||||
print(f"Error fetching market info: {str(e)}")
|
||||
return {}
|
||||
|
||||
def get_client(self) -> object:
|
||||
""" Return a reference to the exchange_interface client."""
|
||||
return self.client
|
||||
|
||||
def get_avail_intervals(self) -> tuple:
|
||||
def get_avail_intervals(self) -> Tuple[str, ...]:
|
||||
"""Returns a list of time intervals available for trading."""
|
||||
return self.intervals
|
||||
|
||||
|
|
@ -42,15 +177,15 @@ class Exchange(ABC):
|
|||
"""Returns Info on all symbols."""
|
||||
return self.exchange_info
|
||||
|
||||
def get_symbols(self) -> list:
|
||||
def get_symbols(self) -> List[str]:
|
||||
"""Returns all symbols available for trading."""
|
||||
return self.symbols
|
||||
|
||||
def get_balances(self) -> list:
|
||||
def get_balances(self) -> List[Dict[str, Union[str, float]]]:
|
||||
"""Returns any non-zero balance-info for all assets."""
|
||||
return self.balances
|
||||
|
||||
def get_symbol_precision_rule(self, symbol) -> int:
|
||||
def get_symbol_precision_rule(self, symbol: str) -> int:
|
||||
"""Returns number of places after the decimal required by the exchange_interface for a specific symbol."""
|
||||
r_value = self.symbols_n_precision.get(symbol)
|
||||
if r_value is None:
|
||||
|
|
@ -58,149 +193,108 @@ class Exchange(ABC):
|
|||
r_value = self.symbols_n_precision.get(symbol)
|
||||
return r_value
|
||||
|
||||
def get_historical_klines(self, symbol, interval, start_dt, end_dt) -> object:
|
||||
"""
|
||||
Return a dataframe containing rows of candle attributes. Attributes very between different exchanges.
|
||||
For example: [(Open time, Open, High, Low, Close,
|
||||
Volume, Close time, Quote symbol volume,
|
||||
Number of trades, Taker buy base symbol volume,
|
||||
Taker buy quote symbol volume, Ignore)]
|
||||
"""
|
||||
return self._fetch_historical_klines(symbol=symbol, interval=interval,
|
||||
start_dt=start_dt, end_dt=end_dt)
|
||||
def get_historical_klines(self, symbol: str, interval: str, start_dt: datetime,
|
||||
end_dt: datetime = None) -> pd.DataFrame:
|
||||
return self._fetch_historical_klines(symbol=symbol, interval=interval, start_dt=start_dt, end_dt=end_dt)
|
||||
|
||||
def get_price(self, symbol) -> float:
|
||||
"""
|
||||
Get the latest price for a single symbol.
|
||||
|
||||
:param symbol: str - The symbol of the symbol.
|
||||
:return: float - The minimum quantity sold per trade.
|
||||
"""
|
||||
def get_price(self, symbol: str) -> float:
|
||||
return self._fetch_price(symbol)
|
||||
|
||||
def get_min_qty(self, symbol) -> float:
|
||||
"""
|
||||
Get the minimum base quantity the exchange_interface allows per order.
|
||||
|
||||
:param symbol: str - The symbol of the trading pair.
|
||||
:return: float - The minimum quantity sold per trade.
|
||||
"""
|
||||
def get_min_qty(self, symbol: str) -> float:
|
||||
return self._fetch_min_qty(symbol)
|
||||
|
||||
def get_min_notional_qty(self, symbol) -> float:
|
||||
"""
|
||||
Get the minimum amount of the quote currency allowed per trade.
|
||||
|
||||
:param symbol: The symbol of the trading pair.
|
||||
:return: The minimum quantity sold per trade.
|
||||
"""
|
||||
def get_min_notional_qty(self, symbol: str) -> float:
|
||||
return self._fetch_min_notional_qty(symbol)
|
||||
|
||||
@abstractmethod
|
||||
def get_active_trades(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_open_orders(self):
|
||||
pass
|
||||
|
||||
def get_order(self, symbol, order_id) -> object:
|
||||
"""
|
||||
Get an order by id.
|
||||
|
||||
:param symbol: The trading pair
|
||||
:param order_id: The order id
|
||||
:return: object
|
||||
"""
|
||||
def get_order(self, symbol: str, order_id: str) -> object:
|
||||
return self._fetch_order(symbol, order_id)
|
||||
|
||||
def place_order(self, symbol, side, type, timeInForce, quantity, price):
|
||||
result, msg = self._place_order(symbol=symbol, side=side, type=type,
|
||||
timeInForce=timeInForce, quantity=quantity, price=price)
|
||||
def place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> \
|
||||
Tuple[str, object]:
|
||||
result, msg = self._place_order(symbol=symbol, side=side, type=type, timeInForce=timeInForce, quantity=quantity,
|
||||
price=price)
|
||||
return result, msg
|
||||
|
||||
def _set_avail_intervals(self) -> tuple:
|
||||
"""Sets a list of time intervals available for trading on the exchange_interface."""
|
||||
return '1m', '3m', '5m', '15m', '30m', '1h', '2h', '4h', '6h', '8h', '12h', '1d', '3d', '1w', '1M'
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
@abstractmethod
|
||||
def _fetch_historical_klines(self, symbol, interval, start_dt, end_dt) -> list:
|
||||
def _set_avail_intervals(self) -> Tuple[str, ...]:
|
||||
"""
|
||||
Return a list [(Open time, Open, High, Low, Close,
|
||||
Volume, Close time, Quote symbol volume,
|
||||
Number of trades, Taker buy base symbol volume,
|
||||
Taker buy quote symbol volume, Ignore)]
|
||||
Sets a list of time intervals available for trading on the exchange_interface.
|
||||
"""
|
||||
pass
|
||||
return tuple(self.client.timeframes.keys())
|
||||
|
||||
@abstractmethod
|
||||
def _fetch_price(self, symbol) -> float:
|
||||
def _set_precision_rule(self, symbol: str) -> None:
|
||||
market_data = self.exchange_info[symbol]
|
||||
precision = market_data['precision']['amount']
|
||||
self.symbols_n_precision[symbol] = precision
|
||||
return
|
||||
|
||||
def _place_order(self, symbol: str, side: str, type: str, timeInForce: str, quantity: float, price: float = None) -> \
|
||||
Tuple[str, object]:
|
||||
def format_arg(value: float) -> float:
|
||||
precision = self.symbols_n_precision.get(symbol, 8)
|
||||
return float(f"{value:.{precision}f}")
|
||||
|
||||
quantity = format_arg(quantity)
|
||||
if price is not None:
|
||||
price = format_arg(price)
|
||||
|
||||
order_params = {
|
||||
'symbol': symbol,
|
||||
'type': type,
|
||||
'side': side,
|
||||
'amount': quantity,
|
||||
'params': {
|
||||
'timeInForce': timeInForce
|
||||
}
|
||||
}
|
||||
|
||||
if price is not None:
|
||||
order_params['price'] = price
|
||||
|
||||
order = self.client.create_order(**order_params)
|
||||
return 'Success', order
|
||||
|
||||
def get_active_trades(self) -> List[Dict[str, Union[str, float]]]:
|
||||
"""
|
||||
Get the latest price for a single symbol.
|
||||
|
||||
:param symbol: str - The symbol of the symbol.
|
||||
:return: float - The minimum quantity sold per trade.
|
||||
Get the active trades (open positions).
|
||||
"""
|
||||
pass
|
||||
if self.api_key and self.api_key_secret:
|
||||
try:
|
||||
positions = self.client.fetch_positions()
|
||||
formatted_trades = []
|
||||
for position in positions:
|
||||
active_trade = {
|
||||
'symbol': position['symbol'],
|
||||
'side': 'buy' if float(position['quantity']) > 0 else 'sell',
|
||||
'quantity': abs(float(position['quantity'])),
|
||||
'price': float(position['entry_price'])
|
||||
}
|
||||
formatted_trades.append(active_trade)
|
||||
return formatted_trades
|
||||
except NotImplementedError:
|
||||
# Handle the case where fetch_positions is not supported
|
||||
return []
|
||||
else:
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def _fetch_min_qty(self, symbol) -> float:
|
||||
def get_open_orders(self) -> List[Dict[str, Union[str, float]]]:
|
||||
"""
|
||||
Get the minimum base quantity the exchange_interface allows per order.
|
||||
|
||||
:param symbol: str - The symbol of the trading pair.
|
||||
:return: float - The minimum quantity sold per trade.
|
||||
Get the open orders.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _fetch_min_notional_qty(self, symbol) -> float:
|
||||
"""
|
||||
Get the minimum amount of the quote currency allowed per trade.
|
||||
|
||||
:param symbol: The symbol of the trading pair.
|
||||
:return: The minimum quantity sold per trade.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _fetch_order(self, symbol, order_id) -> object:
|
||||
"""
|
||||
Get an order by id.
|
||||
:param symbol: The trading pair
|
||||
:param order_id: The order id
|
||||
:return: object
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _connect_exchange(self) -> object:
|
||||
"""Connects to the exchange_interface and sets a reference the client."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _set_exchange_info(self) -> dict:
|
||||
"""Fetches Info on all symbols from the exchange_interface."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _set_symbols(self) -> list:
|
||||
"""Fetches all symbols available for trading on the exchange_interface."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _set_balances(self) -> list:
|
||||
"""Fetches any non-zero balance-info for all assets from exchange_interface."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _set_precision_rule(self, symbol) -> dict:
|
||||
"""Appends a record of places after the decimal required by the exchange_interface indexed by symbol."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _place_order(self, symbol, side, type, timeInForce, quantity, price):
|
||||
result = 'Success'
|
||||
msg = ''
|
||||
return result, msg
|
||||
if self.api_key and self.api_key_secret:
|
||||
try:
|
||||
open_orders = self.client.fetch_open_orders()
|
||||
formatted_orders = []
|
||||
for order in open_orders:
|
||||
open_order = {
|
||||
'symbol': order['symbol'],
|
||||
'side': order['side'],
|
||||
'quantity': order['amount'],
|
||||
'price': order['price']
|
||||
}
|
||||
formatted_orders.append(open_order)
|
||||
return formatted_orders
|
||||
except NotImplementedError:
|
||||
# Handle the case where fetch_balance is not supported
|
||||
return []
|
||||
else:
|
||||
return []
|
||||
|
|
|
|||
14
src/Users.py
14
src/Users.py
|
|
@ -541,8 +541,10 @@ class Users:
|
|||
# Get the user records from the database.
|
||||
user = self.get_user_from_db(user_name)
|
||||
# Get the exchanges list based on the field.
|
||||
return json.loads(user.loc[0, category])
|
||||
except (KeyError, IndexError) as e:
|
||||
exchanges = user.loc[0, category]
|
||||
# Return the list if it exists, otherwise return an empty list.
|
||||
return json.loads(exchanges) if exchanges else []
|
||||
except (KeyError, IndexError, json.JSONDecodeError) as e:
|
||||
# Log the error to the console
|
||||
print(f"Error retrieving exchanges for user '{user_name}' and field '{category}': {str(e)}")
|
||||
return None
|
||||
|
|
@ -558,8 +560,12 @@ class Users:
|
|||
"""
|
||||
# Get the user records from the database.
|
||||
user = self.get_user_from_db(user_name)
|
||||
# Get the old active_exchanges list.
|
||||
active_exchanges = json.loads(user.loc[0, 'active_exchanges'])
|
||||
# Get the old active_exchanges list, or initialize as an empty list if it is None.
|
||||
active_exchanges = user.loc[0, 'active_exchanges']
|
||||
if active_exchanges is None:
|
||||
active_exchanges = []
|
||||
else:
|
||||
active_exchanges = json.loads(active_exchanges)
|
||||
|
||||
# Define the actions for each command
|
||||
actions = {
|
||||
|
|
|
|||
18
src/app.py
18
src/app.py
|
|
@ -30,7 +30,7 @@ cors = CORS(app, supports_credentials=True,
|
|||
r"/api/history": {"origins": ["http://127.0.0.1:5000", "http://localhost:5000"]},
|
||||
r"/api/indicator_init": {"origins": ["http://127.0.0.1:5000", "http://localhost:5000"]}
|
||||
},
|
||||
headers=['Content-Type'])
|
||||
allow_headers=['Content-Type']) # Change from headers to allow_headers
|
||||
|
||||
|
||||
@app.after_request
|
||||
|
|
@ -46,6 +46,8 @@ def index():
|
|||
Fetches data from brighter_trades and inject it into an HTML template.
|
||||
Renders the html template and serves the web application.
|
||||
"""
|
||||
# Clear the session to simulate a new visitor
|
||||
session.clear()
|
||||
try:
|
||||
# Log the user in.
|
||||
user_name = brighter_trades.config.users.load_or_create_user(username=session.get('user'))
|
||||
|
|
@ -61,13 +63,15 @@ def index():
|
|||
|
||||
print('[SERVING INDEX] (USERNAME):', user_name)
|
||||
|
||||
# Ensure that a valid connection with an exchange exist.
|
||||
keys = {'key': config.ALPACA_API_KEY, 'secret': config.ALPACA_API_SECRET}
|
||||
default_exchange = 'binance'
|
||||
default_keys = None # No keys needed for public market data
|
||||
|
||||
# Ensure that a valid connection with an exchange exists
|
||||
result = brighter_trades.connect_user_to_exchange(user_name=user_name,
|
||||
default_exchange='alpaca',
|
||||
default_keys=keys)
|
||||
default_exchange=default_exchange,
|
||||
default_keys=default_keys)
|
||||
if not result:
|
||||
raise ValueError("Couldn't connect to the alpaca exchange.")
|
||||
raise ValueError("Couldn't connect to the default exchange.")
|
||||
|
||||
# A dict of data required to build the html of the user app.
|
||||
# Dynamic content like options and titles and balances to display.
|
||||
|
|
@ -310,4 +314,4 @@ def indicator_init():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True, use_reloader=False)
|
||||
app.run(debug=False, use_reloader=False)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
BINANCE_API_KEY = 'rkp1Xflb5nnwt6jys0PG27KXcqwn0q9lKCLryKcSp4mKW2UOlkPRuAHPg45rQVgj'
|
||||
BINANCE_API_SECRET = 'DiFhhYhF64nkPe5f3V7TRJX2bSVA7ZQZlozSdX7O7uYmBMdK985eA6Kp2B2zKvbK'
|
||||
ALPACA_API_KEY = 'PKN0WFYT9VZYUVRBG1HM'
|
||||
ALPACA_API_SECRET = '0C1I6UcBSR2B0SZrBC3DoKGtcglAny8znorvganx'
|
||||
ALPACA_API_KEY = 'PKE4RD999SJ8L53OUI8O'
|
||||
ALPACA_API_SECRET = 'buwlMoSSfZWGih8Er30quQt4d7brsBWdJXD1KB7C'
|
||||
DB_FILE = "C:/Users/Rob/PycharmProjects/BrighterTrading/data/BrighterTrading.db"
|
||||
|
|
|
|||
|
|
@ -242,7 +242,7 @@ class Database:
|
|||
:param exchange_name: str - The name of the exchange_name.
|
||||
:return: int - The primary id of the exchange_name.
|
||||
"""
|
||||
return self.get_from_static_table(item='id', table='exchange', indexes=HDict({'name': exchange_name}))
|
||||
return self.get_from_static_table(item='id', table='exchange', create_id=True,indexes=HDict({'name': exchange_name}))
|
||||
|
||||
def _fetch_market_id(self, symbol: str, exchange_name: str) -> int:
|
||||
"""
|
||||
|
|
@ -276,24 +276,23 @@ class Database:
|
|||
market_id = self._fetch_market_id(symbol, exchange_name)
|
||||
# Insert the market id into the dataframe.
|
||||
candlesticks.insert(0, 'market_id', market_id)
|
||||
# Create a table schema.
|
||||
# Get a list of all the columns in the dataframe.
|
||||
columns = list(candlesticks.columns.values)
|
||||
# Isolate any extra columns specific to individual exchanges.
|
||||
# The carriage return and tabs are unnecessary, they just tidy output for debugging.
|
||||
columns = ',\n\t\t\t\t\t'.join(columns[7:], )
|
||||
# Define the columns common with all exchanges and append any extras columns.
|
||||
# Create a table schema. todo delete these line if not needed anymore
|
||||
# # Get a list of all the columns in the dataframe.
|
||||
# columns = list(candlesticks.columns.values)
|
||||
# # Isolate any extra columns specific to individual exchanges.
|
||||
# # The carriage return and tabs are unnecessary, they just tidy output for debugging.
|
||||
# columns = ',\n\t\t\t\t\t'.join(columns[7:], )
|
||||
# # Define the columns common with all exchanges and append any extras columns.
|
||||
sql_create = f"""
|
||||
CREATE TABLE IF NOT EXISTS '{table_name}' (
|
||||
id INTEGER PRIMARY KEY,
|
||||
market_id INTEGER,
|
||||
open_time UNIQUE ON CONFLICT IGNORE,
|
||||
open NOT NULL,
|
||||
high NOT NULL,
|
||||
low NOT NULL,
|
||||
close NOT NULL,
|
||||
volume NOT NULL,
|
||||
{columns},
|
||||
open_time INTEGER UNIQUE ON CONFLICT IGNORE,
|
||||
open REAL NOT NULL,
|
||||
high REAL NOT NULL,
|
||||
low REAL NOT NULL,
|
||||
close REAL NOT NULL,
|
||||
volume REAL NOT NULL,
|
||||
FOREIGN KEY (market_id) REFERENCES market (id)
|
||||
)"""
|
||||
# Connect to the database.
|
||||
|
|
@ -338,17 +337,17 @@ class Database:
|
|||
print(f'Got {len(records.index)} records from db')
|
||||
else:
|
||||
# If the table doesn't exist, get them from the exchange_name.
|
||||
print('\nTable didnt exist fetching from exchange_name')
|
||||
print(f'\nTable didnt exist fetching from {ex_details[2]}')
|
||||
temp = (((unix_time_millis(et) - unix_time_millis(st)) / 1000) / 60) / rl
|
||||
print(f'Requesting from {st} to {et}, Should be {temp} records')
|
||||
records = self._populate_table(table_name=table_name, start_time=st, end_time=et, ex_details=ex_details)
|
||||
print(f'Got {len(records.index)} records from exchange_name')
|
||||
print(f'Got {len(records.index)} records from {ex_details[2]}')
|
||||
|
||||
# Check if the records in the db go far enough back to satisfy the query.
|
||||
first_timestamp = query_satisfied(start_datetime=st, records=records, r_length=rl)
|
||||
if first_timestamp:
|
||||
# The records didn't go far enough back if a timestamp was returned.
|
||||
print('\nRecords did not go far enough back. Requesting from exchange_name')
|
||||
print(f'\nRecords did not go far enough back. Requesting from {ex_details[2]}')
|
||||
print(f'first ts on record is: {first_timestamp}')
|
||||
end_time = dt.datetime.utcfromtimestamp(first_timestamp)
|
||||
print(f'Requesting from {st} to {end_time}')
|
||||
|
|
@ -359,7 +358,7 @@ class Database:
|
|||
last_timestamp = query_uptodate(records=records, r_length=rl)
|
||||
if last_timestamp:
|
||||
# The query was not up-to-date if a timestamp was returned.
|
||||
print('\nRecords were not updated. Requesting from exchange_name.')
|
||||
print(f'\nRecords were not updated. Requesting from {ex_details[2]}.')
|
||||
print(f'the last record on file is: {last_timestamp}')
|
||||
start_time = dt.datetime.utcfromtimestamp(last_timestamp)
|
||||
print(f'Requesting from {start_time} to {et}')
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
import logging
|
||||
import json
|
||||
from typing import List, Any
|
||||
|
||||
from typing import List, Any, Dict
|
||||
import pandas as pd
|
||||
import requests
|
||||
import ccxt
|
||||
|
||||
from BinanceFutures import BinanceFuturesExchange, BinanceCoinExchange
|
||||
from BinanceSpot import BinanceSpotExchange
|
||||
from AlpacaPaperExchange import AlpacaPaperExchange
|
||||
from Exchange import Exchange
|
||||
|
||||
# Setup logging
|
||||
# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
|
||||
# This just makes this method cleaner.
|
||||
|
|
@ -17,48 +19,54 @@ def add_row(df, dic):
|
|||
|
||||
class ExchangeInterface:
|
||||
"""
|
||||
Connects and maintains and routs data requests from exchanges.
|
||||
Connects and maintains and routes data requests from exchanges.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
# Create a dataframe to hold all the references and info for each user's configured exchanges.
|
||||
self.exchange_data = pd.DataFrame(columns=['user', 'name', 'reference', 'balance'])
|
||||
# List of available exchanges
|
||||
self.available_exchanges = ['alpaca', 'binance_coin', 'binance_futures', 'binance_spot']
|
||||
self.exchange_data = pd.DataFrame(columns=['user', 'name', 'reference', 'balances'])
|
||||
|
||||
# Populate the list of available exchanges from CCXT
|
||||
self.available_exchanges = self.get_ccxt_exchanges()
|
||||
|
||||
def get_ccxt_exchanges(self) -> List[str]:
|
||||
"""Retrieve the list of available exchanges from CCXT."""
|
||||
return ccxt.exchanges
|
||||
|
||||
def connect_exchange(self, exchange_name: str, user_name: str, api_keys: dict = None) -> bool:
|
||||
"""
|
||||
Initialize and store a reference to the available exchanges.
|
||||
Initialize and store a reference to the available exchanges.
|
||||
|
||||
:param user_name: The name of the user connecting the exchange.
|
||||
:param api_keys: dict - {api: key, api-secret: key}
|
||||
:param exchange_name: str - The name of the exchange.
|
||||
:return: True if success | None on fail.
|
||||
:param user_name: The name of the user connecting the exchange.
|
||||
:param api_keys: dict - {api: key, api-secret: key}
|
||||
:param exchange_name: str - The name of the exchange.
|
||||
:return: True if success | None on fail.
|
||||
"""
|
||||
if exchange_name == 'alpaca':
|
||||
success = self.add_exchange(user_name, AlpacaPaperExchange, (exchange_name, api_keys))
|
||||
elif exchange_name == 'binance_coin':
|
||||
success = self.add_exchange(user_name, BinanceCoinExchange, (exchange_name, api_keys))
|
||||
elif exchange_name == 'binance_futures':
|
||||
success = self.add_exchange(user_name, BinanceFuturesExchange, (exchange_name, api_keys))
|
||||
elif exchange_name == 'binance_spot':
|
||||
success = self.add_exchange(user_name, BinanceSpotExchange, (exchange_name, api_keys))
|
||||
else:
|
||||
success = False
|
||||
return success
|
||||
# logging.debug(
|
||||
# f"Attempting to connect to exchange '{exchange_name}' for user '{user_name}' with API keys: {api_keys}")
|
||||
|
||||
def add_exchange(self, user_name, _class, arg):
|
||||
try:
|
||||
ref = _class(*arg)
|
||||
row = {'user': user_name, 'name': ref.name, 'reference': ref, 'balances': ref.balances}
|
||||
# Initialize the exchange object
|
||||
exchange = Exchange(name=exchange_name, api_keys=api_keys, exchange_id=exchange_name.lower())
|
||||
|
||||
# Update exchange data with the new connection
|
||||
self.add_exchange(user_name, exchange)
|
||||
# logging.debug(f"Successfully connected to exchange '{exchange_name}' for user '{user_name}'")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to connect user '{user_name}' to exchange '{exchange_name}': {str(e)}")
|
||||
return False # Failed to connect
|
||||
|
||||
def add_exchange(self, user_name: str, exchange: Exchange):
|
||||
try:
|
||||
row = {'user': user_name, 'name': exchange.name, 'reference': exchange, 'balances': exchange.balances}
|
||||
self.exchange_data = add_row(self.exchange_data, row)
|
||||
except Exception as e:
|
||||
if e.status_code == 400 and e.error_code == -1021:
|
||||
print("Timestamp ahead of server's time error: Sync your system clock to fix this.")
|
||||
print("Couldn't create an instance of the exchange!:\n", e)
|
||||
if hasattr(e, 'status_code') and e.status_code == 400 and e.error_code == -1021:
|
||||
logging.error("Timestamp ahead of server's time error: Sync your system clock to fix this.")
|
||||
logging.error("Couldn't create an instance of the exchange!:\n", e)
|
||||
raise
|
||||
return True
|
||||
|
||||
def get_exchange(self, ename: str, uname: str) -> Any:
|
||||
"""Return a reference to the exchange_name."""
|
||||
|
|
@ -76,15 +84,15 @@ class ExchangeInterface:
|
|||
connected_exchanges = self.exchange_data.loc[self.exchange_data['user'] == user_name, 'name'].tolist()
|
||||
return connected_exchanges
|
||||
|
||||
def get_available_exchanges(self):
|
||||
def get_available_exchanges(self) -> List[str]:
|
||||
""" Return a list of the exchanges available to connect to"""
|
||||
return self.available_exchanges
|
||||
|
||||
def get_exchange_balances(self, name):
|
||||
def get_exchange_balances(self, name: str) -> pd.Series:
|
||||
""" Return the balances of a single exchange_name"""
|
||||
return self.exchange_data.query("name == @name")['balances']
|
||||
|
||||
def get_all_balances(self, user_name: str) -> dict:
|
||||
def get_all_balances(self, user_name: str) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Return the balances of all connected exchanges indexed by name.
|
||||
|
||||
|
|
@ -105,7 +113,7 @@ class ExchangeInterface:
|
|||
|
||||
return balances_dict
|
||||
|
||||
def get_all_activated(self, user_name, fetch_type='trades'):
|
||||
def get_all_activated(self, user_name: str, fetch_type: str = 'trades') -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Get active trades or open orders as a dictionary indexed by name"""
|
||||
filtered_data = self.exchange_data.loc[self.exchange_data['user'] == user_name, ['name', 'reference']]
|
||||
if filtered_data.empty:
|
||||
|
|
@ -122,16 +130,16 @@ class ExchangeInterface:
|
|||
elif fetch_type == 'orders':
|
||||
data = reference.get_open_orders()
|
||||
else:
|
||||
print(f"Invalid fetch type: {fetch_type}")
|
||||
logging.error(f"Invalid fetch type: {fetch_type}")
|
||||
return {}
|
||||
|
||||
data_dict[name] = data
|
||||
except Exception as e:
|
||||
print(f"Error retrieving data for {name}: {str(e)}")
|
||||
logging.error(f"Error retrieving data for {name}: {str(e)}")
|
||||
|
||||
return data_dict
|
||||
|
||||
def get_order(self, symbol, order_id, target, user_name):
|
||||
def get_order(self, symbol: str, order_id: str, target: str, user_name: str) -> Any:
|
||||
"""
|
||||
Return order - from a target exchange_interface.
|
||||
:param user_name: The name of the user making the request.
|
||||
|
|
@ -145,7 +153,7 @@ class ExchangeInterface:
|
|||
# Return the order.
|
||||
return exchange.get_order(symbol=symbol, order_id=order_id)
|
||||
|
||||
def get_trade_status(self, trade, user_name):
|
||||
def get_trade_status(self, trade, user_name: str) -> str:
|
||||
"""
|
||||
Return the status of a trade
|
||||
Todo: trade order.status might be outdated this request the status from the exchanges order record.
|
||||
|
|
@ -158,7 +166,7 @@ class ExchangeInterface:
|
|||
# Return status.
|
||||
return order['status']
|
||||
|
||||
def get_trade_executed_qty(self, trade, user_name):
|
||||
def get_trade_executed_qty(self, trade, user_name: str) -> float:
|
||||
"""
|
||||
Return the executed quantity of a trade.
|
||||
|
||||
|
|
@ -173,7 +181,7 @@ class ExchangeInterface:
|
|||
# Return quantity.
|
||||
return order['executedQty']
|
||||
|
||||
def get_trade_executed_price(self, trade, user_name):
|
||||
def get_trade_executed_price(self, trade, user_name: str) -> float:
|
||||
"""
|
||||
Return the average price of executed quantity of a trade
|
||||
|
||||
|
|
@ -188,7 +196,7 @@ class ExchangeInterface:
|
|||
return order['price']
|
||||
|
||||
@staticmethod
|
||||
def get_price(symbol, price_source=None):
|
||||
def get_price(symbol: str, price_source: str = None) -> float:
|
||||
"""
|
||||
:param price_source: alternative sources for price.
|
||||
:param symbol: The symbol of the trading pair.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
import ccxt
|
||||
|
||||
|
||||
def main():
|
||||
# Create an instance of the Binance exchange
|
||||
binance = ccxt.binance({
|
||||
'enableRateLimit': True,
|
||||
'verbose': False, # Ensure verbose mode is disabled
|
||||
})
|
||||
|
||||
try:
|
||||
# Load markets to test the connection
|
||||
markets = binance.load_markets()
|
||||
print("Markets loaded successfully")
|
||||
except ccxt.BaseError as e:
|
||||
print(f"Error loading markets: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
import ccxt
|
||||
|
||||
|
||||
def test_public_endpoints(exchange_id, symbol):
|
||||
try:
|
||||
exchange_class = getattr(ccxt, exchange_id)
|
||||
exchange = exchange_class()
|
||||
|
||||
# Check and test fetch_ticker
|
||||
if exchange.has.get('fetchTicker', False):
|
||||
exchange.fetch_ticker(symbol)
|
||||
|
||||
# Check and test fetch_tickers
|
||||
if exchange.has.get('fetchTickers', False):
|
||||
exchange.fetch_tickers([symbol])
|
||||
|
||||
# Check and test fetch_order_book
|
||||
if exchange.has.get('fetchOrderBook', False):
|
||||
exchange.fetch_order_book(symbol)
|
||||
|
||||
# Check and test fetch_trades
|
||||
if exchange.has.get('fetchTrades', False):
|
||||
exchange.fetch_trades(symbol)
|
||||
|
||||
# Check and test fetch_ohlcv
|
||||
if exchange.has.get('fetchOHLCV', False):
|
||||
timeframe = '1m'
|
||||
since = exchange.parse8601('2022-01-01T00:00:00Z')
|
||||
exchange.fetch_ohlcv(symbol, timeframe, since)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Exchange {exchange_id} failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# List of all supported exchanges
|
||||
exchange_ids = ccxt.exchanges
|
||||
|
||||
# Symbol to test (using a common symbol that is likely available on many exchanges)
|
||||
test_symbol = 'BTC/USDT'
|
||||
|
||||
# Dictionary to store the exchanges that support public market data endpoints
|
||||
working_exchanges = []
|
||||
|
||||
# Iterate through each exchange and test public endpoints
|
||||
for exchange_id in exchange_ids:
|
||||
print(f"Testing exchange: {exchange_id}")
|
||||
if test_public_endpoints(exchange_id, test_symbol):
|
||||
working_exchanges.append(exchange_id)
|
||||
|
||||
# Print the working exchanges
|
||||
print("Exchanges with working public market data endpoints:")
|
||||
for exchange in working_exchanges:
|
||||
print(exchange)
|
||||
|
||||
# Optionally save the results to a file
|
||||
with open('working_public_exchanges.txt', 'w') as f:
|
||||
for exchange in working_exchanges:
|
||||
f.write(f"{exchange}\n")
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
ascendex
|
||||
bequant
|
||||
bigone
|
||||
binance
|
||||
binanceus
|
||||
binanceusdm
|
||||
bitbay
|
||||
bitbns
|
||||
bitcoincom
|
||||
bitfinex
|
||||
bitfinex2
|
||||
bitget
|
||||
bitmart
|
||||
bitmex
|
||||
bitopro
|
||||
bitrue
|
||||
bitstamp
|
||||
bitteam
|
||||
blockchaincom
|
||||
btcmarkets
|
||||
bybit
|
||||
coinbase
|
||||
coinbaseadvanced
|
||||
coinbaseexchange
|
||||
coinex
|
||||
coinlist
|
||||
coinmate
|
||||
coinsph
|
||||
cryptocom
|
||||
currencycom
|
||||
delta
|
||||
digifinex
|
||||
exmo
|
||||
fmfwio
|
||||
gemini
|
||||
hitbtc
|
||||
hitbtc3
|
||||
htx
|
||||
huobi
|
||||
indodax
|
||||
kraken
|
||||
kucoin
|
||||
latoken
|
||||
lbank
|
||||
lykke
|
||||
mexc
|
||||
ndax
|
||||
novadax
|
||||
oceanex
|
||||
okx
|
||||
phemex
|
||||
poloniex
|
||||
probit
|
||||
timex
|
||||
tokocrypto
|
||||
tradeogre
|
||||
upbit
|
||||
wazirx
|
||||
whitebit
|
||||
woo
|
||||
xt
|
||||
yobit
|
||||
zonda
|
||||
|
|
@ -1,35 +1,41 @@
|
|||
from functools import lru_cache
|
||||
import datetime as dt
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
epoch = dt.datetime.utcfromtimestamp(0)
|
||||
|
||||
|
||||
def query_uptodate(records: pd.DataFrame, r_length: float):
|
||||
def query_uptodate(records: pd.DataFrame, r_length_min: float) -> Union[float, None]:
|
||||
"""
|
||||
Check if records that span a period of time are up-to-date.
|
||||
Check if records that span a period of time are up-to-date.
|
||||
|
||||
:param records: - The dataframe holding results from a query.
|
||||
:param r_length: - The timespan in minutes of each record in the data.
|
||||
:return: timestamp - The closest timestamp to start_datetime on record.
|
||||
:param records: The dataframe holding results from a query.
|
||||
:param r_length_min: The timespan in minutes of each record in the data.
|
||||
:return: timestamp - None if records are up-to-date otherwise the newest timestamp on record.
|
||||
"""
|
||||
print('\nChecking if the records are up-to-date...')
|
||||
# Get the newest timestamp from the records passed in stored in ms.
|
||||
# Get the newest timestamp from the records passed in stored in ms
|
||||
last_timestamp = float(records.open_time.max())
|
||||
print(f'The last ts on record is {last_timestamp}')
|
||||
# Get a timestamp of the UTC time in millisecond to match the records in the DB.
|
||||
print(f'The last timestamp on record is {last_timestamp}')
|
||||
|
||||
# Get a timestamp of the UTC time in milliseconds to match the records in the DB
|
||||
now_timestamp = unix_time_millis(dt.datetime.utcnow())
|
||||
print(f'The timestamp now is {now_timestamp}')
|
||||
# Get the seconds since the records have been updated.
|
||||
|
||||
# Get the seconds since the records have been updated
|
||||
seconds_since_update = ms_to_seconds(now_timestamp - last_timestamp)
|
||||
|
||||
# Convert to minutes
|
||||
minutes_since_update = seconds_since_update / 60
|
||||
print(f'The minutes since last update is {minutes_since_update}')
|
||||
print(f'And the length of each record is {r_length}')
|
||||
# Return the timestamp if the time since last update is more than the timespan each record covers.
|
||||
if minutes_since_update > (r_length - 0.3):
|
||||
# Return the last timestamp in seconds.
|
||||
print(f'And the length of each record is {r_length_min}')
|
||||
|
||||
# Return the timestamp if the time since last update is more than the timespan each record covers
|
||||
tolerance_minutes = 10 / 60 # 10 seconds tolerance in minutes
|
||||
if minutes_since_update > (r_length_min - tolerance_minutes):
|
||||
# Return the last timestamp in seconds
|
||||
return ms_to_seconds(last_timestamp)
|
||||
return None
|
||||
|
||||
|
|
@ -38,82 +44,98 @@ def ms_to_seconds(timestamp):
|
|||
return timestamp / 1000
|
||||
|
||||
|
||||
def unix_time_seconds(d_time):
|
||||
return (d_time - epoch).total_seconds()
|
||||
|
||||
|
||||
def unix_time_millis(d_time):
|
||||
return (d_time - epoch).total_seconds() * 1000.0
|
||||
|
||||
|
||||
def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length: float):
|
||||
def query_satisfied(start_datetime: dt.datetime, records: pd.DataFrame, r_length_min: float) -> Union[float, None]:
|
||||
"""
|
||||
Check if a records that spans a set time goes far enough back to satisfy a query.
|
||||
Check if records span back far enough to satisfy a query.
|
||||
|
||||
:param start_datetime: - The datetime to be satisfied.
|
||||
:param records: - The dataframe holding results from a query.
|
||||
:param r_length: - The timespan in minutes of each record in the data.
|
||||
:return: timestamp - The closest timestamp to start_datetime on record.
|
||||
This function determines whether the records provided cover the required start_datetime. It calculates
|
||||
the total duration covered by the records and checks if this duration, starting from the earliest record,
|
||||
reaches back to include the start_datetime.
|
||||
|
||||
:param start_datetime: The datetime the query starts at.
|
||||
:param records: The dataframe holding results from a query.
|
||||
:param r_length_min: The timespan in minutes of each record in the data.
|
||||
:return: None if the query is satisfied (records span back far enough),
|
||||
otherwise returns the earliest timestamp in records in seconds.
|
||||
"""
|
||||
# Convert the datetime to a timestamp in milliseconds to match format in the DB.
|
||||
print('\nChecking if the query is satisfied...')
|
||||
|
||||
# Convert start_datetime to Unix timestamp in milliseconds
|
||||
start_timestamp = unix_time_millis(start_datetime)
|
||||
print('Checking if we went far enough back.')
|
||||
print('Requested: start_timestamp:', start_timestamp)
|
||||
# Get the oldest timestamp from the records passed in. Convert from str to float.
|
||||
print(f'Start timestamp: {start_timestamp}')
|
||||
|
||||
# Get the oldest timestamp from the records passed in
|
||||
first_timestamp = float(records.open_time.min())
|
||||
print('Received: first_timestamp:', first_timestamp)
|
||||
if pd.isna(first_timestamp):
|
||||
# If there were no records returned. Signal a need for update by returning the current timestamp.
|
||||
return dt.datetime.utcnow().timestamp()
|
||||
# Get the minutes between the first timestamp on record and the one requested.
|
||||
minutes_between = ms_to_seconds(first_timestamp - start_timestamp) / 60
|
||||
print('minutes_between:', minutes_between)
|
||||
# Return the timestamp if the difference is greater than the timespan of a single record.
|
||||
if minutes_between > r_length:
|
||||
# Return timestamp in seconds.
|
||||
return ms_to_seconds(first_timestamp)
|
||||
return None
|
||||
print(f'First timestamp in records: {first_timestamp}')
|
||||
|
||||
# Calculate the total duration of the records in milliseconds
|
||||
total_duration = len(records) * (r_length_min * 60 * 1000)
|
||||
print(f'Total duration of records: {total_duration}')
|
||||
|
||||
# Check if the first timestamp plus the total duration is greater than or equal to the start timestamp
|
||||
if start_timestamp <= first_timestamp + total_duration:
|
||||
return None
|
||||
|
||||
return first_timestamp / 1000 # Return in seconds
|
||||
|
||||
|
||||
@lru_cache(maxsize=500)
|
||||
def ts_of_n_minutes_ago(n: int, candle_length: float) -> dt.datetime.timestamp:
|
||||
"""
|
||||
Returns an approximate start_datetime of a candle n minutes ago.
|
||||
Returns the approximate datetime for the start of a candle that was 'n' candles ago.
|
||||
|
||||
:param int - n: The number of records ago to calculate.
|
||||
:param float - candle_length: The time in minutes that each candle represents.
|
||||
:return datetime.start_datetime - The approximate start_datetime slightly less than the record is expected to have.
|
||||
:param n: int - The number of candles ago to calculate.
|
||||
:param candle_length: float - The length of each candle in minutes.
|
||||
:return: datetime - The approximate datetime for the start of the 'n'-th candle ago.
|
||||
"""
|
||||
|
||||
# Time will have passed since the last candle has closed. The start_datetime for nth candle back will be
|
||||
# less than now-(n * timeframe) but greater than now-(n+1 * timeframe). Add 1 so we don't miss a candle.
|
||||
# Increment 'n' by 1 to ensure we account for the time that has passed since the last candle closed.
|
||||
n += 1
|
||||
# Calculate the time.
|
||||
|
||||
# Calculate the total minutes ago the 'n'-th candle started.
|
||||
minutes_ago = n * candle_length
|
||||
# Get the date and time of n candle_length ago.
|
||||
date_of = dt.datetime.utcnow() - dt.timedelta(minutes=minutes_ago)
|
||||
# Return the result as a start_datetime.
|
||||
|
||||
# Get the current UTC datetime.
|
||||
now = dt.datetime.utcnow()
|
||||
|
||||
# Calculate the datetime for 'n' candles ago.
|
||||
date_of = now - dt.timedelta(minutes=minutes_ago)
|
||||
|
||||
# Return the calculated datetime.
|
||||
return date_of
|
||||
|
||||
|
||||
@lru_cache(maxsize=20)
|
||||
def timeframe_to_minutes(timeframe):
|
||||
"""
|
||||
Converts a string representing a timeframe into candle_length.
|
||||
Converts a string representing a timeframe into an integer representing the approximate minutes.
|
||||
|
||||
:param timeframe: str - Timeframe format is [multiplier:focus]. eg '15m', '4h', '1d'
|
||||
:return: int - Minutes the timeframe represents ex. '2h'-> 120(candle_length).
|
||||
:return: int - Minutes the timeframe represents ex. '2h'-> 120(minutes).
|
||||
"""
|
||||
# Extract the numerical part of the timeframe param.
|
||||
digits = int("".join([i if i.isdigit() else "" for i in timeframe]))
|
||||
# Extract the alpha part of the timeframe param.
|
||||
letter = "".join([i if i.isalpha() else "" for i in timeframe])
|
||||
|
||||
if letter == 'm':
|
||||
pass
|
||||
elif letter == 'h':
|
||||
digits *= 60
|
||||
elif letter == 'd':
|
||||
digits *= (60 * 24)
|
||||
digits *= 60 * 24
|
||||
elif letter == 'w':
|
||||
digits *= (60 * 24 * 7)
|
||||
digits *= 60 * 24 * 7
|
||||
elif letter == 'M':
|
||||
digits *= (60 * 24 * 7 * 31)
|
||||
digits *= 60 * 24 * 31 # Maximum number of days in a month
|
||||
elif letter == 'Y':
|
||||
digits *= (60 * 24 * 7 * 31 * 365)
|
||||
digits *= 60 * 24 * 365 # Exact number of days in a year
|
||||
|
||||
return digits
|
||||
|
|
|
|||
|
|
@ -1,23 +0,0 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
from AlpacaPaperExchange import AlpacaPaperExchange
|
||||
|
||||
|
||||
def test_alpaca_paper_exchange():
|
||||
ape = AlpacaPaperExchange('ape')
|
||||
# print(f'\nTesting name assignment: {ape.name}')
|
||||
# print(f'\nTesting get_client(): {ape.get_client()}')
|
||||
# print(f'\nTesting get_exchange_info(): {ape.get_exchange_info()}')
|
||||
# print(f'\nTesting get_symbols(): {ape.get_symbols()}')
|
||||
# print(f'\nTesting get_balances(): {ape.get_balances()}')
|
||||
# print(f'\nTesting get_avail_intervals(): {ape.get_avail_intervals()}')
|
||||
last_hour_date_time = datetime.now() - timedelta(hours=24)
|
||||
k = ape.get_historical_klines(symbol="BTC/USDT", interval="5m",start_str=last_hour_date_time, klines_type=0)
|
||||
{print(k) for k in k}
|
||||
# print(f'\nTesting get_symbol_precision_rule(): {ape.get_symbol_precision_rule("BTC/USDT")}')
|
||||
# print(f'\nTesting get_min_qty(): {ape.get_min_qty("ETH/USDT")}')
|
||||
# print(f'\nTesting get_min_notional_qty(): {ape.get_min_notional_qty("BTC/USDT")}')
|
||||
# print(f'\nTesting get_price(): {ape.get_price("ETH/USDT")}')
|
||||
# ape.get_order()
|
||||
|
||||
assert True
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
from binance.enums import HistoricalKlinesType
|
||||
|
||||
from BinanceFutures import BinanceFuturesExchange
|
||||
|
||||
|
||||
def test_binancefuturesexchange():
|
||||
bfe = BinanceFuturesExchange('bfe')
|
||||
print(f'\nTesting name assignment: {bfe.name}')
|
||||
print(f'\nTesting get_client(): {bfe.get_client()}')
|
||||
# print(f'\nTesting get_exchange_info(): {bfe.get_exchange_info()}')
|
||||
print(f'\nTesting get_symbols(): {bfe.get_symbols()}')
|
||||
print(f'\nTesting get_balances(): {bfe.get_balances()}')
|
||||
print(f'\nTesting get_avail_intervals(): {bfe.get_avail_intervals()}')
|
||||
print(f'\nTesting get_historical_klines(): {bfe.get_historical_klines(symbol="BTCUSDT",interval="15m",start_str="1 hour ago UTC", klines_type=HistoricalKlinesType.SPOT)}')
|
||||
print(f'\nTesting get_symbol_precision_rule(): {bfe.get_symbol_precision_rule("ETHUSDT")}')
|
||||
print(f'\nTesting get_min_qty(): {bfe.get_min_qty("ETHUSDT")}')
|
||||
print(f'\nTesting get_min_notional_qty(): {bfe.get_min_notional_qty("BTCUSDT")}')
|
||||
print(f'\nTesting get_price(): {bfe.get_price("BTCUSDT")}')
|
||||
# bse.get_order()
|
||||
assert True
|
||||
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
from binance.enums import HistoricalKlinesType
|
||||
|
||||
from BinanceSpot import BinanceSpotExchange
|
||||
|
||||
|
||||
def test_binance_spot_exchange():
|
||||
bse = BinanceSpotExchange('bse')
|
||||
print(f'\nTesting name assignment: {bse.name}')
|
||||
print(f'Testing get_client(): {bse.get_client()}')
|
||||
# print(f'Testing get_symbol_info(): {bse.get_exchange_info()}')
|
||||
print(f'Testing get_symbols(): {bse.get_symbols()}')
|
||||
print(f'Testing get_balances(): {bse.get_balances()}')
|
||||
print(f'Testing get_avail_intervals(): {bse.get_avail_intervals()}')
|
||||
print(f'Testing get_historical_klines(): {bse.get_historical_klines(symbol="ETHUSDT",interval="15m",start_str="1 hour ago UTC",klines_type=HistoricalKlinesType.SPOT)}')
|
||||
print(f'Testing get_symbol_precision_rule(): {bse.get_symbol_precision_rule("ETHUSDT")}')
|
||||
print(f'Testing get_min_qty(): {bse.get_min_qty("ETHUSDT")}')
|
||||
print(f'Testing get_min_notional_qty(): {bse.get_min_notional_qty("ETHUSDT")}')
|
||||
print(f'Testing get_price(): {bse.get_price("BTCUSDT")}')
|
||||
# #bse.get_order()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,9 +1,124 @@
|
|||
from DataCache import DataCache
|
||||
from exchangeinterface import ExchangeInterface
|
||||
import unittest
|
||||
import pandas as pd
|
||||
import datetime as dt
|
||||
|
||||
|
||||
def test_cache_exists():
|
||||
exchanges = ExchangeInterface()
|
||||
# This object maintains all the cached data. Pass it connection to the exchanges.
|
||||
data = DataCache(exchanges)
|
||||
assert data.cache_exists(key='BTC/USD_2h_alpaca') is False
|
||||
class TestDataCache(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Setup the database connection here
|
||||
self.exchanges = ExchangeInterface()
|
||||
self.exchanges.connect_exchange(exchange_name='binance', user_name='test_guy', api_keys=None)
|
||||
# This object maintains all the cached data. Pass it connection to the exchanges.
|
||||
self.data = DataCache(self.exchanges)
|
||||
|
||||
asset, timeframe, exchange = 'BTC/USD', '2h', 'binance'
|
||||
self.key1 = f'{asset}_{timeframe}_{exchange}'
|
||||
|
||||
asset, timeframe, exchange = 'ETH/USD', '2h', 'binance'
|
||||
self.key2 = f'{asset}_{timeframe}_{exchange}'
|
||||
|
||||
def test_set_cache(self):
|
||||
# Tests
|
||||
print('Testing set_cache flag not set:')
|
||||
self.data.set_cache(data='data', key=self.key1)
|
||||
attr = self.data.__getattribute__('cached_data')
|
||||
self.assertEqual(attr[self.key1], 'data')
|
||||
|
||||
self.data.set_cache(data='more_data', key=self.key1)
|
||||
attr = self.data.__getattribute__('cached_data')
|
||||
self.assertEqual(attr[self.key1], 'more_data')
|
||||
|
||||
print('Testing set_cache no-overwrite flag set:')
|
||||
self.data.set_cache(data='even_more_data', key=self.key1, do_not_overwrite=True)
|
||||
attr = self.data.__getattribute__('cached_data')
|
||||
self.assertEqual(attr[self.key1], 'more_data')
|
||||
|
||||
def test_cache_exists(self):
|
||||
# Tests
|
||||
print('Testing cache_exists() method:')
|
||||
self.assertFalse(self.data.cache_exists(key=self.key2))
|
||||
self.data.set_cache(data='data', key=self.key1)
|
||||
self.assertTrue(self.data.cache_exists(key=self.key1))
|
||||
|
||||
def test_update_candle_cache(self):
|
||||
print('Testing update_candle_cache() method:')
|
||||
# Initial data
|
||||
df_initial = pd.DataFrame({
|
||||
'open_time': [1, 2, 3],
|
||||
'open': [100, 101, 102],
|
||||
'high': [110, 111, 112],
|
||||
'low': [90, 91, 92],
|
||||
'close': [105, 106, 107],
|
||||
'volume': [1000, 1001, 1002]
|
||||
})
|
||||
|
||||
# Data to be added
|
||||
df_new = pd.DataFrame({
|
||||
'open_time': [3, 4, 5],
|
||||
'open': [102, 103, 104],
|
||||
'high': [112, 113, 114],
|
||||
'low': [92, 93, 94],
|
||||
'close': [107, 108, 109],
|
||||
'volume': [1002, 1003, 1004]
|
||||
})
|
||||
|
||||
self.data.set_cache(data=df_initial, key=self.key1)
|
||||
self.data.update_candle_cache(more_records=df_new, key=self.key1)
|
||||
|
||||
result = self.data.get_cache(key=self.key1)
|
||||
expected = pd.DataFrame({
|
||||
'open_time': [1, 2, 3, 4, 5],
|
||||
'open': [100, 101, 102, 103, 104],
|
||||
'high': [110, 111, 112, 113, 114],
|
||||
'low': [90, 91, 92, 93, 94],
|
||||
'close': [105, 106, 107, 108, 109],
|
||||
'volume': [1000, 1001, 1002, 1003, 1004]
|
||||
})
|
||||
|
||||
pd.testing.assert_frame_equal(result, expected)
|
||||
|
||||
def test_update_cached_dict(self):
|
||||
print('Testing update_cached_dict() method:')
|
||||
self.data.set_cache(data={}, key=self.key1)
|
||||
self.data.update_cached_dict(cache_key=self.key1, dict_key='sub_key', data='value')
|
||||
|
||||
cache = self.data.get_cache(key=self.key1)
|
||||
self.assertEqual(cache['sub_key'], 'value')
|
||||
|
||||
def test_get_cache(self):
|
||||
print('Testing get_cache() method:')
|
||||
self.data.set_cache(data='data', key=self.key1)
|
||||
result = self.data.get_cache(key=self.key1)
|
||||
self.assertEqual(result, 'data')
|
||||
|
||||
def test_get_records_since(self):
|
||||
print('Testing get_records_since() method:')
|
||||
df_initial = pd.DataFrame({
|
||||
'open_time': [1, 2, 3],
|
||||
'open': [100, 101, 102],
|
||||
'high': [110, 111, 112],
|
||||
'low': [90, 91, 92],
|
||||
'close': [105, 106, 107],
|
||||
'volume': [1000, 1001, 1002]
|
||||
})
|
||||
|
||||
self.data.set_cache(data=df_initial, key=self.key1)
|
||||
start_datetime = dt.datetime.utcfromtimestamp(2)
|
||||
result = self.data.get_records_since(key=self.key1, start_datetime=start_datetime, record_length=60, ex_details=[]).sort_values(by='open_time').reset_index(drop=True)
|
||||
|
||||
expected = pd.DataFrame({
|
||||
'open_time': [2, 3],
|
||||
'open': [101, 102],
|
||||
'high': [111, 112],
|
||||
'low': [91, 92],
|
||||
'close': [106, 107],
|
||||
'volume': [1001, 1002]
|
||||
})
|
||||
|
||||
pd.testing.assert_frame_equal(result, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,28 +1,73 @@
|
|||
def test_index():
|
||||
from BrighterTrades import BrighterTrades
|
||||
obj = BrighterTrades()
|
||||
assert True
|
||||
import unittest
|
||||
from flask import Flask
|
||||
from src.app import app
|
||||
import json
|
||||
|
||||
|
||||
def test_ws():
|
||||
assert False
|
||||
class FlaskAppTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""
|
||||
Set up the test client and any other test configuration.
|
||||
"""
|
||||
self.app = app.test_client()
|
||||
self.app.testing = True
|
||||
|
||||
def test_index(self):
|
||||
"""
|
||||
Test the index route.
|
||||
"""
|
||||
response = self.app.get('/')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn(b'Welcome', response.data) # Adjust this based on your actual landing page content
|
||||
|
||||
def test_login(self):
|
||||
"""
|
||||
Test the login route with valid and invalid credentials.
|
||||
"""
|
||||
# Valid credentials
|
||||
valid_data = {'user_name': 'test_user', 'password': 'test_password'}
|
||||
response = self.app.post('/login', data=valid_data)
|
||||
self.assertEqual(response.status_code, 302) # Redirects on success
|
||||
|
||||
# Invalid credentials
|
||||
invalid_data = {'user_name': 'wrong_user', 'password': 'wrong_password'}
|
||||
response = self.app.post('/login', data=invalid_data)
|
||||
self.assertEqual(response.status_code, 302) # Redirects on failure
|
||||
self.assertIn(b'Invalid user_name or password', response.data)
|
||||
|
||||
def test_signup(self):
|
||||
"""
|
||||
Test the signup route.
|
||||
"""
|
||||
data = {'email': 'test@example.com', 'user_name': 'new_user', 'password': 'new_password'}
|
||||
response = self.app.post('/signup_submit', data=data)
|
||||
self.assertEqual(response.status_code, 302) # Redirects on success
|
||||
|
||||
def test_signout(self):
|
||||
"""
|
||||
Test the signout route.
|
||||
"""
|
||||
response = self.app.get('/signout')
|
||||
self.assertEqual(response.status_code, 302) # Redirects on signout
|
||||
|
||||
def test_history(self):
|
||||
"""
|
||||
Test the history route.
|
||||
"""
|
||||
data = {"user_name": "test_user"}
|
||||
response = self.app.post('/api/history', data=json.dumps(data), content_type='application/json')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn(b'price_history', response.data)
|
||||
|
||||
def test_indicator_init(self):
|
||||
"""
|
||||
Test the indicator initialization route.
|
||||
"""
|
||||
data = {"user_name": "test_user"}
|
||||
response = self.app.post('/api/indicator_init', data=json.dumps(data), content_type='application/json')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn(b'indicator_data', response.data)
|
||||
|
||||
|
||||
def test_settings():
|
||||
assert False
|
||||
|
||||
|
||||
def test_history():
|
||||
assert False
|
||||
|
||||
|
||||
def test_signup():
|
||||
assert False
|
||||
|
||||
|
||||
def test_signup_submit():
|
||||
assert False
|
||||
|
||||
|
||||
def test_indicator_init():
|
||||
assert False
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,162 @@
|
|||
import time
|
||||
import unittest
|
||||
import datetime as dt
|
||||
import pandas as pd
|
||||
from shared_utilities import (
|
||||
query_uptodate, ms_to_seconds, unix_time_seconds, unix_time_millis,
|
||||
query_satisfied, ts_of_n_minutes_ago, timeframe_to_minutes
|
||||
)
|
||||
|
||||
|
||||
class TestSharedUtilities(unittest.TestCase):
|
||||
|
||||
def test_query_uptodate(self):
|
||||
print('Testing query_uptodate()')
|
||||
|
||||
# (Test case 1) The records should not be up-to-date (very old timestamps)
|
||||
records = pd.DataFrame({
|
||||
'open_time': [1, 2, 3, 4, 5]
|
||||
})
|
||||
result = query_uptodate(records, 1)
|
||||
if result is None:
|
||||
print('Records are up-to-date.')
|
||||
else:
|
||||
print('Records are not up-to-date.')
|
||||
print(f'Result for the first test case: {result}')
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
# (Test case 2) The records should be up-to-date (recent timestamps)
|
||||
now = unix_time_millis(dt.datetime.utcnow())
|
||||
recent_records = pd.DataFrame({
|
||||
'open_time': [now - 70000, now - 60000, now - 40000]
|
||||
})
|
||||
result = query_uptodate(recent_records, 1)
|
||||
if result is None:
|
||||
print('Records are up-to-date.')
|
||||
else:
|
||||
print('Records are not up-to-date.')
|
||||
print(f'Result for the second test case: {result}')
|
||||
self.assertIsNone(result)
|
||||
|
||||
# (Test case 3) The records just under the tolerance for a record length of 1 hour.
|
||||
# The records should not be up-to-date (recent timestamps)
|
||||
one_hour = 60 * 60 * 1000 # one hour in milliseconds
|
||||
tolerance_milliseconds = 10 * 1000 # tolerance in milliseconds
|
||||
recent_time = unix_time_millis(dt.datetime.utcnow())
|
||||
borderline_records = pd.DataFrame({
|
||||
'open_time': [recent_time - one_hour + (tolerance_milliseconds - 3)] # just within the tolerance
|
||||
})
|
||||
result = query_uptodate(borderline_records, 60)
|
||||
if result is None:
|
||||
print('Records are up-to-date.')
|
||||
else:
|
||||
print('Records are not up-to-date.')
|
||||
print(f'Result for the third test case: {result}')
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
# (Test case 4) The records just over the tolerance for a record length of 1 hour.
|
||||
# The records should be up-to-date (recent timestamps)
|
||||
one_hour = 60 * 60 * 1000 # one hour in milliseconds
|
||||
tolerance_milliseconds = 10 * 1000 # tolerance in milliseconds
|
||||
recent_time = unix_time_millis(dt.datetime.utcnow())
|
||||
borderline_records = pd.DataFrame({
|
||||
'open_time': [recent_time - one_hour + (tolerance_milliseconds + 3)] # just within the tolerance
|
||||
})
|
||||
result = query_uptodate(borderline_records, 60)
|
||||
if result is None:
|
||||
print('Records are up-to-date.')
|
||||
else:
|
||||
print('Records are not up-to-date.')
|
||||
print(f'Result for the third test case: {result}')
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_ms_to_seconds(self):
|
||||
print('Testing ms_to_seconds()')
|
||||
self.assertEqual(ms_to_seconds(1000), 1)
|
||||
self.assertEqual(ms_to_seconds(0), 0)
|
||||
|
||||
def test_unix_time_seconds(self):
|
||||
print('Testing unix_time_seconds()')
|
||||
time = dt.datetime(2020, 1, 1)
|
||||
self.assertEqual(unix_time_seconds(time), 1577836800)
|
||||
|
||||
def test_unix_time_millis(self):
|
||||
print('Testing unix_time_millis()')
|
||||
time = dt.datetime(2020, 1, 1)
|
||||
self.assertEqual(unix_time_millis(time), 1577836800000.0)
|
||||
|
||||
def test_query_satisfied(self):
|
||||
print('Testing query_satisfied()')
|
||||
|
||||
# Test case where the records should satisfy the query (records cover the start time)
|
||||
start_datetime = dt.datetime(2020, 1, 1)
|
||||
records = pd.DataFrame({
|
||||
'open_time': [1577836800000.0 - 600000, 1577836800000.0 - 300000, 1577836800000.0]
|
||||
# Covering the start time
|
||||
})
|
||||
result = query_satisfied(start_datetime, records, 1)
|
||||
if result is None:
|
||||
print('Query is satisfied: records span back far enough.')
|
||||
else:
|
||||
print('Query is not satisfied: records do not span back far enough.')
|
||||
print(f'Result for the first test case: {result}')
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
# Test case where the records should not satisfy the query (recent records but not enough)
|
||||
recent_time = unix_time_millis(dt.datetime.utcnow())
|
||||
records = pd.DataFrame({
|
||||
'open_time': [recent_time - 300 * 60 * 1000, recent_time - 240 * 60 * 1000, recent_time - 180 * 60 * 1000]
|
||||
})
|
||||
result = query_satisfied(start_datetime, records, 1)
|
||||
if result is None:
|
||||
print('Query is satisfied: records span back far enough.')
|
||||
else:
|
||||
print('Query is not satisfied: records do not span back far enough.')
|
||||
print(f'Result for the second test case: {result}')
|
||||
self.assertIsNone(result)
|
||||
|
||||
# Additional test case for partial coverage
|
||||
start_datetime = dt.datetime.utcnow() - dt.timedelta(minutes=300)
|
||||
records = pd.DataFrame({
|
||||
'open_time': [unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=240)),
|
||||
unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=180)),
|
||||
unix_time_millis(dt.datetime.utcnow() - dt.timedelta(minutes=120))]
|
||||
})
|
||||
result = query_satisfied(start_datetime, records, 60)
|
||||
if result is None:
|
||||
print('Query is satisfied: records span back far enough.')
|
||||
else:
|
||||
print('Query is not satisfied: records do not span back far enough.')
|
||||
print(f'Result for the third test case: {result}')
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_ts_of_n_minutes_ago(self):
|
||||
print('Testing ts_of_n_minutes_ago()')
|
||||
now = dt.datetime.utcnow()
|
||||
|
||||
test_cases = [
|
||||
(60, 1), # 60 candles of 1 minute each
|
||||
(10, 5), # 10 candles of 5 minutes each
|
||||
(1, 1440), # 1 candle of 1 day (1440 minutes)
|
||||
(7, 10080), # 7 candles of 1 week (10080 minutes)
|
||||
(30, 60), # 30 candles of 1 hour (60 minutes)
|
||||
]
|
||||
|
||||
for n, candle_length in test_cases:
|
||||
with self.subTest(n=n, candle_length=candle_length):
|
||||
result = ts_of_n_minutes_ago(n, candle_length)
|
||||
expected_time = now - dt.timedelta(minutes=(n + 1) * candle_length)
|
||||
self.assertAlmostEqual(unix_time_seconds(result), unix_time_seconds(expected_time), delta=60)
|
||||
|
||||
def test_timeframe_to_minutes(self):
|
||||
print('Testing timeframe_to_minutes()')
|
||||
self.assertEqual(timeframe_to_minutes('15m'), 15)
|
||||
self.assertEqual(timeframe_to_minutes('1h'), 60)
|
||||
self.assertEqual(timeframe_to_minutes('1d'), 1440)
|
||||
self.assertEqual(timeframe_to_minutes('1w'), 10080)
|
||||
self.assertEqual(timeframe_to_minutes('1M'), 44640) # 31 days in a month
|
||||
self.assertEqual(timeframe_to_minutes('1Y'), 525600) # 525600 minutes in a year
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue