729 lines
27 KiB
Python
729 lines
27 KiB
Python
"""
|
|
Paper Trading Broker Implementation for BrighterTrading.
|
|
|
|
Simulates order execution with live price data for paper trading.
|
|
Orders are filled based on current market prices with configurable
|
|
slippage and commission.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Any, Dict, List, Optional, Callable
|
|
import uuid
|
|
import json
|
|
from datetime import datetime, timezone, timedelta
|
|
|
|
from .base_broker import (
|
|
BaseBroker, OrderResult, OrderSide, OrderType, OrderStatus, Position
|
|
)
|
|
|
|
# Module-level logger, named after this module via logging.getLogger(__name__).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PaperOrder:
    """Represents a paper trading order.

    Holds the request parameters (symbol, side, type, size, optional price and
    SL/TP levels) together with the mutable execution state (status, fill
    quantity/price, commission, locked funds, timestamps) that the paper
    broker updates as the order is filled or cancelled.
    """

    def __init__(
        self,
        order_id: str,
        symbol: str,
        side: OrderSide,
        order_type: OrderType,
        size: float,
        price: Optional[float] = None,
        stop_loss: Optional[float] = None,
        take_profit: Optional[float] = None,
        time_in_force: str = 'GTC'
    ):
        """
        Create a new order in PENDING status with no fill recorded.

        :param order_id: Identifier assigned by the broker.
        :param symbol: Trading symbol the order is for.
        :param side: Buy or sell.
        :param order_type: Market, limit, etc.
        :param size: Order quantity.
        :param price: Limit price (None for market orders).
        :param stop_loss: Optional stop-loss level for the resulting position.
        :param take_profit: Optional take-profit level for the resulting position.
        :param time_in_force: Time-in-force code, stored as given.
        """
        self.order_id = order_id
        self.symbol = symbol
        self.side = side
        self.order_type = order_type
        self.size = size
        self.price = price
        self.stop_loss = stop_loss
        self.take_profit = take_profit
        self.time_in_force = time_in_force
        self.status = OrderStatus.PENDING
        self.filled_qty = 0.0
        self.filled_price = 0.0
        self.commission = 0.0
        self.locked_funds = 0.0  # Amount locked for this order
        self.created_at = datetime.now(timezone.utc)
        self.filled_at: Optional[datetime] = None

        # Fill analytics written by the broker when the order executes.
        # Declared up front so every order exposes them, filled or not.
        self.entry_price = 0.0
        self.realized_pnl = 0.0
        self.is_profitable = False

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary (enums as values, datetimes as ISO strings)."""
        return {
            'order_id': self.order_id,
            'symbol': self.symbol,
            'side': self.side.value,
            'order_type': self.order_type.value,
            'size': self.size,
            'price': self.price,
            'stop_loss': self.stop_loss,
            'take_profit': self.take_profit,
            # Included so persisted orders round-trip their time-in-force.
            'time_in_force': self.time_in_force,
            'status': self.status.value,
            'filled_qty': self.filled_qty,
            'filled_price': self.filled_price,
            'commission': self.commission,
            'locked_funds': self.locked_funds,
            'created_at': self.created_at.isoformat(),
            'filled_at': self.filled_at.isoformat() if self.filled_at else None
        }
