"""
Base Broker Interface for BrighterTrading.

This abstract base class defines the interface that all broker implementations
must follow, enabling strategies to work identically across backtest, paper,
and live trading modes.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
from enum import Enum
import logging

logger = logging.getLogger(__name__)


class OrderSide(Enum):
    """Order side enumeration."""
    BUY = 'buy'
    SELL = 'sell'


class OrderType(Enum):
    """Order type enumeration."""
    MARKET = 'market'
    LIMIT = 'limit'
    STOP = 'stop'
    STOP_LIMIT = 'stop_limit'


class OrderStatus(Enum):
    """Order status enumeration."""
    PENDING = 'pending'
    OPEN = 'open'
    FILLED = 'filled'
    PARTIALLY_FILLED = 'partially_filled'
    CANCELLED = 'cancelled'
    REJECTED = 'rejected'
    EXPIRED = 'expired'


@dataclass
class OrderResult:
    """Result of an order placement."""
    success: bool
    order_id: Optional[str] = None
    message: Optional[str] = None
    status: OrderStatus = OrderStatus.PENDING
    filled_qty: float = 0.0
    filled_price: float = 0.0
    commission: float = 0.0


@dataclass
class Position:
    """Represents a position in a symbol."""
    symbol: str
    size: float  # Positive for long, negative for short
    entry_price: float
    current_price: float
    unrealized_pnl: float
    realized_pnl: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert position to dictionary for persistence.

        :return: Dict with one key per field; round-trips through from_dict().
        """
        return {
            'symbol': self.symbol,
            'size': self.size,
            'entry_price': self.entry_price,
            'current_price': self.current_price,
            'unrealized_pnl': self.unrealized_pnl,
            'realized_pnl': self.realized_pnl,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Position':
        """
        Create Position from dictionary.

        :param data: Mapping as produced by to_dict(). 'realized_pnl' is
                     optional and defaults to 0.0; all other keys are required.
        :return: New Position instance.
        :raises KeyError: If a required key is missing.
        """
        return cls(
            symbol=data['symbol'],
            size=data['size'],
            entry_price=data['entry_price'],
            current_price=data['current_price'],
            unrealized_pnl=data['unrealized_pnl'],
            realized_pnl=data.get('realized_pnl', 0.0),
        )


class BaseBroker(ABC):
    """
    Abstract base class for all broker implementations.

    Provides a unified interface for:
    - Order placement and management
    - Balance and position tracking
    - Price retrieval

    Subclasses must implement all abstract methods.
    """

    def __init__(self, initial_balance: float = 10000.0, commission: float = 0.001,
                 slippage: float = 0.0):
        """
        Initialize the broker.

        :param initial_balance: Starting balance in quote currency.
        :param commission: Commission rate (0.001 = 0.1%).
        :param slippage: Slippage rate for market orders.
        """
        self.initial_balance = initial_balance
        self.commission = commission
        self.slippage = slippage
        # Storage for subclasses: orders keyed by order ID, positions keyed by symbol.
        self._orders: Dict[str, Dict[str, Any]] = {}
        self._positions: Dict[str, Position] = {}

    @abstractmethod
    def place_order(
        self,
        symbol: str,
        side: OrderSide,
        order_type: OrderType,
        size: float,
        price: Optional[float] = None,
        stop_loss: Optional[float] = None,
        take_profit: Optional[float] = None,
        time_in_force: str = 'GTC'
    ) -> OrderResult:
        """
        Place an order.

        :param symbol: Trading symbol (e.g., 'BTC/USDT').
        :param side: OrderSide.BUY or OrderSide.SELL.
        :param order_type: Type of order (market, limit, etc.).
        :param size: Order size in base currency.
        :param price: Limit price (required for limit orders).
        :param stop_loss: Stop loss price.
        :param take_profit: Take profit price.
        :param time_in_force: Time in force ('GTC', 'IOC', 'FOK').
        :return: OrderResult with order details.
        """
        pass

    @abstractmethod
    def cancel_order(self, order_id: str) -> bool:
        """
        Cancel an existing order.

        :param order_id: The order ID to cancel.
        :return: True if cancelled successfully.
        """
        pass

    @abstractmethod
    def get_order(self, order_id: str) -> Optional[Dict[str, Any]]:
        """
        Get order details by ID.

        :param order_id: The order ID.
        :return: Order details or None if not found.
        """
        pass

    @abstractmethod
    def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get all open orders, optionally filtered by symbol.

        :param symbol: Optional symbol to filter by.
        :return: List of open orders.
        """
        pass

    @abstractmethod
    def get_balance(self, asset: Optional[str] = None) -> float:
        """
        Get balance for an asset or total balance.

        :param asset: Asset symbol (e.g., 'USDT'). None for total.
        :return: Balance amount.
        """
        pass

    @abstractmethod
    def get_available_balance(self, asset: Optional[str] = None) -> float:
        """
        Get available balance (not locked in orders).

        :param asset: Asset symbol. None for total.
        :return: Available balance.
        """
        pass

    @abstractmethod
    def get_position(self, symbol: str) -> Optional[Position]:
        """
        Get current position for a symbol.

        :param symbol: Trading symbol.
        :return: Position object or None if no position.
        """
        pass

    @abstractmethod
    def get_all_positions(self) -> List[Position]:
        """
        Get all open positions.

        :return: List of Position objects.
        """
        pass

    @abstractmethod
    def get_current_price(self, symbol: str) -> float:
        """
        Get the current market price for a symbol.

        :param symbol: Trading symbol.
        :return: Current price.
        """
        pass

    @abstractmethod
    def update(self) -> List[Dict[str, Any]]:
        """
        Process pending orders and update positions.

        Called on each tick/bar to check for fills and update state.

        :return: List of events (fills, updates, etc.).
        """
        pass

    def get_equity(self) -> float:
        """
        Get total equity (balance + unrealized P&L).

        :return: get_balance() (total, asset=None) plus the sum of
                 unrealized_pnl over get_all_positions().
        """
        balance = self.get_balance()
        unrealized_pnl = sum(pos.unrealized_pnl for pos in self.get_all_positions())
        return balance + unrealized_pnl

    def close_position(self, symbol: str) -> OrderResult:
        """
        Close an entire position for a symbol with a market order.

        A long position (size > 0) is closed with a SELL and a short position
        (size < 0) with a BUY, each for abs(size) units.

        :param symbol: Trading symbol.
        :return: OrderResult for the closing order. If there is no position
                 (or its size is zero) no order is placed, and a result with
                 success=False and status=OrderStatus.REJECTED is returned.
        """
        position = self.get_position(symbol)
        if position is None or position.size == 0:
            # Nothing was submitted, so report REJECTED rather than the
            # dataclass default PENDING, which would imply an order in flight.
            return OrderResult(
                success=False,
                message=f"No position in {symbol}",
                status=OrderStatus.REJECTED,
            )

        side = OrderSide.SELL if position.size > 0 else OrderSide.BUY
        return self.place_order(
            symbol=symbol,
            side=side,
            order_type=OrderType.MARKET,
            size=abs(position.size),
        )

    def close_all_positions(self) -> List[OrderResult]:
        """
        Close all open positions.

        Non-zero positions are collected into a list before any order is
        placed. Closing a position may mutate the broker's position store, and
        iterating get_all_positions() directly would raise if a subclass returns
        a live view (e.g. dict.values()).

        :return: List of OrderResults, one per non-zero position, in the order
                 yielded by get_all_positions().
        """
        to_close = [pos for pos in self.get_all_positions() if pos.size != 0]
        return [self.close_position(pos.symbol) for pos in to_close]