Phase 2: Broker abstraction foundation

- Create brokers package with unified trading interface
- BaseBroker: Abstract base class defining broker contract
- BacktestBroker: Wraps Backtrader for backtesting mode
- PaperBroker: Simulated order execution for paper trading
- Factory function to create broker based on trading mode
- Comprehensive test suite for broker functionality

The broker abstraction enables strategies to work identically
across backtest, paper, and live trading modes.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
rob 2026-02-28 17:00:43 -04:00
parent e245f80e18
commit f1182d4e0c
7 changed files with 1300 additions and 1 deletions

View File

@ -1,13 +1,16 @@
# backtest_strategy_instance.py # backtest_strategy_instance.py
import logging import logging
from typing import Any, Optional from typing import Any, Optional, TYPE_CHECKING
import pandas as pd import pandas as pd
import datetime as dt import datetime as dt
import backtrader as bt import backtrader as bt
from StrategyInstance import StrategyInstance from StrategyInstance import StrategyInstance
if TYPE_CHECKING:
from brokers import BacktestBroker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

19
src/brokers/__init__.py Normal file
View File

@ -0,0 +1,19 @@
"""
Broker Abstraction Layer for BrighterTrading.
This package provides a unified interface for executing trades across different modes:
- BacktestBroker: Uses Backtrader for historical simulation
- PaperBroker: Simulates fills with live price data
- LiveBroker: Executes real trades via exchange APIs (CCXT)
"""
from .base_broker import BaseBroker, OrderSide, OrderType, OrderStatus, OrderResult, Position
from .backtest_broker import BacktestBroker
from .paper_broker import PaperBroker
from .factory import create_broker, TradingMode, get_available_modes
__all__ = [
'BaseBroker', 'OrderSide', 'OrderType', 'OrderStatus', 'OrderResult', 'Position',
'BacktestBroker', 'PaperBroker',
'create_broker', 'TradingMode', 'get_available_modes'
]

View File