|
|
|
|
|
|
class PaperBroker(BaseBroker):
    """
    Paper trading broker that simulates order execution.

    Features:
    - Simulated fills based on market prices
    - Configurable slippage and commission
    - Position tracking with P&L calculation
    - Persistent state across restarts (via data_cache)

    Positions opened by this class are long-only: BUY fills add to a
    position and SELL fills reduce it (a SELL requires an existing position).
    """

    def __init__(
        self,
        price_provider: Optional[Callable[[str], float]] = None,
        data_cache: Any = None,
        initial_balance: float = 10000.0,
        commission: float = 0.001,
        slippage: float = 0.0005
    ):
        """
        Initialize the PaperBroker.

        :param price_provider: Callable that returns current price for a symbol.
        :param data_cache: DataCache instance for persistence.
        :param initial_balance: Starting balance in quote currency.
        :param commission: Commission rate (0.001 = 0.1%).
        :param slippage: Slippage rate for market orders.
        """
        super().__init__(initial_balance, commission, slippage)
        self._price_provider = price_provider
        self._data_cache = data_cache

        # Balance tracking. Funds reserved by open buy orders are moved from
        # _cash into _locked_balance so they cannot be spent twice.
        self._cash = initial_balance
        self._locked_balance = 0.0

        # Order and position storage
        self._orders: Dict[str, PaperOrder] = {}
        self._positions: Dict[str, Position] = {}
        self._trade_history: List[Dict[str, Any]] = []

        # SL/TP tracking per symbol: {symbol: {stop_loss, take_profit, side, entry_price}}
        self._position_sltp: Dict[str, Dict[str, Any]] = {}

        # Current prices cache
        self._current_prices: Dict[str, float] = {}

        # Set once the persistence table/cache has been prepared, so repeated
        # save/load calls skip the schema checks (which read every stored row).
        self._persistence_ready = False

    def set_price_provider(self, provider: Callable[[str], float]):
        """Set the price provider callable."""
        self._price_provider = provider

    def update_price(self, symbol: str, price: float):
        """Update the current price for a symbol."""
        self._current_prices[symbol] = price

    def _new_order_id(self) -> str:
        """Return a short (8-char) uuid-based id not already used by a known order."""
        order_id = str(uuid.uuid4())[:8]
        # 8 hex chars collide eventually; never overwrite an existing order.
        while order_id in self._orders:
            order_id = str(uuid.uuid4())[:8]
        return order_id

    def place_order(
        self,
        symbol: str,
        side: OrderSide,
        order_type: OrderType,
        size: float,
        price: Optional[float] = None,
        stop_loss: Optional[float] = None,
        take_profit: Optional[float] = None,
        time_in_force: str = 'GTC'
    ) -> OrderResult:
        """
        Place a paper trading order.

        Market orders fill immediately at the current price adjusted for
        slippage. Other order types are stored as OPEN; buy orders lock the
        required funds (value plus commission) until filled or cancelled.
        A LIMIT order given no price rests at the current price.

        :return: OrderResult; ``success`` is False when the order is rejected
                 (non-positive size, no price available, insufficient funds,
                 or insufficient position for a sell).
        """
        # Zero/negative sizes would invert the cash and position accounting.
        if size <= 0:
            return OrderResult(
                success=False,
                message=f"Invalid order size: {size}"
            )

        # Validate order
        current_price = self.get_current_price(symbol)
        if current_price <= 0:
            return OrderResult(
                success=False,
                message=f"Cannot get current price for {symbol}"
            )

        # Calculate required margin/cost using the actual execution price
        # For market orders: include slippage; for limit orders: use limit price
        if order_type == OrderType.MARKET:
            execution_price = self._calculate_fill_price(symbol, side, current_price)
        else:
            # A limit order needs a concrete level for update() to compare
            # against; without one it would raise on every update pass.
            if order_type == OrderType.LIMIT and not price:
                price = current_price
            # For limit orders, use the limit price (or current price if not specified)
            execution_price = price if price else current_price

        order_value = size * execution_price
        required_funds = order_value * (1 + self.commission)

        if side == OrderSide.BUY and required_funds > self._cash:
            return OrderResult(
                success=False,
                message=f"Insufficient funds: need {required_funds:.2f}, have {self._cash:.2f}"
            )

        # For sell orders, check if we have the position
        if side == OrderSide.SELL:
            position = self._positions.get(symbol)
            if position is None or position.size < size:
                available = position.size if position else 0
                return OrderResult(
                    success=False,
                    message=f"Insufficient position: need {size}, have {available}"
                )

        order_id = self._new_order_id()
        order = PaperOrder(
            order_id=order_id,
            symbol=symbol,
            side=side,
            order_type=order_type,
            size=size,
            price=price,
            stop_loss=stop_loss,
            take_profit=take_profit,
            time_in_force=time_in_force
        )

        if order_type == OrderType.MARKET:
            # For market orders, fill immediately at the slippage-adjusted price
            fill_price = execution_price
            self._fill_order(order, fill_price)
            logger.info(f"PaperBroker: Market order filled: {side.value} {size} {symbol} @ {fill_price:.4f}")
        else:
            # Store pending order
            order.status = OrderStatus.OPEN
            self._orders[order_id] = order
            # Lock funds for buy limit orders
            if side == OrderSide.BUY:
                order.locked_funds = required_funds
                self._locked_balance += required_funds
                self._cash -= required_funds
            logger.info(f"PaperBroker: Limit order placed: {side.value} {size} {symbol} @ {price}")

        return OrderResult(
            success=True,
            order_id=order_id,
            status=order.status,
            filled_qty=order.filled_qty,
            filled_price=order.filled_price,
            commission=order.commission,
            message=f"Order {order_id} {'filled' if order.status == OrderStatus.FILLED else 'placed'}"
        )

    def _calculate_fill_price(self, symbol: str, side: OrderSide, market_price: float) -> float:
        """Calculate fill price with slippage (buys pay more, sells receive less)."""
        if side == OrderSide.BUY:
            return market_price * (1 + self.slippage)
        else:
            return market_price * (1 - self.slippage)

    def _fill_order(self, order: PaperOrder, fill_price: float):
        """
        Execute a full fill of ``order`` at ``fill_price``.

        Updates the order's fill fields, cash, the position (averaging the
        entry on buys, realizing P&L net of commission on sells), SL/TP
        tracking, and the trade history, and stores the order in ``_orders``.
        Callers are responsible for releasing any funds locked by the order.
        """
        order.filled_qty = order.size
        order.filled_price = fill_price
        order.commission = order.size * fill_price * self.commission
        order.status = OrderStatus.FILLED
        order.filled_at = datetime.now(timezone.utc)

        # Calculate profitability for sell orders (for fee tracking)
        order.is_profitable = False
        order.realized_pnl = 0.0
        order.entry_price = 0.0

        # Update balances and positions
        order_value = order.size * fill_price

        if order.side == OrderSide.BUY:
            # Deduct cost from cash
            total_cost = order_value + order.commission
            self._cash -= total_cost

            # Update position (size-weighted average entry price)
            if order.symbol in self._positions:
                existing = self._positions[order.symbol]
                new_size = existing.size + order.size
                new_entry = (existing.entry_price * existing.size + fill_price * order.size) / new_size
                existing.size = new_size
                existing.entry_price = new_entry
            else:
                self._positions[order.symbol] = Position(
                    symbol=order.symbol,
                    size=order.size,
                    entry_price=fill_price,
                    current_price=fill_price,
                    unrealized_pnl=0.0
                )

            # Record SL/TP for this position (if set on order)
            if order.stop_loss or order.take_profit:
                self._position_sltp[order.symbol] = {
                    'stop_loss': order.stop_loss,
                    'take_profit': order.take_profit,
                    'side': 'long',  # BUY opens a long position
                    'entry_price': fill_price
                }
                logger.info(f"SL/TP set for {order.symbol}: SL={order.stop_loss}, TP={order.take_profit}")
        else:
            # Add proceeds to cash
            total_proceeds = order_value - order.commission
            self._cash += total_proceeds

            # Update position
            if order.symbol in self._positions:
                position = self._positions[order.symbol]
                order.entry_price = position.entry_price  # Store for fee calculation
                realized_pnl = (fill_price - position.entry_price) * order.size - order.commission
                position.realized_pnl += realized_pnl
                position.size -= order.size

                # Track profitability for fee calculation
                order.realized_pnl = realized_pnl
                order.is_profitable = realized_pnl > 0

                # Remove position if fully closed
                if position.size <= 0:
                    del self._positions[order.symbol]
                    # Clear SL/TP tracking for this symbol
                    if order.symbol in self._position_sltp:
                        del self._position_sltp[order.symbol]

        # Record trade
        self._trade_history.append({
            'order_id': order.order_id,
            'symbol': order.symbol,
            'side': order.side.value,
            'size': order.size,
            'price': fill_price,
            'commission': order.commission,
            'timestamp': order.filled_at.isoformat()
        })

        # Store in orders dict for reference
        self._orders[order.order_id] = order

    def cancel_order(self, order_id: str) -> bool:
        """Cancel an OPEN order, releasing any funds it locked. Returns False if not cancellable."""
        if order_id not in self._orders:
            return False

        order = self._orders[order_id]
        if order.status != OrderStatus.OPEN:
            return False

        # Release locked funds for buy orders
        if order.side == OrderSide.BUY and order.locked_funds > 0:
            self._locked_balance -= order.locked_funds
            self._cash += order.locked_funds
            order.locked_funds = 0

        order.status = OrderStatus.CANCELLED
        logger.info(f"PaperBroker: Order {order_id} cancelled")
        return True

    def get_order(self, order_id: str) -> Optional[Dict[str, Any]]:
        """Get order details as a dict, or None if the id is unknown."""
        if order_id in self._orders:
            return self._orders[order_id].to_dict()
        return None

    def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get all OPEN orders as dicts, optionally filtered by symbol."""
        open_orders = [
            order.to_dict() for order in self._orders.values()
            if order.status == OrderStatus.OPEN
            and (symbol is None or order.symbol == symbol)
        ]
        return open_orders

    def get_balance(self, asset: Optional[str] = None) -> float:
        """Get total equity (cash + locked + position values). ``asset`` is ignored."""
        total = self._cash + self._locked_balance
        for position in self._positions.values():
            # Include full position value (size * current_price), not just unrealized PnL
            total += position.size * position.current_price
        return total

    def get_available_balance(self, asset: Optional[str] = None) -> float:
        """Get available cash (not locked in orders). ``asset`` is ignored."""
        return self._cash

    def get_position(self, symbol: str) -> Optional[Position]:
        """Get position for a symbol."""
        return self._positions.get(symbol)

    def get_all_positions(self) -> List[Position]:
        """Get all open positions."""
        return list(self._positions.values())

    def get_current_price(self, symbol: str) -> float:
        """
        Get current price for a symbol, or 0.0 if none is available.

        Prices pushed via update_price() (or restored from state) take
        precedence; otherwise the price provider is queried and its answer
        cached.
        """
        # NOTE(review): a provider price is cached and then returned on every
        # later call, so with a provider but no update_price() feed the price
        # never refreshes. Confirm this is intended for the caller's setup.
        # First check cache
        if symbol in self._current_prices:
            return self._current_prices[symbol]

        # Then try price provider
        if self._price_provider:
            try:
                price = self._price_provider(symbol)
                self._current_prices[symbol] = price
                return price
            except Exception as e:
                logger.warning(f"PaperBroker: Failed to get price for {symbol}: {e}")

        return 0.0

    def update(self) -> List[Dict[str, Any]]:
        """
        Process pending limit orders and update positions.

        Called on each price update to check for fills. Runs three passes in
        order: refresh position P&L, evaluate SL/TP levels (closing positions
        that hit them), then fill limit orders whose price has been reached.

        :return: Event dicts of type 'sltp_triggered' and 'fill'.
        """
        events: List[Dict[str, Any]] = []
        self._refresh_position_pnl()
        self._evaluate_sltp(events)
        self._process_limit_orders(events)
        return events

    def _refresh_position_pnl(self):
        """Mark every position to the current price and recompute its unrealized P&L."""
        for symbol, position in self._positions.items():
            current_price = self.get_current_price(symbol)
            if current_price > 0:
                position.current_price = current_price
                position.unrealized_pnl = (current_price - position.entry_price) * position.size

    def _evaluate_sltp(self, events: List[Dict[str, Any]]):
        """Close positions whose SL/TP level was crossed, appending an event for each."""
        # Iterate a snapshot: entries are deleted here and in _fill_order.
        for symbol, sltp in list(self._position_sltp.items()):
            if symbol not in self._positions:
                del self._position_sltp[symbol]
                continue

            position = self._positions[symbol]
            current_price = self.get_current_price(symbol)

            if position.size <= 0 or current_price <= 0:
                del self._position_sltp[symbol]
                continue

            triggered = None
            trigger_price = current_price

            # Long position: SL triggers when price drops, TP when price rises
            if sltp['side'] == 'long':
                if sltp.get('stop_loss') and current_price <= sltp['stop_loss']:
                    triggered = 'stop_loss'
                elif sltp.get('take_profit') and current_price >= sltp['take_profit']:
                    triggered = 'take_profit'
            # Short position: SL triggers when price rises, TP when price drops
            else:
                if sltp.get('stop_loss') and current_price >= sltp['stop_loss']:
                    triggered = 'stop_loss'
                elif sltp.get('take_profit') and current_price <= sltp['take_profit']:
                    triggered = 'take_profit'

            if triggered:
                # Auto-close position
                close_result = self.close_position(symbol)
                if close_result.success:
                    events.append({
                        'type': 'sltp_triggered',
                        'trigger': triggered,
                        'symbol': symbol,
                        'trigger_price': trigger_price,
                        'size': close_result.filled_qty,
                        'pnl': position.unrealized_pnl
                    })
                    logger.info(f"SL/TP triggered: {triggered} for {symbol} at {trigger_price}")
                    # SL/TP tracking cleared in _fill_order when position closes

    def _process_limit_orders(self, events: List[Dict[str, Any]]):
        """Fill OPEN limit orders whose price was reached, appending a 'fill' event for each."""
        # Snapshot: cancel_order/_fill_order touch _orders during the loop.
        for order_id, order in list(self._orders.items()):
            if order.status != OrderStatus.OPEN:
                continue

            current_price = self.get_current_price(order.symbol)
            if current_price <= 0:
                continue

            if order.order_type != OrderType.LIMIT:
                continue

            if order.price is None:
                # A limit without a price can never be evaluated (e.g. state
                # saved before prices were defaulted); cancel it so any locked
                # funds are released instead of raising on the comparison.
                logger.warning(f"PaperBroker: Cancelling limit order {order_id} with no limit price")
                self.cancel_order(order_id)
                continue

            if order.side == OrderSide.BUY:
                should_fill = current_price <= order.price
            elif order.side == OrderSide.SELL:
                should_fill = current_price >= order.price
            else:
                should_fill = False

            if not should_fill:
                continue

            if order.side == OrderSide.SELL:
                # Sell limits do not reserve the position; if it was closed or
                # reduced since placement, filling would credit cash for
                # assets that are no longer held.
                position = self._positions.get(order.symbol)
                if position is None or position.size < order.size:
                    logger.warning(
                        f"PaperBroker: Cancelling sell limit {order_id}: "
                        f"position in {order.symbol} no longer covers {order.size}"
                    )
                    self.cancel_order(order_id)
                    continue

            # Release locked funds first (for buy orders)
            if order.side == OrderSide.BUY and order.locked_funds > 0:
                self._locked_balance -= order.locked_funds
                self._cash += order.locked_funds
                order.locked_funds = 0

            fill_price = order.price  # Limit orders fill at limit price
            self._fill_order(order, fill_price)

            events.append({
                'type': 'fill',
                'order_id': order_id,
                'symbol': order.symbol,
                'side': order.side.value,
                'size': order.filled_qty,
                'filled_qty': order.filled_qty,
                'price': order.filled_price,
                'filled_price': order.filled_price,
                'commission': order.commission,
                'is_profitable': getattr(order, 'is_profitable', False),
                'realized_pnl': getattr(order, 'realized_pnl', 0.0),
                'entry_price': getattr(order, 'entry_price', 0.0)
            })

            logger.info(f"PaperBroker: Limit order filled: {order.side.value} {order.size} {order.symbol} @ {fill_price:.4f}")

    def get_trade_history(self) -> List[Dict[str, Any]]:
        """Get all executed trades (shallow copy of the history list)."""
        return self._trade_history.copy()

    def reset(self):
        """Reset the broker to initial state."""
        self._cash = self.initial_balance
        self._locked_balance = 0.0
        self._orders.clear()
        self._positions.clear()
        self._trade_history.clear()
        self._current_prices.clear()
        # Stale SL/TP levels would otherwise apply to a new position on the same symbol.
        self._position_sltp.clear()
        logger.info(f"PaperBroker: Reset with balance {self.initial_balance}")

    # ==================== State Persistence Methods ====================

    def _ensure_persistence_cache(self) -> bool:
        """
        Ensure the persistence table/cache exists.

        Creates/migrates the ``paper_broker_states`` table when the data
        cache exposes ``db.execute_sql`` and registers the cache. The work
        is done once per instance; later calls return True immediately.

        :return: True when persistence is usable, False otherwise.
        """
        if not self._data_cache:
            return False

        if getattr(self, '_persistence_ready', False):
            return True

        try:
            # Ensure backing DB table exists for datacache read/write methods.
            if hasattr(self._data_cache, 'db') and hasattr(self._data_cache.db, 'execute_sql'):
                self._data_cache.db.execute_sql(
                    'CREATE TABLE IF NOT EXISTS "paper_broker_states" ('
                    'id INTEGER PRIMARY KEY AUTOINCREMENT, '
                    'tbl_key TEXT UNIQUE, '
                    'strategy_instance_id TEXT UNIQUE, '
                    'broker_state TEXT, '
                    'updated_at TEXT)',
                    []
                )

                # Migration path for any older local table that was created without tbl_key.
                try:
                    existing_df = self._data_cache.db.get_all_rows('paper_broker_states')
                    if 'tbl_key' not in existing_df.columns:
                        self._data_cache.db.execute_sql(
                            'ALTER TABLE "paper_broker_states" ADD COLUMN tbl_key TEXT',
                            []
                        )
                except Exception:
                    # If schema inspection fails, continue with current schema.
                    pass

                # Keep tbl_key aligned with strategy_instance_id for DataCache overwrite semantics.
                # '' (single quotes) is the SQL string literal; "" would be an identifier.
                self._data_cache.db.execute_sql(
                    'UPDATE "paper_broker_states" '
                    'SET tbl_key = strategy_instance_id '
                    "WHERE tbl_key IS NULL OR tbl_key = ''",
                    []
                )

            self._data_cache.create_cache(
                name='paper_broker_states',
                cache_type='table',
                size_limit=5000,
                eviction_policy='deny',
                default_expiration=timedelta(days=7),
                columns=['id', 'tbl_key', 'strategy_instance_id', 'broker_state', 'updated_at']
            )
            self._persistence_ready = True
            return True
        except Exception as e:
            logger.error(f"PaperBroker: Error ensuring persistence cache: {e}", exc_info=True)
            return False

    def to_state_dict(self) -> Dict[str, Any]:
        """
        Serialize broker state to a dictionary for persistence.

        Returns dict containing all state needed to restore the broker.
        """
        # Serialize orders
        orders_data = {}
        for order_id, order in self._orders.items():
            orders_data[order_id] = order.to_dict()

        # Serialize positions
        positions_data = {}
        for symbol, position in self._positions.items():
            positions_data[symbol] = position.to_dict()

        return {
            'cash': self._cash,
            'locked_balance': self._locked_balance,
            'initial_balance': self.initial_balance,
            'commission': self.commission,
            'slippage': self.slippage,
            'orders': orders_data,
            'positions': positions_data,
            'trade_history': self._trade_history,
            'current_prices': self._current_prices,
            'position_sltp': self._position_sltp,
        }

    def from_state_dict(self, state: Dict[str, Any]):
        """
        Restore broker state from a dictionary.

        Configuration (initial balance, commission, slippage) is not
        restored; only balances, orders, positions, history, prices and
        SL/TP tracking are.

        :param state: State dict from to_state_dict().
        """
        if not state:
            return

        # Restore balances
        self._cash = state.get('cash', self.initial_balance)
        self._locked_balance = state.get('locked_balance', 0.0)

        # Restore orders
        self._orders.clear()
        orders_data = state.get('orders', {})
        for order_id, order_dict in orders_data.items():
            order = PaperOrder(
                order_id=order_dict['order_id'],
                symbol=order_dict['symbol'],
                side=OrderSide(order_dict['side']),
                order_type=OrderType(order_dict['order_type']),
                size=order_dict['size'],
                price=order_dict.get('price'),
                stop_loss=order_dict.get('stop_loss'),
                take_profit=order_dict.get('take_profit'),
                # Older saved states may lack this key.
                time_in_force=order_dict.get('time_in_force', 'GTC'),
            )
            order.status = OrderStatus(order_dict['status'])
            order.filled_qty = order_dict.get('filled_qty', 0.0)
            order.filled_price = order_dict.get('filled_price', 0.0)
            order.commission = order_dict.get('commission', 0.0)
            order.locked_funds = order_dict.get('locked_funds', 0.0)

            if order_dict.get('created_at'):
                order.created_at = datetime.fromisoformat(order_dict['created_at'])
            if order_dict.get('filled_at'):
                order.filled_at = datetime.fromisoformat(order_dict['filled_at'])

            self._orders[order_id] = order

        # Restore positions
        self._positions.clear()
        positions_data = state.get('positions', {})
        for symbol, pos_dict in positions_data.items():
            self._positions[symbol] = Position.from_dict(pos_dict)

        # Restore trade history
        self._trade_history = state.get('trade_history', [])

        # Restore price cache
        self._current_prices = state.get('current_prices', {})

        # Restore SL/TP tracking
        self._position_sltp = state.get('position_sltp', {})

        logger.info(f"PaperBroker: State restored - cash: {self._cash:.2f}, "
                    f"positions: {len(self._positions)}, orders: {len(self._orders)}")

    def save_state(self, strategy_instance_id: str) -> bool:
        """
        Save broker state to the data cache.

        :param strategy_instance_id: Unique identifier for the strategy instance.
        :return: True if saved successfully.
        """
        if not self._data_cache:
            logger.warning("PaperBroker: No data cache available for persistence")
            return False

        try:
            if not self._ensure_persistence_cache():
                return False

            state_dict = self.to_state_dict()
            state_json = json.dumps(state_dict)

            # Check if state already exists
            existing = self._data_cache.get_rows_from_datacache(
                cache_name='paper_broker_states',
                filter_vals=[('strategy_instance_id', strategy_instance_id)]
            )

            columns = ('tbl_key', 'strategy_instance_id', 'broker_state', 'updated_at')
            values = (strategy_instance_id, strategy_instance_id, state_json, datetime.now(timezone.utc).isoformat())

            if existing.empty:
                # Insert new state
                self._data_cache.insert_row_into_datacache(
                    cache_name='paper_broker_states',
                    columns=columns,
                    values=values
                )
            else:
                # Update existing state
                self._data_cache.modify_datacache_item(
                    cache_name='paper_broker_states',
                    filter_vals=[('strategy_instance_id', strategy_instance_id)],
                    field_names=columns,
                    new_values=values,
                    overwrite='strategy_instance_id'
                )

            logger.debug(f"PaperBroker: State saved for {strategy_instance_id}")
            return True

        except Exception as e:
            logger.error(f"PaperBroker: Error saving state: {e}", exc_info=True)
            return False

    def load_state(self, strategy_instance_id: str) -> bool:
        """
        Load broker state from the data cache.

        :param strategy_instance_id: Unique identifier for the strategy instance.
        :return: True if state was loaded successfully.
        """
        if not self._data_cache:
            logger.warning("PaperBroker: No data cache available for persistence")
            return False

        try:
            if not self._ensure_persistence_cache():
                return False

            existing = self._data_cache.get_rows_from_datacache(
                cache_name='paper_broker_states',
                filter_vals=[('strategy_instance_id', strategy_instance_id)]
            )

            if existing.empty:
                logger.debug(f"PaperBroker: No saved state for {strategy_instance_id}")
                return False

            state_json = existing.iloc[0].get('broker_state', '{}')
            state_dict = json.loads(state_json)
            self.from_state_dict(state_dict)

            logger.info(f"PaperBroker: State loaded for {strategy_instance_id}")
            return True

        except Exception as e:
            logger.error(f"PaperBroker: Error loading state: {e}", exc_info=True)
            return False