# brighter-trading/src/brokers/base_broker.py
#
# 276 lines
# 7.5 KiB
# Python

"""
Base Broker Interface for BrighterTrading.
This abstract base class defines the interface that all broker implementations
must follow, enabling strategies to work identically across backtest, paper,
and live trading modes.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
from enum import Enum
import logging
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class OrderSide(Enum):
    """Order side enumeration.

    Each member's value is the lowercase string form of the side, so a
    member can be recovered from that string with ``OrderSide('buy')``.
    """

    BUY = 'buy'
    SELL = 'sell'
class OrderType(Enum):
    """Order type enumeration.

    Each member's value is the lowercase, underscore-separated string form
    of the type, so a member can be recovered from that string with
    ``OrderType('stop_limit')``.
    """

    MARKET = 'market'
    LIMIT = 'limit'
    STOP = 'stop'
    STOP_LIMIT = 'stop_limit'
class OrderStatus(Enum):
    """Order status enumeration.

    Each member's value is the lowercase, underscore-separated string form
    of the status, so a member can be recovered from that string with
    ``OrderStatus('partially_filled')``.
    """

    PENDING = 'pending'
    OPEN = 'open'
    FILLED = 'filled'
    PARTIALLY_FILLED = 'partially_filled'
    CANCELLED = 'cancelled'
    REJECTED = 'rejected'
    EXPIRED = 'expired'
@dataclass
class OrderResult:
    """Result of an order placement.

    Only ``success`` is required; every other field has a default so a
    failure can be reported as ``OrderResult(success=False, message=...)``.
    """

    success: bool
    # Identifier and free-text message are optional and default to None.
    order_id: Optional[str] = None
    message: Optional[str] = None
    # A result starts out as PENDING unless the caller says otherwise.
    status: OrderStatus = OrderStatus.PENDING
    # Fill details and commission default to zero.
    filled_qty: float = 0.0
    filled_price: float = 0.0
    commission: float = 0.0
@dataclass
class Position:
    """Represents a position in a symbol.

    ``to_dict`` and ``from_dict`` are inverses of each other and use the
    field names as dictionary keys.
    """

    symbol: str
    size: float  # Positive for long, negative for short
    entry_price: float
    current_price: float
    unrealized_pnl: float
    realized_pnl: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Convert position to dictionary for persistence.

        :return: Dictionary with one entry per field, keyed by field name.
        """
        return {
            'symbol': self.symbol,
            'size': self.size,
            'entry_price': self.entry_price,
            'current_price': self.current_price,
            'unrealized_pnl': self.unrealized_pnl,
            'realized_pnl': self.realized_pnl,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Position':
        """Create Position from dictionary.

        :param data: Dictionary in the shape produced by ``to_dict``.
            ``realized_pnl`` may be omitted and then defaults to 0.0.
        :return: New Position instance.
        :raises KeyError: If any key other than ``realized_pnl`` is missing.
        """
        return cls(
            symbol=data['symbol'],
            size=data['size'],
            entry_price=data['entry_price'],
            current_price=data['current_price'],
            unrealized_pnl=data['unrealized_pnl'],
            # Optional so dictionaries without this key still load.
            realized_pnl=data.get('realized_pnl', 0.0),
        )
class BaseBroker(ABC):
    """
    Abstract base class for all broker implementations.

    Provides a unified interface for:
    - Order placement and management
    - Balance and position tracking
    - Price retrieval

    Subclasses must implement all abstract methods. The concrete helpers
    (``get_equity``, ``close_position`` and ``close_all_positions``) are
    written purely in terms of that abstract interface.
    """

    def __init__(self, initial_balance: float = 10000.0,
                 commission: float = 0.001,
                 slippage: float = 0.0):
        """
        Initialize the broker.

        :param initial_balance: Starting balance in quote currency.
        :param commission: Commission rate (0.001 = 0.1%).
        :param slippage: Slippage rate for market orders.
        """
        self.initial_balance = initial_balance
        self.commission = commission
        self.slippage = slippage
        # Empty order/position stores keyed by string. This base class never
        # reads them itself; presumably they are bookkeeping for subclasses.
        self._orders: Dict[str, Dict[str, Any]] = {}
        self._positions: Dict[str, Position] = {}

    @abstractmethod
    def place_order(
        self,
        symbol: str,
        side: OrderSide,
        order_type: OrderType,
        size: float,
        price: Optional[float] = None,
        stop_loss: Optional[float] = None,
        take_profit: Optional[float] = None,
        time_in_force: str = 'GTC'
    ) -> OrderResult:
        """
        Place an order.

        :param symbol: Trading symbol (e.g., 'BTC/USDT').
        :param side: OrderSide.BUY or OrderSide.SELL.
        :param order_type: Type of order (market, limit, etc.).
        :param size: Order size in base currency.
        :param price: Limit price (required for limit orders).
        :param stop_loss: Stop loss price.
        :param take_profit: Take profit price.
        :param time_in_force: Time in force ('GTC', 'IOC', 'FOK').
        :return: OrderResult with order details.
        """

    @abstractmethod
    def cancel_order(self, order_id: str) -> bool:
        """
        Cancel an existing order.

        :param order_id: The order ID to cancel.
        :return: True if cancelled successfully.
        """

    @abstractmethod
    def get_order(self, order_id: str) -> Optional[Dict[str, Any]]:
        """
        Get order details by ID.

        :param order_id: The order ID.
        :return: Order details or None if not found.
        """

    @abstractmethod
    def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get all open orders, optionally filtered by symbol.

        :param symbol: Optional symbol to filter by.
        :return: List of open orders.
        """

    @abstractmethod
    def get_balance(self, asset: Optional[str] = None) -> float:
        """
        Get balance for an asset or total balance.

        :param asset: Asset symbol (e.g., 'USDT'). None for total.
        :return: Balance amount.
        """

    @abstractmethod
    def get_available_balance(self, asset: Optional[str] = None) -> float:
        """
        Get available balance (not locked in orders).

        :param asset: Asset symbol. None for total.
        :return: Available balance.
        """

    @abstractmethod
    def get_position(self, symbol: str) -> Optional[Position]:
        """
        Get current position for a symbol.

        :param symbol: Trading symbol.
        :return: Position object or None if no position.
        """

    @abstractmethod
    def get_all_positions(self) -> List[Position]:
        """
        Get all open positions.

        :return: List of Position objects.
        """

    @abstractmethod
    def get_current_price(self, symbol: str) -> float:
        """
        Get the current market price for a symbol.

        :param symbol: Trading symbol.
        :return: Current price.
        """

    @abstractmethod
    def update(self) -> List[Dict[str, Any]]:
        """
        Process pending orders and update positions.

        Called on each tick/bar to check for fills and update state.

        :return: List of events (fills, updates, etc.).
        """

    def get_equity(self) -> float:
        """
        Get total equity (balance + unrealized P&L).

        The balance is ``get_balance()`` with no asset argument; the
        unrealized P&L is summed over ``get_all_positions()``.

        :return: Total equity value.
        """
        balance = self.get_balance()
        unrealized_pnl = sum(pos.unrealized_pnl for pos in self.get_all_positions())
        return balance + unrealized_pnl

    def close_position(self, symbol: str) -> OrderResult:
        """
        Close an entire position for a symbol.

        Places a market order on the opposite side for the absolute size of
        the current position.

        :param symbol: Trading symbol.
        :return: OrderResult for the closing order, or an unsuccessful
            OrderResult if there is no position (or its size is zero).
        """
        position = self.get_position(symbol)
        if position is None or position.size == 0:
            return OrderResult(success=False, message=f"No position in {symbol}")

        # A long (positive size) is closed by selling, a short by buying.
        side = OrderSide.SELL if position.size > 0 else OrderSide.BUY
        size = abs(position.size)
        return self.place_order(
            symbol=symbol,
            side=side,
            order_type=OrderType.MARKET,
            size=size
        )

    def close_all_positions(self) -> List[OrderResult]:
        """
        Close all open positions.

        :return: List of OrderResults, one per position whose size was
            non-zero when it was visited.
        """
        results = []
        # Iterate over a snapshot: close_position() may cause an
        # implementation to add or remove entries in the very collection
        # that get_all_positions() returned, which would otherwise skip
        # positions (list) or raise RuntimeError (dict view) mid-loop.
        for position in list(self.get_all_positions()):
            if position.size != 0:
                results.append(self.close_position(position.symbol))
        return results