diff --git a/pytest.ini b/pytest.ini
index fe048b6..0488659 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -5,6 +5,8 @@ python_classes = Test*
python_functions = test_*
# Default: exclude integration tests (run with: pytest -m integration)
addopts = -v --tb=short -m "not integration"
+filterwarnings =
+ ignore:'crypt' is deprecated and slated for removal in Python 3\.13:DeprecationWarning
markers =
live_testnet: marks tests as requiring live testnet API keys (deselect with '-m "not live_testnet"')
diff --git a/requirements.txt b/requirements.txt
index d834492..70c4415 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,7 @@ Flask-Cors~=3.0.10
email_validator~=2.2.0
aiohttp>=3.9.0
websockets>=12.0
-requests>=2.31.0
\ No newline at end of file
+requests>=2.31.0
+# Bitcoin wallet and encryption
+bit>=0.8.0
+cryptography>=41.0.0
\ No newline at end of file
diff --git a/src/BrighterTrades.py b/src/BrighterTrades.py
index 7eeacc8..79bd709 100644
--- a/src/BrighterTrades.py
+++ b/src/BrighterTrades.py
@@ -13,6 +13,7 @@ from indicators import Indicators
from Signals import Signals
from trade import Trades
from edm_client import EdmClient, EdmWebSocketClient
+from wallet import WalletManager
# Configure logging
logger = logging.getLogger(__name__)
@@ -71,6 +72,18 @@ class BrighterTrades:
edm_client=self.edm_client)
self.backtests = {} # In-memory storage for backtests (replace with DB access in production)
+ # Wallet manager for Bitcoin wallets and credits ledger
+ wallet_config = self.config.get_setting('wallet') or {}
+ wallet_keys = wallet_config.get('encryption_keys', {1: 'default_dev_key_change_in_production'})
+ # Ensure keys are int -> str mapping
+ wallet_keys = {int(k): v for k, v in wallet_keys.items()}
+ self.wallet_manager = WalletManager(
+ database=self.data.db,
+ encryption_keys=wallet_keys,
+ default_network=wallet_config.get('bitcoin_network', 'testnet')
+ )
+ logger.info(f"Wallet manager initialized (network: {wallet_config.get('bitcoin_network', 'testnet')})")
+
@staticmethod
def _coerce_user_id(user_id: Any) -> int | None:
if user_id is None or user_id == '':
@@ -752,7 +765,12 @@ class BrighterTrades:
# This ensures subscribers run with the creator's indicator definitions
indicator_owner_id = creator_id if is_subscribed and not is_owner else None
- # Early exchange requirements validation
+ # Check for strategy fees (only for non-owners running in paper/live mode)
+ strategy_fee = float(strategy_row.get('fee', 0.0))
+ has_fee = strategy_fee > 0 and not is_owner and mode in ['paper', 'live']
+ strategy_run_id = None # Will be set if fee accumulation is started
+
+ # Early exchange requirements validation (BEFORE fee accumulation to avoid orphaned fees)
from exchange_validation import extract_required_exchanges, validate_exchange_requirements
strategy_full = self.strategies.get_strategy_by_tbl_key(strategy_id)
required_exchanges = extract_required_exchanges(strategy_full)
@@ -933,6 +951,28 @@ class BrighterTrades:
# Paper mode: random UUID since paper state is ephemeral
strategy_instance_id = str(uuid.uuid4())
+ # Start fee accumulation only after all startup validation has passed.
+ if has_fee and creator_id is not None:
+ strategy_run_id = f"{strategy_id}_{user_id}_{uuid.uuid4().hex[:8]}"
+ fee_result = self.wallet_manager.start_fee_accumulation(
+ strategy_run_id=strategy_run_id,
+ user_id=user_id,
+ creator_user_id=creator_id,
+ fee_percent=int(strategy_fee), # 1-100% of exchange commission
+ estimated_trades=10 # Check for ~10 trades worth of credits
+ )
+
+ if not fee_result['success']:
+ return {
+ 'success': False,
+ 'message': fee_result.get('error', 'Failed to start fee accumulation'),
+ 'balance_available': fee_result.get('available', 0),
+ 'recommended_minimum': fee_result.get('recommended_minimum', 10000),
+ 'need_deposit': True
+ }
+
+ logger.info(f"Started fee accumulation for strategy {strategy_id}: run_id={strategy_run_id}")
+
instance = self.strategies.create_strategy_instance(
mode=mode,
strategy_instance_id=strategy_instance_id,
@@ -950,6 +990,15 @@ class BrighterTrades:
indicator_owner_id=indicator_owner_id, # For subscribed strategies, use creator's indicators
)
+ # Store fee tracking info on the instance
+ if strategy_run_id:
+ instance.strategy_run_id = strategy_run_id
+ instance.has_fee = True
+ instance.wallet_manager = self.wallet_manager
+ else:
+ instance.strategy_run_id = None
+ instance.has_fee = False
+
# Store the active instance
self.strategies.active_instances[instance_key] = instance
@@ -980,6 +1029,8 @@ class BrighterTrades:
return result
except Exception as e:
+ if strategy_run_id:
+ self.wallet_manager.cancel_fee_accumulation(strategy_run_id)
logger.error(f"Failed to create strategy instance: {e}", exc_info=True)
return {"success": False, "message": f"Failed to start strategy: {str(e)}"}
@@ -1028,9 +1079,20 @@ class BrighterTrades:
if hasattr(instance, 'trade_history'):
final_stats['total_trades'] = len(instance.trade_history)
+ # Settle accumulated fees if this was a paid strategy
+ fee_settlement = None
+ if hasattr(instance, 'strategy_run_id') and instance.strategy_run_id:
+ settle_result = self.wallet_manager.settle_accumulated_fees(instance.strategy_run_id)
+ fee_settlement = {
+ 'fees_settled': settle_result.get('settled', 0),
+ 'trades_charged': settle_result.get('trades', 0)
+ }
+ logger.info(f"Settled {settle_result.get('settled', 0)} sats for "
+ f"{settle_result.get('trades', 0)} trades on strategy {strategy_id}")
+
logger.info(f"Stopped strategy '{strategy_name}' for user {user_id} in {mode} mode")
- return {
+ result = {
"success": True,
"message": f"Strategy '{strategy_name}' stopped.",
"strategy_id": strategy_id,
@@ -1040,6 +1102,11 @@ class BrighterTrades:
"final_stats": final_stats,
}
+ if fee_settlement:
+ result["fee_settlement"] = fee_settlement
+
+ return result
+
def get_strategy_status(
self,
user_id: int,
@@ -1809,3 +1876,81 @@ class BrighterTrades:
if msg_type == 'reply':
# If the message is a reply log the response to the terminal.
print(f"\napp.py:Received reply: {msg_data}")
+
+ # ===== Wallet Methods =====
+
+ def create_user_wallet(self, user_id: int, user_sweep_address: str = None) -> dict:
+ """
+ Create BTC wallet for registered user (two-address model).
+
+ Args:
+ user_id: User ID to create wallet for.
+ user_sweep_address: Optional user-provided sweep address.
+
+ Returns:
+ Dict with success status and wallet details.
+ """
+ user = self.users.get_user_by_id(user_id)
+ if not user:
+ return {'success': False, 'error': 'User not found'}
+
+ # Check if user is a guest (guests can't have wallets)
+ username = user.get('user_name', '')
+ if username == 'guest' or user.get('is_guest', False):
+ return {'success': False, 'error': 'Only registered users can create wallets'}
+
+ return self.wallet_manager.create_wallet(user_id, user_sweep_address=user_sweep_address)
+
+ def get_user_wallet(self, user_id: int) -> dict:
+ """
+ Get user's wallet info (without private key).
+
+ Args:
+ user_id: User ID to look up.
+
+ Returns:
+ Dict with wallet info or None.
+ """
+ return self.wallet_manager.get_wallet(user_id)
+
+ def get_credits_balance(self, user_id: int) -> int:
+ """
+ Get user's spendable credits balance.
+
+ Args:
+ user_id: User ID to look up.
+
+ Returns:
+ Balance in satoshis.
+ """
+ return self.wallet_manager.get_credits_balance(user_id)
+
+ def request_withdrawal(self, user_id: int, amount_satoshis: int,
+ destination_address: str) -> dict:
+ """
+ Request BTC withdrawal.
+
+ Args:
+ user_id: User ID requesting withdrawal.
+ amount_satoshis: Amount to withdraw.
+ destination_address: Bitcoin address to send to.
+
+ Returns:
+ Dict with success status.
+ """
+ return self.wallet_manager.request_withdrawal(
+ user_id, amount_satoshis, destination_address
+ )
+
+ def get_transaction_history(self, user_id: int, limit: int = 20) -> list:
+ """
+ Get user's recent ledger transactions.
+
+ Args:
+ user_id: User ID to look up.
+ limit: Maximum number of transactions.
+
+ Returns:
+ List of transaction dicts.
+ """
+ return self.wallet_manager.get_transaction_history(user_id, limit)
diff --git a/src/Database.py b/src/Database.py
index 9008c27..41099cc 100644
--- a/src/Database.py
+++ b/src/Database.py
@@ -89,12 +89,16 @@ class Database:
def __init__(self, db_file: str = None):
self.db_file = db_file
- def execute_sql(self, sql: str, params: list = None) -> None:
+ def execute_sql(self, sql: str, params: tuple = None,
+ fetch_one: bool = False, fetch_all: bool = False) -> Any:
"""
- Executes a raw SQL statement with optional parameters.
+ Executes a raw SQL statement with optional parameters and fetch modes.
:param sql: SQL statement to execute.
:param params: Optional tuple of parameters to pass with the SQL statement.
+ :param fetch_one: If True, returns a single row (or None).
+ :param fetch_all: If True, returns all rows as a list (or empty list).
+ :return: None, single row, or list of rows depending on fetch parameters.
"""
with SQLite(self.db_file) as con:
cur = con.cursor()
@@ -103,6 +107,37 @@ class Database:
else:
cur.execute(sql, params)
+ if fetch_one:
+ return cur.fetchone()
+ elif fetch_all:
+ return cur.fetchall()
+ return None
+
+ def execute_in_transaction(self, statements: list) -> bool:
+ """
+ Executes multiple SQL statements within a single transaction.
+ All statements succeed or all fail (atomic).
+
+ :param statements: List of (sql, params) tuples to execute.
+ :return: True if all statements succeeded.
+ :raises: Exception if any statement fails (transaction rolled back).
+ """
+ conn = sqlite3.connect(self.db_file)
+ try:
+ cur = conn.cursor()
+ for sql, params in statements:
+ if params is None:
+ cur.execute(sql)
+ else:
+ cur.execute(sql, params)
+ conn.commit()
+ return True
+ except Exception:
+ conn.rollback()
+ raise
+ finally:
+ conn.close()
+
def get_all_rows(self, table_name: str) -> pd.DataFrame:
"""
Retrieves all rows from a table.
diff --git a/src/StrategyInstance.py b/src/StrategyInstance.py
index 16be176..af0d381 100644
--- a/src/StrategyInstance.py
+++ b/src/StrategyInstance.py
@@ -839,3 +839,45 @@ class StrategyInstance:
Retrieves the available balance for the strategy.
"""
return self.trades.get_available_balance(self.strategy_id)
+
+ def accumulate_trade_fee(self, trade_value_usd: float, commission_rate: float,
+ is_profitable: bool) -> dict:
+ """
+ Accumulate fee for a completed trade (only called for paid strategies).
+
+ This method is called when a trade fills to accumulate fees that will
+ be settled when the strategy stops.
+
+ :param trade_value_usd: The USD value of the trade.
+ :param commission_rate: The exchange commission rate (e.g., 0.001 for 0.1%).
+ :param is_profitable: Whether the trade was profitable.
+ :return: Dict with fee charged amount.
+ """
+ # Check if this strategy has fee tracking enabled
+ if not getattr(self, 'has_fee', False) or not getattr(self, 'strategy_run_id', None):
+ return {'success': True, 'fee_charged': 0, 'reason': 'no_fees_enabled'}
+
+ wallet_manager = getattr(self, 'wallet_manager', None)
+ if not wallet_manager:
+ return {'success': False, 'error': 'No wallet manager available'}
+
+ # Calculate exchange fee in USD, then convert to satoshis
+ # Assume 1 BTC = $50,000 for conversion (rough estimate, could be made dynamic)
+ btc_price_usd = 50000 # This could be fetched from exchange
+ exchange_fee_usd = trade_value_usd * commission_rate
+
+ # Convert to satoshis: 1 BTC = 100,000,000 satoshis
+ exchange_fee_btc = exchange_fee_usd / btc_price_usd
+ exchange_fee_satoshis = int(exchange_fee_btc * 100_000_000)
+
+ # Accumulate the fee
+ result = wallet_manager.accumulate_trade_fee(
+ strategy_run_id=self.strategy_run_id,
+ exchange_fee_satoshis=exchange_fee_satoshis,
+ is_profitable=is_profitable
+ )
+
+ if result.get('fee_charged', 0) > 0:
+ logger.debug(f"Accumulated fee: {result['fee_charged']} sats for trade worth ${trade_value_usd:.2f}")
+
+ return result
diff --git a/src/Users.py b/src/Users.py
index 94db5d2..331c3ca 100644
--- a/src/Users.py
+++ b/src/Users.py
@@ -60,6 +60,25 @@ class BaseUser:
filter_vals=('id', user_id)
)
+ def get_user_by_id(self, user_id: int) -> dict | None:
+ """
+ Retrieves user data as a dict based on the user ID.
+
+ :param user_id: The ID of the user.
+ :return: A dict containing the user's data, or None if not found.
+ """
+ try:
+ user_df = self.data.get_rows_from_datacache(
+ cache_name='users', filter_vals=[('id', user_id)])
+
+ if user_df is None or user_df.empty:
+ return None
+
+ # Convert first row to dict
+ return user_df.iloc[0].to_dict()
+ except Exception:
+ return None
+
def _remove_user_from_memory(self, user_name: str) -> None:
"""
Private method to remove a user's data from the cache (memory).
@@ -202,8 +221,11 @@ class UserAccountManagement(BaseUser):
"""
if self.validate_password(username=username, password=password):
self.modify_user_data(username=username, field_name="status", new_data="logged_in")
- self.modify_user_data(username=username, field_name="signin_time",
- new_data=dt.datetime.utcnow().timestamp())
+ self.modify_user_data(
+ username=username,
+ field_name="signin_time",
+ new_data=dt.datetime.now(dt.timezone.utc).timestamp()
+ )
return True
return False
@@ -216,13 +238,60 @@ class UserAccountManagement(BaseUser):
"""
# Update the user's status and sign-in time in both cache and database
self.modify_user_data(username=username, field_name='status', new_data='logged_out')
- self.modify_user_data(username=username, field_name='signin_time', new_data=dt.datetime.utcnow().timestamp())
+ self.modify_user_data(
+ username=username,
+ field_name='signin_time',
+ new_data=dt.datetime.now(dt.timezone.utc).timestamp()
+ )
# Remove the user's data from the cache
self._remove_user_from_memory(user_name=username)
return True
+ def update_email(self, user_id: int, email: str) -> bool:
+ """
+ Updates the user's email address.
+
+ :param user_id: The ID of the user.
+ :param email: The new email address.
+ :return: True on success, False on failure.
+ """
+ try:
+ username = self.get_username(user_id)
+ if not username:
+ return False
+
+ self.modify_user_data(username=username, field_name='email', new_data=email)
+ return True
+ except Exception:
+ return False
+
+ def update_password(self, user_id: int, current_password: str, new_password: str) -> dict:
+ """
+ Updates the user's password after validating the current password.
+
+ :param user_id: The ID of the user.
+ :param current_password: The current password for verification.
+ :param new_password: The new password to set.
+ :return: Dict with 'success' key and optionally 'error' key.
+ """
+ try:
+ username = self.get_username(user_id)
+ if not username:
+ return {'success': False, 'error': 'User not found'}
+
+ # Validate current password
+ if not self.validate_password(username=username, password=current_password):
+ return {'success': False, 'error': 'Current password is incorrect'}
+
+ # Hash and store new password
+ encrypted_password = self.scramble_text(new_password)
+ self.modify_user_data(username=username, field_name='password', new_data=encrypted_password)
+ return {'success': True}
+ except Exception as e:
+ return {'success': False, 'error': str(e)}
+
def log_out_all_users(self, enforcement: str = 'hard') -> None:
"""
Logs out all users by updating their status in the database and clearing the data.
@@ -378,7 +447,7 @@ class UserAccountManagement(BaseUser):
:param some_text: The text to encrypt.
:return: The hashed text.
"""
- return bcrypt.hash(some_text, rounds=13)
+ return bcrypt.using(rounds=13).hash(some_text)
class UserExchangeManagement(UserAccountManagement):
diff --git a/src/app.py b/src/app.py
index 374d343..1bd8f19 100644
--- a/src/app.py
+++ b/src/app.py
@@ -20,6 +20,7 @@ from email_validator import validate_email, EmailNotValidError # noqa: E402
# Local application imports
from BrighterTrades import BrighterTrades # noqa: E402
from utils import sanitize_for_json # noqa: E402
+from wallet import WalletBackgroundJobs # noqa: E402
# Set up logging
log_level_name = os.getenv('BRIGHTER_LOG_LEVEL', 'INFO').upper()
@@ -72,6 +73,8 @@ CORS_HEADERS = 'Content-Type'
# Socket ID to authenticated user_id mapping
# This is the source of truth for WebSocket authentication - never trust client payloads
socket_user_mapping = {} # request.sid -> user_id
+wallet_jobs = None
+_wallet_jobs_started = False
# Set the app directly with the globals.
app.config.from_object(__name__)
@@ -247,6 +250,41 @@ def start_strategy_loop():
# Start the loop when the app starts (will be called from main block)
+def start_wallet_background_jobs():
+ """
+ Start wallet background jobs once per process.
+
+ This supports both `python src/app.py` and WSGI imports.
+ Jobs are skipped in test contexts.
+ """
+ global wallet_jobs, _wallet_jobs_started
+
+ if _wallet_jobs_started:
+ return
+
+ if app.config.get('TESTING') or os.getenv('PYTEST_CURRENT_TEST'):
+ return
+
+ if os.getenv('BRIGHTER_DISABLE_WALLET_JOBS', '').lower() in ('1', 'true', 'yes'):
+ logging.info("Wallet background jobs disabled by BRIGHTER_DISABLE_WALLET_JOBS")
+ return
+
+ if not getattr(brighter_trades, 'wallet_manager', None):
+ logging.warning("Wallet manager not initialized - background jobs disabled")
+ return
+
+ wallet_jobs = WalletBackgroundJobs(brighter_trades.wallet_manager, socketio)
+ wallet_jobs.start_all_jobs()
+ _wallet_jobs_started = True
+ logging.info("Wallet background jobs started")
+
+
+@app.before_request
+def ensure_background_jobs_started():
+ """Ensure wallet background jobs are running in non-`__main__` deployments."""
+ start_wallet_background_jobs()
+
+
def _coerce_user_id(user_id):
if user_id is None or user_id == '':
return None
@@ -764,6 +802,240 @@ def _validate_blockly_xml(xml_string: str) -> bool:
return False
+# =============================================================================
+# Wallet API Routes
+# =============================================================================
+
+def _get_current_user_id():
+ """Get user_id from session. Returns None if not logged in."""
+ user_name = session.get('user')
+ if not user_name:
+ return None
+ return brighter_trades.users.get_id(user_name)
+
+
+# === User Profile APIs ===
+
+@app.route('/api/user/profile', methods=['GET'])
+def get_user_profile():
+ """Get current user's profile info."""
+ user_id = _get_current_user_id()
+ if not user_id:
+ return jsonify({'success': False, 'error': 'Not logged in'}), 401
+
+ user = brighter_trades.users.get_user_by_id(user_id)
+ if user:
+ return jsonify({
+ 'success': True,
+ 'profile': {
+ 'username': user.get('user_name'),
+ 'email': user.get('email', '')
+ }
+ })
+ return jsonify({'success': False, 'error': 'User not found'}), 404
+
+
+@app.route('/api/user/email', methods=['POST'])
+def update_user_email():
+ """Update current user's email address."""
+ user_id = _get_current_user_id()
+ if not user_id:
+ return jsonify({'success': False, 'error': 'Not logged in'}), 401
+
+ data = request.get_json() or {}
+ email = data.get('email', '').strip()
+
+ if not email:
+ return jsonify({'success': False, 'error': 'Email is required'}), 400
+
+ # Basic email validation
+ if '@' not in email or '.' not in email:
+ return jsonify({'success': False, 'error': 'Invalid email format'}), 400
+
+ result = brighter_trades.users.update_email(user_id, email)
+ if result:
+ return jsonify({'success': True})
+ return jsonify({'success': False, 'error': 'Failed to update email'}), 500
+
+
+@app.route('/api/user/password', methods=['POST'])
+def update_user_password():
+ """Update current user's password."""
+ user_id = _get_current_user_id()
+ if not user_id:
+ return jsonify({'success': False, 'error': 'Not logged in'}), 401
+
+ data = request.get_json() or {}
+ current_password = data.get('current_password', '')
+ new_password = data.get('new_password', '')
+
+ if not current_password or not new_password:
+ return jsonify({'success': False, 'error': 'Both current and new password are required'}), 400
+
+ if len(new_password) < 6:
+ return jsonify({'success': False, 'error': 'Password must be at least 6 characters'}), 400
+
+ result = brighter_trades.users.update_password(user_id, current_password, new_password)
+ if result.get('success'):
+ return jsonify({'success': True})
+ return jsonify({'success': False, 'error': result.get('error', 'Failed to update password')}), 400
+
+
+# === Wallet APIs ===
+
+@app.route('/api/wallet', methods=['GET'])
+def get_wallet():
+ """Get current user's wallet info (without private key)."""
+ user_id = _get_current_user_id()
+ if not user_id:
+ return jsonify({'success': False, 'error': 'Not logged in'}), 401
+
+ wallet = brighter_trades.get_user_wallet(user_id)
+ if wallet:
+ return jsonify({'success': True, 'wallet': wallet})
+ return jsonify({'success': True, 'wallet': None})
+
+
+@app.route('/api/wallet/create', methods=['POST'])
+def create_wallet():
+ """Create a new wallet for the current user (two-address model)."""
+ user_id = _get_current_user_id()
+ if not user_id:
+ return jsonify({'success': False, 'error': 'Not logged in'}), 401
+
+ data = request.get_json() or {}
+ user_sweep_address = data.get('user_sweep_address') # Optional user-provided sweep address
+
+ result = brighter_trades.create_user_wallet(user_id, user_sweep_address=user_sweep_address)
+ return jsonify(result)
+
+
+@app.route('/api/wallet/credits', methods=['GET'])
+def get_credits():
+ """Get user's spendable credits balance (from ledger)."""
+ user_id = _get_current_user_id()
+ if not user_id:
+ return jsonify({'success': False, 'error': 'Not logged in'}), 401
+
+ balance = brighter_trades.get_credits_balance(user_id)
+ return jsonify({'success': True, 'balance': balance})
+
+
+@app.route('/api/wallet/withdraw', methods=['POST'])
+def request_withdrawal():
+ """Request BTC withdrawal (processed async)."""
+ user_id = _get_current_user_id()
+ if not user_id:
+ return jsonify({'success': False, 'error': 'Not logged in'}), 401
+
+ data = request.get_json() or {}
+ amount = data.get('amount_satoshis')
+ destination = data.get('destination_address')
+
+ if not amount or not destination:
+ return jsonify({'success': False, 'error': 'Missing amount or destination'}), 400
+
+ try:
+ amount = int(amount)
+ except (TypeError, ValueError):
+ return jsonify({'success': False, 'error': 'Invalid amount'}), 400
+
+ # SECURITY: Prevent negative withdrawal (would credit balance)
+ if amount <= 0:
+ return jsonify({'success': False, 'error': 'Amount must be positive'}), 400
+
+ result = brighter_trades.request_withdrawal(user_id, amount, destination)
+ return jsonify(result)
+
+
+@app.route('/api/wallet/transactions', methods=['GET'])
+def get_transactions():
+ """Get user's recent transaction history."""
+ user_id = _get_current_user_id()
+ if not user_id:
+ return jsonify({'success': False, 'error': 'Not logged in'}), 401
+
+ limit = request.args.get('limit', 20, type=int)
+ transactions = brighter_trades.get_transaction_history(user_id, limit)
+ return jsonify({'success': True, 'transactions': transactions})
+
+
+@app.route('/api/wallet/keys', methods=['GET'])
+def get_wallet_keys():
+ """
+ Get fee address keys (for viewing after wallet creation).
+ Note: Sweep address private key is NOT stored and cannot be retrieved.
+ """
+ user_id = _get_current_user_id()
+ if not user_id:
+ return jsonify({'success': False, 'error': 'Not logged in'}), 401
+
+ # Keep key export disabled by default; can be opt-in for controlled POC sessions.
+ wallet_cfg = brighter_trades.config.get_setting('wallet') or {}
+ if not wallet_cfg.get('allow_key_export', False):
+ return jsonify({'success': False, 'error': 'Key export is disabled'}), 403
+
+ keys = brighter_trades.wallet_manager.get_wallet_keys(user_id)
+ if keys:
+ return jsonify({'success': True, **keys})
+ return jsonify({'success': False, 'error': 'Wallet not found or keys unavailable'})
+
+
+@app.route('/api/admin/credit', methods=['POST'])
+def admin_credit_user():
+ """
+ Admin endpoint to manually credit a user's wallet.
+ For POC testing until deposit detection is implemented.
+
+ SECURITY: In POC mode, users can only credit themselves.
+ In production, add proper admin role checking.
+ """
+ user_id = _get_current_user_id()
+ if not user_id:
+ return jsonify({'success': False, 'error': 'Not logged in'}), 401
+
+ data = request.get_json() or {}
+ target_user_id = data.get('user_id')
+ amount = data.get('amount_satoshis', 10000) # Default 10k sats
+ reason = data.get('reason', 'POC test credit')
+
+ # SECURITY: In POC mode, only allow crediting yourself
+ # Remove this restriction when proper admin auth is implemented
+ if target_user_id is not None:
+ try:
+ if int(target_user_id) != user_id:
+ return jsonify({'success': False, 'error': 'Can only credit your own account in POC mode'}), 403
+ except (TypeError, ValueError):
+ return jsonify({'success': False, 'error': 'Invalid user_id'}), 400
+
+ target_user_id = user_id # Force to self
+
+ try:
+ amount = int(amount)
+ except (TypeError, ValueError):
+ return jsonify({'success': False, 'error': 'Invalid amount'}), 400
+
+ # SECURITY: Prevent negative credits
+ if amount <= 0:
+ return jsonify({'success': False, 'error': 'Amount must be positive'}), 400
+
+ # SECURITY: Limit POC credits to prevent abuse
+ MAX_POC_CREDIT = 100000 # 0.001 BTC max per credit
+ if amount > MAX_POC_CREDIT:
+ return jsonify({'success': False, 'error': f'POC credit limited to {MAX_POC_CREDIT} satoshis'}), 400
+
+ result = brighter_trades.wallet_manager.admin_credit(
+ user_id=target_user_id,
+ amount_satoshis=amount,
+ reason=reason
+ )
+ return jsonify(result)
+
+
+# =============================================================================
+# Health Check Routes
+# =============================================================================
+
@app.route('/health/edm', methods=['GET'])
def edm_health():
"""
@@ -787,4 +1059,7 @@ if __name__ == '__main__':
start_strategy_loop()
logging.info("Strategy execution loop started in background")
+ # Start wallet background jobs (auto-sweep, deposit detection, withdrawal processing)
+ start_wallet_background_jobs()
+
socketio.run(app, host='127.0.0.1', port=5002, debug=False, use_reloader=False)
diff --git a/src/brokers/live_broker.py b/src/brokers/live_broker.py
index 50a4a13..74dd306 100644
--- a/src/brokers/live_broker.py
+++ b/src/brokers/live_broker.py
@@ -955,6 +955,16 @@ class LiveBroker(BaseBroker):
# Emit fill event if order was filled
if order.status == OrderStatus.FILLED and old_status != OrderStatus.FILLED:
+ # Calculate profitability for sell orders (before position update)
+ is_profitable = False
+ realized_pnl = 0.0
+ entry_price = 0.0
+ if order.side == OrderSide.SELL and order.symbol in self._positions:
+ pos = self._positions[order.symbol]
+ entry_price = pos.entry_price
+ realized_pnl = (order.filled_price - pos.entry_price) * order.filled_qty - order.commission
+ is_profitable = realized_pnl > 0
+
events.append({
'type': 'fill',
'order_id': order_id,
@@ -965,7 +975,10 @@ class LiveBroker(BaseBroker):
'filled_qty': order.filled_qty,
'price': order.filled_price,
'filled_price': order.filled_price,
- 'commission': order.commission
+ 'commission': order.commission,
+ 'is_profitable': is_profitable,
+ 'realized_pnl': realized_pnl,
+ 'entry_price': entry_price
})
logger.info(f"Order filled: {order_id} - {order.side.value} {order.filled_qty} {order.symbol} @ {order.filled_price}")
diff --git a/src/brokers/paper_broker.py b/src/brokers/paper_broker.py
index f0c45c3..165cd05 100644
--- a/src/brokers/paper_broker.py
+++ b/src/brokers/paper_broker.py
@@ -230,6 +230,11 @@ class PaperBroker(BaseBroker):
order.status = OrderStatus.FILLED
order.filled_at = datetime.now(timezone.utc)
+ # Calculate profitability for sell orders (for fee tracking)
+ order.is_profitable = False
+ order.realized_pnl = 0.0
+ order.entry_price = 0.0
+
# Update balances and positions
order_value = order.size * fill_price
@@ -261,10 +266,15 @@ class PaperBroker(BaseBroker):
# Update position
if order.symbol in self._positions:
position = self._positions[order.symbol]
+ order.entry_price = position.entry_price # Store for fee calculation
realized_pnl = (fill_price - position.entry_price) * order.size - order.commission
position.realized_pnl += realized_pnl
position.size -= order.size
+ # Track profitability for fee calculation
+ order.realized_pnl = realized_pnl
+ order.is_profitable = realized_pnl > 0
+
# Remove position if fully closed
if position.size <= 0:
del self._positions[order.symbol]
@@ -405,7 +415,10 @@ class PaperBroker(BaseBroker):
'filled_qty': order.filled_qty,
'price': order.filled_price,
'filled_price': order.filled_price,
- 'commission': order.commission
+ 'commission': order.commission,
+ 'is_profitable': getattr(order, 'is_profitable', False),
+ 'realized_pnl': getattr(order, 'realized_pnl', 0.0),
+ 'entry_price': getattr(order, 'entry_price', 0.0)
})
logger.info(f"PaperBroker: Limit order filled: {order.side.value} {order.size} {order.symbol} @ {fill_price:.4f}")
diff --git a/src/live_strategy_instance.py b/src/live_strategy_instance.py
index f3b1100..aedf4a1 100644
--- a/src/live_strategy_instance.py
+++ b/src/live_strategy_instance.py
@@ -375,6 +375,21 @@ class LiveStrategyInstance(StrategyInstance):
})
logger.info(f"Order filled: {event}")
+ # Accumulate strategy fees for paid strategies
+ # Fee is charged on profitable sell orders (closing positions)
+ if event.get('side') == 'sell':
+ filled_qty = event.get('filled_qty', 0)
+ filled_price = event.get('filled_price', 0)
+ trade_value = filled_qty * filled_price
+ commission_rate = self.live_broker._commission
+ # Use actual profitability from broker (based on realized PnL)
+ is_profitable = event.get('is_profitable', False)
+ self.accumulate_trade_fee(
+ trade_value_usd=trade_value,
+ commission_rate=commission_rate,
+ is_profitable=is_profitable
+ )
+
# Update balance attributes
self._update_balances()
diff --git a/src/paper_strategy_instance.py b/src/paper_strategy_instance.py
index 48b7a28..ea70741 100644
--- a/src/paper_strategy_instance.py
+++ b/src/paper_strategy_instance.py
@@ -240,6 +240,19 @@ class PaperStrategyInstance(StrategyInstance):
'filled_price': event.get('filled_price', event.get('price')),
})
+ # Accumulate strategy fees for paid strategies
+ # Fee is charged on profitable sells (closing a position)
+ if event.get('side') == 'sell':
+ trade_value = event.get('filled_qty', 0) * event.get('filled_price', 0)
+ commission_rate = self.paper_broker._commission
+ # Use actual profitability from broker (based on realized PnL)
+ is_profitable = event.get('is_profitable', False)
+ self.accumulate_trade_fee(
+ trade_value_usd=trade_value,
+ commission_rate=commission_rate,
+ is_profitable=is_profitable
+ )
+
# Update exec context with current data
self.exec_context['current_candle'] = candle_data
self.exec_context['current_price'] = price
diff --git a/src/static/Account.js b/src/static/Account.js
new file mode 100644
index 0000000..2cb281a
--- /dev/null
+++ b/src/static/Account.js
@@ -0,0 +1,695 @@
+/**
+ * Account - Manages user wallet and credits with two-address custody model
+ *
+ * Two Address Model:
+ * - Fee Address: We store private key, used for strategy fees (max ~$50)
+ * - Sweep Address: We don't store private key, user controls, receives excess
+ */
+class Account {
+ constructor() {
+ this.walletInfo = null;
+ this.creditsBalance = 0;
+ this.transactions = [];
+ this.userProfile = null;
+ this.initialized = false;
+ }
+
+ /**
+ * Initialize the account panel
+ */
+ async initialize() {
+ if (this.initialized) {
+ return;
+ }
+ this.setupEventListeners();
+ this.initialized = true;
+ }
+
+ /**
+ * Show the account settings dialog
+ */
+ async showAccountSettings() {
+ // Check if user is logged in (not a guest)
+ const username = document.getElementById('username_display')?.textContent || '';
+ if (username.startsWith('guest') || !username) {
+ alert('Please sign in to access account settings.');
+ return;
+ }
+
+ // Show dialog
+ document.getElementById('account_settings_form').style.display = 'block';
+
+ // Load profile and wallet data
+ await Promise.all([
+ this.loadProfile(),
+ this.refresh()
+ ]);
+
+ // Default to profile tab
+ this.switchTab('profile');
+ }
+
+ /**
+ * Close the account settings dialog
+ */
+ closeAccountSettings() {
+ document.getElementById('account_settings_form').style.display = 'none';
+ }
+
+ /**
+ * Switch between tabs in the account settings
+ */
+ switchTab(tabName) {
+ // Update tab buttons
+ document.querySelectorAll('.account-tab').forEach(tab => {
+ tab.classList.toggle('active', tab.dataset.tab === tabName);
+ });
+
+ // Update tab panels
+ document.querySelectorAll('.tab-panel').forEach(panel => {
+ panel.classList.toggle('active', panel.id === `${tabName}_tab`);
+ });
+ }
+
+ /**
+ * Load user profile data
+ */
+ async loadProfile() {
+ try {
+ const response = await fetch('/api/user/profile');
+ const data = await response.json();
+
+ if (data.success) {
+ this.userProfile = data.profile;
+ const emailEl = document.getElementById('current_email');
+ if (emailEl) {
+ emailEl.textContent = data.profile.email || 'Not set';
+ }
+ }
+ } catch (error) {
+ console.error('Error loading profile:', error);
+ }
+ }
+
+ /**
+ * Update user email
+ */
+ async updateEmail() {
+ const newEmail = document.getElementById('new_email').value.trim();
+ if (!newEmail) {
+ alert('Please enter a new email address.');
+ return;
+ }
+
+ // Basic email validation
+ if (!newEmail.includes('@') || !newEmail.includes('.')) {
+ alert('Please enter a valid email address.');
+ return;
+ }
+
+ try {
+ const response = await fetch('/api/user/email', {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({ email: newEmail })
+ });
+
+ const data = await response.json();
+
+ if (data.success) {
+ alert('Email updated successfully.');
+ document.getElementById('current_email').textContent = newEmail;
+ document.getElementById('new_email').value = '';
+ } else {
+ alert('Failed to update email: ' + (data.error || 'Unknown error'));
+ }
+ } catch (error) {
+ console.error('Error updating email:', error);
+ alert('Failed to update email: ' + error.message);
+ }
+ }
+
+ /**
+ * Update user password
+ */
+ async updatePassword() {
+ const currentPassword = document.getElementById('current_password').value;
+ const newPassword = document.getElementById('new_password').value;
+ const confirmPassword = document.getElementById('confirm_password').value;
+
+ if (!currentPassword || !newPassword || !confirmPassword) {
+ alert('Please fill in all password fields.');
+ return;
+ }
+
+ if (newPassword !== confirmPassword) {
+ alert('New passwords do not match.');
+ return;
+ }
+
+ if (newPassword.length < 6) {
+ alert('Password must be at least 6 characters.');
+ return;
+ }
+
+ try {
+ const response = await fetch('/api/user/password', {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({
+ current_password: currentPassword,
+ new_password: newPassword
+ })
+ });
+
+ const data = await response.json();
+
+ if (data.success) {
+ alert('Password updated successfully.');
+ document.getElementById('current_password').value = '';
+ document.getElementById('new_password').value = '';
+ document.getElementById('confirm_password').value = '';
+ } else {
+ alert('Failed to update password: ' + (data.error || 'Unknown error'));
+ }
+ } catch (error) {
+ console.error('Error updating password:', error);
+ alert('Failed to update password: ' + error.message);
+ }
+ }
+
+ /**
+ * Setup checkbox event listeners for modals
+ */
+ setupEventListeners() {
+ // Terms checkbox enables create button
+ const termsCheckbox = document.getElementById('terms_checkbox');
+ const createBtn = document.getElementById('create_wallet_btn');
+ if (termsCheckbox && createBtn) {
+ termsCheckbox.addEventListener('change', () => {
+ createBtn.disabled = !termsCheckbox.checked;
+ });
+ }
+
+ // Keys saved checkbox enables close button
+ const keysSavedCheckbox = document.getElementById('keys_saved_checkbox');
+ const closeKeysBtn = document.getElementById('close_keys_btn');
+ if (keysSavedCheckbox && closeKeysBtn) {
+ keysSavedCheckbox.addEventListener('change', () => {
+ closeKeysBtn.disabled = !keysSavedCheckbox.checked;
+ });
+ }
+ }
+
+ /**
+ * Refresh wallet info, credits balance, and transactions
+ */
+ async refresh() {
+ try {
+ // Load wallet and credits in parallel
+ const [walletRes, creditsRes, txRes] = await Promise.all([
+ fetch('/api/wallet'),
+ fetch('/api/wallet/credits'),
+ fetch('/api/wallet/transactions?limit=20')
+ ]);
+
+ const walletData = await walletRes.json();
+ const creditsData = await creditsRes.json();
+ const txData = await txRes.json();
+
+ if (walletData.success && walletData.wallet) {
+ this.walletInfo = walletData.wallet;
+ this.displayWallet();
+ } else {
+ this.walletInfo = null;
+ this.displayNoWallet();
+ }
+
+ if (creditsData.success) {
+ this.creditsBalance = creditsData.balance || 0;
+ this.updateCreditsDisplay();
+ }
+
+ if (txData.success) {
+ this.transactions = txData.transactions || [];
+ this.displayTransactions();
+ }
+ } catch (error) {
+ console.error('Error refreshing account:', error);
+ }
+ }
+
+ /**
+ * Display wallet info
+ */
+ displayWallet() {
+ const noWallet = document.getElementById('no_wallet');
+ const walletInfo = document.getElementById('wallet_info');
+
+ if (noWallet) noWallet.style.display = 'none';
+ if (walletInfo) walletInfo.style.display = 'block';
+
+ const addressEl = document.getElementById('btc_address');
+ const sweepAddressEl = document.getElementById('sweep_address');
+ const networkEl = document.getElementById('wallet_network');
+ const warningEl = document.getElementById('wallet_disabled_warning');
+
+ if (addressEl) {
+ addressEl.textContent = this.walletInfo.fee_address || this.walletInfo.address;
+ }
+ if (sweepAddressEl) {
+ sweepAddressEl.textContent = this.walletInfo.sweep_address || 'Not configured';
+ }
+ if (networkEl) {
+ networkEl.textContent = this.walletInfo.network === 'testnet' ? '(Testnet)' : '(Mainnet)';
+ }
+ if (warningEl) {
+ warningEl.style.display = this.walletInfo.is_disabled ? 'block' : 'none';
+ }
+ }
+
+ /**
+ * Display no wallet state
+ */
+ displayNoWallet() {
+ const noWallet = document.getElementById('no_wallet');
+ const walletInfo = document.getElementById('wallet_info');
+
+ if (noWallet) noWallet.style.display = 'block';
+ if (walletInfo) walletInfo.style.display = 'none';
+ }
+
+ /**
+ * Update credits balance display
+ */
+ updateCreditsDisplay() {
+ const balanceEl = document.getElementById('credits_balance');
+ if (balanceEl) {
+ balanceEl.textContent = this.formatSatoshis(this.creditsBalance);
+ }
+ }
+
+ /**
+ * Format satoshis as BTC with appropriate precision
+ */
+ formatSatoshis(satoshis) {
+ const btc = satoshis / 100000000;
+ if (btc === 0) return '0';
+ if (btc < 0.0001) return btc.toFixed(8);
+ if (btc < 1) return btc.toFixed(6);
+ return btc.toFixed(4);
+ }
+
+ /**
+ * Display transaction history
+ */
+ displayTransactions() {
+ const tbody = document.getElementById('transactions_body');
+ if (!tbody) return;
+
+ if (!this.transactions || this.transactions.length === 0) {
+ tbody.innerHTML = '
-
Fee(%):
-
+
+ Fee:
+
+ %
+ ⓘ
@@ -154,6 +159,31 @@
z-index: 10;
border-radius: 3px;
}
+
+/* Fee info icon tooltip */
+.fee-info-icon {
+ display: inline-block;
+ width: 16px;
+ height: 16px;
+ background: #3E3AF2;
+ color: white;
+ border-radius: 50%;
+ text-align: center;
+ font-size: 11px;
+ line-height: 16px;
+ cursor: help;
+ margin-left: 5px;
+ vertical-align: middle;
+}
+
+.fee-info-icon:hover {
+ background: #0A07DF;
+}
+
+/* Native tooltip styling via title attribute - multiline support */
+.fee-info-icon[title] {
+ position: relative;
+}
diff --git a/src/wallet/__init__.py b/src/wallet/__init__.py
new file mode 100644
index 0000000..45df518
--- /dev/null
+++ b/src/wallet/__init__.py
@@ -0,0 +1,8 @@
+"""Bitcoin wallet and credits ledger module for strategy fees."""
+
+from .wallet_manager import WalletManager
+from .bitcoin_service import BitcoinService
+from .encryption import KeyEncryption
+from .background_jobs import WalletBackgroundJobs
+
+__all__ = ['WalletManager', 'BitcoinService', 'KeyEncryption', 'WalletBackgroundJobs']
diff --git a/src/wallet/background_jobs.py b/src/wallet/background_jobs.py
new file mode 100644
index 0000000..020787d
--- /dev/null
+++ b/src/wallet/background_jobs.py
@@ -0,0 +1,194 @@
+"""
+Background jobs for wallet operations.
+
+- Auto-sweep: Transfer excess funds (over $50 cap) to user's sweep address
+- Deposit detection: Monitor for incoming deposits and credit user ledgers
+- Withdrawal processing: Process pending withdrawal requests
+"""
+import logging
+import time
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from .wallet_manager import WalletManager
+
+logger = logging.getLogger(__name__)
+
+# Job intervals in seconds
+AUTO_SWEEP_INTERVAL = 300 # 5 minutes
+DEPOSIT_CHECK_INTERVAL = 60 # 1 minute
+WITHDRAWAL_PROCESS_INTERVAL = 30 # 30 seconds
+
+
+class WalletBackgroundJobs:
+ """
+ Manages background jobs for wallet operations.
+ Uses eventlet for async execution compatible with Flask-SocketIO.
+ """
+
+ def __init__(self, wallet_manager: 'WalletManager', socketio=None):
+ self.wallet_manager = wallet_manager
+ self.socketio = socketio
+ self._running = False
+ self._jobs_started = False
+
+ def start_all_jobs(self):
+ """Start all background jobs."""
+ if self._jobs_started:
+ logger.warning("Wallet background jobs already started")
+ return
+
+ self._running = True
+ self._jobs_started = True
+
+ if self.socketio:
+ # Use SocketIO's background task mechanism
+ self.socketio.start_background_task(self._auto_sweep_loop)
+ self.socketio.start_background_task(self._deposit_detection_loop)
+ self.socketio.start_background_task(self._withdrawal_processing_loop)
+ logger.info("Wallet background jobs started via SocketIO")
+ else:
+ # Fallback to eventlet directly
+ import eventlet
+ eventlet.spawn(self._auto_sweep_loop)
+ eventlet.spawn(self._deposit_detection_loop)
+ eventlet.spawn(self._withdrawal_processing_loop)
+ logger.info("Wallet background jobs started via eventlet")
+
+ def stop_all_jobs(self):
+ """Stop all background jobs."""
+ self._running = False
+ logger.info("Wallet background jobs stopping...")
+
+ # === Auto-Sweep Job ===
+
+ def _auto_sweep_loop(self):
+ """Background loop for auto-sweep functionality."""
+ logger.info("Auto-sweep job started")
+ while self._running:
+ try:
+ self._process_auto_sweeps()
+ except Exception as e:
+ logger.error(f"Auto-sweep error: {e}")
+ self._sleep(AUTO_SWEEP_INTERVAL)
+
+ def _process_auto_sweeps(self):
+ """Check all wallets and sweep excess funds."""
+ try:
+ # Get all wallets that might need sweeping
+ wallets = self.wallet_manager.get_wallets_over_cap()
+
+ for wallet in wallets:
+ user_id = wallet['user_id']
+ balance = wallet['balance']
+ sweep_address = wallet['sweep_address']
+ cap = self.wallet_manager.BALANCE_CAP_SATOSHIS
+
+ if not sweep_address:
+ logger.warning(f"User {user_id} has no sweep address configured")
+ continue
+
+ excess = balance - cap
+ if excess <= 0:
+ continue
+
+ # Keep a small buffer (1000 sats) to avoid sweeping tiny amounts
+ if excess < 1000:
+ continue
+
+ logger.info(f"Auto-sweeping {excess} satoshis for user {user_id}")
+
+ # Perform the sweep
+ result = self.wallet_manager.auto_sweep(user_id, excess)
+ if result.get('success'):
+ logger.info(f"Auto-sweep successful for user {user_id}: {result.get('tx_hash', 'simulated')}")
+ else:
+ logger.error(f"Auto-sweep failed for user {user_id}: {result.get('error')}")
+
+ except Exception as e:
+ logger.error(f"Error processing auto-sweeps: {e}")
+
+ # === Deposit Detection Job ===
+
+ def _deposit_detection_loop(self):
+ """Background loop for deposit detection."""
+ logger.info("Deposit detection job started")
+ while self._running:
+ try:
+ self._check_for_deposits()
+ except Exception as e:
+ logger.error(f"Deposit detection error: {e}")
+ self._sleep(DEPOSIT_CHECK_INTERVAL)
+
+ def _check_for_deposits(self):
+ """Check all wallet addresses for new deposits."""
+ try:
+ # Include disabled wallets so incoming deposits are still credited.
+ wallets = self.wallet_manager.get_wallets_for_deposit_monitoring()
+
+ for wallet in wallets:
+ user_id = wallet['user_id']
+ fee_address = wallet['fee_address']
+ network = wallet['network']
+
+ # Check for new deposits at the fee address
+ new_deposits = self.wallet_manager.check_address_for_deposits(
+ user_id=user_id,
+ address=fee_address,
+ network=network
+ )
+
+ for deposit in new_deposits:
+ logger.info(f"New deposit detected for user {user_id}: "
+ f"{deposit['amount_satoshis']} sats, tx: {deposit['tx_hash']}")
+
+ except Exception as e:
+ logger.error(f"Error checking for deposits: {e}")
+
+ # === Withdrawal Processing Job ===
+
+ def _withdrawal_processing_loop(self):
+ """Background loop for processing pending withdrawals."""
+ logger.info("Withdrawal processing job started")
+ while self._running:
+ try:
+ self._process_pending_withdrawals()
+ except Exception as e:
+ logger.error(f"Withdrawal processing error: {e}")
+ self._sleep(WITHDRAWAL_PROCESS_INTERVAL)
+
+ def _process_pending_withdrawals(self):
+ """Process all pending withdrawal requests."""
+ try:
+ # Get pending withdrawals (status = 'reserved')
+ pending = self.wallet_manager.get_pending_withdrawals()
+
+ for withdrawal in pending:
+ withdrawal_id = withdrawal['id']
+ user_id = withdrawal['user_id']
+ amount = withdrawal['amount_satoshis']
+ destination = withdrawal['destination_address']
+
+ logger.info(f"Processing withdrawal {withdrawal_id} for user {user_id}: "
+ f"{amount} sats to {destination}")
+
+ # Process the withdrawal
+ result = self.wallet_manager.process_withdrawal(withdrawal_id)
+
+ if result.get('success'):
+ logger.info(f"Withdrawal {withdrawal_id} completed: {result.get('tx_hash', 'simulated')}")
+ else:
+ logger.error(f"Withdrawal {withdrawal_id} failed: {result.get('error')}")
+
+ except Exception as e:
+ logger.error(f"Error processing withdrawals: {e}")
+
+ def _sleep(self, seconds):
+ """Sleep that respects the running flag and works with eventlet."""
+ import eventlet
+ # Sleep in small increments to allow quick shutdown
+ remaining = seconds
+ while remaining > 0 and self._running:
+ sleep_time = min(remaining, 5)
+ eventlet.sleep(sleep_time)
+ remaining -= sleep_time
diff --git a/src/wallet/bitcoin_service.py b/src/wallet/bitcoin_service.py
new file mode 100644
index 0000000..2977e5c
--- /dev/null
+++ b/src/wallet/bitcoin_service.py
@@ -0,0 +1,160 @@
+"""Bitcoin network operations using bit library."""
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class BitcoinService:
+ """
+ Service for Bitcoin network operations.
+ Handles address generation, balance checking, and transactions.
+ """
+
+ def __init__(self, testnet: bool = True):
+ """
+ Initialize Bitcoin service.
+
+ Args:
+ testnet: If True, use testnet; otherwise use mainnet.
+ """
+ self.testnet = testnet
+
+ def _get_key_class(self):
+ """Get the appropriate key class based on network."""
+ from bit import Key, PrivateKeyTestnet
+ return PrivateKeyTestnet if self.testnet else Key
+
+ def generate_keypair(self) -> dict:
+ """
+ Generate new Bitcoin keypair.
+
+ Returns:
+ Dict with 'address', 'public_key', and 'private_key' (WIF format).
+ """
+ key_class = self._get_key_class()
+ key = key_class()
+ return {
+ 'address': key.address,
+ 'public_key': key.public_key.hex(),
+ 'private_key': key.to_wif() # Wallet Import Format
+ }
+
+ def get_balance(self, address: str) -> int:
+ """
+ Get balance in satoshis from blockchain.
+
+ Args:
+ address: Bitcoin address to check.
+
+ Returns:
+ Balance in satoshis, or 0 on error.
+ """
+ try:
+ from bit.network import NetworkAPI
+ if self.testnet:
+ balance = NetworkAPI.get_balance_testnet(address)
+ else:
+ balance = NetworkAPI.get_balance(address)
+ return balance
+ except Exception as e:
+ logger.error(f"Failed to fetch balance for {address}: {e}")
+ return 0
+
+ def get_unspent(self, address: str) -> list:
+ """
+ Get unspent transaction outputs for address.
+
+ Args:
+ address: Bitcoin address to check.
+
+ Returns:
+ List of UTXOs, or empty list on error.
+ """
+ try:
+ from bit.network import NetworkAPI
+ if self.testnet:
+ return NetworkAPI.get_unspent_testnet(address)
+ return NetworkAPI.get_unspent(address)
+ except Exception as e:
+ logger.error(f"Failed to fetch UTXOs for {address}: {e}")
+ return []
+
+ def send_transaction(self, private_key_wif: str, to_address: str,
+ amount_satoshis: int) -> Optional[str]:
+ """
+ Send BTC transaction.
+
+ Args:
+ private_key_wif: Sender's private key in WIF format.
+ to_address: Recipient's Bitcoin address.
+ amount_satoshis: Amount to send in satoshis.
+
+ Returns:
+ Transaction hash on success, None on failure.
+ """
+ try:
+ key_class = self._get_key_class()
+ key = key_class(private_key_wif)
+ tx_hash = key.send([(to_address, amount_satoshis, 'satoshi')])
+ logger.info(f"Sent {amount_satoshis} satoshis to {to_address}: {tx_hash}")
+ return tx_hash
+ except Exception as e:
+ logger.error(f"Failed to send transaction: {e}")
+ return None
+
+ def validate_address(self, address: str) -> bool:
+ """
+ Validate Bitcoin address format and checksum.
+
+ Args:
+ address: Address string to validate.
+
+ Returns:
+ True if valid for this network, False otherwise.
+ """
+ if not address or not isinstance(address, str):
+ return False
+
+ try:
+ address = address.strip()
+
+ if self.testnet:
+ # Testnet addresses: m/n (legacy), 2 (P2SH), tb1 (bech32)
+ if address.startswith(('m', 'n', '2')):
+ # Base58 testnet address
+ return 26 <= len(address) <= 35
+ elif address.startswith('tb1'):
+ # Bech32 testnet address
+ return 42 <= len(address) <= 62
+ return False
+ else:
+ # Mainnet addresses: 1 (legacy), 3 (P2SH), bc1 (bech32)
+ if address.startswith(('1', '3')):
+ # Base58 mainnet address
+ return 26 <= len(address) <= 35
+ elif address.startswith('bc1'):
+ # Bech32 mainnet address
+ return 42 <= len(address) <= 62
+ return False
+
+ except Exception:
+ return False
+
+ def estimate_fee(self, num_inputs: int = 1, num_outputs: int = 2) -> int:
+ """
+ Estimate transaction fee in satoshis.
+
+ Args:
+ num_inputs: Number of transaction inputs.
+ num_outputs: Number of transaction outputs.
+
+ Returns:
+ Estimated fee in satoshis.
+ """
+ # Rough estimate: ~148 bytes per input, ~34 bytes per output, ~10 overhead
+ tx_size = (num_inputs * 148) + (num_outputs * 34) + 10
+ # Use a reasonable fee rate (satoshis per byte)
+ # Could be made dynamic based on network conditions
+ fee_rate = 10
+ return tx_size * fee_rate
diff --git a/src/wallet/encryption.py b/src/wallet/encryption.py
new file mode 100644
index 0000000..d755bb1
--- /dev/null
+++ b/src/wallet/encryption.py
@@ -0,0 +1,87 @@
+"""Symmetric encryption for wallet keys with versioning."""
+import base64
+import hashlib
+from cryptography.fernet import Fernet
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
+
+
+class KeyEncryption:
+ """
+ Versioned encryption for wallet keys.
+ Supports key rotation by maintaining multiple key versions.
+ """
+ CURRENT_VERSION = 1
+
+ def __init__(self, master_keys: dict):
+ """
+ Initialize with versioned master keys.
+
+ Args:
+ master_keys: Dict mapping version numbers to master key strings.
+ Example: {1: 'key_v1', 2: 'key_v2'}
+ """
+ self.fernets = {}
+ for version, key in master_keys.items():
+ # Derive unique salt per version (deterministic but version-specific)
+ salt = hashlib.sha256(f"brighter_wallet_v{version}".encode()).digest()[:16]
+ kdf = PBKDF2HMAC(
+ algorithm=hashes.SHA256(),
+ length=32,
+ salt=salt,
+ iterations=100000,
+ )
+ derived_key = base64.urlsafe_b64encode(kdf.derive(key.encode()))
+ self.fernets[int(version)] = Fernet(derived_key)
+
+ def encrypt(self, data: str, version: int = None) -> str:
+ """
+ Encrypt data with current (or specified) key version.
+
+ Args:
+ data: Plaintext string to encrypt.
+ version: Key version to use (defaults to CURRENT_VERSION).
+
+ Returns:
+ Encrypted data as string.
+ """
+ version = version or self.CURRENT_VERSION
+ if version not in self.fernets:
+ raise ValueError(f"Unknown encryption key version: {version}")
+ return self.fernets[version].encrypt(data.encode()).decode()
+
+ def decrypt(self, encrypted_data: str, version: int) -> str:
+ """
+ Decrypt data using the specified key version.
+
+ Args:
+ encrypted_data: Encrypted string to decrypt.
+ version: Key version that was used to encrypt.
+
+ Returns:
+ Decrypted plaintext string.
+
+ Raises:
+ ValueError: If the key version is unknown.
+ """
+ if version not in self.fernets:
+ raise ValueError(f"Unknown encryption key version: {version}")
+ return self.fernets[version].decrypt(encrypted_data.encode()).decode()
+
+ def re_encrypt(self, encrypted_data: str, old_version: int,
+ new_version: int = None) -> tuple[str, int]:
+ """
+ Re-encrypt data from old version to new version.
+ Used during key rotation.
+
+ Args:
+ encrypted_data: Data encrypted with old_version.
+ old_version: Version the data is currently encrypted with.
+ new_version: Target version (defaults to CURRENT_VERSION).
+
+ Returns:
+ Tuple of (new_encrypted_data, new_version).
+ """
+ new_version = new_version or self.CURRENT_VERSION
+ plaintext = self.decrypt(encrypted_data, old_version)
+ return self.encrypt(plaintext, new_version), new_version
diff --git a/src/wallet/wallet_manager.py b/src/wallet/wallet_manager.py
new file mode 100644
index 0000000..3000250
--- /dev/null
+++ b/src/wallet/wallet_manager.py
@@ -0,0 +1,1005 @@
+"""Main wallet management operations using credits ledger with balance cap."""
+import logging
+import uuid
+from typing import Optional
+
+from .bitcoin_service import BitcoinService
+from .encryption import KeyEncryption
+
+logger = logging.getLogger(__name__)
+
+# Balance cap in satoshis (~$50 at typical BTC prices)
+# 0.001 BTC ~ $50-100 range depending on BTC price
+BALANCE_CAP_SATOSHIS = 100000 # 0.001 BTC
+
+
+class WalletManager:
+ """
+ Manages user Bitcoin wallets and internal credits ledger.
+
+ Two-Address Custody Model:
+ 1. Fee Address: We store private key, used for strategy fees (max ~$50)
+ 2. Sweep Address: We only store public address, user controls private key
+
+ Credits Ledger: Internal balance for instant, reliable strategy fees
+ """
+
+ # Balance cap as class attribute so background jobs can access it
+ BALANCE_CAP_SATOSHIS = BALANCE_CAP_SATOSHIS
+
+ def __init__(self, database, encryption_keys: dict, default_network: str = 'testnet'):
+ """
+ Initialize wallet manager.
+
+ Args:
+ database: Database instance for persistence.
+ encryption_keys: Dict mapping version numbers to encryption keys.
+ default_network: 'testnet' or 'mainnet' for new wallets.
+ """
+ self.db = database
+ self.encryption = KeyEncryption(encryption_keys)
+ self.default_network = default_network
+
+ # Ensure wallet tables exist
+ self._ensure_tables()
+
+ def _ensure_tables(self):
+ """Create wallet-related tables if they don't exist."""
+ # Wallets table - stores Fee Address (with keys) and Sweep Address (public only)
+ self.db.execute_sql("""
+ CREATE TABLE IF NOT EXISTS wallets (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL UNIQUE,
+ btc_address TEXT NOT NULL,
+ public_key_encrypted TEXT NOT NULL,
+ private_key_encrypted TEXT NOT NULL,
+ sweep_address TEXT,
+ encryption_key_version INTEGER NOT NULL DEFAULT 1,
+ network TEXT NOT NULL DEFAULT 'testnet' CHECK (network IN ('testnet', 'mainnet')),
+ is_disabled INTEGER NOT NULL DEFAULT 0,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (user_id) REFERENCES users(id)
+ )
+ """)
+
+ # Migration: Add sweep_address column if missing (for existing wallets)
+ try:
+ self.db.execute_sql("ALTER TABLE wallets ADD COLUMN sweep_address TEXT")
+ except Exception:
+ pass # Column already exists
+
+ # Credits ledger - source of truth for spendable balance
+ self.db.execute_sql("""
+ CREATE TABLE IF NOT EXISTS credits_ledger (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ amount_satoshis INTEGER NOT NULL CHECK (amount_satoshis != 0),
+ tx_type TEXT NOT NULL CHECK (tx_type IN ('deposit', 'withdrawal', 'withdrawal_reversal', 'fee_paid', 'fee_received', 'admin_credit')),
+ reference_id TEXT,
+ counterparty_user_id INTEGER,
+ idempotency_key TEXT NOT NULL UNIQUE,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (user_id) REFERENCES users(id),
+ FOREIGN KEY (counterparty_user_id) REFERENCES users(id)
+ )
+ """)
+
+ # Pending fees during strategy run
+ self.db.execute_sql("""
+ CREATE TABLE IF NOT EXISTS pending_strategy_fees (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ strategy_run_id TEXT NOT NULL UNIQUE,
+ user_id INTEGER NOT NULL,
+ creator_user_id INTEGER NOT NULL,
+ fee_percent INTEGER NOT NULL DEFAULT 10,
+ accumulated_satoshis INTEGER NOT NULL DEFAULT 0,
+ trade_count INTEGER NOT NULL DEFAULT 0,
+ started_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (user_id) REFERENCES users(id),
+ FOREIGN KEY (creator_user_id) REFERENCES users(id)
+ )
+ """)
+
+ # Migration: Add fee_percent column if missing
+ try:
+ self.db.execute_sql("ALTER TABLE pending_strategy_fees ADD COLUMN fee_percent INTEGER NOT NULL DEFAULT 10")
+ except Exception:
+ pass # Column already exists
+
+ # Withdrawal requests (async processing)
+ self.db.execute_sql("""
+ CREATE TABLE IF NOT EXISTS withdrawal_requests (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ amount_satoshis INTEGER NOT NULL CHECK (amount_satoshis > 0),
+ destination_address TEXT NOT NULL,
+ status TEXT NOT NULL DEFAULT 'pending' CHECK (status IN ('pending', 'reserved', 'processing', 'completed', 'failed', 'reversed')),
+ btc_txhash TEXT,
+ error_message TEXT,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ processed_at TIMESTAMP,
+ FOREIGN KEY (user_id) REFERENCES users(id)
+ )
+ """)
+
+ # Deposit tracking (for dedupe)
+ self.db.execute_sql("""
+ CREATE TABLE IF NOT EXISTS deposits (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ network TEXT NOT NULL,
+ tx_hash TEXT NOT NULL,
+ vout INTEGER NOT NULL,
+ amount_satoshis INTEGER NOT NULL,
+ confirmations INTEGER NOT NULL DEFAULT 0,
+ credited INTEGER NOT NULL DEFAULT 0,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (user_id) REFERENCES users(id)
+ )
+ """)
+
+ # Create indexes for performance
+ self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_wallets_user_id ON wallets(user_id)")
+ self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_ledger_user_id ON credits_ledger(user_id)")
+ self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_ledger_idempotency ON credits_ledger(idempotency_key)")
+ self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_pending_fees_run ON pending_strategy_fees(strategy_run_id)")
+ self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_withdrawals_status ON withdrawal_requests(status)")
+ self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_deposits_credited ON deposits(credited)")
+ # Unique constraint on deposits (network, tx_hash, vout)
+ self.db.execute_sql("CREATE UNIQUE INDEX IF NOT EXISTS idx_deposits_unique ON deposits(network, tx_hash, vout)")
+
+ logger.info("Wallet tables ensured")
+
+ def create_wallet(self, user_id: int, user_sweep_address: str = None) -> dict:
+ """
+ Generate and store new wallet for user (two-address model).
+
+ Creates:
+ 1. Fee Address - We store private key, used for strategy fees
+ 2. Sweep Address - Either user-provided or we generate one (we don't store private key)
+
+ Args:
+ user_id: User ID to create wallet for.
+ user_sweep_address: Optional user-provided sweep address (e.g., from their own wallet).
+
+ Returns:
+ Dict with success status and both addresses' details.
+ The sweep address private key is only returned if we generated it.
+ """
+ existing = self.get_wallet(user_id)
+ if existing:
+ return {'success': False, 'error': 'Wallet already exists'}
+
+ bitcoin = BitcoinService(testnet=(self.default_network == 'testnet'))
+
+ # Validate user-provided sweep address if given
+ if user_sweep_address:
+ if not bitcoin.validate_address(user_sweep_address):
+ return {'success': False, 'error': 'Invalid Bitcoin address for sweep address'}
+
+ # Generate Fee Address (we store and manage)
+ fee_keypair = bitcoin.generate_keypair()
+
+ # Sweep Address: use user-provided or generate new
+ if user_sweep_address:
+ sweep_address = user_sweep_address
+ sweep_keypair = None # User controls this, we don't have the keys
+ else:
+ # Generate Sweep Address (user controls - we DON'T store private key)
+ sweep_keypair = bitcoin.generate_keypair()
+ sweep_address = sweep_keypair['address']
+
+ version = KeyEncryption.CURRENT_VERSION
+ encrypted_public = self.encryption.encrypt(fee_keypair['public_key'], version)
+ encrypted_private = self.encryption.encrypt(fee_keypair['private_key'], version)
+
+ self.db.execute_sql(
+ """INSERT INTO wallets (user_id, btc_address, public_key_encrypted,
+ private_key_encrypted, sweep_address, encryption_key_version, network)
+ VALUES (?, ?, ?, ?, ?, ?, ?)""",
+ (user_id, fee_keypair['address'], encrypted_public, encrypted_private,
+ sweep_address, version, self.default_network)
+ )
+
+ logger.info(f"Created two-address wallet for user {user_id}: "
+ f"fee={fee_keypair['address']}, sweep={sweep_address} "
+ f"(user_provided={bool(user_sweep_address)})")
+
+ result = {
+ 'success': True,
+ # Fee Address (we manage)
+ 'fee_address': fee_keypair['address'],
+ 'fee_public_key': fee_keypair['public_key'],
+ 'fee_private_key': fee_keypair['private_key'],
+ # Sweep Address
+ 'sweep_address': sweep_address,
+ 'network': self.default_network,
+ }
+
+ # Only include sweep private key if we generated it
+ if sweep_keypair:
+ result['sweep_public_key'] = sweep_keypair['public_key']
+ result['sweep_private_key'] = sweep_keypair['private_key']
+ result['warning'] = 'SAVE YOUR SWEEP ADDRESS PRIVATE KEY. It is NOT stored and CANNOT be recovered.'
+ else:
+ result['message'] = 'Using your provided sweep address. Excess funds will be sent there.'
+
+ return result
+
+ def get_wallet(self, user_id: int) -> Optional[dict]:
+ """
+ Get wallet info (without decrypted private keys).
+
+ Args:
+ user_id: User ID to look up.
+
+ Returns:
+ Dict with wallet info, or None if no wallet exists.
+ """
+ result = self.db.execute_sql(
+ "SELECT btc_address, network, is_disabled, created_at, sweep_address FROM wallets WHERE user_id = ?",
+ (user_id,), fetch_one=True
+ )
+ if not result:
+ return None
+ return {
+ 'fee_address': result[0],
+ 'address': result[0], # Backwards compatibility
+ 'network': result[1],
+ 'is_disabled': bool(result[2]),
+ 'created_at': result[3],
+ 'sweep_address': result[4]
+ }
+
+ def get_wallet_keys(self, user_id: int) -> Optional[dict]:
+ """
+ Get fee address keys (for "View Keys" feature).
+ Note: Sweep address private key is NOT stored.
+
+ Args:
+ user_id: User ID to look up.
+
+ Returns:
+ Dict with fee address and decrypted private key, or None.
+ """
+ result = self.db.execute_sql(
+ """SELECT btc_address, public_key_encrypted, private_key_encrypted,
+ encryption_key_version, network FROM wallets WHERE user_id = ?""",
+ (user_id,), fetch_one=True
+ )
+ if not result:
+ return None
+
+ address, encrypted_public, encrypted_private, version, network = result
+
+ try:
+ private_key = self.encryption.decrypt(encrypted_private, version)
+ return {
+ 'fee_address': address,
+ 'fee_private_key': private_key,
+ 'network': network
+ }
+ except Exception as e:
+ logger.error(f"Failed to decrypt keys for user {user_id}: {e}")
+ return None
+
+ def get_credits_balance(self, user_id: int) -> int:
+ """
+ Get user's spendable credits balance (sum of ledger entries).
+
+ Args:
+ user_id: User ID to look up.
+
+ Returns:
+ Balance in satoshis.
+ """
+ result = self.db.execute_sql(
+ "SELECT COALESCE(SUM(amount_satoshis), 0) FROM credits_ledger WHERE user_id = ?",
+ (user_id,), fetch_one=True
+ )
+ return result[0] if result else 0
+
+ def check_balance_cap(self, user_id: int) -> dict:
+ """
+ Check if balance exceeds cap and disable/enable wallet accordingly.
+
+ Args:
+ user_id: User ID to check.
+
+ Returns:
+ Dict with 'over_cap' status and current balance.
+ """
+ balance = self.get_credits_balance(user_id)
+ if balance > BALANCE_CAP_SATOSHIS:
+ self.db.execute_sql(
+ "UPDATE wallets SET is_disabled = 1 WHERE user_id = ?", (user_id,)
+ )
+ logger.warning(f"User {user_id} wallet disabled - over balance cap "
+ f"({balance} > {BALANCE_CAP_SATOSHIS})")
+ return {'over_cap': True, 'balance': balance, 'cap': BALANCE_CAP_SATOSHIS}
+ else:
+ # Re-enable if under cap
+ self.db.execute_sql(
+ "UPDATE wallets SET is_disabled = 0 WHERE user_id = ?", (user_id,)
+ )
+ return {'over_cap': False, 'balance': balance}
+
+ # === Strategy Fee Accumulation ===
+
+ def start_fee_accumulation(self, strategy_run_id: str, user_id: int,
+ creator_user_id: int, fee_percent: int = 10,
+ estimated_trades: int = 10) -> dict:
+ """
+ Start fee accumulation for a strategy run.
+ Checks if user has enough credits for estimated trades.
+
+ Args:
+ strategy_run_id: Unique identifier for this strategy run.
+ user_id: User running the strategy.
+ creator_user_id: Strategy creator who will receive fees.
+ fee_percent: Percentage of exchange commission to charge (1-100).
+ estimated_trades: Estimated number of trades for credit check.
+
+ Returns:
+ Dict with success status.
+ """
+ # If user is running their own strategy, no fees apply
+ if user_id == creator_user_id:
+ return {'success': True, 'message': 'No fees for own strategy'}
+
+ # Validate fee_percent
+ fee_percent = max(1, min(100, int(fee_percent)))
+
+ # Check wallet is enabled
+ wallet = self.get_wallet(user_id)
+ if wallet and wallet['is_disabled']:
+ return {'success': False, 'error': 'Wallet disabled - withdraw funds to re-enable'}
+
+ # Estimate fee per trade (rough estimate, actual depends on trade size)
+ # Scale estimate by fee_percent (higher fee = more credits needed)
+ base_estimate = 1000 # ~1000 sats per trade at 100%
+ min_credits_needed = (base_estimate * fee_percent // 100) * estimated_trades
+
+ balance = self.get_credits_balance(user_id)
+ if balance < min_credits_needed:
+ return {
+ 'success': False,
+ 'error': 'Insufficient credits for strategy fees',
+ 'available': balance,
+ 'recommended_minimum': min_credits_needed
+ }
+
+ # Create pending fees record with fee_percent
+ self.db.execute_sql(
+ """INSERT OR REPLACE INTO pending_strategy_fees
+ (strategy_run_id, user_id, creator_user_id, fee_percent, accumulated_satoshis, trade_count)
+ VALUES (?, ?, ?, ?, 0, 0)""",
+ (strategy_run_id, user_id, creator_user_id, fee_percent)
+ )
+
+ logger.info(f"Started fee accumulation for run {strategy_run_id}: "
+ f"user={user_id}, creator={creator_user_id}, fee_percent={fee_percent}%")
+
+ return {'success': True, 'strategy_run_id': strategy_run_id}
+
+ def accumulate_trade_fee(self, strategy_run_id: str, exchange_fee_satoshis: int,
+ is_profitable: bool) -> dict:
+ """
+ Accumulate fee for a single trade (called after each trade).
+ Only charges on profitable trades.
+
+ Args:
+ strategy_run_id: Strategy run identifier.
+ exchange_fee_satoshis: Exchange commission in satoshis.
+ is_profitable: Whether the trade was profitable.
+
+ Returns:
+ Dict with fee charged amount.
+ """
+ if not is_profitable:
+ return {'success': True, 'fee_charged': 0, 'reason': 'unprofitable_trade'}
+
+ # Get the fee_percent for this strategy run
+ result = self.db.execute_sql(
+ "SELECT fee_percent FROM pending_strategy_fees WHERE strategy_run_id = ?",
+ (strategy_run_id,), fetch_one=True
+ )
+
+ if not result:
+ return {'success': False, 'fee_charged': 0, 'reason': 'no_pending_fees_record'}
+
+ fee_percent = result[0]
+
+ # Fee is fee_percent% of exchange commission
+ strategy_fee = (exchange_fee_satoshis * fee_percent) // 100
+
+ if strategy_fee <= 0:
+ return {'success': True, 'fee_charged': 0, 'reason': 'fee_too_small'}
+
+ # Add to accumulated fees
+ self.db.execute_sql(
+ """UPDATE pending_strategy_fees
+ SET accumulated_satoshis = accumulated_satoshis + ?,
+ trade_count = trade_count + 1
+ WHERE strategy_run_id = ?""",
+ (strategy_fee, strategy_run_id)
+ )
+
+ logger.debug(f"Accumulated {strategy_fee} sats fee ({fee_percent}% of {exchange_fee_satoshis}) for run {strategy_run_id}")
+
+ return {'success': True, 'fee_charged': strategy_fee}
+
+ def settle_accumulated_fees(self, strategy_run_id: str) -> dict:
+ """
+ Settle all accumulated fees when strategy stops.
+ Transfers total from user to creator.
+
+ Args:
+ strategy_run_id: Strategy run identifier.
+
+ Returns:
+ Dict with settlement details.
+ """
+ # Get pending fees
+ pending = self.db.execute_sql(
+ """SELECT user_id, creator_user_id, accumulated_satoshis, trade_count
+ FROM pending_strategy_fees WHERE strategy_run_id = ?""",
+ (strategy_run_id,), fetch_one=True
+ )
+
+ if not pending:
+ return {'success': True, 'message': 'No pending fees found', 'settled': 0}
+
+ user_id, creator_user_id, total_fees, trade_count = pending
+
+ if total_fees <= 0:
+ # Clean up record, no fees to settle
+ self.db.execute_sql(
+ "DELETE FROM pending_strategy_fees WHERE strategy_run_id = ?",
+ (strategy_run_id,)
+ )
+ return {'success': True, 'settled': 0, 'trades': trade_count}
+
+ # Check user still has sufficient balance
+ balance = self.get_credits_balance(user_id)
+ if balance < total_fees:
+ # Settle what we can, log the shortfall
+ logger.warning(f"User {user_id} has insufficient balance for full fee settlement. "
+ f"Owed: {total_fees}, Available: {balance}")
+ total_fees = max(0, balance) # Settle partial
+
+ if total_fees > 0:
+ idempotency_key = f"fee_settle_{strategy_run_id}"
+
+ # Check if already settled (idempotency)
+ existing = self.db.execute_sql(
+ "SELECT id FROM credits_ledger WHERE idempotency_key = ?",
+ (idempotency_key,), fetch_one=True
+ )
+ if existing:
+ logger.info(f"Fees already settled for run {strategy_run_id}")
+ else:
+ # Atomic fee settlement - both debit and credit in one transaction
+ # Prevents partial state if one insert fails
+ try:
+ self.db.execute_in_transaction([
+ # Debit user
+ ("""INSERT INTO credits_ledger
+ (user_id, amount_satoshis, tx_type, reference_id,
+ counterparty_user_id, idempotency_key)
+ VALUES (?, ?, 'fee_paid', ?, ?, ?)""",
+ (user_id, -total_fees, strategy_run_id, creator_user_id, idempotency_key)),
+ # Credit creator
+ ("""INSERT INTO credits_ledger
+ (user_id, amount_satoshis, tx_type, reference_id,
+ counterparty_user_id, idempotency_key)
+ VALUES (?, ?, 'fee_received', ?, ?, ?)""",
+ (creator_user_id, total_fees, strategy_run_id, user_id,
+ f"{idempotency_key}_creator"))
+ ])
+ logger.info(f"Settled {total_fees} sats from user {user_id} to creator {creator_user_id}")
+ except Exception as e:
+ logger.error(f"Failed to settle fees for run {strategy_run_id}: {e}")
+ # Don't delete pending record - leave for retry
+ return {'success': False, 'error': f'Fee settlement failed: {e}'}
+
+ # Clean up pending record
+ self.db.execute_sql(
+ "DELETE FROM pending_strategy_fees WHERE strategy_run_id = ?",
+ (strategy_run_id,)
+ )
+
+ # Check balance cap after settlement
+ self.check_balance_cap(user_id)
+
+ return {'success': True, 'settled': total_fees, 'trades': trade_count}
+
+ def get_pending_fees(self, strategy_run_id: str) -> dict:
+ """
+ Get current accumulated fees for a running strategy.
+
+ Args:
+ strategy_run_id: Strategy run identifier.
+
+ Returns:
+ Dict with accumulated fees and trade count.
+ """
+ result = self.db.execute_sql(
+ """SELECT accumulated_satoshis, trade_count
+ FROM pending_strategy_fees WHERE strategy_run_id = ?""",
+ (strategy_run_id,), fetch_one=True
+ )
+ if not result:
+ return {'accumulated_satoshis': 0, 'trade_count': 0}
+ return {'accumulated_satoshis': result[0], 'trade_count': result[1]}
+
+ def cancel_fee_accumulation(self, strategy_run_id: str) -> dict:
+ """
+ Cancel fee accumulation for a strategy run without charging fees.
+
+ Used when strategy startup fails after a pending fee record has been created.
+
+ Args:
+ strategy_run_id: Strategy run identifier.
+
+ Returns:
+ Dict with cancellation status.
+ """
+ try:
+ self.db.execute_sql(
+ "DELETE FROM pending_strategy_fees WHERE strategy_run_id = ?",
+ (strategy_run_id,)
+ )
+ return {'success': True}
+ except Exception as e:
+ logger.error(f"Failed to cancel fee accumulation for run {strategy_run_id}: {e}")
+ return {'success': False, 'error': str(e)}
+
+ # === Deposits and Withdrawals ===
+
+ def credit_deposit(self, user_id: int, amount_satoshis: int,
+ network: str, tx_hash: str, vout: int) -> dict:
+ """
+ Credit user's ledger when BTC deposit is confirmed.
+
+ Args:
+ user_id: User ID to credit.
+ amount_satoshis: Amount to credit.
+ network: Bitcoin network ('testnet' or 'mainnet').
+ tx_hash: Transaction hash.
+ vout: Output index within transaction.
+
+ Returns:
+ Dict with success status.
+ """
+ idempotency_key = f"deposit_{network}_{tx_hash}_{vout}"
+
+ existing = self.db.execute_sql(
+ "SELECT id FROM credits_ledger WHERE idempotency_key = ?",
+ (idempotency_key,), fetch_one=True
+ )
+ if existing:
+ return {'success': True, 'already_processed': True}
+
+ self.db.execute_sql(
+ """INSERT INTO credits_ledger
+ (user_id, amount_satoshis, tx_type, reference_id, idempotency_key)
+ VALUES (?, ?, 'deposit', ?, ?)""",
+ (user_id, amount_satoshis, tx_hash, idempotency_key)
+ )
+
+ logger.info(f"Credited {amount_satoshis} sats to user {user_id} from deposit {tx_hash}")
+
+ # Check balance cap after deposit
+ cap_check = self.check_balance_cap(user_id)
+ if cap_check['over_cap']:
+ logger.warning(f"User {user_id} wallet disabled - over balance cap after deposit")
+
+ return {'success': True, 'over_cap': cap_check.get('over_cap', False)}
+
+ def request_withdrawal(self, user_id: int, amount_satoshis: int,
+ destination_address: str) -> dict:
+ """
+ Queue a withdrawal request with immediate reservation.
+
+ Args:
+ user_id: User ID requesting withdrawal.
+ amount_satoshis: Amount to withdraw.
+ destination_address: Bitcoin address to send to.
+
+ Returns:
+ Dict with success status.
+ """
+ wallet = self.get_wallet(user_id)
+ if not wallet:
+ return {'success': False, 'error': 'No wallet'}
+
+ # Validate address for the wallet's network
+ bitcoin = BitcoinService(testnet=(wallet['network'] == 'testnet'))
+ if not bitcoin.validate_address(destination_address):
+ return {'success': False, 'error': 'Invalid Bitcoin address for this network'}
+
+ balance = self.get_credits_balance(user_id)
+ if balance < amount_satoshis:
+ return {'success': False, 'error': 'Insufficient balance', 'available': balance}
+
+ # Reserve immediately (debit from ledger)
+ idempotency_key = f"withdrawal_reserve_{uuid.uuid4().hex}"
+ try:
+ self.db.execute_in_transaction([
+ ("""INSERT INTO credits_ledger
+ (user_id, amount_satoshis, tx_type, reference_id, idempotency_key)
+ VALUES (?, ?, 'withdrawal', ?, ?)""",
+ (user_id, -amount_satoshis, destination_address, idempotency_key)),
+ ("""INSERT INTO withdrawal_requests
+ (user_id, amount_satoshis, destination_address, status)
+ VALUES (?, ?, ?, 'reserved')""",
+ (user_id, amount_satoshis, destination_address))
+ ])
+ except Exception as e:
+ logger.error(f"Failed to reserve withdrawal for user {user_id}: {e}")
+ return {'success': False, 'error': 'Failed to queue withdrawal request'}
+
+ logger.info(f"Withdrawal requested: {amount_satoshis} sats to {destination_address} for user {user_id}")
+
+ # Re-check balance cap (may re-enable wallet after withdrawal)
+ self.check_balance_cap(user_id)
+
+ return {'success': True, 'message': 'Withdrawal reserved and queued'}
+
+ def admin_credit(self, user_id: int, amount_satoshis: int, reason: str) -> dict:
+ """
+ Admin function to manually credit a user (for POC testing).
+
+ Args:
+ user_id: User ID to credit.
+ amount_satoshis: Amount to credit.
+ reason: Reason for the credit.
+
+ Returns:
+ Dict with success status.
+ """
+ idempotency_key = f"admin_{user_id}_{uuid.uuid4().hex}"
+ self.db.execute_sql(
+ """INSERT INTO credits_ledger
+ (user_id, amount_satoshis, tx_type, reference_id, idempotency_key)
+ VALUES (?, ?, 'admin_credit', ?, ?)""",
+ (user_id, amount_satoshis, reason, idempotency_key)
+ )
+
+ logger.info(f"Admin credited {amount_satoshis} sats to user {user_id}: {reason}")
+
+ self.check_balance_cap(user_id)
+ return {'success': True}
+
+ def get_transaction_history(self, user_id: int, limit: int = 20) -> list:
+ """
+ Get user's recent ledger transactions.
+
+ Args:
+ user_id: User ID to look up.
+ limit: Maximum number of transactions to return.
+
+ Returns:
+ List of transaction dicts.
+ """
+ results = self.db.execute_sql(
+ """SELECT tx_type, amount_satoshis, reference_id, created_at
+ FROM credits_ledger WHERE user_id = ?
+ ORDER BY created_at DESC LIMIT ?""",
+ (user_id, limit), fetch_all=True
+ )
+ return [{'type': r[0], 'amount': r[1], 'reference': r[2], 'date': r[3]}
+ for r in (results or [])]
+
+ # === Background Job Support Methods ===
+
+ def get_wallets_over_cap(self) -> list:
+ """
+ Get all wallets with balance exceeding the cap (for auto-sweep).
+
+ Returns:
+ List of dicts with user_id, balance, and sweep_address.
+ """
+ # Get all users with wallets and compute their balances
+ results = self.db.execute_sql(
+ """SELECT w.user_id, w.sweep_address,
+ COALESCE(SUM(cl.amount_satoshis), 0) as balance
+ FROM wallets w
+ LEFT JOIN credits_ledger cl ON w.user_id = cl.user_id
+ WHERE w.sweep_address IS NOT NULL
+ GROUP BY w.user_id
+ HAVING balance > ?""",
+ (BALANCE_CAP_SATOSHIS,), fetch_all=True
+ )
+ return [{'user_id': r[0], 'sweep_address': r[1], 'balance': r[2]}
+ for r in (results or [])]
+
+ def auto_sweep(self, user_id: int, amount_satoshis: int) -> dict:
+ """
+ Sweep excess funds to user's sweep address.
+
+ Args:
+ user_id: User ID to sweep for.
+ amount_satoshis: Amount to sweep.
+
+ Returns:
+ Dict with success status and tx_hash.
+ """
+ wallet = self.get_wallet(user_id)
+ if not wallet:
+ return {'success': False, 'error': 'No wallet'}
+
+ sweep_address = wallet.get('sweep_address')
+ if not sweep_address:
+ return {'success': False, 'error': 'No sweep address configured'}
+
+ # Get fee address private key
+ keys = self.get_wallet_keys(user_id)
+ if not keys:
+ return {'success': False, 'error': 'Could not retrieve wallet keys'}
+
+ bitcoin = BitcoinService(testnet=(wallet['network'] == 'testnet'))
+
+ try:
+ # Execute the sweep transaction
+ tx_hash = bitcoin.send_transaction(
+ private_key_wif=keys['fee_private_key'],
+ to_address=sweep_address,
+ amount_satoshis=amount_satoshis
+ )
+ if not tx_hash:
+ return {'success': False, 'error': 'Sweep transaction failed: no tx hash returned'}
+
+ # Record the sweep as a withdrawal in the ledger
+ idempotency_key = f"auto_sweep_{user_id}_{tx_hash}"
+ self.db.execute_sql(
+ """INSERT INTO credits_ledger
+ (user_id, amount_satoshis, tx_type, reference_id, idempotency_key)
+ VALUES (?, ?, 'withdrawal', ?, ?)""",
+ (user_id, -amount_satoshis, f"auto_sweep:{tx_hash}", idempotency_key)
+ )
+
+ logger.info(f"Auto-swept {amount_satoshis} sats for user {user_id} to {sweep_address}: {tx_hash}")
+
+ # Re-check balance cap
+ self.check_balance_cap(user_id)
+
+ return {'success': True, 'tx_hash': tx_hash}
+
+ except Exception as e:
+ logger.error(f"Auto-sweep failed for user {user_id}: {e}")
+ return {'success': False, 'error': str(e)}
+
+ def get_all_active_wallets(self) -> list:
+ """
+ Get all active (non-disabled) wallets for deposit checking.
+
+ Returns:
+ List of dicts with user_id, fee_address, and network.
+ """
+ results = self.db.execute_sql(
+ """SELECT user_id, btc_address, network
+ FROM wallets WHERE is_disabled = 0""",
+ fetch_all=True
+ )
+ return [{'user_id': r[0], 'fee_address': r[1], 'network': r[2]}
+ for r in (results or [])]
+
+ def get_wallets_for_deposit_monitoring(self) -> list:
+ """
+ Get all wallets for deposit checking, including disabled wallets.
+
+ Disabled wallets must still receive deposit credits so users can move
+ back under the cap via sweeping/withdrawal flows.
+
+ Returns:
+ List of dicts with user_id, fee_address, and network.
+ """
+ results = self.db.execute_sql(
+ """SELECT user_id, btc_address, network
+ FROM wallets""",
+ fetch_all=True
+ )
+ return [{'user_id': r[0], 'fee_address': r[1], 'network': r[2]}
+ for r in (results or [])]
+
+ def check_address_for_deposits(self, user_id: int, address: str, network: str) -> list:
+ """
+ Check an address for new deposits and credit the user's ledger.
+
+ Args:
+ user_id: User ID who owns the address.
+ address: Bitcoin address to check.
+ network: Bitcoin network ('testnet' or 'mainnet').
+
+ Returns:
+ List of new deposits found and credited.
+ """
+ bitcoin = BitcoinService(testnet=(network == 'testnet'))
+ new_deposits = []
+
+ try:
+ # Get unspent outputs (UTXOs) for the address
+ utxos = bitcoin.get_unspent(address)
+
+ for utxo in utxos:
+ tx_hash = utxo.txid
+ vout = utxo.txindex
+ amount_satoshis = utxo.amount # bit library returns in satoshis
+
+ # Check if we've already processed this deposit
+ existing = self.db.execute_sql(
+ """SELECT id, credited FROM deposits
+ WHERE network = ? AND tx_hash = ? AND vout = ?""",
+ (network, tx_hash, vout), fetch_one=True
+ )
+
+ if existing and int(existing[1]) == 1:
+ continue # Already processed
+
+ # Ensure tracking row exists. Keep credited=0 until ledger credit succeeds.
+ if not existing:
+ self.db.execute_sql(
+ """INSERT OR IGNORE INTO deposits
+ (user_id, network, tx_hash, vout, amount_satoshis, credited)
+ VALUES (?, ?, ?, ?, ?, 0)""",
+ (user_id, network, tx_hash, vout, amount_satoshis)
+ )
+
+ # Credit the user's ledger
+ result = self.credit_deposit(user_id, amount_satoshis, network, tx_hash, vout)
+
+ if result.get('success'):
+ self.db.execute_sql(
+ """UPDATE deposits
+ SET credited = 1
+ WHERE network = ? AND tx_hash = ? AND vout = ?""",
+ (network, tx_hash, vout)
+ )
+
+ if not result.get('already_processed'):
+ new_deposits.append({
+ 'tx_hash': tx_hash,
+ 'vout': vout,
+ 'amount_satoshis': amount_satoshis
+ })
+
+ except Exception as e:
+ logger.error(f"Error checking deposits for {address}: {e}")
+
+ return new_deposits
+
+ def get_pending_withdrawals(self) -> list:
+ """
+ Get all pending withdrawal requests (status = 'reserved').
+
+ Returns:
+ List of pending withdrawal dicts.
+ """
+ results = self.db.execute_sql(
+ """SELECT id, user_id, amount_satoshis, destination_address
+ FROM withdrawal_requests
+ WHERE status = 'reserved'
+ ORDER BY created_at ASC""",
+ fetch_all=True
+ )
+ return [{'id': r[0], 'user_id': r[1], 'amount_satoshis': r[2], 'destination_address': r[3]}
+ for r in (results or [])]
+
+ def process_withdrawal(self, withdrawal_id: int) -> dict:
+ """
+ Process a pending withdrawal request.
+
+ Args:
+ withdrawal_id: ID of the withdrawal request.
+
+ Returns:
+ Dict with success status and tx_hash.
+ """
+ # Get withdrawal details
+ result = self.db.execute_sql(
+ """SELECT user_id, amount_satoshis, destination_address, status
+ FROM withdrawal_requests WHERE id = ?""",
+ (withdrawal_id,), fetch_one=True
+ )
+
+ if not result:
+ return {'success': False, 'error': 'Withdrawal not found'}
+
+ user_id, amount_satoshis, destination_address, status = result
+
+ if status != 'reserved':
+ return {'success': False, 'error': f'Invalid status: {status}'}
+
+ # Get wallet and keys
+ wallet = self.get_wallet(user_id)
+ if not wallet:
+ self._fail_withdrawal(withdrawal_id, 'No wallet found')
+ return {'success': False, 'error': 'No wallet'}
+
+ keys = self.get_wallet_keys(user_id)
+ if not keys:
+ self._fail_withdrawal(withdrawal_id, 'Could not retrieve wallet keys')
+ return {'success': False, 'error': 'Could not retrieve keys'}
+
+ # Mark as processing
+ self.db.execute_sql(
+ "UPDATE withdrawal_requests SET status = 'processing' WHERE id = ?",
+ (withdrawal_id,)
+ )
+
+ bitcoin = BitcoinService(testnet=(wallet['network'] == 'testnet'))
+
+ try:
+ # Execute the withdrawal transaction
+ tx_hash = bitcoin.send_transaction(
+ private_key_wif=keys['fee_private_key'],
+ to_address=destination_address,
+ amount_satoshis=amount_satoshis
+ )
+ if not tx_hash:
+ raise RuntimeError("Withdrawal transaction failed: no tx hash returned")
+
+ # Mark as completed
+ self.db.execute_sql(
+ """UPDATE withdrawal_requests
+ SET status = 'completed', btc_txhash = ?, processed_at = CURRENT_TIMESTAMP
+ WHERE id = ?""",
+ (tx_hash, withdrawal_id)
+ )
+
+ logger.info(f"Withdrawal {withdrawal_id} completed: {tx_hash}")
+
+ return {'success': True, 'tx_hash': tx_hash}
+
+ except Exception as e:
+ logger.error(f"Withdrawal {withdrawal_id} failed: {e}")
+ self._fail_withdrawal(withdrawal_id, str(e))
+ return {'success': False, 'error': str(e)}
+
+ def _fail_withdrawal(self, withdrawal_id: int, error_message: str):
+ """
+ Mark a withdrawal as failed and reverse the ledger debit.
+
+ Args:
+ withdrawal_id: ID of the failed withdrawal.
+ error_message: Error message to record.
+ """
+ # Get withdrawal details
+ result = self.db.execute_sql(
+ """SELECT user_id, amount_satoshis, destination_address
+ FROM withdrawal_requests WHERE id = ?""",
+ (withdrawal_id,), fetch_one=True
+ )
+
+ if not result:
+ return
+
+ user_id, amount_satoshis, destination_address = result
+
+ # Mark as failed
+ self.db.execute_sql(
+ """UPDATE withdrawal_requests
+ SET status = 'failed', error_message = ?, processed_at = CURRENT_TIMESTAMP
+ WHERE id = ?""",
+ (error_message, withdrawal_id)
+ )
+
+ # Reverse the ledger debit (credit the user back)
+ idempotency_key = f"withdrawal_reversal_{withdrawal_id}"
+
+ # Check if already reversed
+ existing = self.db.execute_sql(
+ "SELECT id FROM credits_ledger WHERE idempotency_key = ?",
+ (idempotency_key,), fetch_one=True
+ )
+
+ if not existing:
+ self.db.execute_sql(
+ """INSERT INTO credits_ledger
+ (user_id, amount_satoshis, tx_type, reference_id, idempotency_key)
+ VALUES (?, ?, 'withdrawal_reversal', ?, ?)""",
+ (user_id, amount_satoshis, f"failed_withdrawal_{withdrawal_id}", idempotency_key)
+ )
+ logger.info(f"Reversed withdrawal {withdrawal_id} - credited {amount_satoshis} sats back to user {user_id}")
diff --git a/tests/test_strategy_execution.py b/tests/test_strategy_execution.py
index 43e3028..1bc516c 100644
--- a/tests/test_strategy_execution.py
+++ b/tests/test_strategy_execution.py
@@ -334,6 +334,9 @@ class TestStopStrategy:
bt = BrighterTrades(MagicMock())
bt.strategies = MagicMock()
bt.strategies.active_instances = {}
+ # Mock wallet_manager for fee settlement
+ bt.wallet_manager = MagicMock()
+ bt.wallet_manager.settle_accumulated_fees.return_value = {'success': True, 'settled': 0}
return bt
def test_stop_strategy_not_running(self, mock_brighter_trades):
diff --git a/tests/test_wallet.py b/tests/test_wallet.py
new file mode 100644
index 0000000..56acb33
--- /dev/null
+++ b/tests/test_wallet.py
@@ -0,0 +1,827 @@
+"""
+Tests for the wallet module.
+Tests encryption, Bitcoin service, and wallet manager functionality.
+"""
+import pytest
+import tempfile
+import os
+import sys
+from unittest.mock import patch
+
+# Add src to path for imports
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+
+from wallet.encryption import KeyEncryption
+from wallet.bitcoin_service import BitcoinService
+from wallet.wallet_manager import WalletManager, BALANCE_CAP_SATOSHIS
+
+
+class TestKeyEncryption:
+ """Tests for KeyEncryption."""
+
+ def test_encrypt_decrypt(self):
+ """Test basic encryption and decryption."""
+ keys = {1: 'test_master_key_v1'}
+ encryption = KeyEncryption(keys)
+
+ original = "my_secret_private_key"
+ encrypted = encryption.encrypt(original)
+ decrypted = encryption.decrypt(encrypted, version=1)
+
+ assert decrypted == original
+ assert encrypted != original
+
+ def test_versioned_keys(self):
+ """Test encryption with multiple key versions."""
+ keys = {1: 'key_v1', 2: 'key_v2'}
+ encryption = KeyEncryption(keys)
+
+ data = "secret_data"
+ encrypted_v1 = encryption.encrypt(data, version=1)
+ encrypted_v2 = encryption.encrypt(data, version=2)
+
+ # Both should decrypt correctly with their respective versions
+ assert encryption.decrypt(encrypted_v1, version=1) == data
+ assert encryption.decrypt(encrypted_v2, version=2) == data
+
+ # Encrypted data should be different for different versions
+ assert encrypted_v1 != encrypted_v2
+
+ def test_re_encrypt(self):
+ """Test re-encryption from old version to new version."""
+ keys = {1: 'old_key', 2: 'new_key'}
+ encryption = KeyEncryption(keys)
+
+ data = "secret_data"
+ encrypted_v1 = encryption.encrypt(data, version=1)
+
+ # Re-encrypt to version 2
+ encrypted_v2, new_version = encryption.re_encrypt(encrypted_v1, old_version=1, new_version=2)
+
+ assert new_version == 2
+ assert encryption.decrypt(encrypted_v2, version=2) == data
+
+ def test_unknown_version_raises(self):
+ """Test that decrypting with unknown version raises."""
+ keys = {1: 'test_key'}
+ encryption = KeyEncryption(keys)
+
+ with pytest.raises(ValueError, match="Unknown encryption key version"):
+ encryption.decrypt("some_data", version=99)
+
+
+class TestBitcoinService:
+ """Tests for BitcoinService."""
+
+ def test_generate_keypair_testnet(self):
+ """Test generating testnet keypair."""
+ service = BitcoinService(testnet=True)
+ keypair = service.generate_keypair()
+
+ assert 'address' in keypair
+ assert 'public_key' in keypair
+ assert 'private_key' in keypair
+
+ # Testnet addresses start with m, n, or tb1
+ address = keypair['address']
+ assert address.startswith(('m', 'n', 'tb1'))
+
+ def test_validate_testnet_address(self):
+ """Test validating testnet addresses."""
+ service = BitcoinService(testnet=True)
+
+ # Generate a valid testnet address
+ keypair = service.generate_keypair()
+ assert service.validate_address(keypair['address'])
+
+ # Invalid addresses
+ assert not service.validate_address('')
+ assert not service.validate_address(None)
+ assert not service.validate_address('invalid')
+ assert not service.validate_address('1BvBMSEYstWetqTFn5Au4m4GFg7xJaNVN2') # Mainnet
+
+ def test_validate_mainnet_address(self):
+ """Test validating mainnet addresses."""
+ service = BitcoinService(testnet=False)
+
+ # Valid mainnet addresses
+ assert service.validate_address('1BvBMSEYstWetqTFn5Au4m4GFg7xJaNVN2')
+ assert service.validate_address('3J98t1WpEZ73CNmQviecrnyiWrnqRhWNLy')
+ assert service.validate_address('bc1qar0srrr7xfkvy5l643lydnw9re59gtzzwf5mdq')
+
+ # Testnet addresses should be invalid on mainnet
+ assert not service.validate_address('mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt')
+
+ def test_estimate_fee(self):
+ """Test fee estimation."""
+ service = BitcoinService(testnet=True)
+
+ fee = service.estimate_fee(num_inputs=1, num_outputs=2)
+ assert fee > 0
+ assert isinstance(fee, int)
+
+ # More inputs/outputs should cost more
+ fee_larger = service.estimate_fee(num_inputs=3, num_outputs=5)
+ assert fee_larger > fee
+
+
+class MockDatabase:
+ """Mock database for testing WalletManager."""
+
+ def __init__(self):
+ self.tables = {}
+ self.data = {}
+
+ def execute_sql(self, sql, params=None, fetch_one=False, fetch_all=False):
+ """Simple mock SQL execution."""
+ sql_lower = sql.lower().strip()
+
+ if sql_lower.startswith('create table') or sql_lower.startswith('create index'):
+ # Table creation - just ignore
+ return None
+
+ if sql_lower.startswith('alter table'):
+ # Migration - add sweep_address column
+ return None
+
+ if sql_lower.startswith('insert'):
+ # Store insert data
+ if 'wallets' in sql_lower:
+ # Two-address model: params are (user_id, btc_address, public_key, private_key, sweep_address, version, network)
+ self.data.setdefault('wallets', {})[params[0]] = {
+ 'user_id': params[0],
+ 'btc_address': params[1],
+ 'public_key_encrypted': params[2],
+ 'private_key_encrypted': params[3],
+ 'sweep_address': params[4] if len(params) > 6 else None,
+ 'encryption_key_version': params[5] if len(params) > 6 else params[4],
+ 'network': params[6] if len(params) > 6 else params[5],
+ 'is_disabled': 0,
+ 'created_at': 'now'
+ }
+ elif 'credits_ledger' in sql_lower:
+ self.data.setdefault('credits_ledger', []).append({
+ 'user_id': params[0],
+ 'amount_satoshis': params[1],
+ 'tx_type': params[2],
+ 'reference_id': params[3],
+ 'idempotency_key': params[-1],
+ })
+ elif 'pending_strategy_fees' in sql_lower:
+ self.data.setdefault('pending_fees', {})[params[0]] = {
+ 'strategy_run_id': params[0],
+ 'user_id': params[1],
+ 'creator_user_id': params[2],
+ 'fee_percent': params[3] if len(params) > 3 else 10,
+ 'accumulated_satoshis': 0,
+ 'trade_count': 0,
+ }
+ elif 'withdrawal_requests' in sql_lower:
+ self.data.setdefault('withdrawals', []).append({
+ 'user_id': params[0],
+ 'amount_satoshis': params[1],
+ 'destination_address': params[2],
+ 'status': 'reserved', # status is hardcoded in SQL
+ })
+ return None
+
+ if sql_lower.startswith('select'):
+ if fetch_one:
+ if 'wallets' in sql_lower and 'user_id' in sql_lower:
+ wallet = self.data.get('wallets', {}).get(params[0])
+ if wallet:
+ # Check if it's the get_wallet_keys query (includes encrypted keys)
+ if 'private_key_encrypted' in sql_lower:
+ return (wallet['btc_address'], wallet['public_key_encrypted'],
+ wallet['private_key_encrypted'], wallet['encryption_key_version'],
+ wallet['network'])
+ # Regular get_wallet query (5 columns including sweep_address)
+ return (wallet['btc_address'], wallet['network'],
+ wallet['is_disabled'], wallet['created_at'],
+ wallet.get('sweep_address'))
+ elif 'credits_ledger' in sql_lower and 'sum' in sql_lower:
+ ledger = self.data.get('credits_ledger', [])
+ user_total = sum(e['amount_satoshis'] for e in ledger if e['user_id'] == params[0])
+ return (user_total,)
+ elif 'credits_ledger' in sql_lower and 'idempotency' in sql_lower:
+ ledger = self.data.get('credits_ledger', [])
+ for entry in ledger:
+ if entry.get('idempotency_key') == params[0]:
+ return (1,)
+ return None
+ elif 'pending_strategy_fees' in sql_lower:
+ fees = self.data.get('pending_fees', {}).get(params[0])
+ if fees:
+ # Different queries select different columns
+ if 'fee_percent' in sql_lower and 'accumulated' not in sql_lower:
+ # accumulate_trade_fee: SELECT fee_percent
+ return (fees.get('fee_percent', 10),)
+ elif 'accumulated_satoshis, trade_count' in sql_lower and 'user_id' not in sql_lower.split('select')[1].split('from')[0]:
+ # get_pending_fees: SELECT accumulated_satoshis, trade_count
+ return (fees['accumulated_satoshis'], fees['trade_count'])
+ else:
+ # settle_accumulated_fees: SELECT user_id, creator_user_id, accumulated_satoshis, trade_count
+ return (fees['user_id'], fees['creator_user_id'],
+ fees['accumulated_satoshis'], fees['trade_count'])
+ return None
+
+ if fetch_all:
+ if 'credits_ledger' in sql_lower:
+ ledger = self.data.get('credits_ledger', [])
+ user_entries = [e for e in ledger if e['user_id'] == params[0]]
+ return [(e['tx_type'], e['amount_satoshis'], e['reference_id'], 'now')
+ for e in user_entries[-params[1]:]]
+ return []
+
+ if sql_lower.startswith('update'):
+ if 'wallets' in sql_lower and 'is_disabled' in sql_lower:
+ user_id = params[0] # Only param is user_id
+ if user_id in self.data.get('wallets', {}):
+ # Parse the value from the SQL itself (is_disabled = 0 or is_disabled = 1)
+ if 'is_disabled = 1' in sql_lower:
+ self.data['wallets'][user_id]['is_disabled'] = 1
+ else:
+ self.data['wallets'][user_id]['is_disabled'] = 0
+ elif 'pending_strategy_fees' in sql_lower and 'accumulated_satoshis' in sql_lower:
+ # Update: accumulated_satoshis = accumulated_satoshis + ?, trade_count = trade_count + 1
+ fee_amount = params[0]
+ run_id = params[1]
+ if run_id in self.data.get('pending_fees', {}):
+ self.data['pending_fees'][run_id]['accumulated_satoshis'] += fee_amount
+ self.data['pending_fees'][run_id]['trade_count'] += 1
+ return None
+
+ if sql_lower.startswith('delete'):
+ if 'pending_strategy_fees' in sql_lower:
+ run_id = params[0]
+ if run_id in self.data.get('pending_fees', {}):
+ del self.data['pending_fees'][run_id]
+ return None
+
+ return None
+
+ def execute_in_transaction(self, statements):
+ """Execute multiple statements atomically (mock version)."""
+ # In mock, just execute each statement in order
+ for sql, params in statements:
+ self.execute_sql(sql, params)
+ return True
+
+
+class TestWalletManager:
+ """Tests for WalletManager."""
+
+ @pytest.fixture
+ def wallet_manager(self):
+ """Create a wallet manager with mock database."""
+ db = MockDatabase()
+ keys = {1: 'test_encryption_key'}
+ return WalletManager(db, keys, default_network='testnet')
+
+ def test_create_wallet(self, wallet_manager):
+ """Test creating a new wallet (two-address model)."""
+ result = wallet_manager.create_wallet(user_id=1)
+
+ assert result['success']
+ # Fee Address (we store keys)
+ assert 'fee_address' in result
+ assert 'fee_private_key' in result
+ # Sweep Address (we don't store private key)
+ assert 'sweep_address' in result
+ assert 'sweep_private_key' in result
+ assert result['network'] == 'testnet'
+
+ def test_create_wallet_duplicate(self, wallet_manager):
+ """Test creating duplicate wallet fails."""
+ wallet_manager.create_wallet(user_id=1)
+ result = wallet_manager.create_wallet(user_id=1)
+
+ assert not result['success']
+ assert 'already exists' in result['error'].lower()
+
+ def test_get_wallet(self, wallet_manager):
+ """Test getting wallet info."""
+ wallet_manager.create_wallet(user_id=1)
+ wallet = wallet_manager.get_wallet(user_id=1)
+
+ assert wallet is not None
+ assert 'address' in wallet
+ assert wallet['network'] == 'testnet'
+ assert 'private_key' not in wallet # Private key should not be exposed
+
+ def test_credits_balance_empty(self, wallet_manager):
+ """Test credits balance starts at zero."""
+ balance = wallet_manager.get_credits_balance(user_id=1)
+ assert balance == 0
+
+ def test_admin_credit(self, wallet_manager):
+ """Test admin credit adds to balance."""
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=10000, reason='test')
+ balance = wallet_manager.get_credits_balance(user_id=1)
+ assert balance == 10000
+
+ def test_credit_deposit(self, wallet_manager):
+ """Test deposit crediting."""
+ result = wallet_manager.credit_deposit(
+ user_id=1,
+ amount_satoshis=50000,
+ network='testnet',
+ tx_hash='abc123',
+ vout=0
+ )
+
+ assert result['success']
+ assert wallet_manager.get_credits_balance(user_id=1) == 50000
+
+ def test_deposit_idempotency(self, wallet_manager):
+ """Test deposit is idempotent."""
+ wallet_manager.credit_deposit(
+ user_id=1, amount_satoshis=50000,
+ network='testnet', tx_hash='abc123', vout=0
+ )
+ wallet_manager.credit_deposit(
+ user_id=1, amount_satoshis=50000,
+ network='testnet', tx_hash='abc123', vout=0
+ )
+
+ # Should only be credited once
+ balance = wallet_manager.get_credits_balance(user_id=1)
+ assert balance == 50000
+
+ def test_fee_accumulation(self, wallet_manager):
+ """Test fee accumulation during strategy run with default 10% fee."""
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial')
+
+ # Start accumulation with default 10% fee
+ result = wallet_manager.start_fee_accumulation(
+ strategy_run_id='run_001',
+ user_id=1,
+ creator_user_id=2,
+ fee_percent=10, # 10% of exchange fee
+ estimated_trades=5
+ )
+ assert result['success']
+
+ # Accumulate some fees
+ wallet_manager.accumulate_trade_fee('run_001', exchange_fee_satoshis=10000, is_profitable=True)
+ wallet_manager.accumulate_trade_fee('run_001', exchange_fee_satoshis=5000, is_profitable=True)
+ wallet_manager.accumulate_trade_fee('run_001', exchange_fee_satoshis=8000, is_profitable=False) # Should not accumulate
+
+ # Check pending fees
+ pending = wallet_manager.get_pending_fees('run_001')
+ assert pending['accumulated_satoshis'] == 1500 # 10% of 10000 + 10% of 5000
+ assert pending['trade_count'] == 2 # Only profitable trades counted
+
+ def test_fee_accumulation_custom_percent(self, wallet_manager):
+ """Test fee accumulation with custom fee percentage (50%)."""
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial')
+
+ # Start accumulation with 50% fee
+ result = wallet_manager.start_fee_accumulation(
+ strategy_run_id='run_002',
+ user_id=1,
+ creator_user_id=2,
+ fee_percent=50, # 50% of exchange fee
+ estimated_trades=5
+ )
+ assert result['success']
+
+ # Accumulate fee: 50% of 10000 = 5000
+ wallet_manager.accumulate_trade_fee('run_002', exchange_fee_satoshis=10000, is_profitable=True)
+
+ pending = wallet_manager.get_pending_fees('run_002')
+ assert pending['accumulated_satoshis'] == 5000 # 50% of 10000
+ assert pending['trade_count'] == 1
+
+ def test_fee_accumulation_100_percent(self, wallet_manager):
+ """Test fee accumulation at 100% (same as exchange commission)."""
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial')
+
+ # Start accumulation with 100% fee
+ result = wallet_manager.start_fee_accumulation(
+ strategy_run_id='run_003',
+ user_id=1,
+ creator_user_id=2,
+ fee_percent=100, # 100% of exchange fee
+ estimated_trades=5
+ )
+ assert result['success']
+
+ # Accumulate fee: 100% of 10000 = 10000
+ wallet_manager.accumulate_trade_fee('run_003', exchange_fee_satoshis=10000, is_profitable=True)
+
+ pending = wallet_manager.get_pending_fees('run_003')
+ assert pending['accumulated_satoshis'] == 10000 # 100% of 10000
+ assert pending['trade_count'] == 1
+
+ def test_fee_settlement(self, wallet_manager):
+ """Test fee settlement when strategy stops."""
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial')
+ wallet_manager.admin_credit(user_id=2, amount_satoshis=0, reason='creator setup')
+
+ wallet_manager.start_fee_accumulation('run_001', user_id=1, creator_user_id=2)
+ wallet_manager.accumulate_trade_fee('run_001', exchange_fee_satoshis=10000, is_profitable=True)
+
+ # Settle fees
+ result = wallet_manager.settle_accumulated_fees('run_001')
+
+ assert result['success']
+ assert result['settled'] == 1000 # 10% of 10000
+ assert result['trades'] == 1
+
+ # Check balances
+ user_balance = wallet_manager.get_credits_balance(user_id=1)
+ creator_balance = wallet_manager.get_credits_balance(user_id=2)
+
+ assert user_balance == 99000 # 100000 - 1000
+ assert creator_balance == 1000 # Received fee
+
+ def test_insufficient_credits_for_strategy(self, wallet_manager):
+ """Test that strategy start fails with insufficient credits."""
+ # User has no credits
+ result = wallet_manager.start_fee_accumulation(
+ strategy_run_id='run_001',
+ user_id=1,
+ creator_user_id=2,
+ estimated_trades=10
+ )
+
+ assert not result['success']
+ assert 'insufficient' in result['error'].lower()
+
+ def test_withdrawal_request(self, wallet_manager):
+ """Test withdrawal request."""
+ wallet_manager.create_wallet(user_id=1)
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=50000, reason='test')
+
+ result = wallet_manager.request_withdrawal(
+ user_id=1,
+ amount_satoshis=10000,
+ destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt'
+ )
+
+ assert result['success']
+
+ # Balance should be reduced immediately
+ balance = wallet_manager.get_credits_balance(user_id=1)
+ assert balance == 40000
+
+ def test_withdrawal_insufficient_balance(self, wallet_manager):
+ """Test withdrawal fails with insufficient balance."""
+ wallet_manager.create_wallet(user_id=1)
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=5000, reason='test')
+
+ result = wallet_manager.request_withdrawal(
+ user_id=1,
+ amount_satoshis=10000,
+ destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt'
+ )
+
+ assert not result['success']
+ assert 'insufficient' in result['error'].lower()
+
+ def test_withdrawal_invalid_address(self, wallet_manager):
+ """Test withdrawal fails with invalid address."""
+ wallet_manager.create_wallet(user_id=1)
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=50000, reason='test')
+
+ result = wallet_manager.request_withdrawal(
+ user_id=1,
+ amount_satoshis=10000,
+ destination_address='invalid_address'
+ )
+
+ assert not result['success']
+ assert 'invalid' in result['error'].lower()
+
+ def test_withdrawal_reservation_transaction_failure(self, wallet_manager):
+ """Test withdrawal reservation fails cleanly when transaction fails."""
+ wallet_manager.create_wallet(user_id=1)
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=50000, reason='test')
+
+ def fail_transaction(_statements):
+ raise RuntimeError('db transaction failed')
+
+ wallet_manager.db.execute_in_transaction = fail_transaction
+
+ result = wallet_manager.request_withdrawal(
+ user_id=1,
+ amount_satoshis=10000,
+ destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt'
+ )
+
+ assert not result['success']
+ assert 'failed to queue' in result['error'].lower()
+ # Balance should remain unchanged because reservation wasn't committed.
+ assert wallet_manager.get_credits_balance(user_id=1) == 50000
+
+ def test_own_strategy_no_fees(self, wallet_manager):
+ """Test that running own strategy doesn't require fees."""
+ result = wallet_manager.start_fee_accumulation(
+ strategy_run_id='run_001',
+ user_id=1,
+ creator_user_id=1, # Same user
+ estimated_trades=10
+ )
+
+ assert result['success']
+ assert 'no fees' in result.get('message', '').lower()
+
+ def test_transaction_history(self, wallet_manager):
+ """Test getting transaction history."""
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=10000, reason='credit1')
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=5000, reason='credit2')
+
+ history = wallet_manager.get_transaction_history(user_id=1, limit=10)
+
+ assert len(history) == 2
+ assert all('type' in tx for tx in history)
+ assert all('amount' in tx for tx in history)
+
+
+class TestBalanceCap:
+ """Tests for balance cap functionality."""
+
+ @pytest.fixture
+ def wallet_manager(self):
+ db = MockDatabase()
+ keys = {1: 'test_key'}
+ return WalletManager(db, keys, default_network='testnet')
+
+ def test_balance_cap_disables_wallet(self, wallet_manager):
+ """Test that exceeding balance cap disables wallet."""
+ wallet_manager.create_wallet(user_id=1)
+
+ # Credit above cap
+ wallet_manager.admin_credit(
+ user_id=1,
+ amount_satoshis=BALANCE_CAP_SATOSHIS + 10000,
+ reason='over cap'
+ )
+
+ result = wallet_manager.check_balance_cap(user_id=1)
+ assert result['over_cap']
+
+ # Wallet should show as disabled
+ wallet = wallet_manager.get_wallet(user_id=1)
+ assert wallet['is_disabled']
+
+ def test_balance_under_cap_enables_wallet(self, wallet_manager):
+ """Test that being under balance cap enables wallet."""
+ wallet_manager.create_wallet(user_id=1)
+
+ # Credit under cap
+ wallet_manager.admin_credit(
+ user_id=1,
+ amount_satoshis=BALANCE_CAP_SATOSHIS - 10000,
+ reason='under cap'
+ )
+
+ result = wallet_manager.check_balance_cap(user_id=1)
+ assert not result['over_cap']
+
+ wallet = wallet_manager.get_wallet(user_id=1)
+ assert not wallet['is_disabled']
+
+
+class MockDatabaseForBackgroundJobs(MockDatabase):
+ """Extended mock database for testing background job methods."""
+
+ def execute_sql(self, sql, params=None, fetch_one=False, fetch_all=False):
+ sql_lower = sql.lower().strip()
+
+ # Handle get_wallets_over_cap query
+ if 'left join credits_ledger' in sql_lower and 'having balance' in sql_lower:
+ results = []
+ for user_id, wallet in self.data.get('wallets', {}).items():
+ if not wallet.get('sweep_address'):
+ continue
+ # Calculate balance from ledger
+ ledger = self.data.get('credits_ledger', [])
+ balance = sum(e['amount_satoshis'] for e in ledger if e['user_id'] == user_id)
+ if balance > params[0]: # params[0] is BALANCE_CAP_SATOSHIS
+ results.append((user_id, wallet.get('sweep_address'), balance))
+ if fetch_all:
+ return results
+ return None
+
+ # Handle get_all_active_wallets query
+ if 'select user_id, btc_address, network' in sql_lower and 'is_disabled = 0' in sql_lower:
+ results = []
+ for user_id, wallet in self.data.get('wallets', {}).items():
+ if not wallet.get('is_disabled'):
+ results.append((user_id, wallet['btc_address'], wallet['network']))
+ if fetch_all:
+ return results
+ return None
+
+ # Handle get_wallets_for_deposit_monitoring query (includes disabled wallets)
+ if 'select user_id, btc_address, network' in sql_lower and 'from wallets' in sql_lower and 'is_disabled' not in sql_lower:
+ results = []
+ for user_id, wallet in self.data.get('wallets', {}).items():
+ results.append((user_id, wallet['btc_address'], wallet['network']))
+ if fetch_all:
+ return results
+ return None
+
+ # Handle deposits table queries
+ if 'insert' in sql_lower and 'deposits' in sql_lower:
+ self.data.setdefault('deposits', []).append({
+ 'user_id': params[0],
+ 'network': params[1],
+ 'tx_hash': params[2],
+ 'vout': params[3],
+ 'amount_satoshis': params[4],
+ 'credited': params[5] if len(params) > 5 else 0
+ })
+ return None
+
+ if 'select' in sql_lower and 'deposits' in sql_lower and 'network' in sql_lower and fetch_one:
+ deposits = self.data.get('deposits', [])
+ for d in deposits:
+ if d['network'] == params[0] and d['tx_hash'] == params[1] and d['vout'] == params[2]:
+ return (d.get('id', 1),)
+ return None
+
+ # Handle get_pending_withdrawals query
+ if 'withdrawal_requests' in sql_lower and "status = 'reserved'" in sql_lower and fetch_all:
+ withdrawals = self.data.get('withdrawals', [])
+ results = []
+ for i, w in enumerate(withdrawals):
+ if w.get('status') == 'reserved':
+ results.append((i + 1, w['user_id'], w['amount_satoshis'], w['destination_address']))
+ return results
+
+ # Handle withdrawal_requests SELECT for process_withdrawal and _fail_withdrawal
+ if 'withdrawal_requests' in sql_lower and 'select' in sql_lower and 'where id' in sql_lower and fetch_one:
+ withdrawals = self.data.get('withdrawals', [])
+ withdrawal_id = params[0]
+ if withdrawal_id > 0 and withdrawal_id <= len(withdrawals):
+ w = withdrawals[withdrawal_id - 1]
+ # process_withdrawal query includes status, _fail_withdrawal does not
+ if 'status' in sql_lower:
+ return (w['user_id'], w['amount_satoshis'], w['destination_address'], w.get('status', 'reserved'))
+ else:
+ return (w['user_id'], w['amount_satoshis'], w['destination_address'])
+ return None
+
+ # Handle withdrawal_requests UPDATE
+ if 'update' in sql_lower and 'withdrawal_requests' in sql_lower:
+ withdrawal_id = params[-1] # Last param is the ID in WHERE clause
+ withdrawals = self.data.get('withdrawals', [])
+ if withdrawal_id > 0 and withdrawal_id <= len(withdrawals):
+ if 'processing' in sql_lower:
+ withdrawals[withdrawal_id - 1]['status'] = 'processing'
+ elif 'completed' in sql_lower:
+ withdrawals[withdrawal_id - 1]['status'] = 'completed'
+ withdrawals[withdrawal_id - 1]['btc_txhash'] = params[0]
+ elif 'failed' in sql_lower:
+ withdrawals[withdrawal_id - 1]['status'] = 'failed'
+ withdrawals[withdrawal_id - 1]['error_message'] = params[0]
+ return None
+
+ # Fall back to parent implementation
+ return super().execute_sql(sql, params, fetch_one, fetch_all)
+
+
+class TestBackgroundJobSupport:
+ """Tests for background job support methods in WalletManager."""
+
+ @pytest.fixture
+ def wallet_manager(self):
+ db = MockDatabaseForBackgroundJobs()
+ keys = {1: 'test_key'}
+ return WalletManager(db, keys, default_network='testnet')
+
+ def test_get_wallets_over_cap(self, wallet_manager):
+ """Test getting wallets with balance over cap."""
+ # Create wallet with sweep address
+ wallet_manager.create_wallet(user_id=1)
+
+ # Credit above cap
+ wallet_manager.admin_credit(
+ user_id=1,
+ amount_satoshis=BALANCE_CAP_SATOSHIS + 50000,
+ reason='over cap'
+ )
+
+ wallets = wallet_manager.get_wallets_over_cap()
+
+ assert len(wallets) == 1
+ assert wallets[0]['user_id'] == 1
+ assert wallets[0]['balance'] == BALANCE_CAP_SATOSHIS + 50000
+ assert wallets[0]['sweep_address'] is not None
+
+ def test_get_wallets_over_cap_empty(self, wallet_manager):
+ """Test getting wallets when none are over cap."""
+ wallet_manager.create_wallet(user_id=1)
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=10000, reason='under cap')
+
+ wallets = wallet_manager.get_wallets_over_cap()
+ assert len(wallets) == 0
+
+ def test_get_all_active_wallets(self, wallet_manager):
+ """Test getting all active (non-disabled) wallets."""
+ wallet_manager.create_wallet(user_id=1)
+ wallet_manager.create_wallet(user_id=2)
+
+ # Disable user 2's wallet
+ wallet_manager.db.data['wallets'][2]['is_disabled'] = 1
+
+ wallets = wallet_manager.get_all_active_wallets()
+
+ assert len(wallets) == 1
+ assert wallets[0]['user_id'] == 1
+
+ def test_get_wallets_for_deposit_monitoring_includes_disabled(self, wallet_manager):
+ """Deposit monitoring should include disabled wallets."""
+ wallet_manager.create_wallet(user_id=1)
+ wallet_manager.create_wallet(user_id=2)
+ wallet_manager.db.data['wallets'][2]['is_disabled'] = 1
+
+ wallets = wallet_manager.get_wallets_for_deposit_monitoring()
+ user_ids = {w['user_id'] for w in wallets}
+
+ assert user_ids == {1, 2}
+
+ def test_get_pending_withdrawals(self, wallet_manager):
+ """Test getting pending withdrawal requests."""
+ wallet_manager.create_wallet(user_id=1)
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='test')
+
+ # Request withdrawal
+ wallet_manager.request_withdrawal(
+ user_id=1,
+ amount_satoshis=10000,
+ destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt'
+ )
+
+ pending = wallet_manager.get_pending_withdrawals()
+
+ assert len(pending) == 1
+ assert pending[0]['user_id'] == 1
+ assert pending[0]['amount_satoshis'] == 10000
+
+ def test_auto_sweep_no_wallet(self, wallet_manager):
+ """Test auto_sweep fails when no wallet exists."""
+ result = wallet_manager.auto_sweep(user_id=999, amount_satoshis=10000)
+
+ assert not result['success']
+ assert 'no wallet' in result['error'].lower()
+
+ def test_auto_sweep_no_sweep_address(self, wallet_manager):
+ """Test auto_sweep fails when no sweep address configured."""
+ wallet_manager.create_wallet(user_id=1)
+ # Remove sweep address
+ wallet_manager.db.data['wallets'][1]['sweep_address'] = None
+
+ result = wallet_manager.auto_sweep(user_id=1, amount_satoshis=10000)
+
+ assert not result['success']
+ assert 'sweep address' in result['error'].lower()
+
+ def test_process_withdrawal_not_found(self, wallet_manager):
+ """Test processing non-existent withdrawal."""
+ result = wallet_manager.process_withdrawal(withdrawal_id=999)
+
+ assert not result['success']
+ assert 'not found' in result['error'].lower()
+
+ def test_fail_withdrawal_reverses_credit(self, wallet_manager):
+ """Test that failing a withdrawal reverses the ledger debit."""
+ wallet_manager.create_wallet(user_id=1)
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='test')
+
+ # Request withdrawal (debits 10000)
+ wallet_manager.request_withdrawal(
+ user_id=1,
+ amount_satoshis=10000,
+ destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt'
+ )
+
+ # Balance should be reduced
+ assert wallet_manager.get_credits_balance(user_id=1) == 90000
+
+ # Fail the withdrawal
+ wallet_manager._fail_withdrawal(withdrawal_id=1, error_message='Test failure')
+
+ # Balance should be restored
+ assert wallet_manager.get_credits_balance(user_id=1) == 100000
+
+ def test_process_withdrawal_none_txhash_reverses_credit(self, wallet_manager):
+ """None tx hash from send_transaction should fail and reverse reservation."""
+ wallet_manager.create_wallet(user_id=1)
+ wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='test')
+ wallet_manager.request_withdrawal(
+ user_id=1,
+ amount_satoshis=10000,
+ destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt'
+ )
+ assert wallet_manager.get_credits_balance(user_id=1) == 90000
+
+ with patch('wallet.wallet_manager.BitcoinService.send_transaction', return_value=None):
+ result = wallet_manager.process_withdrawal(withdrawal_id=1)
+
+ assert not result['success']
+ assert 'no tx hash' in result['error'].lower()
+ assert wallet_manager.get_credits_balance(user_id=1) == 100000