@ -0,0 +1,269 @@
"""
Backtest Broker Implementation for BrighterTrading.
Wraps Backtrader's broker functionality to provide a unified interface
for backtesting strategies.
"""
import logging
from typing import Any, Dict, List, Optional
import uuid
try:
import backtrader as bt
except ImportError:
bt = None
from .base_broker import (
BaseBroker, OrderResult, OrderSide, OrderType, OrderStatus, Position
)
logger = logging.getLogger(__name__)
class BacktestBroker(BaseBroker):
"""
Broker implementation for backtesting using Backtrader.
This broker delegates order execution to Backtrader's internal broker,
while providing the unified BaseBroker interface.
"""
def __init__(
self,
backtrader_strategy: 'bt.Strategy' = None,
initial_balance: float = 10000.0,
commission: float = 0.001,
slippage: float = 0.0
):
"""
Initialize the BacktestBroker.
:param backtrader_strategy: Reference to the Backtrader strategy instance.
:param initial_balance: Starting balance (set in Cerebro, not here).
:param commission: Commission rate (set in Cerebro, not here).
:param slippage: Slippage rate.
"""
super().__init__(initial_balance, commission, slippage)
self._bt_strategy = backtrader_strategy
self._pending_orders: Dict[str, Any] = {}
def set_backtrader_strategy(self, strategy: 'bt.Strategy'):
"""Set the Backtrader strategy reference after initialization."""
self._bt_strategy = strategy
@property
def _bt_broker(self):
"""Get the Backtrader broker instance."""
if self._bt_strategy is None:
raise RuntimeError("Backtrader strategy not set")
return self._bt_strategy.broker
@property
def _bt_data(self):
"""Get the Backtrader data feed."""
if self._bt_strategy is None:
raise RuntimeError("Backtrader strategy not set")
return self._bt_strategy.data
def place_order(
self,
symbol: str,
side: OrderSide,
order_type: OrderType,
size: float,
price: Optional[float] = None,
stop_loss: Optional[float] = None,
take_profit: Optional[float] = None,
time_in_force: str = 'GTC'
) -> OrderResult:
"""Place an order via Backtrader."""
if self._bt_strategy is None:
return OrderResult(
success=False,
message="Backtrader strategy not initialized"
)
try:
# Map order type to Backtrader execution type
if order_type == OrderType.MARKET:
exectype = bt.Order.Market
bt_price = None
elif order_type == OrderType.LIMIT:
exectype = bt.Order.Limit
bt_price = price
elif order_type == OrderType.STOP:
exectype = bt.Order.Stop
bt_price = price
else:
exectype = bt.Order.Market
bt_price = None
# Place the order via Backtrader
if side == OrderSide.BUY:
if bt_price is not None:
bt_order = self._bt_strategy.buy(size=size, exectype=exectype, price=bt_price)
else:
bt_order = self._bt_strategy.buy(size=size, exectype=exectype)
else:
if bt_price is not None:
bt_order = self._bt_strategy.sell(size=size, exectype=exectype, price=bt_price)
else:
bt_order = self._bt_strategy.sell(size=size, exectype=exectype)
# Generate order ID
order_id = str(bt_order.ref) if bt_order else str(uuid.uuid4())
# Store order reference
self._pending_orders[order_id] = bt_order
logger.info(f"BacktestBroker: Placed {side.value} order for {size} {symbol}")
return OrderResult(
success=True,
order_id=order_id,
status=OrderStatus.PENDING,
message=f"Order placed: {side.value} {size} {symbol}"
)
except Exception as e:
logger.error(f"BacktestBroker: Failed to place order: {e}")
return OrderResult(
success=False,
message=str(e)
)
def cancel_order(self, order_id: str) -> bool:
"""Cancel an order."""
if order_id in self._pending_orders:
bt_order = self._pending_orders[order_id]
if bt_order is not None:
self._bt_broker.cancel(bt_order)
del self._pending_orders[order_id]
return True
return False
def get_order(self, order_id: str) -> Optional[Dict[str, Any]]:
"""Get order details."""
if order_id in self._pending_orders:
bt_order = self._pending_orders[order_id]
if bt_order is not None:
return {
'order_id': order_id,
'status': self._map_bt_status(bt_order.status),
'size': bt_order.size,
'price': bt_order.price,
'executed_size': bt_order.executed.size if bt_order.executed else 0,
'executed_price': bt_order.executed.price if bt_order.executed else 0
}
return None
def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get all open orders."""
orders = []
for order_id, bt_order in self._pending_orders.items():
if bt_order is not None and bt_order.status in [bt.Order.Submitted, bt.Order.Accepted]:
orders.append(self.get_order(order_id))
return [o for o in orders if o is not None]
def get_balance(self, asset: Optional[str] = None) -> float:
"""Get total portfolio value."""
return self._bt_broker.getvalue()
def get_available_balance(self, asset: Optional[str] = None) -> float:
"""Get available cash."""
return self._bt_broker.getcash()
def get_position(self, symbol: str) -> Optional[Position]:
"""Get position for a symbol."""
if self._bt_strategy is None:
return None
bt_position = self._bt_broker.getposition(self._bt_data)
if bt_position.size == 0:
return None
current_price = self.get_current_price(symbol)
entry_price = bt_position.price
unrealized_pnl = (current_price - entry_price) * bt_position.size
return Position(
symbol=symbol,
size=bt_position.size,
entry_price=entry_price,
current_price=current_price,
unrealized_pnl=unrealized_pnl
)
def get_all_positions(self) -> List[Position]:
"""Get all open positions."""
positions = []
if self._bt_strategy is None:
return positions
# For single-data backtests, just check the main data
position = self.get_position("backtest_symbol")
if position is not None:
positions.append(position)
return positions
def get_current_price(self, symbol: str) -> float:
"""Get current price from Backtrader data feed."""
if self._bt_strategy is None:
return 0.0
return self._bt_data.close[0]
def update(self) -> List[Dict[str, Any]]:
"""
Process pending orders.
In Backtrader, order processing is handled by the engine.
This method checks for completed orders and updates state.
"""
events = []
# Check for completed orders
completed_orders = []
for order_id, bt_order in self._pending_orders.items():
if bt_order is None:
continue
if bt_order.status == bt.Order.Completed:
events.append({
'type': 'fill',
'order_id': order_id,
'size': bt_order.executed.size,
'price': bt_order.executed.price,
'commission': bt_order.executed.comm
})
completed_orders.append(order_id)
elif bt_order.status in [bt.Order.Canceled, bt.Order.Margin, bt.Order.Rejected]:
events.append({
'type': 'cancelled',
'order_id': order_id,
'reason': self._map_bt_status(bt_order.status)
})
completed_orders.append(order_id)
# Remove completed orders
for order_id in completed_orders:
del self._pending_orders[order_id]
return events
@staticmethod
def _map_bt_status(bt_status: int) -> str:
"""Map Backtrader order status to string."""
status_map = {
bt.Order.Created: 'created',
bt.Order.Submitted: 'submitted',
bt.Order.Accepted: 'accepted',
bt.Order.Partial: 'partial',
bt.Order.Completed: 'completed',
bt.Order.Canceled: 'cancelled',
bt.Order.Expired: 'expired',
bt.Order.Margin: 'margin_call',
bt.Order.Rejected: 'rejected'
}
return status_map.get(bt_status, 'unknown')

252
src/brokers/base_broker.py Normal file
View File

@ -0,0 +1,252 @@
"""
Base Broker Interface for BrighterTrading.
This abstract base class defines the interface that all broker implementations
must follow, enabling strategies to work identically across backtest, paper,
and live trading modes.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
from enum import Enum
import logging
logger = logging.getLogger(__name__)
class OrderSide(Enum):
"""Order side enumeration."""
BUY = 'buy'
SELL = 'sell'
class OrderType(Enum):
"""Order type enumeration."""
MARKET = 'market'
LIMIT = 'limit'
STOP = 'stop'
STOP_LIMIT = 'stop_limit'
class OrderStatus(Enum):
"""Order status enumeration."""
PENDING = 'pending'
OPEN = 'open'
FILLED = 'filled'
PARTIALLY_FILLED = 'partially_filled'
CANCELLED = 'cancelled'
REJECTED = 'rejected'
EXPIRED = 'expired'
@dataclass
class OrderResult:
"""Result of an order placement."""
success: bool
order_id: Optional[str] = None
message: Optional[str] = None
status: OrderStatus = OrderStatus.PENDING
filled_qty: float = 0.0
filled_price: float = 0.0
commission: float = 0.0
@dataclass
class Position:
"""Represents a position in a symbol."""
symbol: str
size: float # Positive for long, negative for short
entry_price: float
current_price: float
unrealized_pnl: float
realized_pnl: float = 0.0
class BaseBroker(ABC):
"""
Abstract base class for all broker implementations.
Provides a unified interface for:
- Order placement and management
- Balance and position tracking
- Price retrieval
Subclasses must implement all abstract methods.
"""
def __init__(self, initial_balance: float = 10000.0,
commission: float = 0.001,
slippage: float = 0.0):
"""
Initialize the broker.
:param initial_balance: Starting balance in quote currency.
:param commission: Commission rate (0.001 = 0.1%).
:param slippage: Slippage rate for market orders.
"""
self.initial_balance = initial_balance
self.commission = commission
self.slippage = slippage
self._orders: Dict[str, Dict[str, Any]] = {}
self._positions: Dict[str, Position] = {}
@abstractmethod
def place_order(
self,
symbol: str,
side: OrderSide,
order_type: OrderType,
size: float,
price: Optional[float] = None,
stop_loss: Optional[float] = None,
take_profit: Optional[float] = None,
time_in_force: str = 'GTC'
) -> OrderResult:
"""
Place an order.
:param symbol: Trading symbol (e.g., 'BTC/USDT').
:param side: OrderSide.BUY or OrderSide.SELL.
:param order_type: Type of order (market, limit, etc.).
:param size: Order size in base currency.
:param price: Limit price (required for limit orders).
:param stop_loss: Stop loss price.
:param take_profit: Take profit price.
:param time_in_force: Time in force ('GTC', 'IOC', 'FOK').
:return: OrderResult with order details.
"""
pass
@abstractmethod
def cancel_order(self, order_id: str) -> bool:
"""
Cancel an existing order.
:param order_id: The order ID to cancel.
:return: True if cancelled successfully.
"""
pass
@abstractmethod
def get_order(self, order_id: str) -> Optional[Dict[str, Any]]:
"""
Get order details by ID.
:param order_id: The order ID.
:return: Order details or None if not found.
"""
pass
@abstractmethod
def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
"""
Get all open orders, optionally filtered by symbol.
:param symbol: Optional symbol to filter by.
:return: List of open orders.
"""
pass
@abstractmethod
def get_balance(self, asset: Optional[str] = None) -> float:
"""
Get balance for an asset or total balance.
:param asset: Asset symbol (e.g., 'USDT'). None for total.
:return: Balance amount.
"""
pass
@abstractmethod
def get_available_balance(self, asset: Optional[str] = None) -> float:
"""
Get available balance (not locked in orders).
:param asset: Asset symbol. None for total.
:return: Available balance.
"""
pass
@abstractmethod
def get_position(self, symbol: str) -> Optional[Position]:
"""
Get current position for a symbol.
:param symbol: Trading symbol.
:return: Position object or None if no position.
"""
pass
@abstractmethod
def get_all_positions(self) -> List[Position]:
"""
Get all open positions.
:return: List of Position objects.
"""
pass
@abstractmethod
def get_current_price(self, symbol: str) -> float:
"""
Get the current market price for a symbol.
:param symbol: Trading symbol.
:return: Current price.
"""
pass
@abstractmethod
def update(self) -> List[Dict[str, Any]]:
"""
Process pending orders and update positions.
Called on each tick/bar to check for fills and update state.
:return: List of events (fills, updates, etc.).
"""
pass
def get_equity(self) -> float:
"""
Get total equity (balance + unrealized P&L).
:return: Total equity value.
"""
balance = self.get_balance()
unrealized_pnl = sum(pos.unrealized_pnl for pos in self.get_all_positions())
return balance + unrealized_pnl
def close_position(self, symbol: str) -> OrderResult:
"""
Close an entire position for a symbol.
:param symbol: Trading symbol.
:return: OrderResult for the closing order.
"""
position = self.get_position(symbol)
if position is None or position.size == 0:
return OrderResult(success=False, message=f"No position in {symbol}")
side = OrderSide.SELL if position.size > 0 else OrderSide.BUY
size = abs(position.size)
return self.place_order(
symbol=symbol,
side=side,
order_type=OrderType.MARKET,
size=size
)
def close_all_positions(self) -> List[OrderResult]:
"""
Close all open positions.
:return: List of OrderResults for each closed position.
"""
results = []
for position in self.get_all_positions():
if position.size != 0:
results.append(self.close_position(position.symbol))
return results

