From 7e77f55837e073d94bb757d8b2ff6f8b5ab3e0d7 Mon Sep 17 00:00:00 2001 From: rob Date: Mon, 9 Mar 2026 12:45:25 -0300 Subject: [PATCH] Add Bitcoin wallet system with configurable strategy fees Features: - Two-address custody model (Fee Address + Sweep Address) - Credits ledger for instant, reliable strategy fee transactions - Configurable fee percentage (1-100% of exchange commission) - Background jobs for auto-sweep, deposit detection, withdrawal processing - Account settings dialog accessible via username click - $50 balance cap with auto-sweep to user's sweep address Security improvements: - Atomic withdrawal reservation prevents partial state - Fee accumulation cleanup on strategy startup failure - Deposit monitoring includes disabled wallets for recovery - Null tx hash checks prevent silent failures - Key export disabled by default Co-Authored-By: Claude Opus 4.5 --- pytest.ini | 2 + requirements.txt | 5 +- src/BrighterTrades.py | 149 ++- src/Database.py | 39 +- src/StrategyInstance.py | 42 + src/Users.py | 77 +- src/app.py | 275 ++++++ src/brokers/live_broker.py | 15 +- src/brokers/paper_broker.py | 15 +- src/live_strategy_instance.py | 15 + src/paper_strategy_instance.py | 13 + src/static/Account.js | 695 ++++++++++++++ src/static/Strategies.js | 6 + src/static/brighterStyles.css | 2 +- src/static/general.js | 10 + src/templates/account_settings_dialog.html | 520 ++++++++++ src/templates/index.html | 2 + src/templates/login.html | 2 +- src/templates/new_strategy_popup.html | 36 +- src/wallet/__init__.py | 8 + src/wallet/background_jobs.py | 194 ++++ src/wallet/bitcoin_service.py | 160 ++++ src/wallet/encryption.py | 87 ++ src/wallet/wallet_manager.py | 1005 ++++++++++++++++++++ tests/test_strategy_execution.py | 3 + tests/test_wallet.py | 827 ++++++++++++++++ 26 files changed, 4188 insertions(+), 16 deletions(-) create mode 100644 src/static/Account.js create mode 100644 src/templates/account_settings_dialog.html create mode 100644 src/wallet/__init__.py create mode 100644 src/wallet/background_jobs.py create mode 100644 src/wallet/bitcoin_service.py create mode 100644 src/wallet/encryption.py create mode 100644 src/wallet/wallet_manager.py create mode 100644 tests/test_wallet.py diff --git a/pytest.ini b/pytest.ini index fe048b6..0488659 100644 --- a/pytest.ini +++ b/pytest.ini @@ -5,6 +5,8 @@ python_classes = Test* python_functions = test_* # Default: exclude integration tests (run with: pytest -m integration) addopts = -v --tb=short -m "not integration" +filterwarnings = + ignore:'crypt' is deprecated and slated for removal in Python 3\.13:DeprecationWarning markers = live_testnet: marks tests as requiring live testnet API keys (deselect with '-m "not live_testnet"') diff --git a/requirements.txt b/requirements.txt index d834492..70c4415 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,7 @@ Flask-Cors~=3.0.10 email_validator~=2.2.0 aiohttp>=3.9.0 websockets>=12.0 -requests>=2.31.0 \ No newline at end of file +requests>=2.31.0 +# Bitcoin wallet and encryption +bit>=0.8.0 +cryptography>=41.0.0 \ No newline at end of file diff --git a/src/BrighterTrades.py b/src/BrighterTrades.py index 7eeacc8..79bd709 100644 --- a/src/BrighterTrades.py +++ b/src/BrighterTrades.py @@ -13,6 +13,7 @@ from indicators import Indicators from Signals import Signals from trade import Trades from edm_client import EdmClient, EdmWebSocketClient +from wallet import WalletManager # Configure logging logger = logging.getLogger(__name__) @@ -71,6 +72,18 @@ class BrighterTrades: edm_client=self.edm_client) self.backtests = {} # In-memory storage for backtests (replace with DB access in production) + # Wallet manager for Bitcoin wallets and credits ledger + wallet_config = self.config.get_setting('wallet') or {} + wallet_keys = wallet_config.get('encryption_keys', {1: 'default_dev_key_change_in_production'}) + # Ensure keys are int -> str mapping + wallet_keys = {int(k): v for k, v in wallet_keys.items()} + self.wallet_manager = WalletManager( + database=self.data.db, + encryption_keys=wallet_keys, + default_network=wallet_config.get('bitcoin_network', 'testnet') + ) + logger.info(f"Wallet manager initialized (network: {wallet_config.get('bitcoin_network', 'testnet')})") + @staticmethod def _coerce_user_id(user_id: Any) -> int | None: if user_id is None or user_id == '': @@ -752,7 +765,12 @@ class BrighterTrades: # This ensures subscribers run with the creator's indicator definitions indicator_owner_id = creator_id if is_subscribed and not is_owner else None - # Early exchange requirements validation + # Check for strategy fees (only for non-owners running in paper/live mode) + strategy_fee = float(strategy_row.get('fee', 0.0)) + has_fee = strategy_fee > 0 and not is_owner and mode in ['paper', 'live'] + strategy_run_id = None # Will be set if fee accumulation is started + + # Early exchange requirements validation (BEFORE fee accumulation to avoid orphaned fees) from exchange_validation import extract_required_exchanges, validate_exchange_requirements strategy_full = self.strategies.get_strategy_by_tbl_key(strategy_id) required_exchanges = extract_required_exchanges(strategy_full) @@ -933,6 +951,28 @@ class BrighterTrades: # Paper mode: random UUID since paper state is ephemeral strategy_instance_id = str(uuid.uuid4()) + # Start fee accumulation only after all startup validation has passed. + if has_fee and creator_id is not None: + strategy_run_id = f"{strategy_id}_{user_id}_{uuid.uuid4().hex[:8]}" + fee_result = self.wallet_manager.start_fee_accumulation( + strategy_run_id=strategy_run_id, + user_id=user_id, + creator_user_id=creator_id, + fee_percent=int(strategy_fee), # 1-100% of exchange commission + estimated_trades=10 # Check for ~10 trades worth of credits + ) + + if not fee_result['success']: + return { + 'success': False, + 'message': fee_result.get('error', 'Failed to start fee accumulation'), + 'balance_available': fee_result.get('available', 0), + 'recommended_minimum': fee_result.get('recommended_minimum', 10000), + 'need_deposit': True + } + + logger.info(f"Started fee accumulation for strategy {strategy_id}: run_id={strategy_run_id}") + instance = self.strategies.create_strategy_instance( mode=mode, strategy_instance_id=strategy_instance_id, @@ -950,6 +990,15 @@ class BrighterTrades: indicator_owner_id=indicator_owner_id, # For subscribed strategies, use creator's indicators ) + # Store fee tracking info on the instance + if strategy_run_id: + instance.strategy_run_id = strategy_run_id + instance.has_fee = True + instance.wallet_manager = self.wallet_manager + else: + instance.strategy_run_id = None + instance.has_fee = False + # Store the active instance self.strategies.active_instances[instance_key] = instance @@ -980,6 +1029,8 @@ class BrighterTrades: return result except Exception as e: + if strategy_run_id: + self.wallet_manager.cancel_fee_accumulation(strategy_run_id) logger.error(f"Failed to create strategy instance: {e}", exc_info=True) return {"success": False, "message": f"Failed to start strategy: {str(e)}"} @@ -1028,9 +1079,20 @@ class BrighterTrades: if hasattr(instance, 'trade_history'): final_stats['total_trades'] = len(instance.trade_history) + # Settle accumulated fees if this was a paid strategy + fee_settlement = None + if hasattr(instance, 'strategy_run_id') and instance.strategy_run_id: + settle_result = self.wallet_manager.settle_accumulated_fees(instance.strategy_run_id) + fee_settlement = { + 'fees_settled': settle_result.get('settled', 0), + 'trades_charged': settle_result.get('trades', 0) + } + logger.info(f"Settled {settle_result.get('settled', 0)} sats for " + f"{settle_result.get('trades', 0)} trades on strategy {strategy_id}") + logger.info(f"Stopped strategy '{strategy_name}' for user {user_id} in {mode} mode") - return { + result = { "success": True, "message": f"Strategy '{strategy_name}' stopped.", "strategy_id": strategy_id, @@ -1040,6 +1102,11 @@ class BrighterTrades: "final_stats": final_stats, } + if fee_settlement: + result["fee_settlement"] = fee_settlement + + return result + def get_strategy_status( self, user_id: int, @@ -1809,3 +1876,81 @@ class BrighterTrades: if msg_type == 'reply': # If the message is a reply log the response to the terminal. print(f"\napp.py:Received reply: {msg_data}") + + # ===== Wallet Methods ===== + + def create_user_wallet(self, user_id: int, user_sweep_address: str = None) -> dict: + """ + Create BTC wallet for registered user (two-address model). + + Args: + user_id: User ID to create wallet for. + user_sweep_address: Optional user-provided sweep address. + + Returns: + Dict with success status and wallet details. + """ + user = self.users.get_user_by_id(user_id) + if not user: + return {'success': False, 'error': 'User not found'} + + # Check if user is a guest (guests can't have wallets) + username = user.get('user_name', '') + if username == 'guest' or user.get('is_guest', False): + return {'success': False, 'error': 'Only registered users can create wallets'} + + return self.wallet_manager.create_wallet(user_id, user_sweep_address=user_sweep_address) + + def get_user_wallet(self, user_id: int) -> dict: + """ + Get user's wallet info (without private key). + + Args: + user_id: User ID to look up. + + Returns: + Dict with wallet info or None. + """ + return self.wallet_manager.get_wallet(user_id) + + def get_credits_balance(self, user_id: int) -> int: + """ + Get user's spendable credits balance. + + Args: + user_id: User ID to look up. + + Returns: + Balance in satoshis. + """ + return self.wallet_manager.get_credits_balance(user_id) + + def request_withdrawal(self, user_id: int, amount_satoshis: int, + destination_address: str) -> dict: + """ + Request BTC withdrawal. + + Args: + user_id: User ID requesting withdrawal. + amount_satoshis: Amount to withdraw. + destination_address: Bitcoin address to send to. + + Returns: + Dict with success status. + """ + return self.wallet_manager.request_withdrawal( + user_id, amount_satoshis, destination_address + ) + + def get_transaction_history(self, user_id: int, limit: int = 20) -> list: + """ + Get user's recent ledger transactions. + + Args: + user_id: User ID to look up. + limit: Maximum number of transactions. + + Returns: + List of transaction dicts. + """ + return self.wallet_manager.get_transaction_history(user_id, limit) diff --git a/src/Database.py b/src/Database.py index 9008c27..41099cc 100644 --- a/src/Database.py +++ b/src/Database.py @@ -89,12 +89,16 @@ class Database: def __init__(self, db_file: str = None): self.db_file = db_file - def execute_sql(self, sql: str, params: list = None) -> None: + def execute_sql(self, sql: str, params: tuple = None, + fetch_one: bool = False, fetch_all: bool = False) -> Any: """ - Executes a raw SQL statement with optional parameters. + Executes a raw SQL statement with optional parameters and fetch modes. :param sql: SQL statement to execute. :param params: Optional tuple of parameters to pass with the SQL statement. + :param fetch_one: If True, returns a single row (or None). + :param fetch_all: If True, returns all rows as a list (or empty list). + :return: None, single row, or list of rows depending on fetch parameters. """ with SQLite(self.db_file) as con: cur = con.cursor() @@ -103,6 +107,37 @@ class Database: else: cur.execute(sql, params) + if fetch_one: + return cur.fetchone() + elif fetch_all: + return cur.fetchall() + return None + + def execute_in_transaction(self, statements: list) -> bool: + """ + Executes multiple SQL statements within a single transaction. + All statements succeed or all fail (atomic). + + :param statements: List of (sql, params) tuples to execute. + :return: True if all statements succeeded. + :raises: Exception if any statement fails (transaction rolled back). + """ + conn = sqlite3.connect(self.db_file) + try: + cur = conn.cursor() + for sql, params in statements: + if params is None: + cur.execute(sql) + else: + cur.execute(sql, params) + conn.commit() + return True + except Exception: + conn.rollback() + raise + finally: + conn.close() + def get_all_rows(self, table_name: str) -> pd.DataFrame: """ Retrieves all rows from a table. diff --git a/src/StrategyInstance.py b/src/StrategyInstance.py index 16be176..af0d381 100644 --- a/src/StrategyInstance.py +++ b/src/StrategyInstance.py @@ -839,3 +839,45 @@ class StrategyInstance: Retrieves the available balance for the strategy. """ return self.trades.get_available_balance(self.strategy_id) + + def accumulate_trade_fee(self, trade_value_usd: float, commission_rate: float, + is_profitable: bool) -> dict: + """ + Accumulate fee for a completed trade (only called for paid strategies). + + This method is called when a trade fills to accumulate fees that will + be settled when the strategy stops. + + :param trade_value_usd: The USD value of the trade. + :param commission_rate: The exchange commission rate (e.g., 0.001 for 0.1%). + :param is_profitable: Whether the trade was profitable. + :return: Dict with fee charged amount. + """ + # Check if this strategy has fee tracking enabled + if not getattr(self, 'has_fee', False) or not getattr(self, 'strategy_run_id', None): + return {'success': True, 'fee_charged': 0, 'reason': 'no_fees_enabled'} + + wallet_manager = getattr(self, 'wallet_manager', None) + if not wallet_manager: + return {'success': False, 'error': 'No wallet manager available'} + + # Calculate exchange fee in USD, then convert to satoshis + # Assume 1 BTC = $50,000 for conversion (rough estimate, could be made dynamic) + btc_price_usd = 50000 # This could be fetched from exchange + exchange_fee_usd = trade_value_usd * commission_rate + + # Convert to satoshis: 1 BTC = 100,000,000 satoshis + exchange_fee_btc = exchange_fee_usd / btc_price_usd + exchange_fee_satoshis = int(exchange_fee_btc * 100_000_000) + + # Accumulate the fee + result = wallet_manager.accumulate_trade_fee( + strategy_run_id=self.strategy_run_id, + exchange_fee_satoshis=exchange_fee_satoshis, + is_profitable=is_profitable + ) + + if result.get('fee_charged', 0) > 0: + logger.debug(f"Accumulated fee: {result['fee_charged']} sats for trade worth ${trade_value_usd:.2f}") + + return result diff --git a/src/Users.py b/src/Users.py index 94db5d2..331c3ca 100644 --- a/src/Users.py +++ b/src/Users.py @@ -60,6 +60,25 @@ class BaseUser: filter_vals=('id', user_id) ) + def get_user_by_id(self, user_id: int) -> dict | None: + """ + Retrieves user data as a dict based on the user ID. + + :param user_id: The ID of the user. + :return: A dict containing the user's data, or None if not found. + """ + try: + user_df = self.data.get_rows_from_datacache( + cache_name='users', filter_vals=[('id', user_id)]) + + if user_df is None or user_df.empty: + return None + + # Convert first row to dict + return user_df.iloc[0].to_dict() + except Exception: + return None + def _remove_user_from_memory(self, user_name: str) -> None: """ Private method to remove a user's data from the cache (memory). @@ -202,8 +221,11 @@ class UserAccountManagement(BaseUser): """ if self.validate_password(username=username, password=password): self.modify_user_data(username=username, field_name="status", new_data="logged_in") - self.modify_user_data(username=username, field_name="signin_time", - new_data=dt.datetime.utcnow().timestamp()) + self.modify_user_data( + username=username, + field_name="signin_time", + new_data=dt.datetime.now(dt.timezone.utc).timestamp() + ) return True return False @@ -216,13 +238,60 @@ class UserAccountManagement(BaseUser): """ # Update the user's status and sign-in time in both cache and database self.modify_user_data(username=username, field_name='status', new_data='logged_out') - self.modify_user_data(username=username, field_name='signin_time', new_data=dt.datetime.utcnow().timestamp()) + self.modify_user_data( + username=username, + field_name='signin_time', + new_data=dt.datetime.now(dt.timezone.utc).timestamp() + ) # Remove the user's data from the cache self._remove_user_from_memory(user_name=username) return True + def update_email(self, user_id: int, email: str) -> bool: + """ + Updates the user's email address. + + :param user_id: The ID of the user. + :param email: The new email address. + :return: True on success, False on failure. + """ + try: + username = self.get_username(user_id) + if not username: + return False + + self.modify_user_data(username=username, field_name='email', new_data=email) + return True + except Exception: + return False + + def update_password(self, user_id: int, current_password: str, new_password: str) -> dict: + """ + Updates the user's password after validating the current password. + + :param user_id: The ID of the user. + :param current_password: The current password for verification. + :param new_password: The new password to set. + :return: Dict with 'success' key and optionally 'error' key. + """ + try: + username = self.get_username(user_id) + if not username: + return {'success': False, 'error': 'User not found'} + + # Validate current password + if not self.validate_password(username=username, password=current_password): + return {'success': False, 'error': 'Current password is incorrect'} + + # Hash and store new password + encrypted_password = self.scramble_text(new_password) + self.modify_user_data(username=username, field_name='password', new_data=encrypted_password) + return {'success': True} + except Exception as e: + return {'success': False, 'error': str(e)} + def log_out_all_users(self, enforcement: str = 'hard') -> None: """ Logs out all users by updating their status in the database and clearing the data. @@ -378,7 +447,7 @@ class UserAccountManagement(BaseUser): :param some_text: The text to encrypt. :return: The hashed text. """ - return bcrypt.hash(some_text, rounds=13) + return bcrypt.using(rounds=13).hash(some_text) class UserExchangeManagement(UserAccountManagement): diff --git a/src/app.py b/src/app.py index 374d343..1bd8f19 100644 --- a/src/app.py +++ b/src/app.py @@ -20,6 +20,7 @@ from email_validator import validate_email, EmailNotValidError # noqa: E402 # Local application imports from BrighterTrades import BrighterTrades # noqa: E402 from utils import sanitize_for_json # noqa: E402 +from wallet import WalletBackgroundJobs # noqa: E402 # Set up logging log_level_name = os.getenv('BRIGHTER_LOG_LEVEL', 'INFO').upper() @@ -72,6 +73,8 @@ CORS_HEADERS = 'Content-Type' # Socket ID to authenticated user_id mapping # This is the source of truth for WebSocket authentication - never trust client payloads socket_user_mapping = {} # request.sid -> user_id +wallet_jobs = None +_wallet_jobs_started = False # Set the app directly with the globals. app.config.from_object(__name__) @@ -247,6 +250,41 @@ def start_strategy_loop(): # Start the loop when the app starts (will be called from main block) +def start_wallet_background_jobs(): + """ + Start wallet background jobs once per process. + + This supports both `python src/app.py` and WSGI imports. + Jobs are skipped in test contexts. + """ + global wallet_jobs, _wallet_jobs_started + + if _wallet_jobs_started: + return + + if app.config.get('TESTING') or os.getenv('PYTEST_CURRENT_TEST'): + return + + if os.getenv('BRIGHTER_DISABLE_WALLET_JOBS', '').lower() in ('1', 'true', 'yes'): + logging.info("Wallet background jobs disabled by BRIGHTER_DISABLE_WALLET_JOBS") + return + + if not getattr(brighter_trades, 'wallet_manager', None): + logging.warning("Wallet manager not initialized - background jobs disabled") + return + + wallet_jobs = WalletBackgroundJobs(brighter_trades.wallet_manager, socketio) + wallet_jobs.start_all_jobs() + _wallet_jobs_started = True + logging.info("Wallet background jobs started") + + +@app.before_request +def ensure_background_jobs_started(): + """Ensure wallet background jobs are running in non-`__main__` deployments.""" + start_wallet_background_jobs() + + def _coerce_user_id(user_id): if user_id is None or user_id == '': return None @@ -764,6 +802,240 @@ def _validate_blockly_xml(xml_string: str) -> bool: return False +# ============================================================================= +# Wallet API Routes +# ============================================================================= + +def _get_current_user_id(): + """Get user_id from session. Returns None if not logged in.""" + user_name = session.get('user') + if not user_name: + return None + return brighter_trades.users.get_id(user_name) + + +# === User Profile APIs === + +@app.route('/api/user/profile', methods=['GET']) +def get_user_profile(): + """Get current user's profile info.""" + user_id = _get_current_user_id() + if not user_id: + return jsonify({'success': False, 'error': 'Not logged in'}), 401 + + user = brighter_trades.users.get_user_by_id(user_id) + if user: + return jsonify({ + 'success': True, + 'profile': { + 'username': user.get('user_name'), + 'email': user.get('email', '') + } + }) + return jsonify({'success': False, 'error': 'User not found'}), 404 + + +@app.route('/api/user/email', methods=['POST']) +def update_user_email(): + """Update current user's email address.""" + user_id = _get_current_user_id() + if not user_id: + return jsonify({'success': False, 'error': 'Not logged in'}), 401 + + data = request.get_json() or {} + email = data.get('email', '').strip() + + if not email: + return jsonify({'success': False, 'error': 'Email is required'}), 400 + + # Basic email validation + if '@' not in email or '.' not in email: + return jsonify({'success': False, 'error': 'Invalid email format'}), 400 + + result = brighter_trades.users.update_email(user_id, email) + if result: + return jsonify({'success': True}) + return jsonify({'success': False, 'error': 'Failed to update email'}), 500 + + +@app.route('/api/user/password', methods=['POST']) +def update_user_password(): + """Update current user's password.""" + user_id = _get_current_user_id() + if not user_id: + return jsonify({'success': False, 'error': 'Not logged in'}), 401 + + data = request.get_json() or {} + current_password = data.get('current_password', '') + new_password = data.get('new_password', '') + + if not current_password or not new_password: + return jsonify({'success': False, 'error': 'Both current and new password are required'}), 400 + + if len(new_password) < 6: + return jsonify({'success': False, 'error': 'Password must be at least 6 characters'}), 400 + + result = brighter_trades.users.update_password(user_id, current_password, new_password) + if result.get('success'): + return jsonify({'success': True}) + return jsonify({'success': False, 'error': result.get('error', 'Failed to update password')}), 400 + + +# === Wallet APIs === + +@app.route('/api/wallet', methods=['GET']) +def get_wallet(): + """Get current user's wallet info (without private key).""" + user_id = _get_current_user_id() + if not user_id: + return jsonify({'success': False, 'error': 'Not logged in'}), 401 + + wallet = brighter_trades.get_user_wallet(user_id) + if wallet: + return jsonify({'success': True, 'wallet': wallet}) + return jsonify({'success': True, 'wallet': None}) + + +@app.route('/api/wallet/create', methods=['POST']) +def create_wallet(): + """Create a new wallet for the current user (two-address model).""" + user_id = _get_current_user_id() + if not user_id: + return jsonify({'success': False, 'error': 'Not logged in'}), 401 + + data = request.get_json() or {} + user_sweep_address = data.get('user_sweep_address') # Optional user-provided sweep address + + result = brighter_trades.create_user_wallet(user_id, user_sweep_address=user_sweep_address) + return jsonify(result) + + +@app.route('/api/wallet/credits', methods=['GET']) +def get_credits(): + """Get user's spendable credits balance (from ledger).""" + user_id = _get_current_user_id() + if not user_id: + return jsonify({'success': False, 'error': 'Not logged in'}), 401 + + balance = brighter_trades.get_credits_balance(user_id) + return jsonify({'success': True, 'balance': balance}) + + +@app.route('/api/wallet/withdraw', methods=['POST']) +def request_withdrawal(): + """Request BTC withdrawal (processed async).""" + user_id = _get_current_user_id() + if not user_id: + return jsonify({'success': False, 'error': 'Not logged in'}), 401 + + data = request.get_json() or {} + amount = data.get('amount_satoshis') + destination = data.get('destination_address') + + if not amount or not destination: + return jsonify({'success': False, 'error': 'Missing amount or destination'}), 400 + + try: + amount = int(amount) + except (TypeError, ValueError): + return jsonify({'success': False, 'error': 'Invalid amount'}), 400 + + # SECURITY: Prevent negative withdrawal (would credit balance) + if amount <= 0: + return jsonify({'success': False, 'error': 'Amount must be positive'}), 400 + + result = brighter_trades.request_withdrawal(user_id, amount, destination) + return jsonify(result) + + +@app.route('/api/wallet/transactions', methods=['GET']) +def get_transactions(): + """Get user's recent transaction history.""" + user_id = _get_current_user_id() + if not user_id: + return jsonify({'success': False, 'error': 'Not logged in'}), 401 + + limit = request.args.get('limit', 20, type=int) + transactions = brighter_trades.get_transaction_history(user_id, limit) + return jsonify({'success': True, 'transactions': transactions}) + + +@app.route('/api/wallet/keys', methods=['GET']) +def get_wallet_keys(): + """ + Get fee address keys (for viewing after wallet creation). + Note: Sweep address private key is NOT stored and cannot be retrieved. + """ + user_id = _get_current_user_id() + if not user_id: + return jsonify({'success': False, 'error': 'Not logged in'}), 401 + + # Keep key export disabled by default; can be opt-in for controlled POC sessions. + wallet_cfg = brighter_trades.config.get_setting('wallet') or {} + if not wallet_cfg.get('allow_key_export', False): + return jsonify({'success': False, 'error': 'Key export is disabled'}), 403 + + keys = brighter_trades.wallet_manager.get_wallet_keys(user_id) + if keys: + return jsonify({'success': True, **keys}) + return jsonify({'success': False, 'error': 'Wallet not found or keys unavailable'}) + + +@app.route('/api/admin/credit', methods=['POST']) +def admin_credit_user(): + """ + Admin endpoint to manually credit a user's wallet. + For POC testing until deposit detection is implemented. + + SECURITY: In POC mode, users can only credit themselves. + In production, add proper admin role checking. + """ + user_id = _get_current_user_id() + if not user_id: + return jsonify({'success': False, 'error': 'Not logged in'}), 401 + + data = request.get_json() or {} + target_user_id = data.get('user_id') + amount = data.get('amount_satoshis', 10000) # Default 10k sats + reason = data.get('reason', 'POC test credit') + + # SECURITY: In POC mode, only allow crediting yourself + # Remove this restriction when proper admin auth is implemented + if target_user_id is not None: + try: + if int(target_user_id) != user_id: + return jsonify({'success': False, 'error': 'Can only credit your own account in POC mode'}), 403 + except (TypeError, ValueError): + return jsonify({'success': False, 'error': 'Invalid user_id'}), 400 + + target_user_id = user_id # Force to self + + try: + amount = int(amount) + except (TypeError, ValueError): + return jsonify({'success': False, 'error': 'Invalid amount'}), 400 + + # SECURITY: Prevent negative credits + if amount <= 0: + return jsonify({'success': False, 'error': 'Amount must be positive'}), 400 + + # SECURITY: Limit POC credits to prevent abuse + MAX_POC_CREDIT = 100000 # 0.001 BTC max per credit + if amount > MAX_POC_CREDIT: + return jsonify({'success': False, 'error': f'POC credit limited to {MAX_POC_CREDIT} satoshis'}), 400 + + result = brighter_trades.wallet_manager.admin_credit( + user_id=target_user_id, + amount_satoshis=amount, + reason=reason + ) + return jsonify(result) + + +# ============================================================================= +# Health Check Routes +# ============================================================================= + @app.route('/health/edm', methods=['GET']) def edm_health(): """ @@ -787,4 +1059,7 @@ if __name__ == '__main__': start_strategy_loop() logging.info("Strategy execution loop started in background") + # Start wallet background jobs (auto-sweep, deposit detection, withdrawal processing) + start_wallet_background_jobs() + socketio.run(app, host='127.0.0.1', port=5002, debug=False, use_reloader=False) diff --git a/src/brokers/live_broker.py b/src/brokers/live_broker.py index 50a4a13..74dd306 100644 --- a/src/brokers/live_broker.py +++ b/src/brokers/live_broker.py @@ -955,6 +955,16 @@ class LiveBroker(BaseBroker): # Emit fill event if order was filled if order.status == OrderStatus.FILLED and old_status != OrderStatus.FILLED: + # Calculate profitability for sell orders (before position update) + is_profitable = False + realized_pnl = 0.0 + entry_price = 0.0 + if order.side == OrderSide.SELL and order.symbol in self._positions: + pos = self._positions[order.symbol] + entry_price = pos.entry_price + realized_pnl = (order.filled_price - pos.entry_price) * order.filled_qty - order.commission + is_profitable = realized_pnl > 0 + events.append({ 'type': 'fill', 'order_id': order_id, @@ -965,7 +975,10 @@ class LiveBroker(BaseBroker): 'filled_qty': order.filled_qty, 'price': order.filled_price, 'filled_price': order.filled_price, - 'commission': order.commission + 'commission': order.commission, + 'is_profitable': is_profitable, + 'realized_pnl': realized_pnl, + 'entry_price': entry_price }) logger.info(f"Order filled: {order_id} - {order.side.value} {order.filled_qty} {order.symbol} @ {order.filled_price}") diff --git a/src/brokers/paper_broker.py b/src/brokers/paper_broker.py index f0c45c3..165cd05 100644 --- a/src/brokers/paper_broker.py +++ b/src/brokers/paper_broker.py @@ -230,6 +230,11 @@ class PaperBroker(BaseBroker): order.status = OrderStatus.FILLED order.filled_at = datetime.now(timezone.utc) + # Calculate profitability for sell orders (for fee tracking) + order.is_profitable = False + order.realized_pnl = 0.0 + order.entry_price = 0.0 + # Update balances and positions order_value = order.size * fill_price @@ -261,10 +266,15 @@ class PaperBroker(BaseBroker): # Update position if order.symbol in self._positions: position = self._positions[order.symbol] + order.entry_price = position.entry_price # Store for fee calculation realized_pnl = (fill_price - position.entry_price) * order.size - order.commission position.realized_pnl += realized_pnl position.size -= order.size + # Track profitability for fee calculation + order.realized_pnl = realized_pnl + order.is_profitable = realized_pnl > 0 + # Remove position if fully closed if position.size <= 0: del self._positions[order.symbol] @@ -405,7 +415,10 @@ class PaperBroker(BaseBroker): 'filled_qty': order.filled_qty, 'price': order.filled_price, 'filled_price': order.filled_price, - 'commission': order.commission + 'commission': order.commission, + 'is_profitable': getattr(order, 'is_profitable', False), + 'realized_pnl': getattr(order, 'realized_pnl', 0.0), + 'entry_price': getattr(order, 'entry_price', 0.0) }) logger.info(f"PaperBroker: Limit order filled: {order.side.value} {order.size} {order.symbol} @ {fill_price:.4f}") diff --git a/src/live_strategy_instance.py b/src/live_strategy_instance.py index f3b1100..aedf4a1 100644 --- a/src/live_strategy_instance.py +++ b/src/live_strategy_instance.py @@ -375,6 +375,21 @@ class LiveStrategyInstance(StrategyInstance): }) logger.info(f"Order filled: {event}") + # Accumulate strategy fees for paid strategies + # Fee is charged on profitable sell orders (closing positions) + if event.get('side') == 'sell': + filled_qty = event.get('filled_qty', 0) + filled_price = event.get('filled_price', 0) + trade_value = filled_qty * filled_price + commission_rate = self.live_broker._commission + # Use actual profitability from broker (based on realized PnL) + is_profitable = event.get('is_profitable', False) + self.accumulate_trade_fee( + trade_value_usd=trade_value, + commission_rate=commission_rate, + is_profitable=is_profitable + ) + # Update balance attributes self._update_balances() diff --git a/src/paper_strategy_instance.py b/src/paper_strategy_instance.py index 48b7a28..ea70741 100644 --- a/src/paper_strategy_instance.py +++ b/src/paper_strategy_instance.py @@ -240,6 +240,19 @@ class PaperStrategyInstance(StrategyInstance): 'filled_price': event.get('filled_price', event.get('price')), }) + # Accumulate strategy fees for paid strategies + # Fee is charged on profitable sells (closing a position) + if event.get('side') == 'sell': + trade_value = event.get('filled_qty', 0) * event.get('filled_price', 0) + commission_rate = self.paper_broker._commission + # Use actual profitability from broker (based on realized PnL) + is_profitable = event.get('is_profitable', False) + self.accumulate_trade_fee( + trade_value_usd=trade_value, + commission_rate=commission_rate, + is_profitable=is_profitable + ) + # Update exec context with current data self.exec_context['current_candle'] = candle_data self.exec_context['current_price'] = price diff --git a/src/static/Account.js b/src/static/Account.js new file mode 100644 index 0000000..2cb281a --- /dev/null +++ b/src/static/Account.js @@ -0,0 +1,695 @@ +/** + * Account - Manages user wallet and credits with two-address custody model + * + * Two Address Model: + * - Fee Address: We store private key, used for strategy fees (max ~$50) + * - Sweep Address: We don't store private key, user controls, receives excess + */ +class Account { + constructor() { + this.walletInfo = null; + this.creditsBalance = 0; + this.transactions = []; + this.userProfile = null; + this.initialized = false; + } + + /** + * Initialize the account panel + */ + async initialize() { + if (this.initialized) { + return; + } + this.setupEventListeners(); + this.initialized = true; + } + + /** + * Show the account settings dialog + */ + async showAccountSettings() { + // Check if user is logged in (not a guest) + const username = document.getElementById('username_display')?.textContent || ''; + if (username.startsWith('guest') || !username) { + alert('Please sign in to access account settings.'); + return; + } + + // Show dialog + document.getElementById('account_settings_form').style.display = 'block'; + + // Load profile and wallet data + await Promise.all([ + this.loadProfile(), + this.refresh() + ]); + + // Default to profile tab + this.switchTab('profile'); + } + + /** + * Close the account settings dialog + */ + closeAccountSettings() { + document.getElementById('account_settings_form').style.display = 'none'; + } + + /** + * Switch between tabs in the account settings + */ + switchTab(tabName) { + // Update tab buttons + document.querySelectorAll('.account-tab').forEach(tab => { + tab.classList.toggle('active', tab.dataset.tab === tabName); + }); + + // Update tab panels + document.querySelectorAll('.tab-panel').forEach(panel => { + panel.classList.toggle('active', panel.id === `${tabName}_tab`); + }); + } + + /** + * Load user profile data + */ + async loadProfile() { + try { + const response = await fetch('/api/user/profile'); + const data = await response.json(); + + if (data.success) { + this.userProfile = data.profile; + const emailEl = document.getElementById('current_email'); + if (emailEl) { + emailEl.textContent = data.profile.email || 'Not set'; + } + } + } catch (error) { + console.error('Error loading profile:', error); + } + } + + /** + * Update user email + */ + async updateEmail() { + const newEmail = document.getElementById('new_email').value.trim(); + if (!newEmail) { + alert('Please enter a new email address.'); + return; + } + + // Basic email validation + if (!newEmail.includes('@') || !newEmail.includes('.')) { + alert('Please enter a valid email address.'); + return; + } + + try { + const response = await fetch('/api/user/email', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ email: newEmail }) + }); + + const data = await response.json(); + + if (data.success) { + alert('Email updated successfully.'); + document.getElementById('current_email').textContent = newEmail; + document.getElementById('new_email').value = ''; + } else { + alert('Failed to update email: ' + (data.error || 'Unknown error')); + } + } catch (error) { + console.error('Error updating email:', error); + alert('Failed to update email: ' + error.message); + } + } + + /** + * Update user password + */ + async updatePassword() { + const currentPassword = document.getElementById('current_password').value; + const newPassword = document.getElementById('new_password').value; + const confirmPassword = document.getElementById('confirm_password').value; + + if (!currentPassword || !newPassword || !confirmPassword) { + alert('Please fill in all password fields.'); + return; + } + + if (newPassword !== confirmPassword) { + alert('New passwords do not match.'); + return; + } + + if (newPassword.length < 6) { + alert('Password must be at least 6 characters.'); + return; + } + + try { + const response = await fetch('/api/user/password', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + current_password: currentPassword, + new_password: newPassword + }) + }); + + const data = await response.json(); + + if (data.success) { + alert('Password updated successfully.'); + document.getElementById('current_password').value = ''; + document.getElementById('new_password').value = ''; + document.getElementById('confirm_password').value = ''; + } else { + alert('Failed to update password: ' + (data.error || 'Unknown error')); + } + } catch (error) { + console.error('Error updating password:', error); + alert('Failed to update password: ' + error.message); + } + } + + /** + * Setup checkbox event listeners for modals + */ + setupEventListeners() { + // Terms checkbox enables create button + const termsCheckbox = document.getElementById('terms_checkbox'); + const createBtn = document.getElementById('create_wallet_btn'); + if (termsCheckbox && createBtn) { + termsCheckbox.addEventListener('change', () => { + createBtn.disabled = !termsCheckbox.checked; + }); + } + + // Keys saved checkbox enables close button + const keysSavedCheckbox = document.getElementById('keys_saved_checkbox'); + const closeKeysBtn = document.getElementById('close_keys_btn'); + if (keysSavedCheckbox && closeKeysBtn) { + keysSavedCheckbox.addEventListener('change', () => { + closeKeysBtn.disabled = !keysSavedCheckbox.checked; + }); + } + } + + /** + * Refresh wallet info, credits balance, and transactions + */ + async refresh() { + try { + // Load wallet and credits in parallel + const [walletRes, creditsRes, txRes] = await Promise.all([ + fetch('/api/wallet'), + fetch('/api/wallet/credits'), + fetch('/api/wallet/transactions?limit=20') + ]); + + const walletData = await walletRes.json(); + const creditsData = await creditsRes.json(); + const txData = await txRes.json(); + + if (walletData.success && walletData.wallet) { + this.walletInfo = walletData.wallet; + this.displayWallet(); + } else { + this.walletInfo = null; + this.displayNoWallet(); + } + + if (creditsData.success) { + this.creditsBalance = creditsData.balance || 0; + this.updateCreditsDisplay(); + } + + if (txData.success) { + this.transactions = txData.transactions || []; + this.displayTransactions(); + } + } catch (error) { + console.error('Error refreshing account:', error); + } + } + + /** + * Display wallet info + */ + displayWallet() { + const noWallet = document.getElementById('no_wallet'); + const walletInfo = document.getElementById('wallet_info'); + + if (noWallet) noWallet.style.display = 'none'; + if (walletInfo) walletInfo.style.display = 'block'; + + const addressEl = document.getElementById('btc_address'); + const sweepAddressEl = document.getElementById('sweep_address'); + const networkEl = document.getElementById('wallet_network'); + const warningEl = document.getElementById('wallet_disabled_warning'); + + if (addressEl) { + addressEl.textContent = this.walletInfo.fee_address || this.walletInfo.address; + } + if (sweepAddressEl) { + sweepAddressEl.textContent = this.walletInfo.sweep_address || 'Not configured'; + } + if (networkEl) { + networkEl.textContent = this.walletInfo.network === 'testnet' ? '(Testnet)' : '(Mainnet)'; + } + if (warningEl) { + warningEl.style.display = this.walletInfo.is_disabled ? 'block' : 'none'; + } + } + + /** + * Display no wallet state + */ + displayNoWallet() { + const noWallet = document.getElementById('no_wallet'); + const walletInfo = document.getElementById('wallet_info'); + + if (noWallet) noWallet.style.display = 'block'; + if (walletInfo) walletInfo.style.display = 'none'; + } + + /** + * Update credits balance display + */ + updateCreditsDisplay() { + const balanceEl = document.getElementById('credits_balance'); + if (balanceEl) { + balanceEl.textContent = this.formatSatoshis(this.creditsBalance); + } + } + + /** + * Format satoshis as BTC with appropriate precision + */ + formatSatoshis(satoshis) { + const btc = satoshis / 100000000; + if (btc === 0) return '0'; + if (btc < 0.0001) return btc.toFixed(8); + if (btc < 1) return btc.toFixed(6); + return btc.toFixed(4); + } + + /** + * Display transaction history + */ + displayTransactions() { + const tbody = document.getElementById('transactions_body'); + if (!tbody) return; + + if (!this.transactions || this.transactions.length === 0) { + tbody.innerHTML = 'No transactions yet'; + return; + } + + tbody.innerHTML = this.transactions.map(tx => { + const date = this.escapeHtml(new Date(tx.date).toLocaleString()); + const amountClass = tx.amount >= 0 ? 'tx-positive' : 'tx-negative'; + const amountStr = (tx.amount >= 0 ? '+' : '') + this.formatSatoshis(tx.amount); + const typeDisplay = this.formatTxType(tx.type); + const refFull = this.escapeHtml(tx.reference || ''); + const refShort = tx.reference ? this.escapeHtml(tx.reference.substring(0, 16)) + '...' : '-'; + + return ` + + ${date} + ${typeDisplay} + ${amountStr} + ${refShort} + + `; + }).join(''); + } + + /** + * Format transaction type for display + */ + formatTxType(type) { + const typeMap = { + 'deposit': 'Deposit', + 'withdrawal': 'Withdrawal', + 'fee_paid': 'Fee Paid', + 'fee_received': 'Fee Received', + 'admin_credit': 'Credit', + 'withdrawal_reversal': 'Reversal', + 'auto_sweep': 'Auto-Sweep' + }; + return typeMap[type] || this.escapeHtml(type); + } + + /** + * Escape HTML to prevent XSS attacks + */ + escapeHtml(text) { + if (text === null || text === undefined) return ''; + const div = document.createElement('div'); + div.textContent = String(text); + return div.innerHTML; + } + + /** + * Show wallet setup dialog with terms + */ + showSetupDialog() { + // Reset checkbox + const termsCheckbox = document.getElementById('terms_checkbox'); + const createBtn = document.getElementById('create_wallet_btn'); + if (termsCheckbox) termsCheckbox.checked = false; + if (createBtn) createBtn.disabled = true; + + // Reset sweep option to "generate" + const generateRadio = document.querySelector('input[name="sweep_option"][value="generate"]'); + if (generateRadio) generateRadio.checked = true; + this.toggleSweepOption(); + + // Clear any previous user input + const userAddressInput = document.getElementById('user_sweep_address'); + if (userAddressInput) userAddressInput.value = ''; + + document.getElementById('wallet_setup_modal').style.display = 'block'; + } + + /** + * Toggle sweep address input visibility based on radio selection + */ + toggleSweepOption() { + const ownRadio = document.querySelector('input[name="sweep_option"][value="own"]'); + const ownAddressSection = document.getElementById('own_address_input'); + + if (ownRadio && ownAddressSection) { + ownAddressSection.style.display = ownRadio.checked ? 'block' : 'none'; + } + } + + /** + * Close wallet setup dialog + */ + closeSetupDialog() { + document.getElementById('wallet_setup_modal').style.display = 'none'; + } + + /** + * Create a new wallet (two-address model) + */ + async createWallet() { + const termsCheckbox = document.getElementById('terms_checkbox'); + if (!termsCheckbox || !termsCheckbox.checked) { + alert('Please agree to the terms before creating a wallet.'); + return; + } + + // Check if user is providing their own sweep address + const ownRadio = document.querySelector('input[name="sweep_option"][value="own"]'); + const userSweepAddress = document.getElementById('user_sweep_address'); + let sweepAddress = null; + + if (ownRadio && ownRadio.checked) { + sweepAddress = userSweepAddress ? userSweepAddress.value.trim() : ''; + if (!sweepAddress) { + alert('Please enter your Bitcoin address for the sweep address.'); + return; + } + } + + const btn = document.getElementById('create_wallet_btn'); + btn.disabled = true; + btn.textContent = 'Generating...'; + + try { + const requestBody = {}; + if (sweepAddress) { + requestBody.user_sweep_address = sweepAddress; + } + + const response = await fetch('/api/wallet/create', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(requestBody) + }); + const data = await response.json(); + + if (data.success) { + // Close setup dialog + this.closeSetupDialog(); + + // Show keys modal - requires acknowledgment before closing + document.getElementById('modal_network').textContent = data.network; + document.getElementById('modal_fee_address').textContent = data.fee_address; + document.getElementById('modal_fee_private_key').textContent = data.fee_private_key; + document.getElementById('modal_sweep_address').textContent = data.sweep_address; + + // Handle sweep private key - only shown if we generated it + const sweepPrivateKeyEl = document.getElementById('modal_sweep_private_key'); + const sweepPrivateKeyRow = sweepPrivateKeyEl ? sweepPrivateKeyEl.closest('.key-row') : null; + + if (data.sweep_private_key) { + // We generated the sweep address - show private key + sweepPrivateKeyEl.textContent = data.sweep_private_key; + if (sweepPrivateKeyRow) sweepPrivateKeyRow.style.display = 'flex'; + } else { + // User provided their own address - no private key to show + if (sweepPrivateKeyRow) sweepPrivateKeyRow.style.display = 'none'; + } + + // Update the sweep section header based on whether we generated it + const sweepSection = document.querySelector('.key-section.sweep-section h5'); + if (sweepSection) { + if (data.sweep_private_key) { + sweepSection.textContent = 'Sweep Address (You Control - SAVE THIS KEY!)'; + } else { + sweepSection.textContent = 'Sweep Address (Your Provided Address)'; + } + } + + // Reset checkbox and button + const keysSavedCheckbox = document.getElementById('keys_saved_checkbox'); + const closeKeysBtn = document.getElementById('close_keys_btn'); + if (keysSavedCheckbox) keysSavedCheckbox.checked = false; + if (closeKeysBtn) closeKeysBtn.disabled = true; + + document.getElementById('wallet_keys_modal').style.display = 'block'; + + // Update wallet info + this.walletInfo = { + fee_address: data.fee_address, + sweep_address: data.sweep_address, + network: data.network, + is_disabled: false + }; + this.creditsBalance = 0; + this.displayWallet(); + this.updateCreditsDisplay(); + } else { + alert('Failed to create wallet: ' + (data.error || 'Unknown error')); + } + } catch (error) { + console.error('Error creating wallet:', error); + alert('Failed to create wallet: ' + error.message); + } finally { + btn.disabled = false; + btn.textContent = 'Create Wallet'; + } + } + + /** + * Close the keys modal (only if acknowledged) + */ + closeKeysModal() { + const keysSavedCheckbox = document.getElementById('keys_saved_checkbox'); + if (!keysSavedCheckbox || !keysSavedCheckbox.checked) { + alert('Please confirm that you have saved your keys before closing.'); + return; + } + document.getElementById('wallet_keys_modal').style.display = 'none'; + } + + /** + * Show view keys dialog (for fee address only - sweep key not stored) + */ + async showViewKeysDialog() { + try { + const response = await fetch('/api/wallet/keys'); + const data = await response.json(); + + if (data.success) { + document.getElementById('view_network').textContent = data.network; + document.getElementById('view_fee_address').textContent = data.fee_address; + document.getElementById('view_fee_private_key').textContent = data.fee_private_key; + document.getElementById('view_keys_modal').style.display = 'block'; + } else { + alert('Failed to retrieve keys: ' + (data.error || 'Unknown error')); + } + } catch (error) { + console.error('Error retrieving keys:', error); + alert('Failed to retrieve keys: ' + error.message); + } + } + + /** + * Close view keys dialog + */ + closeViewKeysDialog() { + document.getElementById('view_keys_modal').style.display = 'none'; + } + + /** + * Copy wallet address to clipboard + */ + copyAddress(type) { + if (!this.walletInfo) return; + + const address = type === 'sweep' + ? (this.walletInfo.sweep_address || '') + : (this.walletInfo.fee_address || this.walletInfo.address); + + if (!address) return; + + navigator.clipboard.writeText(address).then(() => { + // Brief visual feedback would be nice here + }).catch(err => { + console.error('Failed to copy:', err); + }); + } + + /** + * Copy a key value to clipboard + */ + copyKey(elementId) { + const el = document.getElementById(elementId); + if (!el) return; + + navigator.clipboard.writeText(el.textContent).then(() => { + // Could add visual feedback + }).catch(err => { + console.error('Failed to copy:', err); + }); + } + + /** + * Show withdrawal dialog + */ + showWithdrawDialog() { + document.getElementById('withdraw_amount').value = ''; + document.getElementById('withdraw_address').value = ''; + document.getElementById('withdraw_modal').style.display = 'block'; + } + + /** + * Close withdrawal dialog + */ + closeWithdrawDialog() { + document.getElementById('withdraw_modal').style.display = 'none'; + } + + /** + * Submit withdrawal request + */ + async submitWithdrawal() { + const amountStr = document.getElementById('withdraw_amount').value.trim(); + const address = document.getElementById('withdraw_address').value.trim(); + + if (!amountStr || !address) { + alert('Please enter amount and destination address'); + return; + } + + const btcAmount = parseFloat(amountStr); + if (isNaN(btcAmount) || btcAmount <= 0) { + alert('Invalid amount'); + return; + } + + const satoshis = Math.floor(btcAmount * 100000000); + + try { + const response = await fetch('/api/wallet/withdraw', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + amount_satoshis: satoshis, + destination_address: address + }) + }); + + const data = await response.json(); + + if (data.success) { + alert('Withdrawal request submitted. It will be processed shortly.'); + this.closeWithdrawDialog(); + this.refresh(); + } else { + alert('Withdrawal failed: ' + (data.error || 'Unknown error')); + } + } catch (error) { + console.error('Error requesting withdrawal:', error); + alert('Withdrawal failed: ' + error.message); + } + } + + /** + * Show test credit dialog + */ + showCreditDialog() { + document.getElementById('credit_amount').value = '10000'; + document.getElementById('credit_modal').style.display = 'block'; + } + + /** + * Close test credit dialog + */ + closeCreditDialog() { + document.getElementById('credit_modal').style.display = 'none'; + } + + /** + * Submit test credit + */ + async submitCredit() { + const amountStr = document.getElementById('credit_amount').value.trim(); + const amount = parseInt(amountStr); + + if (isNaN(amount) || amount <= 0) { + alert('Invalid amount'); + return; + } + + try { + const response = await fetch('/api/admin/credit', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + amount_satoshis: amount, + reason: 'POC test credit from UI' + }) + }); + + const data = await response.json(); + + if (data.success) { + alert(`Added ${amount} satoshis to your credits`); + this.closeCreditDialog(); + this.refresh(); + } else { + alert('Credit failed: ' + (data.error || 'Unknown error')); + } + } catch (error) { + console.error('Error adding credit:', error); + alert('Credit failed: ' + error.message); + } + } +} + +// Export for use in UI namespace +if (typeof window !== 'undefined') { + window.Account = Account; +} diff --git a/src/static/Strategies.js b/src/static/Strategies.js index f5ce6a3..5350030 100644 --- a/src/static/Strategies.js +++ b/src/static/Strategies.js @@ -1681,6 +1681,12 @@ class Strategies { return; } + // Validate fee is 1-100 if public + if (strategyData.public === 1 && (strategyData.fee < 1 || strategyData.fee > 100)) { + alert("Fee must be between 1 and 100 (percentage of exchange commission)."); + return; + } + if (!strategyData.name) { alert("Please provide a name for the strategy."); return; diff --git a/src/static/brighterStyles.css b/src/static/brighterStyles.css index 39a939f..ca9e43a 100644 --- a/src/static/brighterStyles.css +++ b/src/static/brighterStyles.css @@ -390,7 +390,7 @@ height: 500px; font-size: 15px; } -.active, .collapsible:hover { +.collapsible.active, .collapsible:hover { background-color: #0A07DF; } diff --git a/src/static/general.js b/src/static/general.js index 8849de8..344fecf 100644 --- a/src/static/general.js +++ b/src/static/general.js @@ -12,6 +12,7 @@ class User_Interface { this.signals = new Signals(this); this.backtesting = new Backtesting(this); this.statistics = new Statistics(this.data.comms); + this.account = new Account(); // Register a callback function for when indicator updates are received from the data object this.data.registerCallback_i_updates(this.indicators.update); @@ -35,6 +36,14 @@ class User_Interface { this.initializeResizablePopup("new_trade_form", null, "trade_draggable_header", "resize-trade"); this.initializeResizablePopup("ai_strategy_form", null, "ai_strategy_header", "resize-ai-strategy"); + // Account settings popups + this.initializeResizablePopup("account_settings_form", null, "account_settings_header", "resize-account"); + this.initializeResizablePopup("wallet_setup_modal", null, "wallet_setup_header", "resize-wallet-setup"); + this.initializeResizablePopup("wallet_keys_modal", null, "wallet_keys_header", "resize-wallet-keys"); + this.initializeResizablePopup("view_keys_modal", null, "view_keys_header", "resize-view-keys"); + this.initializeResizablePopup("withdraw_modal", null, "withdraw_header", "resize-withdraw"); + this.initializeResizablePopup("credit_modal", null, "credit_header", "resize-credit"); + // Initialize Backtesting's DOM elements this.backtesting.initialize(); } catch (error) { @@ -69,6 +78,7 @@ class User_Interface { this.exchanges.initialize(); this.strats.initialize('strats_display', 'new_strat_form', this.data); this.backtesting.fetchSavedTests(); + this.account.initialize(); } /** diff --git a/src/templates/account_settings_dialog.html b/src/templates/account_settings_dialog.html new file mode 100644 index 0000000..40d693f --- /dev/null +++ b/src/templates/account_settings_dialog.html @@ -0,0 +1,520 @@ + + + + + + + + + + + + + + + + + + + diff --git a/src/templates/index.html b/src/templates/index.html index 8d1e0e5..7efdc7b 100644 --- a/src/templates/index.html +++ b/src/templates/index.html @@ -32,6 +32,7 @@ + @@ -46,6 +47,7 @@ {% include "trade_details_popup.html" %} {% include "indicator_popup.html" %} {% include "exchange_config_popup.html" %} + {% include "account_settings_dialog.html" %}
- {{user_name}} + {{user_name}} diff --git a/src/templates/new_strategy_popup.html b/src/templates/new_strategy_popup.html index 3abc544..190ec16 100644 --- a/src/templates/new_strategy_popup.html +++ b/src/templates/new_strategy_popup.html @@ -35,9 +35,14 @@
-
- - +
+ + + % + ⓘ
@@ -154,6 +159,31 @@ z-index: 10; border-radius: 3px; } + +/* Fee info icon tooltip */ +.fee-info-icon { + display: inline-block; + width: 16px; + height: 16px; + background: #3E3AF2; + color: white; + border-radius: 50%; + text-align: center; + font-size: 11px; + line-height: 16px; + cursor: help; + margin-left: 5px; + vertical-align: middle; +} + +.fee-info-icon:hover { + background: #0A07DF; +} + +/* Native tooltip styling via title attribute - multiline support */ +.fee-info-icon[title] { + position: relative; +} diff --git a/src/wallet/__init__.py b/src/wallet/__init__.py new file mode 100644 index 0000000..45df518 --- /dev/null +++ b/src/wallet/__init__.py @@ -0,0 +1,8 @@ +"""Bitcoin wallet and credits ledger module for strategy fees.""" + +from .wallet_manager import WalletManager +from .bitcoin_service import BitcoinService +from .encryption import KeyEncryption +from .background_jobs import WalletBackgroundJobs + +__all__ = ['WalletManager', 'BitcoinService', 'KeyEncryption', 'WalletBackgroundJobs'] diff --git a/src/wallet/background_jobs.py b/src/wallet/background_jobs.py new file mode 100644 index 0000000..020787d --- /dev/null +++ b/src/wallet/background_jobs.py @@ -0,0 +1,194 @@ +""" +Background jobs for wallet operations. + +- Auto-sweep: Transfer excess funds (over $50 cap) to user's sweep address +- Deposit detection: Monitor for incoming deposits and credit user ledgers +- Withdrawal processing: Process pending withdrawal requests +""" +import logging +import time +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .wallet_manager import WalletManager + +logger = logging.getLogger(__name__) + +# Job intervals in seconds +AUTO_SWEEP_INTERVAL = 300 # 5 minutes +DEPOSIT_CHECK_INTERVAL = 60 # 1 minute +WITHDRAWAL_PROCESS_INTERVAL = 30 # 30 seconds + + +class WalletBackgroundJobs: + """ + Manages background jobs for wallet operations. + Uses eventlet for async execution compatible with Flask-SocketIO. + """ + + def __init__(self, wallet_manager: 'WalletManager', socketio=None): + self.wallet_manager = wallet_manager + self.socketio = socketio + self._running = False + self._jobs_started = False + + def start_all_jobs(self): + """Start all background jobs.""" + if self._jobs_started: + logger.warning("Wallet background jobs already started") + return + + self._running = True + self._jobs_started = True + + if self.socketio: + # Use SocketIO's background task mechanism + self.socketio.start_background_task(self._auto_sweep_loop) + self.socketio.start_background_task(self._deposit_detection_loop) + self.socketio.start_background_task(self._withdrawal_processing_loop) + logger.info("Wallet background jobs started via SocketIO") + else: + # Fallback to eventlet directly + import eventlet + eventlet.spawn(self._auto_sweep_loop) + eventlet.spawn(self._deposit_detection_loop) + eventlet.spawn(self._withdrawal_processing_loop) + logger.info("Wallet background jobs started via eventlet") + + def stop_all_jobs(self): + """Stop all background jobs.""" + self._running = False + logger.info("Wallet background jobs stopping...") + + # === Auto-Sweep Job === + + def _auto_sweep_loop(self): + """Background loop for auto-sweep functionality.""" + logger.info("Auto-sweep job started") + while self._running: + try: + self._process_auto_sweeps() + except Exception as e: + logger.error(f"Auto-sweep error: {e}") + self._sleep(AUTO_SWEEP_INTERVAL) + + def _process_auto_sweeps(self): + """Check all wallets and sweep excess funds.""" + try: + # Get all wallets that might need sweeping + wallets = self.wallet_manager.get_wallets_over_cap() + + for wallet in wallets: + user_id = wallet['user_id'] + balance = wallet['balance'] + sweep_address = wallet['sweep_address'] + cap = self.wallet_manager.BALANCE_CAP_SATOSHIS + + if not sweep_address: + logger.warning(f"User {user_id} has no sweep address configured") + continue + + excess = balance - cap + if excess <= 0: + continue + + # Keep a small buffer (1000 sats) to avoid sweeping tiny amounts + if excess < 1000: + continue + + logger.info(f"Auto-sweeping {excess} satoshis for user {user_id}") + + # Perform the sweep + result = self.wallet_manager.auto_sweep(user_id, excess) + if result.get('success'): + logger.info(f"Auto-sweep successful for user {user_id}: {result.get('tx_hash', 'simulated')}") + else: + logger.error(f"Auto-sweep failed for user {user_id}: {result.get('error')}") + + except Exception as e: + logger.error(f"Error processing auto-sweeps: {e}") + + # === Deposit Detection Job === + + def _deposit_detection_loop(self): + """Background loop for deposit detection.""" + logger.info("Deposit detection job started") + while self._running: + try: + self._check_for_deposits() + except Exception as e: + logger.error(f"Deposit detection error: {e}") + self._sleep(DEPOSIT_CHECK_INTERVAL) + + def _check_for_deposits(self): + """Check all wallet addresses for new deposits.""" + try: + # Include disabled wallets so incoming deposits are still credited. + wallets = self.wallet_manager.get_wallets_for_deposit_monitoring() + + for wallet in wallets: + user_id = wallet['user_id'] + fee_address = wallet['fee_address'] + network = wallet['network'] + + # Check for new deposits at the fee address + new_deposits = self.wallet_manager.check_address_for_deposits( + user_id=user_id, + address=fee_address, + network=network + ) + + for deposit in new_deposits: + logger.info(f"New deposit detected for user {user_id}: " + f"{deposit['amount_satoshis']} sats, tx: {deposit['tx_hash']}") + + except Exception as e: + logger.error(f"Error checking for deposits: {e}") + + # === Withdrawal Processing Job === + + def _withdrawal_processing_loop(self): + """Background loop for processing pending withdrawals.""" + logger.info("Withdrawal processing job started") + while self._running: + try: + self._process_pending_withdrawals() + except Exception as e: + logger.error(f"Withdrawal processing error: {e}") + self._sleep(WITHDRAWAL_PROCESS_INTERVAL) + + def _process_pending_withdrawals(self): + """Process all pending withdrawal requests.""" + try: + # Get pending withdrawals (status = 'reserved') + pending = self.wallet_manager.get_pending_withdrawals() + + for withdrawal in pending: + withdrawal_id = withdrawal['id'] + user_id = withdrawal['user_id'] + amount = withdrawal['amount_satoshis'] + destination = withdrawal['destination_address'] + + logger.info(f"Processing withdrawal {withdrawal_id} for user {user_id}: " + f"{amount} sats to {destination}") + + # Process the withdrawal + result = self.wallet_manager.process_withdrawal(withdrawal_id) + + if result.get('success'): + logger.info(f"Withdrawal {withdrawal_id} completed: {result.get('tx_hash', 'simulated')}") + else: + logger.error(f"Withdrawal {withdrawal_id} failed: {result.get('error')}") + + except Exception as e: + logger.error(f"Error processing withdrawals: {e}") + + def _sleep(self, seconds): + """Sleep that respects the running flag and works with eventlet.""" + import eventlet + # Sleep in small increments to allow quick shutdown + remaining = seconds + while remaining > 0 and self._running: + sleep_time = min(remaining, 5) + eventlet.sleep(sleep_time) + remaining -= sleep_time diff --git a/src/wallet/bitcoin_service.py b/src/wallet/bitcoin_service.py new file mode 100644 index 0000000..2977e5c --- /dev/null +++ b/src/wallet/bitcoin_service.py @@ -0,0 +1,160 @@ +"""Bitcoin network operations using bit library.""" +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class BitcoinService: + """ + Service for Bitcoin network operations. + Handles address generation, balance checking, and transactions. + """ + + def __init__(self, testnet: bool = True): + """ + Initialize Bitcoin service. + + Args: + testnet: If True, use testnet; otherwise use mainnet. + """ + self.testnet = testnet + + def _get_key_class(self): + """Get the appropriate key class based on network.""" + from bit import Key, PrivateKeyTestnet + return PrivateKeyTestnet if self.testnet else Key + + def generate_keypair(self) -> dict: + """ + Generate new Bitcoin keypair. + + Returns: + Dict with 'address', 'public_key', and 'private_key' (WIF format). + """ + key_class = self._get_key_class() + key = key_class() + return { + 'address': key.address, + 'public_key': key.public_key.hex(), + 'private_key': key.to_wif() # Wallet Import Format + } + + def get_balance(self, address: str) -> int: + """ + Get balance in satoshis from blockchain. + + Args: + address: Bitcoin address to check. + + Returns: + Balance in satoshis, or 0 on error. + """ + try: + from bit.network import NetworkAPI + if self.testnet: + balance = NetworkAPI.get_balance_testnet(address) + else: + balance = NetworkAPI.get_balance(address) + return balance + except Exception as e: + logger.error(f"Failed to fetch balance for {address}: {e}") + return 0 + + def get_unspent(self, address: str) -> list: + """ + Get unspent transaction outputs for address. + + Args: + address: Bitcoin address to check. + + Returns: + List of UTXOs, or empty list on error. + """ + try: + from bit.network import NetworkAPI + if self.testnet: + return NetworkAPI.get_unspent_testnet(address) + return NetworkAPI.get_unspent(address) + except Exception as e: + logger.error(f"Failed to fetch UTXOs for {address}: {e}") + return [] + + def send_transaction(self, private_key_wif: str, to_address: str, + amount_satoshis: int) -> Optional[str]: + """ + Send BTC transaction. + + Args: + private_key_wif: Sender's private key in WIF format. + to_address: Recipient's Bitcoin address. + amount_satoshis: Amount to send in satoshis. + + Returns: + Transaction hash on success, None on failure. + """ + try: + key_class = self._get_key_class() + key = key_class(private_key_wif) + tx_hash = key.send([(to_address, amount_satoshis, 'satoshi')]) + logger.info(f"Sent {amount_satoshis} satoshis to {to_address}: {tx_hash}") + return tx_hash + except Exception as e: + logger.error(f"Failed to send transaction: {e}") + return None + + def validate_address(self, address: str) -> bool: + """ + Validate Bitcoin address format and checksum. + + Args: + address: Address string to validate. + + Returns: + True if valid for this network, False otherwise. + """ + if not address or not isinstance(address, str): + return False + + try: + address = address.strip() + + if self.testnet: + # Testnet addresses: m/n (legacy), 2 (P2SH), tb1 (bech32) + if address.startswith(('m', 'n', '2')): + # Base58 testnet address + return 26 <= len(address) <= 35 + elif address.startswith('tb1'): + # Bech32 testnet address + return 42 <= len(address) <= 62 + return False + else: + # Mainnet addresses: 1 (legacy), 3 (P2SH), bc1 (bech32) + if address.startswith(('1', '3')): + # Base58 mainnet address + return 26 <= len(address) <= 35 + elif address.startswith('bc1'): + # Bech32 mainnet address + return 42 <= len(address) <= 62 + return False + + except Exception: + return False + + def estimate_fee(self, num_inputs: int = 1, num_outputs: int = 2) -> int: + """ + Estimate transaction fee in satoshis. + + Args: + num_inputs: Number of transaction inputs. + num_outputs: Number of transaction outputs. + + Returns: + Estimated fee in satoshis. + """ + # Rough estimate: ~148 bytes per input, ~34 bytes per output, ~10 overhead + tx_size = (num_inputs * 148) + (num_outputs * 34) + 10 + # Use a reasonable fee rate (satoshis per byte) + # Could be made dynamic based on network conditions + fee_rate = 10 + return tx_size * fee_rate diff --git a/src/wallet/encryption.py b/src/wallet/encryption.py new file mode 100644 index 0000000..d755bb1 --- /dev/null +++ b/src/wallet/encryption.py @@ -0,0 +1,87 @@ +"""Symmetric encryption for wallet keys with versioning.""" +import base64 +import hashlib +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + +class KeyEncryption: + """ + Versioned encryption for wallet keys. + Supports key rotation by maintaining multiple key versions. + """ + CURRENT_VERSION = 1 + + def __init__(self, master_keys: dict): + """ + Initialize with versioned master keys. + + Args: + master_keys: Dict mapping version numbers to master key strings. + Example: {1: 'key_v1', 2: 'key_v2'} + """ + self.fernets = {} + for version, key in master_keys.items(): + # Derive unique salt per version (deterministic but version-specific) + salt = hashlib.sha256(f"brighter_wallet_v{version}".encode()).digest()[:16] + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + derived_key = base64.urlsafe_b64encode(kdf.derive(key.encode())) + self.fernets[int(version)] = Fernet(derived_key) + + def encrypt(self, data: str, version: int = None) -> str: + """ + Encrypt data with current (or specified) key version. + + Args: + data: Plaintext string to encrypt. + version: Key version to use (defaults to CURRENT_VERSION). + + Returns: + Encrypted data as string. + """ + version = version or self.CURRENT_VERSION + if version not in self.fernets: + raise ValueError(f"Unknown encryption key version: {version}") + return self.fernets[version].encrypt(data.encode()).decode() + + def decrypt(self, encrypted_data: str, version: int) -> str: + """ + Decrypt data using the specified key version. + + Args: + encrypted_data: Encrypted string to decrypt. + version: Key version that was used to encrypt. + + Returns: + Decrypted plaintext string. + + Raises: + ValueError: If the key version is unknown. + """ + if version not in self.fernets: + raise ValueError(f"Unknown encryption key version: {version}") + return self.fernets[version].decrypt(encrypted_data.encode()).decode() + + def re_encrypt(self, encrypted_data: str, old_version: int, + new_version: int = None) -> tuple[str, int]: + """ + Re-encrypt data from old version to new version. + Used during key rotation. + + Args: + encrypted_data: Data encrypted with old_version. + old_version: Version the data is currently encrypted with. + new_version: Target version (defaults to CURRENT_VERSION). + + Returns: + Tuple of (new_encrypted_data, new_version). + """ + new_version = new_version or self.CURRENT_VERSION + plaintext = self.decrypt(encrypted_data, old_version) + return self.encrypt(plaintext, new_version), new_version diff --git a/src/wallet/wallet_manager.py b/src/wallet/wallet_manager.py new file mode 100644 index 0000000..3000250 --- /dev/null +++ b/src/wallet/wallet_manager.py @@ -0,0 +1,1005 @@ +"""Main wallet management operations using credits ledger with balance cap.""" +import logging +import uuid +from typing import Optional + +from .bitcoin_service import BitcoinService +from .encryption import KeyEncryption + +logger = logging.getLogger(__name__) + +# Balance cap in satoshis (~$50 at typical BTC prices) +# 0.001 BTC ~ $50-100 range depending on BTC price +BALANCE_CAP_SATOSHIS = 100000 # 0.001 BTC + + +class WalletManager: + """ + Manages user Bitcoin wallets and internal credits ledger. + + Two-Address Custody Model: + 1. Fee Address: We store private key, used for strategy fees (max ~$50) + 2. Sweep Address: We only store public address, user controls private key + + Credits Ledger: Internal balance for instant, reliable strategy fees + """ + + # Balance cap as class attribute so background jobs can access it + BALANCE_CAP_SATOSHIS = BALANCE_CAP_SATOSHIS + + def __init__(self, database, encryption_keys: dict, default_network: str = 'testnet'): + """ + Initialize wallet manager. + + Args: + database: Database instance for persistence. + encryption_keys: Dict mapping version numbers to encryption keys. + default_network: 'testnet' or 'mainnet' for new wallets. + """ + self.db = database + self.encryption = KeyEncryption(encryption_keys) + self.default_network = default_network + + # Ensure wallet tables exist + self._ensure_tables() + + def _ensure_tables(self): + """Create wallet-related tables if they don't exist.""" + # Wallets table - stores Fee Address (with keys) and Sweep Address (public only) + self.db.execute_sql(""" + CREATE TABLE IF NOT EXISTS wallets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL UNIQUE, + btc_address TEXT NOT NULL, + public_key_encrypted TEXT NOT NULL, + private_key_encrypted TEXT NOT NULL, + sweep_address TEXT, + encryption_key_version INTEGER NOT NULL DEFAULT 1, + network TEXT NOT NULL DEFAULT 'testnet' CHECK (network IN ('testnet', 'mainnet')), + is_disabled INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) + ) + """) + + # Migration: Add sweep_address column if missing (for existing wallets) + try: + self.db.execute_sql("ALTER TABLE wallets ADD COLUMN sweep_address TEXT") + except Exception: + pass # Column already exists + + # Credits ledger - source of truth for spendable balance + self.db.execute_sql(""" + CREATE TABLE IF NOT EXISTS credits_ledger ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + amount_satoshis INTEGER NOT NULL CHECK (amount_satoshis != 0), + tx_type TEXT NOT NULL CHECK (tx_type IN ('deposit', 'withdrawal', 'withdrawal_reversal', 'fee_paid', 'fee_received', 'admin_credit')), + reference_id TEXT, + counterparty_user_id INTEGER, + idempotency_key TEXT NOT NULL UNIQUE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id), + FOREIGN KEY (counterparty_user_id) REFERENCES users(id) + ) + """) + + # Pending fees during strategy run + self.db.execute_sql(""" + CREATE TABLE IF NOT EXISTS pending_strategy_fees ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + strategy_run_id TEXT NOT NULL UNIQUE, + user_id INTEGER NOT NULL, + creator_user_id INTEGER NOT NULL, + fee_percent INTEGER NOT NULL DEFAULT 10, + accumulated_satoshis INTEGER NOT NULL DEFAULT 0, + trade_count INTEGER NOT NULL DEFAULT 0, + started_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id), + FOREIGN KEY (creator_user_id) REFERENCES users(id) + ) + """) + + # Migration: Add fee_percent column if missing + try: + self.db.execute_sql("ALTER TABLE pending_strategy_fees ADD COLUMN fee_percent INTEGER NOT NULL DEFAULT 10") + except Exception: + pass # Column already exists + + # Withdrawal requests (async processing) + self.db.execute_sql(""" + CREATE TABLE IF NOT EXISTS withdrawal_requests ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + amount_satoshis INTEGER NOT NULL CHECK (amount_satoshis > 0), + destination_address TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending' CHECK (status IN ('pending', 'reserved', 'processing', 'completed', 'failed', 'reversed')), + btc_txhash TEXT, + error_message TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + processed_at TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) + ) + """) + + # Deposit tracking (for dedupe) + self.db.execute_sql(""" + CREATE TABLE IF NOT EXISTS deposits ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + network TEXT NOT NULL, + tx_hash TEXT NOT NULL, + vout INTEGER NOT NULL, + amount_satoshis INTEGER NOT NULL, + confirmations INTEGER NOT NULL DEFAULT 0, + credited INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) + ) + """) + + # Create indexes for performance + self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_wallets_user_id ON wallets(user_id)") + self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_ledger_user_id ON credits_ledger(user_id)") + self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_ledger_idempotency ON credits_ledger(idempotency_key)") + self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_pending_fees_run ON pending_strategy_fees(strategy_run_id)") + self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_withdrawals_status ON withdrawal_requests(status)") + self.db.execute_sql("CREATE INDEX IF NOT EXISTS idx_deposits_credited ON deposits(credited)") + # Unique constraint on deposits (network, tx_hash, vout) + self.db.execute_sql("CREATE UNIQUE INDEX IF NOT EXISTS idx_deposits_unique ON deposits(network, tx_hash, vout)") + + logger.info("Wallet tables ensured") + + def create_wallet(self, user_id: int, user_sweep_address: str = None) -> dict: + """ + Generate and store new wallet for user (two-address model). + + Creates: + 1. Fee Address - We store private key, used for strategy fees + 2. Sweep Address - Either user-provided or we generate one (we don't store private key) + + Args: + user_id: User ID to create wallet for. + user_sweep_address: Optional user-provided sweep address (e.g., from their own wallet). + + Returns: + Dict with success status and both addresses' details. + The sweep address private key is only returned if we generated it. + """ + existing = self.get_wallet(user_id) + if existing: + return {'success': False, 'error': 'Wallet already exists'} + + bitcoin = BitcoinService(testnet=(self.default_network == 'testnet')) + + # Validate user-provided sweep address if given + if user_sweep_address: + if not bitcoin.validate_address(user_sweep_address): + return {'success': False, 'error': 'Invalid Bitcoin address for sweep address'} + + # Generate Fee Address (we store and manage) + fee_keypair = bitcoin.generate_keypair() + + # Sweep Address: use user-provided or generate new + if user_sweep_address: + sweep_address = user_sweep_address + sweep_keypair = None # User controls this, we don't have the keys + else: + # Generate Sweep Address (user controls - we DON'T store private key) + sweep_keypair = bitcoin.generate_keypair() + sweep_address = sweep_keypair['address'] + + version = KeyEncryption.CURRENT_VERSION + encrypted_public = self.encryption.encrypt(fee_keypair['public_key'], version) + encrypted_private = self.encryption.encrypt(fee_keypair['private_key'], version) + + self.db.execute_sql( + """INSERT INTO wallets (user_id, btc_address, public_key_encrypted, + private_key_encrypted, sweep_address, encryption_key_version, network) + VALUES (?, ?, ?, ?, ?, ?, ?)""", + (user_id, fee_keypair['address'], encrypted_public, encrypted_private, + sweep_address, version, self.default_network) + ) + + logger.info(f"Created two-address wallet for user {user_id}: " + f"fee={fee_keypair['address']}, sweep={sweep_address} " + f"(user_provided={bool(user_sweep_address)})") + + result = { + 'success': True, + # Fee Address (we manage) + 'fee_address': fee_keypair['address'], + 'fee_public_key': fee_keypair['public_key'], + 'fee_private_key': fee_keypair['private_key'], + # Sweep Address + 'sweep_address': sweep_address, + 'network': self.default_network, + } + + # Only include sweep private key if we generated it + if sweep_keypair: + result['sweep_public_key'] = sweep_keypair['public_key'] + result['sweep_private_key'] = sweep_keypair['private_key'] + result['warning'] = 'SAVE YOUR SWEEP ADDRESS PRIVATE KEY. It is NOT stored and CANNOT be recovered.' + else: + result['message'] = 'Using your provided sweep address. Excess funds will be sent there.' + + return result + + def get_wallet(self, user_id: int) -> Optional[dict]: + """ + Get wallet info (without decrypted private keys). + + Args: + user_id: User ID to look up. + + Returns: + Dict with wallet info, or None if no wallet exists. + """ + result = self.db.execute_sql( + "SELECT btc_address, network, is_disabled, created_at, sweep_address FROM wallets WHERE user_id = ?", + (user_id,), fetch_one=True + ) + if not result: + return None + return { + 'fee_address': result[0], + 'address': result[0], # Backwards compatibility + 'network': result[1], + 'is_disabled': bool(result[2]), + 'created_at': result[3], + 'sweep_address': result[4] + } + + def get_wallet_keys(self, user_id: int) -> Optional[dict]: + """ + Get fee address keys (for "View Keys" feature). + Note: Sweep address private key is NOT stored. + + Args: + user_id: User ID to look up. + + Returns: + Dict with fee address and decrypted private key, or None. + """ + result = self.db.execute_sql( + """SELECT btc_address, public_key_encrypted, private_key_encrypted, + encryption_key_version, network FROM wallets WHERE user_id = ?""", + (user_id,), fetch_one=True + ) + if not result: + return None + + address, encrypted_public, encrypted_private, version, network = result + + try: + private_key = self.encryption.decrypt(encrypted_private, version) + return { + 'fee_address': address, + 'fee_private_key': private_key, + 'network': network + } + except Exception as e: + logger.error(f"Failed to decrypt keys for user {user_id}: {e}") + return None + + def get_credits_balance(self, user_id: int) -> int: + """ + Get user's spendable credits balance (sum of ledger entries). + + Args: + user_id: User ID to look up. + + Returns: + Balance in satoshis. + """ + result = self.db.execute_sql( + "SELECT COALESCE(SUM(amount_satoshis), 0) FROM credits_ledger WHERE user_id = ?", + (user_id,), fetch_one=True + ) + return result[0] if result else 0 + + def check_balance_cap(self, user_id: int) -> dict: + """ + Check if balance exceeds cap and disable/enable wallet accordingly. + + Args: + user_id: User ID to check. + + Returns: + Dict with 'over_cap' status and current balance. + """ + balance = self.get_credits_balance(user_id) + if balance > BALANCE_CAP_SATOSHIS: + self.db.execute_sql( + "UPDATE wallets SET is_disabled = 1 WHERE user_id = ?", (user_id,) + ) + logger.warning(f"User {user_id} wallet disabled - over balance cap " + f"({balance} > {BALANCE_CAP_SATOSHIS})") + return {'over_cap': True, 'balance': balance, 'cap': BALANCE_CAP_SATOSHIS} + else: + # Re-enable if under cap + self.db.execute_sql( + "UPDATE wallets SET is_disabled = 0 WHERE user_id = ?", (user_id,) + ) + return {'over_cap': False, 'balance': balance} + + # === Strategy Fee Accumulation === + + def start_fee_accumulation(self, strategy_run_id: str, user_id: int, + creator_user_id: int, fee_percent: int = 10, + estimated_trades: int = 10) -> dict: + """ + Start fee accumulation for a strategy run. + Checks if user has enough credits for estimated trades. + + Args: + strategy_run_id: Unique identifier for this strategy run. + user_id: User running the strategy. + creator_user_id: Strategy creator who will receive fees. + fee_percent: Percentage of exchange commission to charge (1-100). + estimated_trades: Estimated number of trades for credit check. + + Returns: + Dict with success status. + """ + # If user is running their own strategy, no fees apply + if user_id == creator_user_id: + return {'success': True, 'message': 'No fees for own strategy'} + + # Validate fee_percent + fee_percent = max(1, min(100, int(fee_percent))) + + # Check wallet is enabled + wallet = self.get_wallet(user_id) + if wallet and wallet['is_disabled']: + return {'success': False, 'error': 'Wallet disabled - withdraw funds to re-enable'} + + # Estimate fee per trade (rough estimate, actual depends on trade size) + # Scale estimate by fee_percent (higher fee = more credits needed) + base_estimate = 1000 # ~1000 sats per trade at 100% + min_credits_needed = (base_estimate * fee_percent // 100) * estimated_trades + + balance = self.get_credits_balance(user_id) + if balance < min_credits_needed: + return { + 'success': False, + 'error': 'Insufficient credits for strategy fees', + 'available': balance, + 'recommended_minimum': min_credits_needed + } + + # Create pending fees record with fee_percent + self.db.execute_sql( + """INSERT OR REPLACE INTO pending_strategy_fees + (strategy_run_id, user_id, creator_user_id, fee_percent, accumulated_satoshis, trade_count) + VALUES (?, ?, ?, ?, 0, 0)""", + (strategy_run_id, user_id, creator_user_id, fee_percent) + ) + + logger.info(f"Started fee accumulation for run {strategy_run_id}: " + f"user={user_id}, creator={creator_user_id}, fee_percent={fee_percent}%") + + return {'success': True, 'strategy_run_id': strategy_run_id} + + def accumulate_trade_fee(self, strategy_run_id: str, exchange_fee_satoshis: int, + is_profitable: bool) -> dict: + """ + Accumulate fee for a single trade (called after each trade). + Only charges on profitable trades. + + Args: + strategy_run_id: Strategy run identifier. + exchange_fee_satoshis: Exchange commission in satoshis. + is_profitable: Whether the trade was profitable. + + Returns: + Dict with fee charged amount. + """ + if not is_profitable: + return {'success': True, 'fee_charged': 0, 'reason': 'unprofitable_trade'} + + # Get the fee_percent for this strategy run + result = self.db.execute_sql( + "SELECT fee_percent FROM pending_strategy_fees WHERE strategy_run_id = ?", + (strategy_run_id,), fetch_one=True + ) + + if not result: + return {'success': False, 'fee_charged': 0, 'reason': 'no_pending_fees_record'} + + fee_percent = result[0] + + # Fee is fee_percent% of exchange commission + strategy_fee = (exchange_fee_satoshis * fee_percent) // 100 + + if strategy_fee <= 0: + return {'success': True, 'fee_charged': 0, 'reason': 'fee_too_small'} + + # Add to accumulated fees + self.db.execute_sql( + """UPDATE pending_strategy_fees + SET accumulated_satoshis = accumulated_satoshis + ?, + trade_count = trade_count + 1 + WHERE strategy_run_id = ?""", + (strategy_fee, strategy_run_id) + ) + + logger.debug(f"Accumulated {strategy_fee} sats fee ({fee_percent}% of {exchange_fee_satoshis}) for run {strategy_run_id}") + + return {'success': True, 'fee_charged': strategy_fee} + + def settle_accumulated_fees(self, strategy_run_id: str) -> dict: + """ + Settle all accumulated fees when strategy stops. + Transfers total from user to creator. + + Args: + strategy_run_id: Strategy run identifier. + + Returns: + Dict with settlement details. + """ + # Get pending fees + pending = self.db.execute_sql( + """SELECT user_id, creator_user_id, accumulated_satoshis, trade_count + FROM pending_strategy_fees WHERE strategy_run_id = ?""", + (strategy_run_id,), fetch_one=True + ) + + if not pending: + return {'success': True, 'message': 'No pending fees found', 'settled': 0} + + user_id, creator_user_id, total_fees, trade_count = pending + + if total_fees <= 0: + # Clean up record, no fees to settle + self.db.execute_sql( + "DELETE FROM pending_strategy_fees WHERE strategy_run_id = ?", + (strategy_run_id,) + ) + return {'success': True, 'settled': 0, 'trades': trade_count} + + # Check user still has sufficient balance + balance = self.get_credits_balance(user_id) + if balance < total_fees: + # Settle what we can, log the shortfall + logger.warning(f"User {user_id} has insufficient balance for full fee settlement. " + f"Owed: {total_fees}, Available: {balance}") + total_fees = max(0, balance) # Settle partial + + if total_fees > 0: + idempotency_key = f"fee_settle_{strategy_run_id}" + + # Check if already settled (idempotency) + existing = self.db.execute_sql( + "SELECT id FROM credits_ledger WHERE idempotency_key = ?", + (idempotency_key,), fetch_one=True + ) + if existing: + logger.info(f"Fees already settled for run {strategy_run_id}") + else: + # Atomic fee settlement - both debit and credit in one transaction + # Prevents partial state if one insert fails + try: + self.db.execute_in_transaction([ + # Debit user + ("""INSERT INTO credits_ledger + (user_id, amount_satoshis, tx_type, reference_id, + counterparty_user_id, idempotency_key) + VALUES (?, ?, 'fee_paid', ?, ?, ?)""", + (user_id, -total_fees, strategy_run_id, creator_user_id, idempotency_key)), + # Credit creator + ("""INSERT INTO credits_ledger + (user_id, amount_satoshis, tx_type, reference_id, + counterparty_user_id, idempotency_key) + VALUES (?, ?, 'fee_received', ?, ?, ?)""", + (creator_user_id, total_fees, strategy_run_id, user_id, + f"{idempotency_key}_creator")) + ]) + logger.info(f"Settled {total_fees} sats from user {user_id} to creator {creator_user_id}") + except Exception as e: + logger.error(f"Failed to settle fees for run {strategy_run_id}: {e}") + # Don't delete pending record - leave for retry + return {'success': False, 'error': f'Fee settlement failed: {e}'} + + # Clean up pending record + self.db.execute_sql( + "DELETE FROM pending_strategy_fees WHERE strategy_run_id = ?", + (strategy_run_id,) + ) + + # Check balance cap after settlement + self.check_balance_cap(user_id) + + return {'success': True, 'settled': total_fees, 'trades': trade_count} + + def get_pending_fees(self, strategy_run_id: str) -> dict: + """ + Get current accumulated fees for a running strategy. + + Args: + strategy_run_id: Strategy run identifier. + + Returns: + Dict with accumulated fees and trade count. + """ + result = self.db.execute_sql( + """SELECT accumulated_satoshis, trade_count + FROM pending_strategy_fees WHERE strategy_run_id = ?""", + (strategy_run_id,), fetch_one=True + ) + if not result: + return {'accumulated_satoshis': 0, 'trade_count': 0} + return {'accumulated_satoshis': result[0], 'trade_count': result[1]} + + def cancel_fee_accumulation(self, strategy_run_id: str) -> dict: + """ + Cancel fee accumulation for a strategy run without charging fees. + + Used when strategy startup fails after a pending fee record has been created. + + Args: + strategy_run_id: Strategy run identifier. + + Returns: + Dict with cancellation status. + """ + try: + self.db.execute_sql( + "DELETE FROM pending_strategy_fees WHERE strategy_run_id = ?", + (strategy_run_id,) + ) + return {'success': True} + except Exception as e: + logger.error(f"Failed to cancel fee accumulation for run {strategy_run_id}: {e}") + return {'success': False, 'error': str(e)} + + # === Deposits and Withdrawals === + + def credit_deposit(self, user_id: int, amount_satoshis: int, + network: str, tx_hash: str, vout: int) -> dict: + """ + Credit user's ledger when BTC deposit is confirmed. + + Args: + user_id: User ID to credit. + amount_satoshis: Amount to credit. + network: Bitcoin network ('testnet' or 'mainnet'). + tx_hash: Transaction hash. + vout: Output index within transaction. + + Returns: + Dict with success status. + """ + idempotency_key = f"deposit_{network}_{tx_hash}_{vout}" + + existing = self.db.execute_sql( + "SELECT id FROM credits_ledger WHERE idempotency_key = ?", + (idempotency_key,), fetch_one=True + ) + if existing: + return {'success': True, 'already_processed': True} + + self.db.execute_sql( + """INSERT INTO credits_ledger + (user_id, amount_satoshis, tx_type, reference_id, idempotency_key) + VALUES (?, ?, 'deposit', ?, ?)""", + (user_id, amount_satoshis, tx_hash, idempotency_key) + ) + + logger.info(f"Credited {amount_satoshis} sats to user {user_id} from deposit {tx_hash}") + + # Check balance cap after deposit + cap_check = self.check_balance_cap(user_id) + if cap_check['over_cap']: + logger.warning(f"User {user_id} wallet disabled - over balance cap after deposit") + + return {'success': True, 'over_cap': cap_check.get('over_cap', False)} + + def request_withdrawal(self, user_id: int, amount_satoshis: int, + destination_address: str) -> dict: + """ + Queue a withdrawal request with immediate reservation. + + Args: + user_id: User ID requesting withdrawal. + amount_satoshis: Amount to withdraw. + destination_address: Bitcoin address to send to. + + Returns: + Dict with success status. + """ + wallet = self.get_wallet(user_id) + if not wallet: + return {'success': False, 'error': 'No wallet'} + + # Validate address for the wallet's network + bitcoin = BitcoinService(testnet=(wallet['network'] == 'testnet')) + if not bitcoin.validate_address(destination_address): + return {'success': False, 'error': 'Invalid Bitcoin address for this network'} + + balance = self.get_credits_balance(user_id) + if balance < amount_satoshis: + return {'success': False, 'error': 'Insufficient balance', 'available': balance} + + # Reserve immediately (debit from ledger) + idempotency_key = f"withdrawal_reserve_{uuid.uuid4().hex}" + try: + self.db.execute_in_transaction([ + ("""INSERT INTO credits_ledger + (user_id, amount_satoshis, tx_type, reference_id, idempotency_key) + VALUES (?, ?, 'withdrawal', ?, ?)""", + (user_id, -amount_satoshis, destination_address, idempotency_key)), + ("""INSERT INTO withdrawal_requests + (user_id, amount_satoshis, destination_address, status) + VALUES (?, ?, ?, 'reserved')""", + (user_id, amount_satoshis, destination_address)) + ]) + except Exception as e: + logger.error(f"Failed to reserve withdrawal for user {user_id}: {e}") + return {'success': False, 'error': 'Failed to queue withdrawal request'} + + logger.info(f"Withdrawal requested: {amount_satoshis} sats to {destination_address} for user {user_id}") + + # Re-check balance cap (may re-enable wallet after withdrawal) + self.check_balance_cap(user_id) + + return {'success': True, 'message': 'Withdrawal reserved and queued'} + + def admin_credit(self, user_id: int, amount_satoshis: int, reason: str) -> dict: + """ + Admin function to manually credit a user (for POC testing). + + Args: + user_id: User ID to credit. + amount_satoshis: Amount to credit. + reason: Reason for the credit. + + Returns: + Dict with success status. + """ + idempotency_key = f"admin_{user_id}_{uuid.uuid4().hex}" + self.db.execute_sql( + """INSERT INTO credits_ledger + (user_id, amount_satoshis, tx_type, reference_id, idempotency_key) + VALUES (?, ?, 'admin_credit', ?, ?)""", + (user_id, amount_satoshis, reason, idempotency_key) + ) + + logger.info(f"Admin credited {amount_satoshis} sats to user {user_id}: {reason}") + + self.check_balance_cap(user_id) + return {'success': True} + + def get_transaction_history(self, user_id: int, limit: int = 20) -> list: + """ + Get user's recent ledger transactions. + + Args: + user_id: User ID to look up. + limit: Maximum number of transactions to return. + + Returns: + List of transaction dicts. + """ + results = self.db.execute_sql( + """SELECT tx_type, amount_satoshis, reference_id, created_at + FROM credits_ledger WHERE user_id = ? + ORDER BY created_at DESC LIMIT ?""", + (user_id, limit), fetch_all=True + ) + return [{'type': r[0], 'amount': r[1], 'reference': r[2], 'date': r[3]} + for r in (results or [])] + + # === Background Job Support Methods === + + def get_wallets_over_cap(self) -> list: + """ + Get all wallets with balance exceeding the cap (for auto-sweep). + + Returns: + List of dicts with user_id, balance, and sweep_address. + """ + # Get all users with wallets and compute their balances + results = self.db.execute_sql( + """SELECT w.user_id, w.sweep_address, + COALESCE(SUM(cl.amount_satoshis), 0) as balance + FROM wallets w + LEFT JOIN credits_ledger cl ON w.user_id = cl.user_id + WHERE w.sweep_address IS NOT NULL + GROUP BY w.user_id + HAVING balance > ?""", + (BALANCE_CAP_SATOSHIS,), fetch_all=True + ) + return [{'user_id': r[0], 'sweep_address': r[1], 'balance': r[2]} + for r in (results or [])] + + def auto_sweep(self, user_id: int, amount_satoshis: int) -> dict: + """ + Sweep excess funds to user's sweep address. + + Args: + user_id: User ID to sweep for. + amount_satoshis: Amount to sweep. + + Returns: + Dict with success status and tx_hash. + """ + wallet = self.get_wallet(user_id) + if not wallet: + return {'success': False, 'error': 'No wallet'} + + sweep_address = wallet.get('sweep_address') + if not sweep_address: + return {'success': False, 'error': 'No sweep address configured'} + + # Get fee address private key + keys = self.get_wallet_keys(user_id) + if not keys: + return {'success': False, 'error': 'Could not retrieve wallet keys'} + + bitcoin = BitcoinService(testnet=(wallet['network'] == 'testnet')) + + try: + # Execute the sweep transaction + tx_hash = bitcoin.send_transaction( + private_key_wif=keys['fee_private_key'], + to_address=sweep_address, + amount_satoshis=amount_satoshis + ) + if not tx_hash: + return {'success': False, 'error': 'Sweep transaction failed: no tx hash returned'} + + # Record the sweep as a withdrawal in the ledger + idempotency_key = f"auto_sweep_{user_id}_{tx_hash}" + self.db.execute_sql( + """INSERT INTO credits_ledger + (user_id, amount_satoshis, tx_type, reference_id, idempotency_key) + VALUES (?, ?, 'withdrawal', ?, ?)""", + (user_id, -amount_satoshis, f"auto_sweep:{tx_hash}", idempotency_key) + ) + + logger.info(f"Auto-swept {amount_satoshis} sats for user {user_id} to {sweep_address}: {tx_hash}") + + # Re-check balance cap + self.check_balance_cap(user_id) + + return {'success': True, 'tx_hash': tx_hash} + + except Exception as e: + logger.error(f"Auto-sweep failed for user {user_id}: {e}") + return {'success': False, 'error': str(e)} + + def get_all_active_wallets(self) -> list: + """ + Get all active (non-disabled) wallets for deposit checking. + + Returns: + List of dicts with user_id, fee_address, and network. + """ + results = self.db.execute_sql( + """SELECT user_id, btc_address, network + FROM wallets WHERE is_disabled = 0""", + fetch_all=True + ) + return [{'user_id': r[0], 'fee_address': r[1], 'network': r[2]} + for r in (results or [])] + + def get_wallets_for_deposit_monitoring(self) -> list: + """ + Get all wallets for deposit checking, including disabled wallets. + + Disabled wallets must still receive deposit credits so users can move + back under the cap via sweeping/withdrawal flows. + + Returns: + List of dicts with user_id, fee_address, and network. + """ + results = self.db.execute_sql( + """SELECT user_id, btc_address, network + FROM wallets""", + fetch_all=True + ) + return [{'user_id': r[0], 'fee_address': r[1], 'network': r[2]} + for r in (results or [])] + + def check_address_for_deposits(self, user_id: int, address: str, network: str) -> list: + """ + Check an address for new deposits and credit the user's ledger. + + Args: + user_id: User ID who owns the address. + address: Bitcoin address to check. + network: Bitcoin network ('testnet' or 'mainnet'). + + Returns: + List of new deposits found and credited. + """ + bitcoin = BitcoinService(testnet=(network == 'testnet')) + new_deposits = [] + + try: + # Get unspent outputs (UTXOs) for the address + utxos = bitcoin.get_unspent(address) + + for utxo in utxos: + tx_hash = utxo.txid + vout = utxo.txindex + amount_satoshis = utxo.amount # bit library returns in satoshis + + # Check if we've already processed this deposit + existing = self.db.execute_sql( + """SELECT id, credited FROM deposits + WHERE network = ? AND tx_hash = ? AND vout = ?""", + (network, tx_hash, vout), fetch_one=True + ) + + if existing and int(existing[1]) == 1: + continue # Already processed + + # Ensure tracking row exists. Keep credited=0 until ledger credit succeeds. + if not existing: + self.db.execute_sql( + """INSERT OR IGNORE INTO deposits + (user_id, network, tx_hash, vout, amount_satoshis, credited) + VALUES (?, ?, ?, ?, ?, 0)""", + (user_id, network, tx_hash, vout, amount_satoshis) + ) + + # Credit the user's ledger + result = self.credit_deposit(user_id, amount_satoshis, network, tx_hash, vout) + + if result.get('success'): + self.db.execute_sql( + """UPDATE deposits + SET credited = 1 + WHERE network = ? AND tx_hash = ? AND vout = ?""", + (network, tx_hash, vout) + ) + + if not result.get('already_processed'): + new_deposits.append({ + 'tx_hash': tx_hash, + 'vout': vout, + 'amount_satoshis': amount_satoshis + }) + + except Exception as e: + logger.error(f"Error checking deposits for {address}: {e}") + + return new_deposits + + def get_pending_withdrawals(self) -> list: + """ + Get all pending withdrawal requests (status = 'reserved'). + + Returns: + List of pending withdrawal dicts. + """ + results = self.db.execute_sql( + """SELECT id, user_id, amount_satoshis, destination_address + FROM withdrawal_requests + WHERE status = 'reserved' + ORDER BY created_at ASC""", + fetch_all=True + ) + return [{'id': r[0], 'user_id': r[1], 'amount_satoshis': r[2], 'destination_address': r[3]} + for r in (results or [])] + + def process_withdrawal(self, withdrawal_id: int) -> dict: + """ + Process a pending withdrawal request. + + Args: + withdrawal_id: ID of the withdrawal request. + + Returns: + Dict with success status and tx_hash. + """ + # Get withdrawal details + result = self.db.execute_sql( + """SELECT user_id, amount_satoshis, destination_address, status + FROM withdrawal_requests WHERE id = ?""", + (withdrawal_id,), fetch_one=True + ) + + if not result: + return {'success': False, 'error': 'Withdrawal not found'} + + user_id, amount_satoshis, destination_address, status = result + + if status != 'reserved': + return {'success': False, 'error': f'Invalid status: {status}'} + + # Get wallet and keys + wallet = self.get_wallet(user_id) + if not wallet: + self._fail_withdrawal(withdrawal_id, 'No wallet found') + return {'success': False, 'error': 'No wallet'} + + keys = self.get_wallet_keys(user_id) + if not keys: + self._fail_withdrawal(withdrawal_id, 'Could not retrieve wallet keys') + return {'success': False, 'error': 'Could not retrieve keys'} + + # Mark as processing + self.db.execute_sql( + "UPDATE withdrawal_requests SET status = 'processing' WHERE id = ?", + (withdrawal_id,) + ) + + bitcoin = BitcoinService(testnet=(wallet['network'] == 'testnet')) + + try: + # Execute the withdrawal transaction + tx_hash = bitcoin.send_transaction( + private_key_wif=keys['fee_private_key'], + to_address=destination_address, + amount_satoshis=amount_satoshis + ) + if not tx_hash: + raise RuntimeError("Withdrawal transaction failed: no tx hash returned") + + # Mark as completed + self.db.execute_sql( + """UPDATE withdrawal_requests + SET status = 'completed', btc_txhash = ?, processed_at = CURRENT_TIMESTAMP + WHERE id = ?""", + (tx_hash, withdrawal_id) + ) + + logger.info(f"Withdrawal {withdrawal_id} completed: {tx_hash}") + + return {'success': True, 'tx_hash': tx_hash} + + except Exception as e: + logger.error(f"Withdrawal {withdrawal_id} failed: {e}") + self._fail_withdrawal(withdrawal_id, str(e)) + return {'success': False, 'error': str(e)} + + def _fail_withdrawal(self, withdrawal_id: int, error_message: str): + """ + Mark a withdrawal as failed and reverse the ledger debit. + + Args: + withdrawal_id: ID of the failed withdrawal. + error_message: Error message to record. + """ + # Get withdrawal details + result = self.db.execute_sql( + """SELECT user_id, amount_satoshis, destination_address + FROM withdrawal_requests WHERE id = ?""", + (withdrawal_id,), fetch_one=True + ) + + if not result: + return + + user_id, amount_satoshis, destination_address = result + + # Mark as failed + self.db.execute_sql( + """UPDATE withdrawal_requests + SET status = 'failed', error_message = ?, processed_at = CURRENT_TIMESTAMP + WHERE id = ?""", + (error_message, withdrawal_id) + ) + + # Reverse the ledger debit (credit the user back) + idempotency_key = f"withdrawal_reversal_{withdrawal_id}" + + # Check if already reversed + existing = self.db.execute_sql( + "SELECT id FROM credits_ledger WHERE idempotency_key = ?", + (idempotency_key,), fetch_one=True + ) + + if not existing: + self.db.execute_sql( + """INSERT INTO credits_ledger + (user_id, amount_satoshis, tx_type, reference_id, idempotency_key) + VALUES (?, ?, 'withdrawal_reversal', ?, ?)""", + (user_id, amount_satoshis, f"failed_withdrawal_{withdrawal_id}", idempotency_key) + ) + logger.info(f"Reversed withdrawal {withdrawal_id} - credited {amount_satoshis} sats back to user {user_id}") diff --git a/tests/test_strategy_execution.py b/tests/test_strategy_execution.py index 43e3028..1bc516c 100644 --- a/tests/test_strategy_execution.py +++ b/tests/test_strategy_execution.py @@ -334,6 +334,9 @@ class TestStopStrategy: bt = BrighterTrades(MagicMock()) bt.strategies = MagicMock() bt.strategies.active_instances = {} + # Mock wallet_manager for fee settlement + bt.wallet_manager = MagicMock() + bt.wallet_manager.settle_accumulated_fees.return_value = {'success': True, 'settled': 0} return bt def test_stop_strategy_not_running(self, mock_brighter_trades): diff --git a/tests/test_wallet.py b/tests/test_wallet.py new file mode 100644 index 0000000..56acb33 --- /dev/null +++ b/tests/test_wallet.py @@ -0,0 +1,827 @@ +""" +Tests for the wallet module. +Tests encryption, Bitcoin service, and wallet manager functionality. +""" +import pytest +import tempfile +import os +import sys +from unittest.mock import patch + +# Add src to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from wallet.encryption import KeyEncryption +from wallet.bitcoin_service import BitcoinService +from wallet.wallet_manager import WalletManager, BALANCE_CAP_SATOSHIS + + +class TestKeyEncryption: + """Tests for KeyEncryption.""" + + def test_encrypt_decrypt(self): + """Test basic encryption and decryption.""" + keys = {1: 'test_master_key_v1'} + encryption = KeyEncryption(keys) + + original = "my_secret_private_key" + encrypted = encryption.encrypt(original) + decrypted = encryption.decrypt(encrypted, version=1) + + assert decrypted == original + assert encrypted != original + + def test_versioned_keys(self): + """Test encryption with multiple key versions.""" + keys = {1: 'key_v1', 2: 'key_v2'} + encryption = KeyEncryption(keys) + + data = "secret_data" + encrypted_v1 = encryption.encrypt(data, version=1) + encrypted_v2 = encryption.encrypt(data, version=2) + + # Both should decrypt correctly with their respective versions + assert encryption.decrypt(encrypted_v1, version=1) == data + assert encryption.decrypt(encrypted_v2, version=2) == data + + # Encrypted data should be different for different versions + assert encrypted_v1 != encrypted_v2 + + def test_re_encrypt(self): + """Test re-encryption from old version to new version.""" + keys = {1: 'old_key', 2: 'new_key'} + encryption = KeyEncryption(keys) + + data = "secret_data" + encrypted_v1 = encryption.encrypt(data, version=1) + + # Re-encrypt to version 2 + encrypted_v2, new_version = encryption.re_encrypt(encrypted_v1, old_version=1, new_version=2) + + assert new_version == 2 + assert encryption.decrypt(encrypted_v2, version=2) == data + + def test_unknown_version_raises(self): + """Test that decrypting with unknown version raises.""" + keys = {1: 'test_key'} + encryption = KeyEncryption(keys) + + with pytest.raises(ValueError, match="Unknown encryption key version"): + encryption.decrypt("some_data", version=99) + + +class TestBitcoinService: + """Tests for BitcoinService.""" + + def test_generate_keypair_testnet(self): + """Test generating testnet keypair.""" + service = BitcoinService(testnet=True) + keypair = service.generate_keypair() + + assert 'address' in keypair + assert 'public_key' in keypair + assert 'private_key' in keypair + + # Testnet addresses start with m, n, or tb1 + address = keypair['address'] + assert address.startswith(('m', 'n', 'tb1')) + + def test_validate_testnet_address(self): + """Test validating testnet addresses.""" + service = BitcoinService(testnet=True) + + # Generate a valid testnet address + keypair = service.generate_keypair() + assert service.validate_address(keypair['address']) + + # Invalid addresses + assert not service.validate_address('') + assert not service.validate_address(None) + assert not service.validate_address('invalid') + assert not service.validate_address('1BvBMSEYstWetqTFn5Au4m4GFg7xJaNVN2') # Mainnet + + def test_validate_mainnet_address(self): + """Test validating mainnet addresses.""" + service = BitcoinService(testnet=False) + + # Valid mainnet addresses + assert service.validate_address('1BvBMSEYstWetqTFn5Au4m4GFg7xJaNVN2') + assert service.validate_address('3J98t1WpEZ73CNmQviecrnyiWrnqRhWNLy') + assert service.validate_address('bc1qar0srrr7xfkvy5l643lydnw9re59gtzzwf5mdq') + + # Testnet addresses should be invalid on mainnet + assert not service.validate_address('mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt') + + def test_estimate_fee(self): + """Test fee estimation.""" + service = BitcoinService(testnet=True) + + fee = service.estimate_fee(num_inputs=1, num_outputs=2) + assert fee > 0 + assert isinstance(fee, int) + + # More inputs/outputs should cost more + fee_larger = service.estimate_fee(num_inputs=3, num_outputs=5) + assert fee_larger > fee + + +class MockDatabase: + """Mock database for testing WalletManager.""" + + def __init__(self): + self.tables = {} + self.data = {} + + def execute_sql(self, sql, params=None, fetch_one=False, fetch_all=False): + """Simple mock SQL execution.""" + sql_lower = sql.lower().strip() + + if sql_lower.startswith('create table') or sql_lower.startswith('create index'): + # Table creation - just ignore + return None + + if sql_lower.startswith('alter table'): + # Migration - add sweep_address column + return None + + if sql_lower.startswith('insert'): + # Store insert data + if 'wallets' in sql_lower: + # Two-address model: params are (user_id, btc_address, public_key, private_key, sweep_address, version, network) + self.data.setdefault('wallets', {})[params[0]] = { + 'user_id': params[0], + 'btc_address': params[1], + 'public_key_encrypted': params[2], + 'private_key_encrypted': params[3], + 'sweep_address': params[4] if len(params) > 6 else None, + 'encryption_key_version': params[5] if len(params) > 6 else params[4], + 'network': params[6] if len(params) > 6 else params[5], + 'is_disabled': 0, + 'created_at': 'now' + } + elif 'credits_ledger' in sql_lower: + self.data.setdefault('credits_ledger', []).append({ + 'user_id': params[0], + 'amount_satoshis': params[1], + 'tx_type': params[2], + 'reference_id': params[3], + 'idempotency_key': params[-1], + }) + elif 'pending_strategy_fees' in sql_lower: + self.data.setdefault('pending_fees', {})[params[0]] = { + 'strategy_run_id': params[0], + 'user_id': params[1], + 'creator_user_id': params[2], + 'fee_percent': params[3] if len(params) > 3 else 10, + 'accumulated_satoshis': 0, + 'trade_count': 0, + } + elif 'withdrawal_requests' in sql_lower: + self.data.setdefault('withdrawals', []).append({ + 'user_id': params[0], + 'amount_satoshis': params[1], + 'destination_address': params[2], + 'status': 'reserved', # status is hardcoded in SQL + }) + return None + + if sql_lower.startswith('select'): + if fetch_one: + if 'wallets' in sql_lower and 'user_id' in sql_lower: + wallet = self.data.get('wallets', {}).get(params[0]) + if wallet: + # Check if it's the get_wallet_keys query (includes encrypted keys) + if 'private_key_encrypted' in sql_lower: + return (wallet['btc_address'], wallet['public_key_encrypted'], + wallet['private_key_encrypted'], wallet['encryption_key_version'], + wallet['network']) + # Regular get_wallet query (5 columns including sweep_address) + return (wallet['btc_address'], wallet['network'], + wallet['is_disabled'], wallet['created_at'], + wallet.get('sweep_address')) + elif 'credits_ledger' in sql_lower and 'sum' in sql_lower: + ledger = self.data.get('credits_ledger', []) + user_total = sum(e['amount_satoshis'] for e in ledger if e['user_id'] == params[0]) + return (user_total,) + elif 'credits_ledger' in sql_lower and 'idempotency' in sql_lower: + ledger = self.data.get('credits_ledger', []) + for entry in ledger: + if entry.get('idempotency_key') == params[0]: + return (1,) + return None + elif 'pending_strategy_fees' in sql_lower: + fees = self.data.get('pending_fees', {}).get(params[0]) + if fees: + # Different queries select different columns + if 'fee_percent' in sql_lower and 'accumulated' not in sql_lower: + # accumulate_trade_fee: SELECT fee_percent + return (fees.get('fee_percent', 10),) + elif 'accumulated_satoshis, trade_count' in sql_lower and 'user_id' not in sql_lower.split('select')[1].split('from')[0]: + # get_pending_fees: SELECT accumulated_satoshis, trade_count + return (fees['accumulated_satoshis'], fees['trade_count']) + else: + # settle_accumulated_fees: SELECT user_id, creator_user_id, accumulated_satoshis, trade_count + return (fees['user_id'], fees['creator_user_id'], + fees['accumulated_satoshis'], fees['trade_count']) + return None + + if fetch_all: + if 'credits_ledger' in sql_lower: + ledger = self.data.get('credits_ledger', []) + user_entries = [e for e in ledger if e['user_id'] == params[0]] + return [(e['tx_type'], e['amount_satoshis'], e['reference_id'], 'now') + for e in user_entries[-params[1]:]] + return [] + + if sql_lower.startswith('update'): + if 'wallets' in sql_lower and 'is_disabled' in sql_lower: + user_id = params[0] # Only param is user_id + if user_id in self.data.get('wallets', {}): + # Parse the value from the SQL itself (is_disabled = 0 or is_disabled = 1) + if 'is_disabled = 1' in sql_lower: + self.data['wallets'][user_id]['is_disabled'] = 1 + else: + self.data['wallets'][user_id]['is_disabled'] = 0 + elif 'pending_strategy_fees' in sql_lower and 'accumulated_satoshis' in sql_lower: + # Update: accumulated_satoshis = accumulated_satoshis + ?, trade_count = trade_count + 1 + fee_amount = params[0] + run_id = params[1] + if run_id in self.data.get('pending_fees', {}): + self.data['pending_fees'][run_id]['accumulated_satoshis'] += fee_amount + self.data['pending_fees'][run_id]['trade_count'] += 1 + return None + + if sql_lower.startswith('delete'): + if 'pending_strategy_fees' in sql_lower: + run_id = params[0] + if run_id in self.data.get('pending_fees', {}): + del self.data['pending_fees'][run_id] + return None + + return None + + def execute_in_transaction(self, statements): + """Execute multiple statements atomically (mock version).""" + # In mock, just execute each statement in order + for sql, params in statements: + self.execute_sql(sql, params) + return True + + +class TestWalletManager: + """Tests for WalletManager.""" + + @pytest.fixture + def wallet_manager(self): + """Create a wallet manager with mock database.""" + db = MockDatabase() + keys = {1: 'test_encryption_key'} + return WalletManager(db, keys, default_network='testnet') + + def test_create_wallet(self, wallet_manager): + """Test creating a new wallet (two-address model).""" + result = wallet_manager.create_wallet(user_id=1) + + assert result['success'] + # Fee Address (we store keys) + assert 'fee_address' in result + assert 'fee_private_key' in result + # Sweep Address (we don't store private key) + assert 'sweep_address' in result + assert 'sweep_private_key' in result + assert result['network'] == 'testnet' + + def test_create_wallet_duplicate(self, wallet_manager): + """Test creating duplicate wallet fails.""" + wallet_manager.create_wallet(user_id=1) + result = wallet_manager.create_wallet(user_id=1) + + assert not result['success'] + assert 'already exists' in result['error'].lower() + + def test_get_wallet(self, wallet_manager): + """Test getting wallet info.""" + wallet_manager.create_wallet(user_id=1) + wallet = wallet_manager.get_wallet(user_id=1) + + assert wallet is not None + assert 'address' in wallet + assert wallet['network'] == 'testnet' + assert 'private_key' not in wallet # Private key should not be exposed + + def test_credits_balance_empty(self, wallet_manager): + """Test credits balance starts at zero.""" + balance = wallet_manager.get_credits_balance(user_id=1) + assert balance == 0 + + def test_admin_credit(self, wallet_manager): + """Test admin credit adds to balance.""" + wallet_manager.admin_credit(user_id=1, amount_satoshis=10000, reason='test') + balance = wallet_manager.get_credits_balance(user_id=1) + assert balance == 10000 + + def test_credit_deposit(self, wallet_manager): + """Test deposit crediting.""" + result = wallet_manager.credit_deposit( + user_id=1, + amount_satoshis=50000, + network='testnet', + tx_hash='abc123', + vout=0 + ) + + assert result['success'] + assert wallet_manager.get_credits_balance(user_id=1) == 50000 + + def test_deposit_idempotency(self, wallet_manager): + """Test deposit is idempotent.""" + wallet_manager.credit_deposit( + user_id=1, amount_satoshis=50000, + network='testnet', tx_hash='abc123', vout=0 + ) + wallet_manager.credit_deposit( + user_id=1, amount_satoshis=50000, + network='testnet', tx_hash='abc123', vout=0 + ) + + # Should only be credited once + balance = wallet_manager.get_credits_balance(user_id=1) + assert balance == 50000 + + def test_fee_accumulation(self, wallet_manager): + """Test fee accumulation during strategy run with default 10% fee.""" + wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial') + + # Start accumulation with default 10% fee + result = wallet_manager.start_fee_accumulation( + strategy_run_id='run_001', + user_id=1, + creator_user_id=2, + fee_percent=10, # 10% of exchange fee + estimated_trades=5 + ) + assert result['success'] + + # Accumulate some fees + wallet_manager.accumulate_trade_fee('run_001', exchange_fee_satoshis=10000, is_profitable=True) + wallet_manager.accumulate_trade_fee('run_001', exchange_fee_satoshis=5000, is_profitable=True) + wallet_manager.accumulate_trade_fee('run_001', exchange_fee_satoshis=8000, is_profitable=False) # Should not accumulate + + # Check pending fees + pending = wallet_manager.get_pending_fees('run_001') + assert pending['accumulated_satoshis'] == 1500 # 10% of 10000 + 10% of 5000 + assert pending['trade_count'] == 2 # Only profitable trades counted + + def test_fee_accumulation_custom_percent(self, wallet_manager): + """Test fee accumulation with custom fee percentage (50%).""" + wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial') + + # Start accumulation with 50% fee + result = wallet_manager.start_fee_accumulation( + strategy_run_id='run_002', + user_id=1, + creator_user_id=2, + fee_percent=50, # 50% of exchange fee + estimated_trades=5 + ) + assert result['success'] + + # Accumulate fee: 50% of 10000 = 5000 + wallet_manager.accumulate_trade_fee('run_002', exchange_fee_satoshis=10000, is_profitable=True) + + pending = wallet_manager.get_pending_fees('run_002') + assert pending['accumulated_satoshis'] == 5000 # 50% of 10000 + assert pending['trade_count'] == 1 + + def test_fee_accumulation_100_percent(self, wallet_manager): + """Test fee accumulation at 100% (same as exchange commission).""" + wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial') + + # Start accumulation with 100% fee + result = wallet_manager.start_fee_accumulation( + strategy_run_id='run_003', + user_id=1, + creator_user_id=2, + fee_percent=100, # 100% of exchange fee + estimated_trades=5 + ) + assert result['success'] + + # Accumulate fee: 100% of 10000 = 10000 + wallet_manager.accumulate_trade_fee('run_003', exchange_fee_satoshis=10000, is_profitable=True) + + pending = wallet_manager.get_pending_fees('run_003') + assert pending['accumulated_satoshis'] == 10000 # 100% of 10000 + assert pending['trade_count'] == 1 + + def test_fee_settlement(self, wallet_manager): + """Test fee settlement when strategy stops.""" + wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial') + wallet_manager.admin_credit(user_id=2, amount_satoshis=0, reason='creator setup') + + wallet_manager.start_fee_accumulation('run_001', user_id=1, creator_user_id=2) + wallet_manager.accumulate_trade_fee('run_001', exchange_fee_satoshis=10000, is_profitable=True) + + # Settle fees + result = wallet_manager.settle_accumulated_fees('run_001') + + assert result['success'] + assert result['settled'] == 1000 # 10% of 10000 + assert result['trades'] == 1 + + # Check balances + user_balance = wallet_manager.get_credits_balance(user_id=1) + creator_balance = wallet_manager.get_credits_balance(user_id=2) + + assert user_balance == 99000 # 100000 - 1000 + assert creator_balance == 1000 # Received fee + + def test_insufficient_credits_for_strategy(self, wallet_manager): + """Test that strategy start fails with insufficient credits.""" + # User has no credits + result = wallet_manager.start_fee_accumulation( + strategy_run_id='run_001', + user_id=1, + creator_user_id=2, + estimated_trades=10 + ) + + assert not result['success'] + assert 'insufficient' in result['error'].lower() + + def test_withdrawal_request(self, wallet_manager): + """Test withdrawal request.""" + wallet_manager.create_wallet(user_id=1) + wallet_manager.admin_credit(user_id=1, amount_satoshis=50000, reason='test') + + result = wallet_manager.request_withdrawal( + user_id=1, + amount_satoshis=10000, + destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt' + ) + + assert result['success'] + + # Balance should be reduced immediately + balance = wallet_manager.get_credits_balance(user_id=1) + assert balance == 40000 + + def test_withdrawal_insufficient_balance(self, wallet_manager): + """Test withdrawal fails with insufficient balance.""" + wallet_manager.create_wallet(user_id=1) + wallet_manager.admin_credit(user_id=1, amount_satoshis=5000, reason='test') + + result = wallet_manager.request_withdrawal( + user_id=1, + amount_satoshis=10000, + destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt' + ) + + assert not result['success'] + assert 'insufficient' in result['error'].lower() + + def test_withdrawal_invalid_address(self, wallet_manager): + """Test withdrawal fails with invalid address.""" + wallet_manager.create_wallet(user_id=1) + wallet_manager.admin_credit(user_id=1, amount_satoshis=50000, reason='test') + + result = wallet_manager.request_withdrawal( + user_id=1, + amount_satoshis=10000, + destination_address='invalid_address' + ) + + assert not result['success'] + assert 'invalid' in result['error'].lower() + + def test_withdrawal_reservation_transaction_failure(self, wallet_manager): + """Test withdrawal reservation fails cleanly when transaction fails.""" + wallet_manager.create_wallet(user_id=1) + wallet_manager.admin_credit(user_id=1, amount_satoshis=50000, reason='test') + + def fail_transaction(_statements): + raise RuntimeError('db transaction failed') + + wallet_manager.db.execute_in_transaction = fail_transaction + + result = wallet_manager.request_withdrawal( + user_id=1, + amount_satoshis=10000, + destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt' + ) + + assert not result['success'] + assert 'failed to queue' in result['error'].lower() + # Balance should remain unchanged because reservation wasn't committed. + assert wallet_manager.get_credits_balance(user_id=1) == 50000 + + def test_own_strategy_no_fees(self, wallet_manager): + """Test that running own strategy doesn't require fees.""" + result = wallet_manager.start_fee_accumulation( + strategy_run_id='run_001', + user_id=1, + creator_user_id=1, # Same user + estimated_trades=10 + ) + + assert result['success'] + assert 'no fees' in result.get('message', '').lower() + + def test_transaction_history(self, wallet_manager): + """Test getting transaction history.""" + wallet_manager.admin_credit(user_id=1, amount_satoshis=10000, reason='credit1') + wallet_manager.admin_credit(user_id=1, amount_satoshis=5000, reason='credit2') + + history = wallet_manager.get_transaction_history(user_id=1, limit=10) + + assert len(history) == 2 + assert all('type' in tx for tx in history) + assert all('amount' in tx for tx in history) + + +class TestBalanceCap: + """Tests for balance cap functionality.""" + + @pytest.fixture + def wallet_manager(self): + db = MockDatabase() + keys = {1: 'test_key'} + return WalletManager(db, keys, default_network='testnet') + + def test_balance_cap_disables_wallet(self, wallet_manager): + """Test that exceeding balance cap disables wallet.""" + wallet_manager.create_wallet(user_id=1) + + # Credit above cap + wallet_manager.admin_credit( + user_id=1, + amount_satoshis=BALANCE_CAP_SATOSHIS + 10000, + reason='over cap' + ) + + result = wallet_manager.check_balance_cap(user_id=1) + assert result['over_cap'] + + # Wallet should show as disabled + wallet = wallet_manager.get_wallet(user_id=1) + assert wallet['is_disabled'] + + def test_balance_under_cap_enables_wallet(self, wallet_manager): + """Test that being under balance cap enables wallet.""" + wallet_manager.create_wallet(user_id=1) + + # Credit under cap + wallet_manager.admin_credit( + user_id=1, + amount_satoshis=BALANCE_CAP_SATOSHIS - 10000, + reason='under cap' + ) + + result = wallet_manager.check_balance_cap(user_id=1) + assert not result['over_cap'] + + wallet = wallet_manager.get_wallet(user_id=1) + assert not wallet['is_disabled'] + + +class MockDatabaseForBackgroundJobs(MockDatabase): + """Extended mock database for testing background job methods.""" + + def execute_sql(self, sql, params=None, fetch_one=False, fetch_all=False): + sql_lower = sql.lower().strip() + + # Handle get_wallets_over_cap query + if 'left join credits_ledger' in sql_lower and 'having balance' in sql_lower: + results = [] + for user_id, wallet in self.data.get('wallets', {}).items(): + if not wallet.get('sweep_address'): + continue + # Calculate balance from ledger + ledger = self.data.get('credits_ledger', []) + balance = sum(e['amount_satoshis'] for e in ledger if e['user_id'] == user_id) + if balance > params[0]: # params[0] is BALANCE_CAP_SATOSHIS + results.append((user_id, wallet.get('sweep_address'), balance)) + if fetch_all: + return results + return None + + # Handle get_all_active_wallets query + if 'select user_id, btc_address, network' in sql_lower and 'is_disabled = 0' in sql_lower: + results = [] + for user_id, wallet in self.data.get('wallets', {}).items(): + if not wallet.get('is_disabled'): + results.append((user_id, wallet['btc_address'], wallet['network'])) + if fetch_all: + return results + return None + + # Handle get_wallets_for_deposit_monitoring query (includes disabled wallets) + if 'select user_id, btc_address, network' in sql_lower and 'from wallets' in sql_lower and 'is_disabled' not in sql_lower: + results = [] + for user_id, wallet in self.data.get('wallets', {}).items(): + results.append((user_id, wallet['btc_address'], wallet['network'])) + if fetch_all: + return results + return None + + # Handle deposits table queries + if 'insert' in sql_lower and 'deposits' in sql_lower: + self.data.setdefault('deposits', []).append({ + 'user_id': params[0], + 'network': params[1], + 'tx_hash': params[2], + 'vout': params[3], + 'amount_satoshis': params[4], + 'credited': params[5] if len(params) > 5 else 0 + }) + return None + + if 'select' in sql_lower and 'deposits' in sql_lower and 'network' in sql_lower and fetch_one: + deposits = self.data.get('deposits', []) + for d in deposits: + if d['network'] == params[0] and d['tx_hash'] == params[1] and d['vout'] == params[2]: + return (d.get('id', 1),) + return None + + # Handle get_pending_withdrawals query + if 'withdrawal_requests' in sql_lower and "status = 'reserved'" in sql_lower and fetch_all: + withdrawals = self.data.get('withdrawals', []) + results = [] + for i, w in enumerate(withdrawals): + if w.get('status') == 'reserved': + results.append((i + 1, w['user_id'], w['amount_satoshis'], w['destination_address'])) + return results + + # Handle withdrawal_requests SELECT for process_withdrawal and _fail_withdrawal + if 'withdrawal_requests' in sql_lower and 'select' in sql_lower and 'where id' in sql_lower and fetch_one: + withdrawals = self.data.get('withdrawals', []) + withdrawal_id = params[0] + if withdrawal_id > 0 and withdrawal_id <= len(withdrawals): + w = withdrawals[withdrawal_id - 1] + # process_withdrawal query includes status, _fail_withdrawal does not + if 'status' in sql_lower: + return (w['user_id'], w['amount_satoshis'], w['destination_address'], w.get('status', 'reserved')) + else: + return (w['user_id'], w['amount_satoshis'], w['destination_address']) + return None + + # Handle withdrawal_requests UPDATE + if 'update' in sql_lower and 'withdrawal_requests' in sql_lower: + withdrawal_id = params[-1] # Last param is the ID in WHERE clause + withdrawals = self.data.get('withdrawals', []) + if withdrawal_id > 0 and withdrawal_id <= len(withdrawals): + if 'processing' in sql_lower: + withdrawals[withdrawal_id - 1]['status'] = 'processing' + elif 'completed' in sql_lower: + withdrawals[withdrawal_id - 1]['status'] = 'completed' + withdrawals[withdrawal_id - 1]['btc_txhash'] = params[0] + elif 'failed' in sql_lower: + withdrawals[withdrawal_id - 1]['status'] = 'failed' + withdrawals[withdrawal_id - 1]['error_message'] = params[0] + return None + + # Fall back to parent implementation + return super().execute_sql(sql, params, fetch_one, fetch_all) + + +class TestBackgroundJobSupport: + """Tests for background job support methods in WalletManager.""" + + @pytest.fixture + def wallet_manager(self): + db = MockDatabaseForBackgroundJobs() + keys = {1: 'test_key'} + return WalletManager(db, keys, default_network='testnet') + + def test_get_wallets_over_cap(self, wallet_manager): + """Test getting wallets with balance over cap.""" + # Create wallet with sweep address + wallet_manager.create_wallet(user_id=1) + + # Credit above cap + wallet_manager.admin_credit( + user_id=1, + amount_satoshis=BALANCE_CAP_SATOSHIS + 50000, + reason='over cap' + ) + + wallets = wallet_manager.get_wallets_over_cap() + + assert len(wallets) == 1 + assert wallets[0]['user_id'] == 1 + assert wallets[0]['balance'] == BALANCE_CAP_SATOSHIS + 50000 + assert wallets[0]['sweep_address'] is not None + + def test_get_wallets_over_cap_empty(self, wallet_manager): + """Test getting wallets when none are over cap.""" + wallet_manager.create_wallet(user_id=1) + wallet_manager.admin_credit(user_id=1, amount_satoshis=10000, reason='under cap') + + wallets = wallet_manager.get_wallets_over_cap() + assert len(wallets) == 0 + + def test_get_all_active_wallets(self, wallet_manager): + """Test getting all active (non-disabled) wallets.""" + wallet_manager.create_wallet(user_id=1) + wallet_manager.create_wallet(user_id=2) + + # Disable user 2's wallet + wallet_manager.db.data['wallets'][2]['is_disabled'] = 1 + + wallets = wallet_manager.get_all_active_wallets() + + assert len(wallets) == 1 + assert wallets[0]['user_id'] == 1 + + def test_get_wallets_for_deposit_monitoring_includes_disabled(self, wallet_manager): + """Deposit monitoring should include disabled wallets.""" + wallet_manager.create_wallet(user_id=1) + wallet_manager.create_wallet(user_id=2) + wallet_manager.db.data['wallets'][2]['is_disabled'] = 1 + + wallets = wallet_manager.get_wallets_for_deposit_monitoring() + user_ids = {w['user_id'] for w in wallets} + + assert user_ids == {1, 2} + + def test_get_pending_withdrawals(self, wallet_manager): + """Test getting pending withdrawal requests.""" + wallet_manager.create_wallet(user_id=1) + wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='test') + + # Request withdrawal + wallet_manager.request_withdrawal( + user_id=1, + amount_satoshis=10000, + destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt' + ) + + pending = wallet_manager.get_pending_withdrawals() + + assert len(pending) == 1 + assert pending[0]['user_id'] == 1 + assert pending[0]['amount_satoshis'] == 10000 + + def test_auto_sweep_no_wallet(self, wallet_manager): + """Test auto_sweep fails when no wallet exists.""" + result = wallet_manager.auto_sweep(user_id=999, amount_satoshis=10000) + + assert not result['success'] + assert 'no wallet' in result['error'].lower() + + def test_auto_sweep_no_sweep_address(self, wallet_manager): + """Test auto_sweep fails when no sweep address configured.""" + wallet_manager.create_wallet(user_id=1) + # Remove sweep address + wallet_manager.db.data['wallets'][1]['sweep_address'] = None + + result = wallet_manager.auto_sweep(user_id=1, amount_satoshis=10000) + + assert not result['success'] + assert 'sweep address' in result['error'].lower() + + def test_process_withdrawal_not_found(self, wallet_manager): + """Test processing non-existent withdrawal.""" + result = wallet_manager.process_withdrawal(withdrawal_id=999) + + assert not result['success'] + assert 'not found' in result['error'].lower() + + def test_fail_withdrawal_reverses_credit(self, wallet_manager): + """Test that failing a withdrawal reverses the ledger debit.""" + wallet_manager.create_wallet(user_id=1) + wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='test') + + # Request withdrawal (debits 10000) + wallet_manager.request_withdrawal( + user_id=1, + amount_satoshis=10000, + destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt' + ) + + # Balance should be reduced + assert wallet_manager.get_credits_balance(user_id=1) == 90000 + + # Fail the withdrawal + wallet_manager._fail_withdrawal(withdrawal_id=1, error_message='Test failure') + + # Balance should be restored + assert wallet_manager.get_credits_balance(user_id=1) == 100000 + + def test_process_withdrawal_none_txhash_reverses_credit(self, wallet_manager): + """None tx hash from send_transaction should fail and reverse reservation.""" + wallet_manager.create_wallet(user_id=1) + wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='test') + wallet_manager.request_withdrawal( + user_id=1, + amount_satoshis=10000, + destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt' + ) + assert wallet_manager.get_credits_balance(user_id=1) == 90000 + + with patch('wallet.wallet_manager.BitcoinService.send_transaction', return_value=None): + result = wallet_manager.process_withdrawal(withdrawal_id=1) + + assert not result['success'] + assert 'no tx hash' in result['error'].lower() + assert wallet_manager.get_credits_balance(user_id=1) == 100000