diff --git a/src/backtest_strategy_instance.py b/src/backtest_strategy_instance.py index 7c5c0d4..d1b9bc9 100644 --- a/src/backtest_strategy_instance.py +++ b/src/backtest_strategy_instance.py @@ -1,13 +1,16 @@ # backtest_strategy_instance.py import logging -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING import pandas as pd import datetime as dt import backtrader as bt from StrategyInstance import StrategyInstance +if TYPE_CHECKING: + from brokers import BacktestBroker + logger = logging.getLogger(__name__) diff --git a/src/brokers/__init__.py b/src/brokers/__init__.py new file mode 100644 index 0000000..ce1da25 --- /dev/null +++ b/src/brokers/__init__.py @@ -0,0 +1,19 @@ +""" +Broker Abstraction Layer for BrighterTrading. + +This package provides a unified interface for executing trades across different modes: +- BacktestBroker: Uses Backtrader for historical simulation +- PaperBroker: Simulates fills with live price data +- LiveBroker: Executes real trades via exchange APIs (CCXT) +""" + +from .base_broker import BaseBroker, OrderSide, OrderType, OrderStatus, OrderResult, Position +from .backtest_broker import BacktestBroker +from .paper_broker import PaperBroker +from .factory import create_broker, TradingMode, get_available_modes + +__all__ = [ + 'BaseBroker', 'OrderSide', 'OrderType', 'OrderStatus', 'OrderResult', 'Position', + 'BacktestBroker', 'PaperBroker', + 'create_broker', 'TradingMode', 'get_available_modes' +] diff --git a/src/brokers/backtest_broker.py b/src/brokers/backtest_broker.py new file mode 100644 index 0000000..d8bccc5 --- /dev/null +++ b/src/brokers/backtest_broker.py @@ -0,0 +1,269 @@ +""" +Backtest Broker Implementation for BrighterTrading. + +Wraps Backtrader's broker functionality to provide a unified interface +for backtesting strategies. +""" + +import logging +from typing import Any, Dict, List, Optional +import uuid + +try: + import backtrader as bt +except ImportError: + bt = None + +from .base_broker import ( + BaseBroker, OrderResult, OrderSide, OrderType, OrderStatus, Position +) + +logger = logging.getLogger(__name__) + + +class BacktestBroker(BaseBroker): + """ + Broker implementation for backtesting using Backtrader. + + This broker delegates order execution to Backtrader's internal broker, + while providing the unified BaseBroker interface. + """ + + def __init__( + self, + backtrader_strategy: 'bt.Strategy' = None, + initial_balance: float = 10000.0, + commission: float = 0.001, + slippage: float = 0.0 + ): + """ + Initialize the BacktestBroker. + + :param backtrader_strategy: Reference to the Backtrader strategy instance. + :param initial_balance: Starting balance (set in Cerebro, not here). + :param commission: Commission rate (set in Cerebro, not here). + :param slippage: Slippage rate. + """ + super().__init__(initial_balance, commission, slippage) + self._bt_strategy = backtrader_strategy + self._pending_orders: Dict[str, Any] = {} + + def set_backtrader_strategy(self, strategy: 'bt.Strategy'): + """Set the Backtrader strategy reference after initialization.""" + self._bt_strategy = strategy + + @property + def _bt_broker(self): + """Get the Backtrader broker instance.""" + if self._bt_strategy is None: + raise RuntimeError("Backtrader strategy not set") + return self._bt_strategy.broker + + @property + def _bt_data(self): + """Get the Backtrader data feed.""" + if self._bt_strategy is None: + raise RuntimeError("Backtrader strategy not set") + return self._bt_strategy.data + + def place_order( + self, + symbol: str, + side: OrderSide, + order_type: OrderType, + size: float, + price: Optional[float] = None, + stop_loss: Optional[float] = None, + take_profit: Optional[float] = None, + time_in_force: str = 'GTC' + ) -> OrderResult: + """Place an order via Backtrader.""" + if self._bt_strategy is None: + return OrderResult( + success=False, + message="Backtrader strategy not initialized" + ) + + try: + # Map order type to Backtrader execution type + if order_type == OrderType.MARKET: + exectype = bt.Order.Market + bt_price = None + elif order_type == OrderType.LIMIT: + exectype = bt.Order.Limit + bt_price = price + elif order_type == OrderType.STOP: + exectype = bt.Order.Stop + bt_price = price + else: + exectype = bt.Order.Market + bt_price = None + + # Place the order via Backtrader + if side == OrderSide.BUY: + if bt_price is not None: + bt_order = self._bt_strategy.buy(size=size, exectype=exectype, price=bt_price) + else: + bt_order = self._bt_strategy.buy(size=size, exectype=exectype) + else: + if bt_price is not None: + bt_order = self._bt_strategy.sell(size=size, exectype=exectype, price=bt_price) + else: + bt_order = self._bt_strategy.sell(size=size, exectype=exectype) + + # Generate order ID + order_id = str(bt_order.ref) if bt_order else str(uuid.uuid4()) + + # Store order reference + self._pending_orders[order_id] = bt_order + + logger.info(f"BacktestBroker: Placed {side.value} order for {size} {symbol}") + + return OrderResult( + success=True, + order_id=order_id, + status=OrderStatus.PENDING, + message=f"Order placed: {side.value} {size} {symbol}" + ) + + except Exception as e: + logger.error(f"BacktestBroker: Failed to place order: {e}") + return OrderResult( + success=False, + message=str(e) + ) + + def cancel_order(self, order_id: str) -> bool: + """Cancel an order.""" + if order_id in self._pending_orders: + bt_order = self._pending_orders[order_id] + if bt_order is not None: + self._bt_broker.cancel(bt_order) + del self._pending_orders[order_id] + return True + return False + + def get_order(self, order_id: str) -> Optional[Dict[str, Any]]: + """Get order details.""" + if order_id in self._pending_orders: + bt_order = self._pending_orders[order_id] + if bt_order is not None: + return { + 'order_id': order_id, + 'status': self._map_bt_status(bt_order.status), + 'size': bt_order.size, + 'price': bt_order.price, + 'executed_size': bt_order.executed.size if bt_order.executed else 0, + 'executed_price': bt_order.executed.price if bt_order.executed else 0 + } + return None + + def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]: + """Get all open orders.""" + orders = [] + for order_id, bt_order in self._pending_orders.items(): + if bt_order is not None and bt_order.status in [bt.Order.Submitted, bt.Order.Accepted]: + orders.append(self.get_order(order_id)) + return [o for o in orders if o is not None] + + def get_balance(self, asset: Optional[str] = None) -> float: + """Get total portfolio value.""" + return self._bt_broker.getvalue() + + def get_available_balance(self, asset: Optional[str] = None) -> float: + """Get available cash.""" + return self._bt_broker.getcash() + + def get_position(self, symbol: str) -> Optional[Position]: + """Get position for a symbol.""" + if self._bt_strategy is None: + return None + + bt_position = self._bt_broker.getposition(self._bt_data) + if bt_position.size == 0: + return None + + current_price = self.get_current_price(symbol) + entry_price = bt_position.price + unrealized_pnl = (current_price - entry_price) * bt_position.size + + return Position( + symbol=symbol, + size=bt_position.size, + entry_price=entry_price, + current_price=current_price, + unrealized_pnl=unrealized_pnl + ) + + def get_all_positions(self) -> List[Position]: + """Get all open positions.""" + positions = [] + if self._bt_strategy is None: + return positions + + # For single-data backtests, just check the main data + position = self.get_position("backtest_symbol") + if position is not None: + positions.append(position) + + return positions + + def get_current_price(self, symbol: str) -> float: + """Get current price from Backtrader data feed.""" + if self._bt_strategy is None: + return 0.0 + return self._bt_data.close[0] + + def update(self) -> List[Dict[str, Any]]: + """ + Process pending orders. + + In Backtrader, order processing is handled by the engine. + This method checks for completed orders and updates state. + """ + events = [] + + # Check for completed orders + completed_orders = [] + for order_id, bt_order in self._pending_orders.items(): + if bt_order is None: + continue + + if bt_order.status == bt.Order.Completed: + events.append({ + 'type': 'fill', + 'order_id': order_id, + 'size': bt_order.executed.size, + 'price': bt_order.executed.price, + 'commission': bt_order.executed.comm + }) + completed_orders.append(order_id) + elif bt_order.status in [bt.Order.Canceled, bt.Order.Margin, bt.Order.Rejected]: + events.append({ + 'type': 'cancelled', + 'order_id': order_id, + 'reason': self._map_bt_status(bt_order.status) + }) + completed_orders.append(order_id) + + # Remove completed orders + for order_id in completed_orders: + del self._pending_orders[order_id] + + return events + + @staticmethod + def _map_bt_status(bt_status: int) -> str: + """Map Backtrader order status to string.""" + status_map = { + bt.Order.Created: 'created', + bt.Order.Submitted: 'submitted', + bt.Order.Accepted: 'accepted', + bt.Order.Partial: 'partial', + bt.Order.Completed: 'completed', + bt.Order.Canceled: 'cancelled', + bt.Order.Expired: 'expired', + bt.Order.Margin: 'margin_call', + bt.Order.Rejected: 'rejected' + } + return status_map.get(bt_status, 'unknown') diff --git a/src/brokers/base_broker.py b/src/brokers/base_broker.py new file mode 100644 index 0000000..eff715a --- /dev/null +++ b/src/brokers/base_broker.py @@ -0,0 +1,252 @@ +""" +Base Broker Interface for BrighterTrading. + +This abstract base class defines the interface that all broker implementations +must follow, enabling strategies to work identically across backtest, paper, +and live trading modes. +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional +from dataclasses import dataclass +from enum import Enum +import logging + +logger = logging.getLogger(__name__) + + +class OrderSide(Enum): + """Order side enumeration.""" + BUY = 'buy' + SELL = 'sell' + + +class OrderType(Enum): + """Order type enumeration.""" + MARKET = 'market' + LIMIT = 'limit' + STOP = 'stop' + STOP_LIMIT = 'stop_limit' + + +class OrderStatus(Enum): + """Order status enumeration.""" + PENDING = 'pending' + OPEN = 'open' + FILLED = 'filled' + PARTIALLY_FILLED = 'partially_filled' + CANCELLED = 'cancelled' + REJECTED = 'rejected' + EXPIRED = 'expired' + + +@dataclass +class OrderResult: + """Result of an order placement.""" + success: bool + order_id: Optional[str] = None + message: Optional[str] = None + status: OrderStatus = OrderStatus.PENDING + filled_qty: float = 0.0 + filled_price: float = 0.0 + commission: float = 0.0 + + +@dataclass +class Position: + """Represents a position in a symbol.""" + symbol: str + size: float # Positive for long, negative for short + entry_price: float + current_price: float + unrealized_pnl: float + realized_pnl: float = 0.0 + + +class BaseBroker(ABC): + """ + Abstract base class for all broker implementations. + + Provides a unified interface for: + - Order placement and management + - Balance and position tracking + - Price retrieval + + Subclasses must implement all abstract methods. + """ + + def __init__(self, initial_balance: float = 10000.0, + commission: float = 0.001, + slippage: float = 0.0): + """ + Initialize the broker. + + :param initial_balance: Starting balance in quote currency. + :param commission: Commission rate (0.001 = 0.1%). + :param slippage: Slippage rate for market orders. + """ + self.initial_balance = initial_balance + self.commission = commission + self.slippage = slippage + self._orders: Dict[str, Dict[str, Any]] = {} + self._positions: Dict[str, Position] = {} + + @abstractmethod + def place_order( + self, + symbol: str, + side: OrderSide, + order_type: OrderType, + size: float, + price: Optional[float] = None, + stop_loss: Optional[float] = None, + take_profit: Optional[float] = None, + time_in_force: str = 'GTC' + ) -> OrderResult: + """ + Place an order. + + :param symbol: Trading symbol (e.g., 'BTC/USDT'). + :param side: OrderSide.BUY or OrderSide.SELL. + :param order_type: Type of order (market, limit, etc.). + :param size: Order size in base currency. + :param price: Limit price (required for limit orders). + :param stop_loss: Stop loss price. + :param take_profit: Take profit price. + :param time_in_force: Time in force ('GTC', 'IOC', 'FOK'). + :return: OrderResult with order details. + """ + pass + + @abstractmethod + def cancel_order(self, order_id: str) -> bool: + """ + Cancel an existing order. + + :param order_id: The order ID to cancel. + :return: True if cancelled successfully. + """ + pass + + @abstractmethod + def get_order(self, order_id: str) -> Optional[Dict[str, Any]]: + """ + Get order details by ID. + + :param order_id: The order ID. + :return: Order details or None if not found. + """ + pass + + @abstractmethod + def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]: + """ + Get all open orders, optionally filtered by symbol. + + :param symbol: Optional symbol to filter by. + :return: List of open orders. + """ + pass + + @abstractmethod + def get_balance(self, asset: Optional[str] = None) -> float: + """ + Get balance for an asset or total balance. + + :param asset: Asset symbol (e.g., 'USDT'). None for total. + :return: Balance amount. + """ + pass + + @abstractmethod + def get_available_balance(self, asset: Optional[str] = None) -> float: + """ + Get available balance (not locked in orders). + + :param asset: Asset symbol. None for total. + :return: Available balance. + """ + pass + + @abstractmethod + def get_position(self, symbol: str) -> Optional[Position]: + """ + Get current position for a symbol. + + :param symbol: Trading symbol. + :return: Position object or None if no position. + """ + pass + + @abstractmethod + def get_all_positions(self) -> List[Position]: + """ + Get all open positions. + + :return: List of Position objects. + """ + pass + + @abstractmethod + def get_current_price(self, symbol: str) -> float: + """ + Get the current market price for a symbol. + + :param symbol: Trading symbol. + :return: Current price. + """ + pass + + @abstractmethod + def update(self) -> List[Dict[str, Any]]: + """ + Process pending orders and update positions. + + Called on each tick/bar to check for fills and update state. + + :return: List of events (fills, updates, etc.). + """ + pass + + def get_equity(self) -> float: + """ + Get total equity (balance + unrealized P&L). + + :return: Total equity value. + """ + balance = self.get_balance() + unrealized_pnl = sum(pos.unrealized_pnl for pos in self.get_all_positions()) + return balance + unrealized_pnl + + def close_position(self, symbol: str) -> OrderResult: + """ + Close an entire position for a symbol. + + :param symbol: Trading symbol. + :return: OrderResult for the closing order. + """ + position = self.get_position(symbol) + if position is None or position.size == 0: + return OrderResult(success=False, message=f"No position in {symbol}") + + side = OrderSide.SELL if position.size > 0 else OrderSide.BUY + size = abs(position.size) + + return self.place_order( + symbol=symbol, + side=side, + order_type=OrderType.MARKET, + size=size + ) + + def close_all_positions(self) -> List[OrderResult]: + """ + Close all open positions. + + :return: List of OrderResults for each closed position. + """ + results = [] + for position in self.get_all_positions(): + if position.size != 0: + results.append(self.close_position(position.symbol)) + return results diff --git a/src/brokers/factory.py b/src/brokers/factory.py new file mode 100644 index 0000000..94ac5f1 --- /dev/null +++ b/src/brokers/factory.py @@ -0,0 +1,88 @@ +""" +Broker Factory for BrighterTrading. + +Creates the appropriate broker instance based on the trading mode. +""" + +import logging +from typing import Any, Optional, Callable + +from .base_broker import BaseBroker +from .backtest_broker import BacktestBroker +from .paper_broker import PaperBroker + +logger = logging.getLogger(__name__) + + +class TradingMode: + """Trading mode constants.""" + BACKTEST = 'backtest' + PAPER = 'paper' + LIVE = 'live' + + +def create_broker( + mode: str, + initial_balance: float = 10000.0, + commission: float = 0.001, + slippage: float = 0.0, + price_provider: Optional[Callable[[str], float]] = None, + data_cache: Any = None, + exchange_interface: Any = None, + user_name: str = None, + exchange_name: str = None, + **kwargs +) -> BaseBroker: + """ + Factory function to create the appropriate broker for the trading mode. + + :param mode: Trading mode ('backtest', 'paper', 'live'). + :param initial_balance: Starting balance. + :param commission: Commission rate. + :param slippage: Slippage rate. + :param price_provider: Callable for getting current prices (paper/live). + :param data_cache: DataCache instance for persistence. + :param exchange_interface: ExchangeInterface for live trading. + :param user_name: User name for live trading. + :param exchange_name: Exchange name for live trading. + :param kwargs: Additional arguments passed to broker constructor. + :return: Broker instance. + """ + mode = mode.lower() + + if mode == TradingMode.BACKTEST: + logger.info("Creating BacktestBroker") + return BacktestBroker( + initial_balance=initial_balance, + commission=commission, + slippage=slippage, + **kwargs + ) + + elif mode == TradingMode.PAPER: + logger.info("Creating PaperBroker") + return PaperBroker( + initial_balance=initial_balance, + commission=commission, + slippage=slippage if slippage > 0 else 0.0005, # Default slippage for paper + price_provider=price_provider, + data_cache=data_cache, + **kwargs + ) + + elif mode == TradingMode.LIVE: + # Live broker will be implemented in Phase 5 + raise NotImplementedError( + "Live trading broker not yet implemented. " + "Use paper trading for testing with live prices." + ) + + else: + raise ValueError(f"Invalid trading mode: {mode}. " + f"Must be one of: {TradingMode.BACKTEST}, " + f"{TradingMode.PAPER}, {TradingMode.LIVE}") + + +def get_available_modes() -> list: + """Get list of available trading modes.""" + return [TradingMode.BACKTEST, TradingMode.PAPER] diff --git a/src/brokers/paper_broker.py b/src/brokers/paper_broker.py new file mode 100644 index 0000000..126c6f8 --- /dev/null +++ b/src/brokers/paper_broker.py @@ -0,0 +1,412 @@ +""" +Paper Trading Broker Implementation for BrighterTrading. + +Simulates order execution with live price data for paper trading. +Orders are filled based on current market prices with configurable +slippage and commission. +""" + +import logging +from typing import Any, Dict, List, Optional, Callable +import uuid +from datetime import datetime, timezone + +from .base_broker import ( + BaseBroker, OrderResult, OrderSide, OrderType, OrderStatus, Position +) + +logger = logging.getLogger(__name__) + + +class PaperOrder: + """Represents a paper trading order.""" + + def __init__( + self, + order_id: str, + symbol: str, + side: OrderSide, + order_type: OrderType, + size: float, + price: Optional[float] = None, + stop_loss: Optional[float] = None, + take_profit: Optional[float] = None, + time_in_force: str = 'GTC' + ): + self.order_id = order_id + self.symbol = symbol + self.side = side + self.order_type = order_type + self.size = size + self.price = price + self.stop_loss = stop_loss + self.take_profit = take_profit + self.time_in_force = time_in_force + self.status = OrderStatus.PENDING + self.filled_qty = 0.0 + self.filled_price = 0.0 + self.commission = 0.0 + self.locked_funds = 0.0 # Amount locked for this order + self.created_at = datetime.now(timezone.utc) + self.filled_at: Optional[datetime] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + 'order_id': self.order_id, + 'symbol': self.symbol, + 'side': self.side.value, + 'order_type': self.order_type.value, + 'size': self.size, + 'price': self.price, + 'stop_loss': self.stop_loss, + 'take_profit': self.take_profit, + 'status': self.status.value, + 'filled_qty': self.filled_qty, + 'filled_price': self.filled_price, + 'commission': self.commission, + 'created_at': self.created_at.isoformat(), + 'filled_at': self.filled_at.isoformat() if self.filled_at else None + } + + +class PaperBroker(BaseBroker): + """ + Paper trading broker that simulates order execution. + + Features: + - Simulated fills based on market prices + - Configurable slippage and commission + - Position tracking with P&L calculation + - Persistent state across restarts (via data_cache) + """ + + def __init__( + self, + price_provider: Optional[Callable[[str], float]] = None, + data_cache: Any = None, + initial_balance: float = 10000.0, + commission: float = 0.001, + slippage: float = 0.0005 + ): + """ + Initialize the PaperBroker. + + :param price_provider: Callable that returns current price for a symbol. + :param data_cache: DataCache instance for persistence. + :param initial_balance: Starting balance in quote currency. + :param commission: Commission rate (0.001 = 0.1%). + :param slippage: Slippage rate for market orders. + """ + super().__init__(initial_balance, commission, slippage) + self._price_provider = price_provider + self._data_cache = data_cache + + # Balance tracking + self._cash = initial_balance + self._locked_balance = 0.0 + + # Order and position storage + self._orders: Dict[str, PaperOrder] = {} + self._positions: Dict[str, Position] = {} + self._trade_history: List[Dict[str, Any]] = [] + + # Current prices cache + self._current_prices: Dict[str, float] = {} + + def set_price_provider(self, provider: Callable[[str], float]): + """Set the price provider callable.""" + self._price_provider = provider + + def update_price(self, symbol: str, price: float): + """Update the current price for a symbol.""" + self._current_prices[symbol] = price + + def place_order( + self, + symbol: str, + side: OrderSide, + order_type: OrderType, + size: float, + price: Optional[float] = None, + stop_loss: Optional[float] = None, + take_profit: Optional[float] = None, + time_in_force: str = 'GTC' + ) -> OrderResult: + """Place a paper trading order.""" + order_id = str(uuid.uuid4())[:8] + + # Validate order + current_price = self.get_current_price(symbol) + if current_price <= 0: + return OrderResult( + success=False, + message=f"Cannot get current price for {symbol}" + ) + + # Calculate required margin/cost + order_value = size * current_price + required_funds = order_value * (1 + self.commission) + + if side == OrderSide.BUY and required_funds > self._cash: + return OrderResult( + success=False, + message=f"Insufficient funds: need {required_funds:.2f}, have {self._cash:.2f}" + ) + + # For sell orders, check if we have the position + if side == OrderSide.SELL: + position = self._positions.get(symbol) + if position is None or position.size < size: + available = position.size if position else 0 + return OrderResult( + success=False, + message=f"Insufficient position: need {size}, have {available}" + ) + + # Create the order + order = PaperOrder( + order_id=order_id, + symbol=symbol, + side=side, + order_type=order_type, + size=size, + price=price, + stop_loss=stop_loss, + take_profit=take_profit, + time_in_force=time_in_force + ) + + # For market orders, fill immediately + if order_type == OrderType.MARKET: + fill_price = self._calculate_fill_price(symbol, side, current_price) + self._fill_order(order, fill_price) + logger.info(f"PaperBroker: Market order filled: {side.value} {size} {symbol} @ {fill_price:.4f}") + else: + # Store pending order + order.status = OrderStatus.OPEN + self._orders[order_id] = order + # Lock funds for buy limit orders + if side == OrderSide.BUY: + order.locked_funds = required_funds + self._locked_balance += required_funds + self._cash -= required_funds + logger.info(f"PaperBroker: Limit order placed: {side.value} {size} {symbol} @ {price}") + + return OrderResult( + success=True, + order_id=order_id, + status=order.status, + filled_qty=order.filled_qty, + filled_price=order.filled_price, + commission=order.commission, + message=f"Order {order_id} {'filled' if order.status == OrderStatus.FILLED else 'placed'}" + ) + + def _calculate_fill_price(self, symbol: str, side: OrderSide, market_price: float) -> float: + """Calculate fill price with slippage.""" + if side == OrderSide.BUY: + return market_price * (1 + self.slippage) + else: + return market_price * (1 - self.slippage) + + def _fill_order(self, order: PaperOrder, fill_price: float): + """Execute an order fill.""" + order.filled_qty = order.size + order.filled_price = fill_price + order.commission = order.size * fill_price * self.commission + order.status = OrderStatus.FILLED + order.filled_at = datetime.now(timezone.utc) + + # Update balances and positions + order_value = order.size * fill_price + + if order.side == OrderSide.BUY: + # Deduct cost from cash + total_cost = order_value + order.commission + self._cash -= total_cost + + # Update position + if order.symbol in self._positions: + existing = self._positions[order.symbol] + new_size = existing.size + order.size + new_entry = (existing.entry_price * existing.size + fill_price * order.size) / new_size + existing.size = new_size + existing.entry_price = new_entry + else: + self._positions[order.symbol] = Position( + symbol=order.symbol, + size=order.size, + entry_price=fill_price, + current_price=fill_price, + unrealized_pnl=0.0 + ) + else: + # Add proceeds to cash + total_proceeds = order_value - order.commission + self._cash += total_proceeds + + # Update position + if order.symbol in self._positions: + position = self._positions[order.symbol] + realized_pnl = (fill_price - position.entry_price) * order.size - order.commission + position.realized_pnl += realized_pnl + position.size -= order.size + + # Remove position if fully closed + if position.size <= 0: + del self._positions[order.symbol] + + # Record trade + self._trade_history.append({ + 'order_id': order.order_id, + 'symbol': order.symbol, + 'side': order.side.value, + 'size': order.size, + 'price': fill_price, + 'commission': order.commission, + 'timestamp': order.filled_at.isoformat() + }) + + # Store in orders dict for reference + self._orders[order.order_id] = order + + def cancel_order(self, order_id: str) -> bool: + """Cancel a pending order.""" + if order_id not in self._orders: + return False + + order = self._orders[order_id] + if order.status != OrderStatus.OPEN: + return False + + # Release locked funds for buy orders + if order.side == OrderSide.BUY and order.locked_funds > 0: + self._locked_balance -= order.locked_funds + self._cash += order.locked_funds + order.locked_funds = 0 + + order.status = OrderStatus.CANCELLED + logger.info(f"PaperBroker: Order {order_id} cancelled") + return True + + def get_order(self, order_id: str) -> Optional[Dict[str, Any]]: + """Get order details.""" + if order_id in self._orders: + return self._orders[order_id].to_dict() + return None + + def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]: + """Get all open orders.""" + open_orders = [ + order.to_dict() for order in self._orders.values() + if order.status == OrderStatus.OPEN + and (symbol is None or order.symbol == symbol) + ] + return open_orders + + def get_balance(self, asset: Optional[str] = None) -> float: + """Get total balance including unrealized P&L.""" + total = self._cash + self._locked_balance + for position in self._positions.values(): + total += position.unrealized_pnl + return total + + def get_available_balance(self, asset: Optional[str] = None) -> float: + """Get available cash (not locked in orders).""" + return self._cash + + def get_position(self, symbol: str) -> Optional[Position]: + """Get position for a symbol.""" + return self._positions.get(symbol) + + def get_all_positions(self) -> List[Position]: + """Get all open positions.""" + return list(self._positions.values()) + + def get_current_price(self, symbol: str) -> float: + """Get current price for a symbol.""" + # First check cache + if symbol in self._current_prices: + return self._current_prices[symbol] + + # Then try price provider + if self._price_provider: + try: + price = self._price_provider(symbol) + self._current_prices[symbol] = price + return price + except Exception as e: + logger.warning(f"PaperBroker: Failed to get price for {symbol}: {e}") + + return 0.0 + + def update(self) -> List[Dict[str, Any]]: + """ + Process pending limit orders and update positions. + + Called on each price update to check for fills. + """ + events = [] + + # Update position P&L + for symbol, position in self._positions.items(): + current_price = self.get_current_price(symbol) + if current_price > 0: + position.current_price = current_price + position.unrealized_pnl = (current_price - position.entry_price) * position.size + + # Check pending limit orders + for order_id, order in list(self._orders.items()): + if order.status != OrderStatus.OPEN: + continue + + current_price = self.get_current_price(order.symbol) + if current_price <= 0: + continue + + should_fill = False + + if order.order_type == OrderType.LIMIT: + if order.side == OrderSide.BUY and current_price <= order.price: + should_fill = True + elif order.side == OrderSide.SELL and current_price >= order.price: + should_fill = True + + if should_fill: + # Release locked funds first (for buy orders) + if order.side == OrderSide.BUY and order.locked_funds > 0: + self._locked_balance -= order.locked_funds + self._cash += order.locked_funds + order.locked_funds = 0 + + fill_price = order.price # Limit orders fill at limit price + self._fill_order(order, fill_price) + + events.append({ + 'type': 'fill', + 'order_id': order_id, + 'symbol': order.symbol, + 'side': order.side.value, + 'size': order.filled_qty, + 'price': order.filled_price, + 'commission': order.commission + }) + + logger.info(f"PaperBroker: Limit order filled: {order.side.value} {order.size} {order.symbol} @ {fill_price:.4f}") + + return events + + def get_trade_history(self) -> List[Dict[str, Any]]: + """Get all executed trades.""" + return self._trade_history.copy() + + def reset(self): + """Reset the broker to initial state.""" + self._cash = self.initial_balance + self._locked_balance = 0.0 + self._orders.clear() + self._positions.clear() + self._trade_history.clear() + self._current_prices.clear() + logger.info(f"PaperBroker: Reset with balance {self.initial_balance}") diff --git a/tests/test_brokers.py b/tests/test_brokers.py new file mode 100644 index 0000000..a21daef --- /dev/null +++ b/tests/test_brokers.py @@ -0,0 +1,256 @@ +""" +Tests for the broker abstraction layer. +""" +import pytest +from brokers import ( + BaseBroker, BacktestBroker, PaperBroker, + OrderSide, OrderType, OrderStatus, OrderResult, Position, + create_broker, TradingMode +) + + +class TestPaperBroker: + """Tests for PaperBroker.""" + + def test_create_paper_broker(self): + """Test creating a paper broker.""" + broker = PaperBroker(initial_balance=10000) + assert broker.get_balance() == 10000 + assert broker.get_available_balance() == 10000 + + def test_paper_broker_market_buy(self): + """Test market buy order.""" + broker = PaperBroker(initial_balance=10000, commission=0.001) + broker.update_price('BTC/USDT', 50000) + + result = broker.place_order( + symbol='BTC/USDT', + side=OrderSide.BUY, + order_type=OrderType.MARKET, + size=0.1 + ) + + assert result.success + assert result.status == OrderStatus.FILLED + + position = broker.get_position('BTC/USDT') + assert position is not None + assert position.size == 0.1 + + # Check balance deducted (price * size + commission) + expected_cost = 50000 * 0.1 * (1 + 0.001) # with slippage + assert broker.get_available_balance() < 10000 + + def test_paper_broker_market_sell(self): + """Test market sell order.""" + broker = PaperBroker(initial_balance=10000, commission=0.001) + broker.update_price('BTC/USDT', 50000) + + # First buy + broker.place_order( + symbol='BTC/USDT', + side=OrderSide.BUY, + order_type=OrderType.MARKET, + size=0.1 + ) + + # Update price and sell + broker.update_price('BTC/USDT', 55000) + result = broker.place_order( + symbol='BTC/USDT', + side=OrderSide.SELL, + order_type=OrderType.MARKET, + size=0.1 + ) + + assert result.success + assert result.status == OrderStatus.FILLED + + # Position should be closed + position = broker.get_position('BTC/USDT') + assert position is None + + def test_paper_broker_insufficient_funds(self): + """Test order rejection due to insufficient funds.""" + broker = PaperBroker(initial_balance=1000) + broker.update_price('BTC/USDT', 50000) + + result = broker.place_order( + symbol='BTC/USDT', + side=OrderSide.BUY, + order_type=OrderType.MARKET, + size=0.1 # Would cost ~5000 + ) + + assert not result.success + assert "Insufficient funds" in result.message + + def test_paper_broker_limit_order(self): + """Test limit order placement and fill.""" + broker = PaperBroker(initial_balance=10000, commission=0.001) + broker.update_price('BTC/USDT', 50000) + + result = broker.place_order( + symbol='BTC/USDT', + side=OrderSide.BUY, + order_type=OrderType.LIMIT, + size=0.1, + price=49000 + ) + + assert result.success + assert result.status == OrderStatus.OPEN + + # Order should be pending + open_orders = broker.get_open_orders() + assert len(open_orders) == 1 + + # Update price below limit - should fill + broker.update_price('BTC/USDT', 48000) + events = broker.update() + + assert len(events) == 1 + assert events[0]['type'] == 'fill' + + # Now position should exist + position = broker.get_position('BTC/USDT') + assert position is not None + assert position.size == 0.1 + + def test_paper_broker_cancel_order(self): + """Test order cancellation.""" + broker = PaperBroker(initial_balance=10000, commission=0, slippage=0) + broker.update_price('BTC/USDT', 50000) + + initial_balance = broker.get_available_balance() + + result = broker.place_order( + symbol='BTC/USDT', + side=OrderSide.BUY, + order_type=OrderType.LIMIT, + size=0.1, + price=49000 + ) + + assert result.success + assert broker.get_available_balance() < initial_balance # Funds locked + + # Cancel the order + cancelled = broker.cancel_order(result.order_id) + assert cancelled + + # Funds should be released + assert broker.get_available_balance() == initial_balance + + def test_paper_broker_pnl_tracking(self): + """Test P&L tracking.""" + broker = PaperBroker(initial_balance=10000, commission=0, slippage=0) + broker.update_price('BTC/USDT', 50000) + + # Buy at 50000 + broker.place_order( + symbol='BTC/USDT', + side=OrderSide.BUY, + order_type=OrderType.MARKET, + size=0.1 + ) + + # Price goes up + broker.update_price('BTC/USDT', 52000) + broker.update() + + position = broker.get_position('BTC/USDT') + assert position is not None + # Unrealized P&L: (52000 - 50000) * 0.1 = 200 + assert position.unrealized_pnl == 200 + + def test_paper_broker_reset(self): + """Test broker reset.""" + broker = PaperBroker(initial_balance=10000) + broker.update_price('BTC/USDT', 50000) + broker.place_order( + symbol='BTC/USDT', + side=OrderSide.BUY, + order_type=OrderType.MARKET, + size=0.1 + ) + + broker.reset() + + assert broker.get_balance() == 10000 + assert broker.get_all_positions() == [] + assert broker.get_open_orders() == [] + + +class TestBrokerFactory: + """Tests for the broker factory.""" + + def test_create_paper_broker(self): + """Test creating paper broker via factory.""" + broker = create_broker( + mode=TradingMode.PAPER, + initial_balance=5000 + ) + assert isinstance(broker, PaperBroker) + assert broker.get_balance() == 5000 + + def test_create_backtest_broker(self): + """Test creating backtest broker via factory.""" + broker = create_broker( + mode=TradingMode.BACKTEST, + initial_balance=10000 + ) + assert isinstance(broker, BacktestBroker) + + def test_create_live_broker_not_implemented(self): + """Test that live broker raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + create_broker(mode=TradingMode.LIVE) + + def test_invalid_mode(self): + """Test that invalid mode raises ValueError.""" + with pytest.raises(ValueError): + create_broker(mode='invalid_mode') + + +class TestOrderResult: + """Tests for OrderResult dataclass.""" + + def test_order_result_success(self): + """Test successful order result.""" + result = OrderResult( + success=True, + order_id='12345', + status=OrderStatus.FILLED, + filled_qty=0.1, + filled_price=50000 + ) + assert result.success + assert result.order_id == '12345' + assert result.status == OrderStatus.FILLED + + def test_order_result_failure(self): + """Test failed order result.""" + result = OrderResult( + success=False, + message="Insufficient funds" + ) + assert not result.success + assert "Insufficient" in result.message + + +class TestPosition: + """Tests for Position dataclass.""" + + def test_position_creation(self): + """Test position creation.""" + position = Position( + symbol='BTC/USDT', + size=0.1, + entry_price=50000, + current_price=51000, + unrealized_pnl=100 + ) + assert position.symbol == 'BTC/USDT' + assert position.size == 0.1 + assert position.unrealized_pnl == 100