88
src/brokers/factory.py Normal file
View File

@ -0,0 +1,88 @@
"""
Broker Factory for BrighterTrading.
Creates the appropriate broker instance based on the trading mode.
"""
import logging
from typing import Any, Optional, Callable
from .base_broker import BaseBroker
from .backtest_broker import BacktestBroker
from .paper_broker import PaperBroker
logger = logging.getLogger(__name__)
class TradingMode:
"""Trading mode constants."""
BACKTEST = 'backtest'
PAPER = 'paper'
LIVE = 'live'
def create_broker(
mode: str,
initial_balance: float = 10000.0,
commission: float = 0.001,
slippage: float = 0.0,
price_provider: Optional[Callable[[str], float]] = None,
data_cache: Any = None,
exchange_interface: Any = None,
user_name: str = None,
exchange_name: str = None,
**kwargs
) -> BaseBroker:
"""
Factory function to create the appropriate broker for the trading mode.
:param mode: Trading mode ('backtest', 'paper', 'live').
:param initial_balance: Starting balance.
:param commission: Commission rate.
:param slippage: Slippage rate.
:param price_provider: Callable for getting current prices (paper/live).
:param data_cache: DataCache instance for persistence.
:param exchange_interface: ExchangeInterface for live trading.
:param user_name: User name for live trading.
:param exchange_name: Exchange name for live trading.
:param kwargs: Additional arguments passed to broker constructor.
:return: Broker instance.
"""
mode = mode.lower()
if mode == TradingMode.BACKTEST:
logger.info("Creating BacktestBroker")
return BacktestBroker(
initial_balance=initial_balance,
commission=commission,
slippage=slippage,
**kwargs
)
elif mode == TradingMode.PAPER:
logger.info("Creating PaperBroker")
return PaperBroker(
initial_balance=initial_balance,
commission=commission,
slippage=slippage if slippage > 0 else 0.0005, # Default slippage for paper
price_provider=price_provider,
data_cache=data_cache,
**kwargs
)
elif mode == TradingMode.LIVE:
# Live broker will be implemented in Phase 5
raise NotImplementedError(
"Live trading broker not yet implemented. "
"Use paper trading for testing with live prices."
)
else:
raise ValueError(f"Invalid trading mode: {mode}. "
f"Must be one of: {TradingMode.BACKTEST}, "
f"{TradingMode.PAPER}, {TradingMode.LIVE}")
def get_available_modes() -> list:
"""Get list of available trading modes."""
return [TradingMode.BACKTEST, TradingMode.PAPER]

