brighter-trading/src/brokers/paper_broker.py

729 lines
27 KiB
Python

"""
Paper Trading Broker Implementation for BrighterTrading.
Simulates order execution with live price data for paper trading.
Orders are filled based on current market prices with configurable
slippage and commission.
"""
import logging
from typing import Any, Dict, List, Optional, Callable
import uuid
import json
from datetime import datetime, timezone, timedelta
from .base_broker import (
BaseBroker, OrderResult, OrderSide, OrderType, OrderStatus, Position
)
# Module-level logger, named after this module via logging.getLogger(__name__).
logger = logging.getLogger(__name__)
class PaperOrder:
    """
    Represents a single paper trading order and its fill state.

    The fill-related fields start at neutral values and are populated by the
    broker when the order executes.
    """

    def __init__(
        self,
        order_id: str,
        symbol: str,
        side: OrderSide,
        order_type: OrderType,
        size: float,
        price: Optional[float] = None,
        stop_loss: Optional[float] = None,
        take_profit: Optional[float] = None,
        time_in_force: str = 'GTC'
    ):
        """
        :param order_id: Broker-assigned order identifier.
        :param symbol: Instrument symbol.
        :param side: Order side (buy or sell).
        :param order_type: Order type (market, limit, ...).
        :param size: Quantity to trade.
        :param price: Order price, if any (None for market orders).
        :param stop_loss: Optional stop-loss price.
        :param take_profit: Optional take-profit price.
        :param time_in_force: Time-in-force label; this class only stores and serializes it.
        """
        self.order_id = order_id
        self.symbol = symbol
        self.side = side
        self.order_type = order_type
        self.size = size
        self.price = price
        self.stop_loss = stop_loss
        self.take_profit = take_profit
        self.time_in_force = time_in_force
        self.status = OrderStatus.PENDING
        self.filled_qty = 0.0
        self.filled_price = 0.0
        self.commission = 0.0
        self.locked_funds = 0.0  # Amount locked for this order
        self.created_at = datetime.now(timezone.utc)
        self.filled_at: Optional[datetime] = None
        # Realized-P&L bookkeeping filled in by the broker on execution.
        # Declared here so every order exposes these attributes, not only
        # orders that have already been filled.
        self.is_profitable = False
        self.realized_pnl = 0.0
        self.entry_price = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the order to a JSON-serializable dictionary.

        Enum fields are stored by their ``.value`` and timestamps as ISO-8601
        strings, so the result can be round-tripped through json.
        """
        return {
            'order_id': self.order_id,
            'symbol': self.symbol,
            'side': self.side.value,
            'order_type': self.order_type.value,
            'size': self.size,
            'price': self.price,
            'stop_loss': self.stop_loss,
            'take_profit': self.take_profit,
            'time_in_force': self.time_in_force,
            'status': self.status.value,
            'filled_qty': self.filled_qty,
            'filled_price': self.filled_price,
            'commission': self.commission,
            'locked_funds': self.locked_funds,
            'created_at': self.created_at.isoformat(),
            'filled_at': self.filled_at.isoformat() if self.filled_at else None
        }
class PaperBroker(BaseBroker):
    """
    Paper trading broker that simulates order execution.

    Features:
    - Simulated fills based on market prices
    - Configurable slippage and commission
    - Position tracking with P&L calculation
    - Persistent state across restarts (via data_cache)

    Only long positions are opened here: SELL orders must be covered by an
    existing position of at least the order size.
    """

    def __init__(
        self,
        price_provider: Optional[Callable[[str], float]] = None,
        data_cache: Any = None,
        initial_balance: float = 10000.0,
        commission: float = 0.001,
        slippage: float = 0.0005
    ):
        """
        Initialize the PaperBroker.

        :param price_provider: Callable that returns current price for a symbol.
        :param data_cache: DataCache instance for persistence.
        :param initial_balance: Starting balance in quote currency.
        :param commission: Commission rate (0.001 = 0.1%).
        :param slippage: Slippage rate for market orders.
        """
        super().__init__(initial_balance, commission, slippage)
        self._price_provider = price_provider
        self._data_cache = data_cache
        # Balance tracking. _cash excludes funds locked by open buy orders;
        # those are held in _locked_balance until the order fills or is cancelled.
        self._cash = initial_balance
        self._locked_balance = 0.0
        # Order and position storage
        self._orders: Dict[str, PaperOrder] = {}
        self._positions: Dict[str, Position] = {}
        self._trade_history: List[Dict[str, Any]] = []
        # SL/TP tracking per symbol: {symbol: {stop_loss, take_profit, side, entry_price}}
        self._position_sltp: Dict[str, Dict[str, Any]] = {}
        # Current prices cache
        self._current_prices: Dict[str, float] = {}
        # Set once _ensure_persistence_cache() succeeds, so the schema work is
        # not repeated on every save/load.
        self._persistence_ready = False

    def set_price_provider(self, provider: Callable[[str], float]):
        """Set the callable used to fetch prices that are not already cached."""
        self._price_provider = provider

    def update_price(self, symbol: str, price: float):
        """Update the cached current price for a symbol."""
        self._current_prices[symbol] = price

    def place_order(
        self,
        symbol: str,
        side: OrderSide,
        order_type: OrderType,
        size: float,
        price: Optional[float] = None,
        stop_loss: Optional[float] = None,
        take_profit: Optional[float] = None,
        time_in_force: str = 'GTC'
    ) -> OrderResult:
        """
        Place a paper trading order.

        Market orders are filled immediately at the current price adjusted for
        slippage. Any other order type is stored as OPEN and evaluated by
        update(). Buy orders of that kind lock their estimated cost (including
        commission) until they fill or are cancelled. A non-market order given
        no price uses the current market price as its price.

        :return: OrderResult with fill/placement details, or success=False with
                 a message explaining the rejection.
        """
        # A zero/negative size would let a "buy" credit cash, so reject it up front.
        if size is None or size <= 0:
            return OrderResult(
                success=False,
                message=f"Invalid order size: {size}"
            )
        order_id = str(uuid.uuid4())[:8]
        # Validate order
        current_price = self.get_current_price(symbol)
        if current_price <= 0:
            return OrderResult(
                success=False,
                message=f"Cannot get current price for {symbol}"
            )
        # Price the order at what it will actually execute at, so the funds
        # check matches the eventual fill.
        if order_type == OrderType.MARKET:
            execution_price = self._calculate_fill_price(symbol, side, current_price)
        else:
            # Resting orders need a concrete price: update() compares against
            # it and fills at it, so fall back to the current price when unset.
            execution_price = price if price else current_price
            price = execution_price
        required_funds = size * execution_price * (1 + self.commission)
        if side == OrderSide.BUY and required_funds > self._cash:
            return OrderResult(
                success=False,
                message=f"Insufficient funds: need {required_funds:.2f}, have {self._cash:.2f}"
            )
        # For sell orders, check if we have the position
        if side == OrderSide.SELL:
            position = self._positions.get(symbol)
            if position is None or position.size < size:
                available = position.size if position else 0
                return OrderResult(
                    success=False,
                    message=f"Insufficient position: need {size}, have {available}"
                )
        order = PaperOrder(
            order_id=order_id,
            symbol=symbol,
            side=side,
            order_type=order_type,
            size=size,
            price=price,
            stop_loss=stop_loss,
            take_profit=take_profit,
            time_in_force=time_in_force
        )
        if order_type == OrderType.MARKET:
            # Market orders fill immediately at the slippage-adjusted price.
            self._fill_order(order, execution_price)
            logger.info(f"PaperBroker: Market order filled: {side.value} {size} {symbol} @ {execution_price:.4f}")
        else:
            # Store pending order
            order.status = OrderStatus.OPEN
            self._orders[order_id] = order
            # Lock funds for buy orders so they cannot be spent twice.
            if side == OrderSide.BUY:
                order.locked_funds = required_funds
                self._locked_balance += required_funds
                self._cash -= required_funds
            logger.info(f"PaperBroker: Limit order placed: {side.value} {size} {symbol} @ {price}")
        return OrderResult(
            success=True,
            order_id=order_id,
            status=order.status,
            filled_qty=order.filled_qty,
            filled_price=order.filled_price,
            commission=order.commission,
            message=f"Order {order_id} {'filled' if order.status == OrderStatus.FILLED else 'placed'}"
        )

    def _calculate_fill_price(self, symbol: str, side: OrderSide, market_price: float) -> float:
        """Return the market price moved against the trader by the slippage rate."""
        if side == OrderSide.BUY:
            return market_price * (1 + self.slippage)
        return market_price * (1 - self.slippage)

    def _fill_order(self, order: PaperOrder, fill_price: float):
        """
        Execute an order fill at ``fill_price``.

        Updates the order's fill fields, cash, positions, SL/TP tracking and
        trade history, and stores the order in the orders dict.
        """
        order.filled_qty = order.size
        order.filled_price = fill_price
        order.commission = order.size * fill_price * self.commission
        order.status = OrderStatus.FILLED
        order.filled_at = datetime.now(timezone.utc)
        # Profitability fields are reported in fill events; reset before computing.
        order.is_profitable = False
        order.realized_pnl = 0.0
        order.entry_price = 0.0
        order_value = order.size * fill_price
        if order.side == OrderSide.BUY:
            # Deduct cost (value + commission) from cash
            self._cash -= order_value + order.commission
            existing = self._positions.get(order.symbol)
            if existing is not None:
                # Average the entry price over the combined size.
                new_size = existing.size + order.size
                existing.entry_price = (existing.entry_price * existing.size + fill_price * order.size) / new_size
                existing.size = new_size
            else:
                self._positions[order.symbol] = Position(
                    symbol=order.symbol,
                    size=order.size,
                    entry_price=fill_price,
                    current_price=fill_price,
                    unrealized_pnl=0.0
                )
            # Record SL/TP for this position (if set on order)
            if order.stop_loss or order.take_profit:
                self._position_sltp[order.symbol] = {
                    'stop_loss': order.stop_loss,
                    'take_profit': order.take_profit,
                    'side': 'long',  # BUY opens a long position
                    'entry_price': fill_price
                }
                logger.info(f"SL/TP set for {order.symbol}: SL={order.stop_loss}, TP={order.take_profit}")
        else:
            # Add proceeds (value - commission) to cash
            self._cash += order_value - order.commission
            position = self._positions.get(order.symbol)
            if position is not None:
                order.entry_price = position.entry_price  # Store for fee calculation
                realized_pnl = (fill_price - position.entry_price) * order.size - order.commission
                position.realized_pnl += realized_pnl
                position.size -= order.size
                # Track profitability for fee calculation
                order.realized_pnl = realized_pnl
                order.is_profitable = realized_pnl > 0
                # Remove position (and its SL/TP tracking) if fully closed
                if position.size <= 0:
                    del self._positions[order.symbol]
                    self._position_sltp.pop(order.symbol, None)
        self._trade_history.append({
            'order_id': order.order_id,
            'symbol': order.symbol,
            'side': order.side.value,
            'size': order.size,
            'price': fill_price,
            'commission': order.commission,
            'timestamp': order.filled_at.isoformat()
        })
        # Store in orders dict for reference
        self._orders[order.order_id] = order

    def cancel_order(self, order_id: str) -> bool:
        """Cancel an OPEN order, releasing any locked funds. Returns False if not cancellable."""
        order = self._orders.get(order_id)
        if order is None or order.status != OrderStatus.OPEN:
            return False
        # Release locked funds for buy orders
        if order.side == OrderSide.BUY and order.locked_funds > 0:
            self._locked_balance -= order.locked_funds
            self._cash += order.locked_funds
            order.locked_funds = 0
        order.status = OrderStatus.CANCELLED
        logger.info(f"PaperBroker: Order {order_id} cancelled")
        return True

    def get_order(self, order_id: str) -> Optional[Dict[str, Any]]:
        """Get order details as a dict, or None if the order is unknown."""
        order = self._orders.get(order_id)
        return order.to_dict() if order is not None else None

    def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get all OPEN orders, optionally restricted to one symbol."""
        return [
            order.to_dict() for order in self._orders.values()
            if order.status == OrderStatus.OPEN
            and (symbol is None or order.symbol == symbol)
        ]

    def get_balance(self, asset: Optional[str] = None) -> float:
        """Get total equity (cash + locked + position values). ``asset`` is ignored."""
        total = self._cash + self._locked_balance
        for position in self._positions.values():
            # Include full position value (size * current_price), not just unrealized PnL
            total += position.size * position.current_price
        return total

    def get_available_balance(self, asset: Optional[str] = None) -> float:
        """Get available cash (not locked in orders). ``asset`` is ignored."""
        return self._cash

    def get_position(self, symbol: str) -> Optional[Position]:
        """Get position for a symbol."""
        return self._positions.get(symbol)

    def get_all_positions(self) -> List[Position]:
        """Get all open positions."""
        return list(self._positions.values())

    def get_current_price(self, symbol: str) -> float:
        """
        Get current price for a symbol.

        Returns the cached price if present; otherwise queries the price
        provider and caches the result when it is a positive number.
        Returns 0.0 when no usable price is available.
        """
        if symbol in self._current_prices:
            return self._current_prices[symbol]
        if self._price_provider:
            try:
                price = float(self._price_provider(symbol))
            except Exception as e:
                logger.warning(f"PaperBroker: Failed to get price for {symbol}: {e}")
                return 0.0
            # Never cache an unusable quote: it would shadow every later
            # provider lookup for this symbol.
            if price > 0:
                self._current_prices[symbol] = price
                return price
            logger.warning(f"PaperBroker: Invalid price {price} for {symbol}")
        return 0.0

    def update(self) -> List[Dict[str, Any]]:
        """
        Process pending limit orders and update positions.
        Called on each price update to check for fills.

        Runs, in order: mark positions to market, evaluate SL/TP triggers
        (which may close positions), then fill open limit orders.

        :return: Event dicts ('sltp_triggered' and 'fill') produced by this call.
        """
        events: List[Dict[str, Any]] = []
        self._refresh_position_pnl()
        events.extend(self._check_sltp_triggers())
        events.extend(self._fill_open_limit_orders())
        return events

    def _refresh_position_pnl(self):
        """Mark every open position to its current price and recompute unrealized P&L."""
        for symbol, position in self._positions.items():
            current_price = self.get_current_price(symbol)
            if current_price > 0:
                position.current_price = current_price
                position.unrealized_pnl = (current_price - position.entry_price) * position.size

    def _check_sltp_triggers(self) -> List[Dict[str, Any]]:
        """Close positions whose stop-loss or take-profit level has been reached."""
        events: List[Dict[str, Any]] = []
        # Iterate a snapshot: closing a position removes its SL/TP entry.
        for symbol, sltp in list(self._position_sltp.items()):
            position = self._positions.get(symbol)
            if position is None or position.size <= 0:
                # Position no longer exists; its SL/TP is meaningless.
                del self._position_sltp[symbol]
                continue
            current_price = self.get_current_price(symbol)
            if current_price <= 0:
                # Price temporarily unavailable: keep the protection and retry later.
                continue
            triggered = None
            if sltp['side'] == 'long':
                # Long position: SL triggers when price drops, TP when price rises
                if sltp.get('stop_loss') and current_price <= sltp['stop_loss']:
                    triggered = 'stop_loss'
                elif sltp.get('take_profit') and current_price >= sltp['take_profit']:
                    triggered = 'take_profit'
            else:
                # Short position: SL triggers when price rises, TP when price drops
                if sltp.get('stop_loss') and current_price >= sltp['stop_loss']:
                    triggered = 'stop_loss'
                elif sltp.get('take_profit') and current_price <= sltp['take_profit']:
                    triggered = 'take_profit'
            if not triggered:
                continue
            # Auto-close position; SL/TP tracking is cleared in _fill_order on close.
            close_result = self.close_position(symbol)
            if close_result.success:
                events.append({
                    'type': 'sltp_triggered',
                    'trigger': triggered,
                    'symbol': symbol,
                    'trigger_price': current_price,
                    'size': close_result.filled_qty,
                    'pnl': position.unrealized_pnl
                })
                logger.info(f"SL/TP triggered: {triggered} for {symbol} at {current_price}")
        return events

    def _fill_open_limit_orders(self) -> List[Dict[str, Any]]:
        """Fill OPEN limit orders whose limit price has been reached."""
        events: List[Dict[str, Any]] = []
        for order_id, order in list(self._orders.items()):
            if order.status != OrderStatus.OPEN or order.order_type != OrderType.LIMIT:
                continue
            current_price = self.get_current_price(order.symbol)
            if current_price <= 0:
                continue
            if order.side == OrderSide.BUY:
                should_fill = current_price <= order.price
            else:
                should_fill = current_price >= order.price
            if not should_fill:
                continue
            if order.side == OrderSide.SELL:
                # The position may have shrunk or closed since placement;
                # never sell more than is held. The order stays OPEN.
                position = self._positions.get(order.symbol)
                if position is None or position.size < order.size:
                    logger.debug(f"PaperBroker: Limit sell {order_id} not covered by position; skipping fill")
                    continue
            # Release locked funds first (for buy orders); _fill_order debits the actual cost.
            if order.side == OrderSide.BUY and order.locked_funds > 0:
                self._locked_balance -= order.locked_funds
                self._cash += order.locked_funds
                order.locked_funds = 0
            fill_price = order.price  # Limit orders fill at limit price
            self._fill_order(order, fill_price)
            events.append({
                'type': 'fill',
                'order_id': order_id,
                'symbol': order.symbol,
                'side': order.side.value,
                'size': order.filled_qty,
                'filled_qty': order.filled_qty,
                'price': order.filled_price,
                'filled_price': order.filled_price,
                'commission': order.commission,
                'is_profitable': getattr(order, 'is_profitable', False),
                'realized_pnl': getattr(order, 'realized_pnl', 0.0),
                'entry_price': getattr(order, 'entry_price', 0.0)
            })
            logger.info(f"PaperBroker: Limit order filled: {order.side.value} {order.size} {order.symbol} @ {fill_price:.4f}")
        return events

    def get_trade_history(self) -> List[Dict[str, Any]]:
        """Get a shallow copy of all executed trades."""
        return self._trade_history.copy()

    def reset(self):
        """Reset the broker to initial state."""
        self._cash = self.initial_balance
        self._locked_balance = 0.0
        self._orders.clear()
        self._positions.clear()
        self._trade_history.clear()
        self._current_prices.clear()
        # Stale SL/TP levels must not carry over to a future position on the same symbol.
        self._position_sltp.clear()
        logger.info(f"PaperBroker: Reset with balance {self.initial_balance}")

    # ==================== State Persistence Methods ====================

    def _ensure_persistence_cache(self) -> bool:
        """
        Ensure the persistence table/cache exists.

        Creates/migrates the backing DB table when the data cache exposes a
        ``db.execute_sql``, then registers the 'paper_broker_states' cache.
        The work is done once per broker; later calls return True immediately.

        :return: True if persistence is available, False otherwise.
        """
        if not self._data_cache:
            return False
        if getattr(self, '_persistence_ready', False):
            return True
        try:
            # Ensure backing DB table exists for datacache read/write methods.
            if hasattr(self._data_cache, 'db') and hasattr(self._data_cache.db, 'execute_sql'):
                self._data_cache.db.execute_sql(
                    'CREATE TABLE IF NOT EXISTS "paper_broker_states" ('
                    'id INTEGER PRIMARY KEY AUTOINCREMENT, '
                    'tbl_key TEXT UNIQUE, '
                    'strategy_instance_id TEXT UNIQUE, '
                    'broker_state TEXT, '
                    'updated_at TEXT)',
                    []
                )
                # Migration path for any older local table that was created without tbl_key.
                try:
                    existing_df = self._data_cache.db.get_all_rows('paper_broker_states')
                    if 'tbl_key' not in existing_df.columns:
                        self._data_cache.db.execute_sql(
                            'ALTER TABLE "paper_broker_states" ADD COLUMN tbl_key TEXT',
                            []
                        )
                except Exception as e:
                    # Best effort: if schema inspection fails, continue with current schema.
                    logger.debug(f"PaperBroker: Schema inspection skipped: {e}")
                # Keep tbl_key aligned with strategy_instance_id for DataCache overwrite
                # semantics. '' (single quotes) is the SQL string literal; "" is an identifier.
                self._data_cache.db.execute_sql(
                    'UPDATE "paper_broker_states" '
                    'SET tbl_key = strategy_instance_id '
                    "WHERE tbl_key IS NULL OR tbl_key = ''",
                    []
                )
            self._data_cache.create_cache(
                name='paper_broker_states',
                cache_type='table',
                size_limit=5000,
                eviction_policy='deny',
                default_expiration=timedelta(days=7),
                columns=['id', 'tbl_key', 'strategy_instance_id', 'broker_state', 'updated_at']
            )
            self._persistence_ready = True
            return True
        except Exception as e:
            logger.error(f"PaperBroker: Error ensuring persistence cache: {e}", exc_info=True)
            return False

    def to_state_dict(self) -> Dict[str, Any]:
        """
        Serialize broker state to a dictionary for persistence.

        :return: Dict containing all state needed to restore the broker.
        """
        return {
            'cash': self._cash,
            'locked_balance': self._locked_balance,
            'initial_balance': self.initial_balance,
            'commission': self.commission,
            'slippage': self.slippage,
            'orders': {order_id: order.to_dict() for order_id, order in self._orders.items()},
            'positions': {symbol: position.to_dict() for symbol, position in self._positions.items()},
            'trade_history': self._trade_history,
            'current_prices': self._current_prices,
            'position_sltp': self._position_sltp,
        }

    def from_state_dict(self, state: Dict[str, Any]):
        """
        Restore broker state from a dictionary.

        Configuration values (initial_balance, commission, slippage) stored in
        the state are not applied; the constructor's values are kept.

        :param state: State dict from to_state_dict(). Empty/None is a no-op.
        """
        if not state:
            return
        # Restore balances
        self._cash = state.get('cash', self.initial_balance)
        self._locked_balance = state.get('locked_balance', 0.0)
        # Restore orders
        self._orders.clear()
        for order_id, order_dict in state.get('orders', {}).items():
            order = PaperOrder(
                order_id=order_dict['order_id'],
                symbol=order_dict['symbol'],
                side=OrderSide(order_dict['side']),
                order_type=OrderType(order_dict['order_type']),
                size=order_dict['size'],
                price=order_dict.get('price'),
                stop_loss=order_dict.get('stop_loss'),
                take_profit=order_dict.get('take_profit'),
                # Older states did not persist time_in_force; fall back to the default.
                time_in_force=order_dict.get('time_in_force', 'GTC'),
            )
            order.status = OrderStatus(order_dict['status'])
            order.filled_qty = order_dict.get('filled_qty', 0.0)
            order.filled_price = order_dict.get('filled_price', 0.0)
            order.commission = order_dict.get('commission', 0.0)
            order.locked_funds = order_dict.get('locked_funds', 0.0)
            if order_dict.get('created_at'):
                order.created_at = datetime.fromisoformat(order_dict['created_at'])
            if order_dict.get('filled_at'):
                order.filled_at = datetime.fromisoformat(order_dict['filled_at'])
            self._orders[order_id] = order
        # Restore positions
        self._positions.clear()
        for symbol, pos_dict in state.get('positions', {}).items():
            self._positions[symbol] = Position.from_dict(pos_dict)
        # Restore trade history, price cache and SL/TP tracking
        self._trade_history = state.get('trade_history', [])
        self._current_prices = state.get('current_prices', {})
        self._position_sltp = state.get('position_sltp', {})
        logger.info(f"PaperBroker: State restored - cash: {self._cash:.2f}, "
                    f"positions: {len(self._positions)}, orders: {len(self._orders)}")

    def save_state(self, strategy_instance_id: str) -> bool:
        """
        Save broker state to the data cache (insert or update by strategy id).

        :param strategy_instance_id: Unique identifier for the strategy instance.
        :return: True if saved successfully.
        """
        if not self._data_cache:
            logger.warning("PaperBroker: No data cache available for persistence")
            return False
        try:
            if not self._ensure_persistence_cache():
                return False
            state_json = json.dumps(self.to_state_dict())
            # Check if state already exists
            existing = self._data_cache.get_rows_from_datacache(
                cache_name='paper_broker_states',
                filter_vals=[('strategy_instance_id', strategy_instance_id)]
            )
            columns = ('tbl_key', 'strategy_instance_id', 'broker_state', 'updated_at')
            values = (strategy_instance_id, strategy_instance_id, state_json, datetime.now(timezone.utc).isoformat())
            if existing.empty:
                # Insert new state
                self._data_cache.insert_row_into_datacache(
                    cache_name='paper_broker_states',
                    columns=columns,
                    values=values
                )
            else:
                # Update existing state
                self._data_cache.modify_datacache_item(
                    cache_name='paper_broker_states',
                    filter_vals=[('strategy_instance_id', strategy_instance_id)],
                    field_names=columns,
                    new_values=values,
                    overwrite='strategy_instance_id'
                )
            logger.debug(f"PaperBroker: State saved for {strategy_instance_id}")
            return True
        except Exception as e:
            logger.error(f"PaperBroker: Error saving state: {e}", exc_info=True)
            return False

    def load_state(self, strategy_instance_id: str) -> bool:
        """
        Load broker state from the data cache.

        :param strategy_instance_id: Unique identifier for the strategy instance.
        :return: True if state was loaded successfully.
        """
        if not self._data_cache:
            logger.warning("PaperBroker: No data cache available for persistence")
            return False
        try:
            if not self._ensure_persistence_cache():
                return False
            existing = self._data_cache.get_rows_from_datacache(
                cache_name='paper_broker_states',
                filter_vals=[('strategy_instance_id', strategy_instance_id)]
            )
            if existing.empty:
                logger.debug(f"PaperBroker: No saved state for {strategy_instance_id}")
                return False
            state_json = existing.iloc[0].get('broker_state', '{}')
            self.from_state_dict(json.loads(state_json))
            logger.info(f"PaperBroker: State loaded for {strategy_instance_id}")
            return True
        except Exception as e:
            logger.error(f"PaperBroker: Error loading state: {e}", exc_info=True)
            return False