412
src/brokers/paper_broker.py Normal file
View File

@ -0,0 +1,412 @@
"""
Paper Trading Broker Implementation for BrighterTrading.
Simulates order execution with live price data for paper trading.
Orders are filled based on current market prices with configurable
slippage and commission.
"""
import logging
from typing import Any, Dict, List, Optional, Callable
import uuid
from datetime import datetime, timezone
from .base_broker import (
BaseBroker, OrderResult, OrderSide, OrderType, OrderStatus, Position
)
logger = logging.getLogger(__name__)
class PaperOrder:
"""Represents a paper trading order."""
def __init__(
self,
order_id: str,
symbol: str,
side: OrderSide,
order_type: OrderType,
size: float,
price: Optional[float] = None,
stop_loss: Optional[float] = None,
take_profit: Optional[float] = None,
time_in_force: str = 'GTC'
):
self.order_id = order_id
self.symbol = symbol
self.side = side
self.order_type = order_type
self.size = size
self.price = price
self.stop_loss = stop_loss
self.take_profit = take_profit
self.time_in_force = time_in_force
self.status = OrderStatus.PENDING
self.filled_qty = 0.0
self.filled_price = 0.0
self.commission = 0.0
self.locked_funds = 0.0 # Amount locked for this order
self.created_at = datetime.now(timezone.utc)
self.filled_at: Optional[datetime] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
'order_id': self.order_id,
'symbol': self.symbol,
'side': self.side.value,
'order_type': self.order_type.value,
'size': self.size,
'price': self.price,
'stop_loss': self.stop_loss,
'take_profit': self.take_profit,
'status': self.status.value,
'filled_qty': self.filled_qty,
'filled_price': self.filled_price,
'commission': self.commission,
'created_at': self.created_at.isoformat(),
'filled_at': self.filled_at.isoformat() if self.filled_at else None
}
class PaperBroker(BaseBroker):
"""
Paper trading broker that simulates order execution.
Features:
- Simulated fills based on market prices
- Configurable slippage and commission
- Position tracking with P&L calculation
- Persistent state across restarts (via data_cache)
"""
def __init__(
self,
price_provider: Optional[Callable[[str], float]] = None,
data_cache: Any = None,
initial_balance: float = 10000.0,
commission: float = 0.001,
slippage: float = 0.0005
):
"""
Initialize the PaperBroker.
:param price_provider: Callable that returns current price for a symbol.
:param data_cache: DataCache instance for persistence.
:param initial_balance: Starting balance in quote currency.
:param commission: Commission rate (0.001 = 0.1%).
:param slippage: Slippage rate for market orders.
"""
super().__init__(initial_balance, commission, slippage)
self._price_provider = price_provider
self._data_cache = data_cache
# Balance tracking
self._cash = initial_balance
self._locked_balance = 0.0
# Order and position storage
self._orders: Dict[str, PaperOrder] = {}
self._positions: Dict[str, Position] = {}
self._trade_history: List[Dict[str, Any]] = []
# Current prices cache
self._current_prices: Dict[str, float] = {}
def set_price_provider(self, provider: Callable[[str], float]):
"""Set the price provider callable."""
self._price_provider = provider
def update_price(self, symbol: str, price: float):
"""Update the current price for a symbol."""
self._current_prices[symbol] = price
def place_order(
self,
symbol: str,
side: OrderSide,
order_type: OrderType,
size: float,
price: Optional[float] = None,
stop_loss: Optional[float] = None,
take_profit: Optional[float] = None,
time_in_force: str = 'GTC'
) -> OrderResult:
"""Place a paper trading order."""
order_id = str(uuid.uuid4())[:8]
# Validate order
current_price = self.get_current_price(symbol)
if current_price <= 0:
return OrderResult(
success=False,
message=f"Cannot get current price for {symbol}"
)
# Calculate required margin/cost
order_value = size * current_price
required_funds = order_value * (1 + self.commission)
if side == OrderSide.BUY and required_funds > self._cash:
return OrderResult(
success=False,
message=f"Insufficient funds: need {required_funds:.2f}, have {self._cash:.2f}"
)
# For sell orders, check if we have the position
if side == OrderSide.SELL:
position = self._positions.get(symbol)
if position is None or position.size < size:
available = position.size if position else 0
return OrderResult(
success=False,
message=f"Insufficient position: need {size}, have {available}"
)
# Create the order
order = PaperOrder(
order_id=order_id,
symbol=symbol,
side=side,
order_type=order_type,
size=size,
price=price,
stop_loss=stop_loss,
take_profit=take_profit,
time_in_force=time_in_force
)
# For market orders, fill immediately
if order_type == OrderType.MARKET:
fill_price = self._calculate_fill_price(symbol, side, current_price)
self._fill_order(order, fill_price)
logger.info(f"PaperBroker: Market order filled: {side.value} {size} {symbol} @ {fill_price:.4f}")
else:
# Store pending order
order.status = OrderStatus.OPEN
self._orders[order_id] = order
# Lock funds for buy limit orders
if side == OrderSide.BUY:
order.locked_funds = required_funds
self._locked_balance += required_funds
self._cash -= required_funds
logger.info(f"PaperBroker: Limit order placed: {side.value} {size} {symbol} @ {price}")
return OrderResult(
success=True,
order_id=order_id,
status=order.status,
filled_qty=order.filled_qty,
filled_price=order.filled_price,
commission=order.commission,
message=f"Order {order_id} {'filled' if order.status == OrderStatus.FILLED else 'placed'}"
)
def _calculate_fill_price(self, symbol: str, side: OrderSide, market_price: float) -> float:
"""Calculate fill price with slippage."""
if side == OrderSide.BUY:
return market_price * (1 + self.slippage)
else:
return market_price * (1 - self.slippage)
def _fill_order(self, order: PaperOrder, fill_price: float):
"""Execute an order fill."""
order.filled_qty = order.size
order.filled_price = fill_price
order.commission = order.size * fill_price * self.commission
order.status = OrderStatus.FILLED
order.filled_at = datetime.now(timezone.utc)
# Update balances and positions
order_value = order.size * fill_price
if order.side == OrderSide.BUY:
# Deduct cost from cash
total_cost = order_value + order.commission
self._cash -= total_cost
# Update position
if order.symbol in self._positions:
existing = self._positions[order.symbol]
new_size = existing.size + order.size
new_entry = (existing.entry_price * existing.size + fill_price * order.size) / new_size
existing.size = new_size
existing.entry_price = new_entry
else:
self._positions[order.symbol] = Position(
symbol=order.symbol,
size=order.size,
entry_price=fill_price,
current_price=fill_price,
unrealized_pnl=0.0
)
else:
# Add proceeds to cash
total_proceeds = order_value - order.commission
self._cash += total_proceeds
# Update position
if order.symbol in self._positions:
position = self._positions[order.symbol]
realized_pnl = (fill_price - position.entry_price) * order.size - order.commission
position.realized_pnl += realized_pnl
position.size -= order.size
# Remove position if fully closed
if position.size <= 0:
del self._positions[order.symbol]
# Record trade
self._trade_history.append({
'order_id': order.order_id,
'symbol': order.symbol,
'side': order.side.value,
'size': order.size,
'price': fill_price,
'commission': order.commission,
'timestamp': order.filled_at.isoformat()
})
# Store in orders dict for reference
self._orders[order.order_id] = order
def cancel_order(self, order_id: str) -> bool:
"""Cancel a pending order."""
if order_id not in self._orders:
return False
order = self._orders[order_id]
if order.status != OrderStatus.OPEN:
return False
# Release locked funds for buy orders
if order.side == OrderSide.BUY and order.locked_funds > 0:
self._locked_balance -= order.locked_funds
self._cash += order.locked_funds
order.locked_funds = 0
order.status = OrderStatus.CANCELLED
logger.info(f"PaperBroker: Order {order_id} cancelled")
return True
def get_order(self, order_id: str) -> Optional[Dict[str, Any]]:
"""Get order details."""
if order_id in self._orders:
return self._orders[order_id].to_dict()
return None
def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get all open orders."""
open_orders = [
order.to_dict() for order in self._orders.values()
if order.status == OrderStatus.OPEN
and (symbol is None or order.symbol == symbol)
]
return open_orders
def get_balance(self, asset: Optional[str] = None) -> float:
"""Get total balance including unrealized P&L."""
total = self._cash + self._locked_balance
for position in self._positions.values():
total += position.unrealized_pnl
return total
def get_available_balance(self, asset: Optional[str] = None) -> float:
"""Get available cash (not locked in orders)."""
return self._cash
def get_position(self, symbol: str) -> Optional[Position]:
"""Get position for a symbol."""
return self._positions.get(symbol)
def get_all_positions(self) -> List[Position]:
"""Get all open positions."""
return list(self._positions.values())
def get_current_price(self, symbol: str) -> float:
"""Get current price for a symbol."""
# First check cache
if symbol in self._current_prices:
return self._current_prices[symbol]
# Then try price provider
if self._price_provider:
try:
price = self._price_provider(symbol)
self._current_prices[symbol] = price
return price
except Exception as e:
logger.warning(f"PaperBroker: Failed to get price for {symbol}: {e}")
return 0.0
def update(self) -> List[Dict[str, Any]]:
"""
Process pending limit orders and update positions.
Called on each price update to check for fills.
"""
events = []
# Update position P&L
for symbol, position in self._positions.items():
current_price = self.get_current_price(symbol)
if current_price > 0:
position.current_price = current_price
position.unrealized_pnl = (current_price - position.entry_price) * position.size
# Check pending limit orders
for order_id, order in list(self._orders.items()):
if order.status != OrderStatus.OPEN:
continue
current_price = self.get_current_price(order.symbol)
if current_price <= 0:
continue
should_fill = False
if order.order_type == OrderType.LIMIT:
if order.side == OrderSide.BUY and current_price <= order.price:
should_fill = True
elif order.side == OrderSide.SELL and current_price >= order.price:
should_fill = True
if should_fill:
# Release locked funds first (for buy orders)
if order.side == OrderSide.BUY and order.locked_funds > 0:
self._locked_balance -= order.locked_funds
self._cash += order.locked_funds
order.locked_funds = 0
fill_price = order.price # Limit orders fill at limit price
self._fill_order(order, fill_price)
events.append({
'type': 'fill',
'order_id': order_id,
'symbol': order.symbol,
'side': order.side.value,
'size': order.filled_qty,
'price': order.filled_price,
'commission': order.commission
})
logger.info(f"PaperBroker: Limit order filled: {order.side.value} {order.size} {order.symbol} @ {fill_price:.4f}")
return events
def get_trade_history(self) -> List[Dict[str, Any]]:
"""Get all executed trades."""
return self._trade_history.copy()
def reset(self):
"""Reset the broker to initial state."""
self._cash = self.initial_balance
self._locked_balance = 0.0
self._orders.clear()
self._positions.clear()
self._trade_history.clear()
self._current_prices.clear()
logger.info(f"PaperBroker: Reset with balance {self.initial_balance}")

256
tests/test_brokers.py Normal file
View File

@ -0,0 +1,256 @@
"""
Tests for the broker abstraction layer.
"""
import pytest
from brokers import (
BaseBroker, BacktestBroker, PaperBroker,
OrderSide, OrderType, OrderStatus, OrderResult, Position,
create_broker, TradingMode
)
class TestPaperBroker:
"""Tests for PaperBroker."""
def test_create_paper_broker(self):
"""Test creating a paper broker."""
broker = PaperBroker(initial_balance=10000)
assert broker.get_balance() == 10000
assert broker.get_available_balance() == 10000
def test_paper_broker_market_buy(self):
"""Test market buy order."""
broker = PaperBroker(initial_balance=10000, commission=0.001)
broker.update_price('BTC/USDT', 50000)
result = broker.place_order(
symbol='BTC/USDT',
side=OrderSide.BUY,
order_type=OrderType.MARKET,
size=0.1
)
assert result.success
assert result.status == OrderStatus.FILLED
position = broker.get_position('BTC/USDT')
assert position is not None
assert position.size == 0.1
# Check balance deducted (price * size + commission)
expected_cost = 50000 * 0.1 * (1 + 0.001) # with slippage
assert broker.get_available_balance() < 10000
def test_paper_broker_market_sell(self):
"""Test market sell order."""
broker = PaperBroker(initial_balance=10000, commission=0.001)
broker.update_price('BTC/USDT', 50000)
# First buy
broker.place_order(
symbol='BTC/USDT',
side=OrderSide.BUY,
order_type=OrderType.MARKET,
size=0.1
)
# Update price and sell
broker.update_price('BTC/USDT', 55000)
result = broker.place_order(
symbol='BTC/USDT',
side=OrderSide.SELL,
order_type=OrderType.MARKET,
size=0.1
)
assert result.success
assert result.status == OrderStatus.FILLED
# Position should be closed
position = broker.get_position('BTC/USDT')
assert position is None
def test_paper_broker_insufficient_funds(self):
"""Test order rejection due to insufficient funds."""
broker = PaperBroker(initial_balance=1000)
broker.update_price('BTC/USDT', 50000)
result = broker.place_order(
symbol='BTC/USDT',
side=OrderSide.BUY,
order_type=OrderType.MARKET,
size=0.1 # Would cost ~5000
)
assert not result.success
assert "Insufficient funds" in result.message
def test_paper_broker_limit_order(self):
"""Test limit order placement and fill."""
broker = PaperBroker(initial_balance=10000, commission=0.001)
broker.update_price('BTC/USDT', 50000)
result = broker.place_order(
symbol='BTC/USDT',
side=OrderSide.BUY,
order_type=OrderType.LIMIT,
size=0.1,
price=49000
)
assert result.success
assert result.status == OrderStatus.OPEN
# Order should be pending
open_orders = broker.get_open_orders()
assert len(open_orders) == 1
# Update price below limit - should fill
broker.update_price('BTC/USDT', 48000)
events = broker.update()
assert len(events) == 1
assert events[0]['type'] == 'fill'
# Now position should exist
position = broker.get_position('BTC/USDT')
assert position is not None
assert position.size == 0.1
def test_paper_broker_cancel_order(self):
"""Test order cancellation."""
broker = PaperBroker(initial_balance=10000, commission=0, slippage=0)
broker.update_price('BTC/USDT', 50000)
initial_balance = broker.get_available_balance()
result = broker.place_order(
symbol='BTC/USDT',
side=OrderSide.BUY,
order_type=OrderType.LIMIT,
size=0.1,
price=49000
)
assert result.success
assert broker.get_available_balance() < initial_balance # Funds locked
# Cancel the order
cancelled = broker.cancel_order(result.order_id)
assert cancelled
# Funds should be released
assert broker.get_available_balance() == initial_balance
def test_paper_broker_pnl_tracking(self):
"""Test P&L tracking."""
broker = PaperBroker(initial_balance=10000, commission=0, slippage=0)
broker.update_price('BTC/USDT', 50000)
# Buy at 50000
broker.place_order(
symbol='BTC/USDT',
side=OrderSide.BUY,
order_type=OrderType.MARKET,
size=0.1
)
# Price goes up
broker.update_price('BTC/USDT', 52000)
broker.update()
position = broker.get_position('BTC/USDT')
assert position is not None
# Unrealized P&L: (52000 - 50000) * 0.1 = 200
assert position.unrealized_pnl == 200
def test_paper_broker_reset(self):
"""Test broker reset."""
broker = PaperBroker(initial_balance=10000)
broker.update_price('BTC/USDT', 50000)
broker.place_order(
symbol='BTC/USDT',
side=OrderSide.BUY,
order_type=OrderType.MARKET,
size=0.1
)
broker.reset()
assert broker.get_balance() == 10000
assert broker.get_all_positions() == []
assert broker.get_open_orders() == []
class TestBrokerFactory:
"""Tests for the broker factory."""
def test_create_paper_broker(self):
"""Test creating paper broker via factory."""
broker = create_broker(
mode=TradingMode.PAPER,
initial_balance=5000
)
assert isinstance(broker, PaperBroker)
assert broker.get_balance() == 5000
def test_create_backtest_broker(self):
"""Test creating backtest broker via factory."""
broker = create_broker(
mode=TradingMode.BACKTEST,
initial_balance=10000
)
assert isinstance(broker, BacktestBroker)
def test_create_live_broker_not_implemented(self):
"""Test that live broker raises NotImplementedError."""
with pytest.raises(NotImplementedError):
create_broker(mode=TradingMode.LIVE)
def test_invalid_mode(self):
"""Test that invalid mode raises ValueError."""
with pytest.raises(ValueError):
create_broker(mode='invalid_mode')
class TestOrderResult:
"""Tests for OrderResult dataclass."""
def test_order_result_success(self):
"""Test successful order result."""
result = OrderResult(
success=True,
order_id='12345',
status=OrderStatus.FILLED,
filled_qty=0.1,
filled_price=50000
)
assert result.success
assert result.order_id == '12345'
assert result.status == OrderStatus.FILLED
def test_order_result_failure(self):
"""Test failed order result."""
result = OrderResult(
success=False,
message="Insufficient funds"
)
assert not result.success
assert "Insufficient" in result.message
class TestPosition:
"""Tests for Position dataclass."""
def test_position_creation(self):
"""Test position creation."""
position = Position(
symbol='BTC/USDT',
size=0.1,
entry_price=50000,
current_price=51000,
unrealized_pnl=100
)
assert position.symbol == 'BTC/USDT'
assert position.size == 0.1
assert position.unrealized_pnl == 100