# brighter-trading/tests/test_wallet.py
#
# 828 lines
# 34 KiB
# Python

"""
Tests for the wallet module.
Tests encryption, Bitcoin service, and wallet manager functionality.
"""
import pytest
import tempfile
import os
import sys
from unittest.mock import patch
# Add src to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from wallet.encryption import KeyEncryption
from wallet.bitcoin_service import BitcoinService
from wallet.wallet_manager import WalletManager, BALANCE_CAP_SATOSHIS
class TestKeyEncryption:
    """Round-trip, key-versioning and error-path tests for KeyEncryption."""

    def test_encrypt_decrypt(self):
        """Encrypting then decrypting with the same key version returns the input."""
        cipher = KeyEncryption({1: 'test_master_key_v1'})
        plaintext = "my_secret_private_key"

        token = cipher.encrypt(plaintext)
        recovered = cipher.decrypt(token, version=1)

        assert recovered == plaintext
        # The ciphertext must not be the plaintext passed through unchanged.
        assert token != plaintext

    def test_versioned_keys(self):
        """Each key version decrypts its own ciphertext, and ciphertexts differ."""
        cipher = KeyEncryption({1: 'key_v1', 2: 'key_v2'})
        payload = "secret_data"

        token_v1 = cipher.encrypt(payload, version=1)
        token_v2 = cipher.encrypt(payload, version=2)

        # Every ciphertext round-trips under the version that produced it.
        assert cipher.decrypt(token_v1, version=1) == payload
        assert cipher.decrypt(token_v2, version=2) == payload

        # The same plaintext under different key versions must not collide.
        assert token_v1 != token_v2

    def test_re_encrypt(self):
        """re_encrypt moves a ciphertext from key version 1 to key version 2."""
        cipher = KeyEncryption({1: 'old_key', 2: 'new_key'})
        payload = "secret_data"
        old_token = cipher.encrypt(payload, version=1)

        # Migrate the ciphertext to version 2.
        new_token, reported_version = cipher.re_encrypt(
            old_token, old_version=1, new_version=2
        )

        assert reported_version == 2
        assert cipher.decrypt(new_token, version=2) == payload

    def test_unknown_version_raises(self):
        """Decrypting with a key version that was never registered raises ValueError."""
        cipher = KeyEncryption({1: 'test_key'})

        with pytest.raises(ValueError, match="Unknown encryption key version"):
            cipher.decrypt("some_data", version=99)
class TestBitcoinService:
    """Keypair generation, address validation and fee estimation tests."""

    def test_generate_keypair_testnet(self):
        """A testnet keypair carries all three fields and a testnet-style address."""
        generated = BitcoinService(testnet=True).generate_keypair()

        for field in ('address', 'public_key', 'private_key'):
            assert field in generated

        # Testnet addresses start with m, n, or tb1.
        assert generated['address'].startswith(('m', 'n', 'tb1'))

    def test_validate_testnet_address(self):
        """A testnet service accepts its own addresses and rejects everything else."""
        service = BitcoinService(testnet=True)

        # A freshly generated testnet address must validate.
        fresh_address = service.generate_keypair()['address']
        assert service.validate_address(fresh_address)

        rejected = (
            '',
            None,
            'invalid',
            '1BvBMSEYstWetqTFn5Au4m4GFg7xJaNVN2',  # Mainnet
        )
        for candidate in rejected:
            assert not service.validate_address(candidate)

    def test_validate_mainnet_address(self):
        """A mainnet service accepts mainnet addresses and rejects a testnet one."""
        service = BitcoinService(testnet=False)

        accepted = (
            '1BvBMSEYstWetqTFn5Au4m4GFg7xJaNVN2',
            '3J98t1WpEZ73CNmQviecrnyiWrnqRhWNLy',
            'bc1qar0srrr7xfkvy5l643lydnw9re59gtzzwf5mdq',
        )
        for candidate in accepted:
            assert service.validate_address(candidate)

        # A testnet address must not pass mainnet validation.
        assert not service.validate_address('mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt')

    def test_estimate_fee(self):
        """Fee estimates are positive ints that grow with inputs and outputs."""
        service = BitcoinService(testnet=True)

        small_fee = service.estimate_fee(num_inputs=1, num_outputs=2)
        assert small_fee > 0
        assert isinstance(small_fee, int)

        # A transaction with more inputs/outputs should cost more.
        large_fee = service.estimate_fee(num_inputs=3, num_outputs=5)
        assert large_fee > small_fee
class MockDatabase:
    """In-memory stand-in for the database used by the WalletManager tests.

    Statements are not parsed. Each one is recognised by its leading keyword
    and by table/column names found in the lower-cased SQL text. Rows live in
    ``self.data`` under these keys:

    * ``'wallets'``: dict of user_id -> wallet row dict
    * ``'credits_ledger'``: list of ledger entry dicts
    * ``'pending_fees'``: dict of strategy_run_id -> fee row dict
    * ``'withdrawals'``: list of withdrawal request dicts
    """

    def __init__(self):
        self.tables = {}
        self.data = {}

    def execute_sql(self, sql, params=None, fetch_one=False, fetch_all=False):
        """Emulate a single SQL statement against the in-memory store.

        Args:
            sql: Statement text; matched case-insensitively.
            params: Positional bind parameters, read by index.
            fetch_one: For SELECTs, return one row tuple or None.
            fetch_all: For SELECTs, return a list of row tuples.

        Returns:
            A tuple (fetch_one), a list of tuples (fetch_all), or None for
            every other statement and for unrecognised queries.
        """
        sql_lower = sql.lower().strip()

        # Schema statements (table/index creation, migrations) are no-ops.
        if sql_lower.startswith(('create table', 'create index', 'alter table')):
            return None

        if sql_lower.startswith('insert'):
            self._insert(sql_lower, params)
            return None

        if sql_lower.startswith('select'):
            if fetch_one:
                return self._select_one(sql_lower, params)
            if fetch_all:
                return self._select_all(sql_lower, params)
            return None

        if sql_lower.startswith('update'):
            self._update(sql_lower, params)
            return None

        if sql_lower.startswith('delete'):
            self._delete(sql_lower, params)
            return None

        return None

    def _insert(self, sql_lower, params):
        """Store the row described by an INSERT, keyed on the table it names."""
        if 'wallets' in sql_lower:
            # Two-address model: params are (user_id, btc_address, public_key,
            # private_key, sweep_address, version, network). The shorter
            # layout has no sweep_address, so the later columns shift left.
            if len(params) > 6:
                sweep_address, key_version, network = params[4], params[5], params[6]
            else:
                sweep_address, key_version, network = None, params[4], params[5]
            self.data.setdefault('wallets', {})[params[0]] = {
                'user_id': params[0],
                'btc_address': params[1],
                'public_key_encrypted': params[2],
                'private_key_encrypted': params[3],
                'sweep_address': sweep_address,
                'encryption_key_version': key_version,
                'network': network,
                'is_disabled': 0,
                'created_at': 'now',
            }
        elif 'credits_ledger' in sql_lower:
            self.data.setdefault('credits_ledger', []).append({
                'user_id': params[0],
                'amount_satoshis': params[1],
                'tx_type': params[2],
                'reference_id': params[3],
                # The idempotency key is always the last bind parameter.
                'idempotency_key': params[-1],
            })
        elif 'pending_strategy_fees' in sql_lower:
            self.data.setdefault('pending_fees', {})[params[0]] = {
                'strategy_run_id': params[0],
                'user_id': params[1],
                'creator_user_id': params[2],
                'fee_percent': params[3] if len(params) > 3 else 10,
                'accumulated_satoshis': 0,
                'trade_count': 0,
            }
        elif 'withdrawal_requests' in sql_lower:
            self.data.setdefault('withdrawals', []).append({
                'user_id': params[0],
                'amount_satoshis': params[1],
                'destination_address': params[2],
                'status': 'reserved',  # status is hardcoded in SQL
            })

    def _select_one(self, sql_lower, params):
        """Answer a single-row SELECT; returns a row tuple or None."""
        if 'wallets' in sql_lower and 'user_id' in sql_lower:
            wallet = self.data.get('wallets', {}).get(params[0])
            if not wallet:
                return None
            # get_wallet_keys query: the column list includes the encrypted keys.
            if 'private_key_encrypted' in sql_lower:
                return (wallet['btc_address'], wallet['public_key_encrypted'],
                        wallet['private_key_encrypted'],
                        wallet['encryption_key_version'], wallet['network'])
            # Regular get_wallet query (5 columns including sweep_address).
            return (wallet['btc_address'], wallet['network'],
                    wallet['is_disabled'], wallet['created_at'],
                    wallet.get('sweep_address'))

        if 'credits_ledger' in sql_lower and 'sum' in sql_lower:
            ledger = self.data.get('credits_ledger', [])
            total = sum(e['amount_satoshis'] for e in ledger
                        if e['user_id'] == params[0])
            return (total,)

        if 'credits_ledger' in sql_lower and 'idempotency' in sql_lower:
            ledger = self.data.get('credits_ledger', [])
            seen = any(e.get('idempotency_key') == params[0] for e in ledger)
            return (1,) if seen else None

        if 'pending_strategy_fees' in sql_lower:
            fees = self.data.get('pending_fees', {}).get(params[0])
            if not fees:
                return None
            # Different callers select different column lists.
            if 'fee_percent' in sql_lower and 'accumulated' not in sql_lower:
                # accumulate_trade_fee: SELECT fee_percent
                return (fees.get('fee_percent', 10),)
            column_list = sql_lower.split('select')[1].split('from')[0]
            if ('accumulated_satoshis, trade_count' in sql_lower
                    and 'user_id' not in column_list):
                # get_pending_fees: SELECT accumulated_satoshis, trade_count
                return (fees['accumulated_satoshis'], fees['trade_count'])
            # settle_accumulated_fees: SELECT user_id, creator_user_id,
            # accumulated_satoshis, trade_count
            return (fees['user_id'], fees['creator_user_id'],
                    fees['accumulated_satoshis'], fees['trade_count'])

        return None

    def _select_all(self, sql_lower, params):
        """Answer a multi-row SELECT; only the ledger history query is modelled."""
        if 'credits_ledger' not in sql_lower:
            return []

        user_id, limit = params[0], params[1]
        # entries[-0:] is the whole list, so a zero limit needs an explicit
        # empty result rather than the slice below.
        if limit == 0:
            return []

        ledger = self.data.get('credits_ledger', [])
        entries = [e for e in ledger if e['user_id'] == user_id]
        return [(e['tx_type'], e['amount_satoshis'], e['reference_id'], 'now')
                for e in entries[-limit:]]

    def _update(self, sql_lower, params):
        """Apply the wallet enable/disable and fee accumulation UPDATEs."""
        if 'wallets' in sql_lower and 'is_disabled' in sql_lower:
            user_id = params[0]  # Only param is user_id
            wallets = self.data.get('wallets', {})
            if user_id in wallets:
                # The new value is written literally into the SQL text
                # (is_disabled = 0 or is_disabled = 1), not bound as a param.
                wallets[user_id]['is_disabled'] = (
                    1 if 'is_disabled = 1' in sql_lower else 0
                )
        elif 'pending_strategy_fees' in sql_lower and 'accumulated_satoshis' in sql_lower:
            # accumulated_satoshis = accumulated_satoshis + ?,
            # trade_count = trade_count + 1
            fee_amount, run_id = params[0], params[1]
            pending = self.data.get('pending_fees', {})
            if run_id in pending:
                pending[run_id]['accumulated_satoshis'] += fee_amount
                pending[run_id]['trade_count'] += 1

    def _delete(self, sql_lower, params):
        """Remove a pending fee row; other DELETEs are ignored."""
        if 'pending_strategy_fees' in sql_lower:
            self.data.get('pending_fees', {}).pop(params[0], None)

    def execute_in_transaction(self, statements):
        """Run each (sql, params) pair in order and return True.

        The mock offers no rollback: statements are simply applied one after
        another through execute_sql.
        """
        for sql, params in statements:
            self.execute_sql(sql, params)
        return True
class TestWalletManager:
    """WalletManager tests: wallets, credits, strategy fees and withdrawals."""

    # Testnet destination shared by the withdrawal tests.
    WITHDRAWAL_ADDRESS = 'mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt'

    @pytest.fixture
    def wallet_manager(self):
        """Provide a testnet WalletManager backed by a fresh MockDatabase."""
        return WalletManager(
            MockDatabase(),
            {1: 'test_encryption_key'},
            default_network='testnet',
        )

    def test_create_wallet(self, wallet_manager):
        """Creating a wallet returns both addresses of the two-address model."""
        created = wallet_manager.create_wallet(user_id=1)

        assert created['success']
        # Fee Address (we store keys)
        assert 'fee_address' in created
        assert 'fee_private_key' in created
        # Sweep Address (we don't store private key)
        assert 'sweep_address' in created
        assert 'sweep_private_key' in created
        assert created['network'] == 'testnet'

    def test_create_wallet_duplicate(self, wallet_manager):
        """A second create_wallet for the same user is rejected."""
        wallet_manager.create_wallet(user_id=1)

        second_attempt = wallet_manager.create_wallet(user_id=1)

        assert not second_attempt['success']
        assert 'already exists' in second_attempt['error'].lower()

    def test_get_wallet(self, wallet_manager):
        """get_wallet reports address and network but never the private key."""
        wallet_manager.create_wallet(user_id=1)

        info = wallet_manager.get_wallet(user_id=1)

        assert info is not None
        assert 'address' in info
        assert info['network'] == 'testnet'
        assert 'private_key' not in info  # Private key should not be exposed

    def test_credits_balance_empty(self, wallet_manager):
        """A user with no ledger entries has a zero balance."""
        assert wallet_manager.get_credits_balance(user_id=1) == 0

    def test_admin_credit(self, wallet_manager):
        """An admin credit shows up in the user's balance."""
        wallet_manager.admin_credit(user_id=1, amount_satoshis=10000, reason='test')

        assert wallet_manager.get_credits_balance(user_id=1) == 10000

    def test_credit_deposit(self, wallet_manager):
        """Crediting a deposit succeeds and raises the balance by its amount."""
        outcome = wallet_manager.credit_deposit(
            user_id=1,
            amount_satoshis=50000,
            network='testnet',
            tx_hash='abc123',
            vout=0,
        )

        assert outcome['success']
        assert wallet_manager.get_credits_balance(user_id=1) == 50000

    def test_deposit_idempotency(self, wallet_manager):
        """Submitting the same deposit twice credits the user only once."""
        for _ in range(2):
            wallet_manager.credit_deposit(
                user_id=1, amount_satoshis=50000,
                network='testnet', tx_hash='abc123', vout=0,
            )

        # Should only be credited once
        assert wallet_manager.get_credits_balance(user_id=1) == 50000

    def test_fee_accumulation(self, wallet_manager):
        """Profitable trades accumulate 10% of the exchange fee; losing ones do not."""
        wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial')

        # Start accumulation with default 10% fee
        started = wallet_manager.start_fee_accumulation(
            strategy_run_id='run_001',
            user_id=1,
            creator_user_id=2,
            fee_percent=10,  # 10% of exchange fee
            estimated_trades=5,
        )
        assert started['success']

        # The final, unprofitable trade should not accumulate anything.
        trades = ((10000, True), (5000, True), (8000, False))
        for exchange_fee, profitable in trades:
            wallet_manager.accumulate_trade_fee(
                'run_001', exchange_fee_satoshis=exchange_fee, is_profitable=profitable
            )

        owed = wallet_manager.get_pending_fees('run_001')
        assert owed['accumulated_satoshis'] == 1500  # 10% of 10000 + 10% of 5000
        assert owed['trade_count'] == 2  # Only profitable trades counted

    def test_fee_accumulation_custom_percent(self, wallet_manager):
        """A 50% fee setting accumulates half of the exchange fee."""
        wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial')

        # Start accumulation with 50% fee
        started = wallet_manager.start_fee_accumulation(
            strategy_run_id='run_002',
            user_id=1,
            creator_user_id=2,
            fee_percent=50,  # 50% of exchange fee
            estimated_trades=5,
        )
        assert started['success']

        # Accumulate fee: 50% of 10000 = 5000
        wallet_manager.accumulate_trade_fee(
            'run_002', exchange_fee_satoshis=10000, is_profitable=True
        )

        owed = wallet_manager.get_pending_fees('run_002')
        assert owed['accumulated_satoshis'] == 5000  # 50% of 10000
        assert owed['trade_count'] == 1

    def test_fee_accumulation_100_percent(self, wallet_manager):
        """A 100% fee setting accumulates the full exchange fee."""
        wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial')

        # Start accumulation with 100% fee
        started = wallet_manager.start_fee_accumulation(
            strategy_run_id='run_003',
            user_id=1,
            creator_user_id=2,
            fee_percent=100,  # 100% of exchange fee
            estimated_trades=5,
        )
        assert started['success']

        # Accumulate fee: 100% of 10000 = 10000
        wallet_manager.accumulate_trade_fee(
            'run_003', exchange_fee_satoshis=10000, is_profitable=True
        )

        owed = wallet_manager.get_pending_fees('run_003')
        assert owed['accumulated_satoshis'] == 10000  # 100% of 10000
        assert owed['trade_count'] == 1

    def test_fee_settlement(self, wallet_manager):
        """Settling moves the accumulated fee from the user to the creator."""
        wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial')
        wallet_manager.admin_credit(user_id=2, amount_satoshis=0, reason='creator setup')
        wallet_manager.start_fee_accumulation('run_001', user_id=1, creator_user_id=2)
        wallet_manager.accumulate_trade_fee(
            'run_001', exchange_fee_satoshis=10000, is_profitable=True
        )

        settlement = wallet_manager.settle_accumulated_fees('run_001')

        assert settlement['success']
        assert settlement['settled'] == 1000  # 10% of 10000
        assert settlement['trades'] == 1

        # Read both balances before asserting on either.
        payer_balance = wallet_manager.get_credits_balance(user_id=1)
        payee_balance = wallet_manager.get_credits_balance(user_id=2)
        assert payer_balance == 99000  # 100000 - 1000
        assert payee_balance == 1000  # Received fee

    def test_insufficient_credits_for_strategy(self, wallet_manager):
        """Starting someone else's strategy with no credits is refused."""
        # User has no credits
        started = wallet_manager.start_fee_accumulation(
            strategy_run_id='run_001',
            user_id=1,
            creator_user_id=2,
            estimated_trades=10,
        )

        assert not started['success']
        assert 'insufficient' in started['error'].lower()

    def test_withdrawal_request(self, wallet_manager):
        """A valid withdrawal succeeds and debits the balance straight away."""
        wallet_manager.create_wallet(user_id=1)
        wallet_manager.admin_credit(user_id=1, amount_satoshis=50000, reason='test')

        outcome = wallet_manager.request_withdrawal(
            user_id=1,
            amount_satoshis=10000,
            destination_address=self.WITHDRAWAL_ADDRESS,
        )

        assert outcome['success']
        # Balance should be reduced immediately
        assert wallet_manager.get_credits_balance(user_id=1) == 40000

    def test_withdrawal_insufficient_balance(self, wallet_manager):
        """Requesting more than the available balance is refused."""
        wallet_manager.create_wallet(user_id=1)
        wallet_manager.admin_credit(user_id=1, amount_satoshis=5000, reason='test')

        outcome = wallet_manager.request_withdrawal(
            user_id=1,
            amount_satoshis=10000,
            destination_address=self.WITHDRAWAL_ADDRESS,
        )

        assert not outcome['success']
        assert 'insufficient' in outcome['error'].lower()

    def test_withdrawal_invalid_address(self, wallet_manager):
        """A malformed destination address is refused."""
        wallet_manager.create_wallet(user_id=1)
        wallet_manager.admin_credit(user_id=1, amount_satoshis=50000, reason='test')

        outcome = wallet_manager.request_withdrawal(
            user_id=1,
            amount_satoshis=10000,
            destination_address='invalid_address',
        )

        assert not outcome['success']
        assert 'invalid' in outcome['error'].lower()

    def test_withdrawal_reservation_transaction_failure(self, wallet_manager):
        """A failing DB transaction yields an error and leaves the balance intact."""
        wallet_manager.create_wallet(user_id=1)
        wallet_manager.admin_credit(user_id=1, amount_satoshis=50000, reason='test')

        def exploding_transaction(_statements):
            raise RuntimeError('db transaction failed')

        # Swap in a transaction runner that always raises.
        wallet_manager.db.execute_in_transaction = exploding_transaction

        outcome = wallet_manager.request_withdrawal(
            user_id=1,
            amount_satoshis=10000,
            destination_address=self.WITHDRAWAL_ADDRESS,
        )

        assert not outcome['success']
        assert 'failed to queue' in outcome['error'].lower()
        # Balance should remain unchanged because reservation wasn't committed.
        assert wallet_manager.get_credits_balance(user_id=1) == 50000

    def test_own_strategy_no_fees(self, wallet_manager):
        """Running your own strategy succeeds without credits and reports no fees."""
        started = wallet_manager.start_fee_accumulation(
            strategy_run_id='run_001',
            user_id=1,
            creator_user_id=1,  # Same user
            estimated_trades=10,
        )

        assert started['success']
        assert 'no fees' in started.get('message', '').lower()

    def test_transaction_history(self, wallet_manager):
        """History lists every ledger entry with its type and amount."""
        for amount, reason in ((10000, 'credit1'), (5000, 'credit2')):
            wallet_manager.admin_credit(user_id=1, amount_satoshis=amount, reason=reason)

        entries = wallet_manager.get_transaction_history(user_id=1, limit=10)

        assert len(entries) == 2
        assert all('type' in entry for entry in entries)
        assert all('amount' in entry for entry in entries)
class TestBalanceCap:
    """Tests for how check_balance_cap toggles a wallet's disabled flag."""

    @pytest.fixture
    def wallet_manager(self):
        """Provide a testnet WalletManager backed by a fresh MockDatabase."""
        return WalletManager(MockDatabase(), {1: 'test_key'}, default_network='testnet')

    def test_balance_cap_disables_wallet(self, wallet_manager):
        """A balance above the cap is reported and the wallet reads as disabled."""
        wallet_manager.create_wallet(user_id=1)
        # Push the balance past the cap.
        wallet_manager.admin_credit(
            user_id=1,
            amount_satoshis=BALANCE_CAP_SATOSHIS + 10000,
            reason='over cap',
        )

        cap_status = wallet_manager.check_balance_cap(user_id=1)
        assert cap_status['over_cap']

        # Wallet should show as disabled
        assert wallet_manager.get_wallet(user_id=1)['is_disabled']

    def test_balance_under_cap_enables_wallet(self, wallet_manager):
        """A balance below the cap is not flagged and the wallet stays enabled."""
        wallet_manager.create_wallet(user_id=1)
        # Stay just short of the cap.
        wallet_manager.admin_credit(
            user_id=1,
            amount_satoshis=BALANCE_CAP_SATOSHIS - 10000,
            reason='under cap',
        )

        cap_status = wallet_manager.check_balance_cap(user_id=1)
        assert not cap_status['over_cap']

        assert not wallet_manager.get_wallet(user_id=1)['is_disabled']
class MockDatabaseForBackgroundJobs(MockDatabase):
    """MockDatabase variant that also answers the background-job queries.

    Adds handling for the over-cap wallet scan, the wallet listings, the
    deposits table and withdrawal request processing. Any statement not
    recognised here is passed on to MockDatabase.execute_sql.
    """

    def execute_sql(self, sql, params=None, fetch_one=False, fetch_all=False):
        """Answer background-job statements, deferring the rest to the parent.

        The checks run in a fixed order and the first match wins, so the more
        specific patterns must stay ahead of the looser ones.
        """
        query = sql.lower().strip()

        # get_wallets_over_cap: wallets with a sweep address whose ledger
        # balance exceeds the cap passed as the first parameter.
        if 'left join credits_ledger' in query and 'having balance' in query:
            over_cap = []
            for user_id, wallet in self.data.get('wallets', {}).items():
                sweep_address = wallet.get('sweep_address')
                if sweep_address:
                    entries = self.data.get('credits_ledger', [])
                    balance = sum(
                        e['amount_satoshis'] for e in entries if e['user_id'] == user_id
                    )
                    if balance > params[0]:  # params[0] is BALANCE_CAP_SATOSHIS
                        over_cap.append((user_id, sweep_address, balance))
            return over_cap if fetch_all else None

        lists_wallets = 'select user_id, btc_address, network' in query

        # get_all_active_wallets: only wallets that are not disabled.
        if lists_wallets and 'is_disabled = 0' in query:
            active = [
                (user_id, wallet['btc_address'], wallet['network'])
                for user_id, wallet in self.data.get('wallets', {}).items()
                if not wallet.get('is_disabled')
            ]
            return active if fetch_all else None

        # get_wallets_for_deposit_monitoring: every wallet, disabled or not.
        if lists_wallets and 'from wallets' in query and 'is_disabled' not in query:
            everyone = [
                (user_id, wallet['btc_address'], wallet['network'])
                for user_id, wallet in self.data.get('wallets', {}).items()
            ]
            return everyone if fetch_all else None

        # Deposits table: record a new deposit row.
        if 'insert' in query and 'deposits' in query:
            self.data.setdefault('deposits', []).append({
                'user_id': params[0],
                'network': params[1],
                'tx_hash': params[2],
                'vout': params[3],
                'amount_satoshis': params[4],
                'credited': params[5] if len(params) > 5 else 0,
            })
            return None

        # Deposits table: look a deposit up by (network, tx_hash, vout).
        if 'select' in query and 'deposits' in query and 'network' in query and fetch_one:
            for deposit in self.data.get('deposits', []):
                same_network = deposit['network'] == params[0]
                if same_network and deposit['tx_hash'] == params[1] and deposit['vout'] == params[2]:
                    return (deposit.get('id', 1),)
            return None

        # get_pending_withdrawals: ids are 1-based positions in the list.
        if 'withdrawal_requests' in query and "status = 'reserved'" in query and fetch_all:
            return [
                (row_id, req['user_id'], req['amount_satoshis'], req['destination_address'])
                for row_id, req in enumerate(self.data.get('withdrawals', []), start=1)
                if req.get('status') == 'reserved'
            ]

        # Single withdrawal lookup for process_withdrawal and _fail_withdrawal.
        if 'withdrawal_requests' in query and 'select' in query and 'where id' in query and fetch_one:
            requests = self.data.get('withdrawals', [])
            withdrawal_id = params[0]
            if 0 < withdrawal_id <= len(requests):
                req = requests[withdrawal_id - 1]
                row = (req['user_id'], req['amount_satoshis'], req['destination_address'])
                # process_withdrawal query includes status, _fail_withdrawal does not
                if 'status' in query:
                    row += (req.get('status', 'reserved'),)
                return row
            return None

        # Withdrawal status transitions.
        if 'update' in query and 'withdrawal_requests' in query:
            withdrawal_id = params[-1]  # Last param is the ID in WHERE clause
            requests = self.data.get('withdrawals', [])
            if 0 < withdrawal_id <= len(requests):
                req = requests[withdrawal_id - 1]
                if 'processing' in query:
                    req['status'] = 'processing'
                elif 'completed' in query:
                    req['status'] = 'completed'
                    req['btc_txhash'] = params[0]
                elif 'failed' in query:
                    req['status'] = 'failed'
                    req['error_message'] = params[0]
            return None

        # Everything else is handled by the base mock.
        return super().execute_sql(sql, params, fetch_one, fetch_all)
class TestBackgroundJobSupport:
    """Tests for the WalletManager methods used by background jobs."""

    # Testnet destination shared by the withdrawal tests.
    WITHDRAWAL_ADDRESS = 'mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt'

    @pytest.fixture
    def wallet_manager(self):
        """Provide a testnet WalletManager backed by the background-job mock DB."""
        return WalletManager(
            MockDatabaseForBackgroundJobs(),
            {1: 'test_key'},
            default_network='testnet',
        )

    def test_get_wallets_over_cap(self, wallet_manager):
        """A wallet credited above the cap is returned with balance and sweep address."""
        # Create wallet with sweep address
        wallet_manager.create_wallet(user_id=1)
        over_cap_amount = BALANCE_CAP_SATOSHIS + 50000
        wallet_manager.admin_credit(
            user_id=1,
            amount_satoshis=over_cap_amount,
            reason='over cap',
        )

        flagged = wallet_manager.get_wallets_over_cap()

        assert len(flagged) == 1
        assert flagged[0]['user_id'] == 1
        assert flagged[0]['balance'] == over_cap_amount
        assert flagged[0]['sweep_address'] is not None

    def test_get_wallets_over_cap_empty(self, wallet_manager):
        """No wallets are returned when every balance is under the cap."""
        wallet_manager.create_wallet(user_id=1)
        wallet_manager.admin_credit(user_id=1, amount_satoshis=10000, reason='under cap')

        flagged = wallet_manager.get_wallets_over_cap()

        assert len(flagged) == 0

    def test_get_all_active_wallets(self, wallet_manager):
        """Disabled wallets are left out of the active wallet listing."""
        for user_id in (1, 2):
            wallet_manager.create_wallet(user_id=user_id)
        # Disable user 2's wallet
        wallet_manager.db.data['wallets'][2]['is_disabled'] = 1

        active = wallet_manager.get_all_active_wallets()

        assert len(active) == 1
        assert active[0]['user_id'] == 1

    def test_get_wallets_for_deposit_monitoring_includes_disabled(self, wallet_manager):
        """Deposit monitoring should include disabled wallets."""
        for user_id in (1, 2):
            wallet_manager.create_wallet(user_id=user_id)
        wallet_manager.db.data['wallets'][2]['is_disabled'] = 1

        monitored = wallet_manager.get_wallets_for_deposit_monitoring()

        assert {entry['user_id'] for entry in monitored} == {1, 2}

    def test_get_pending_withdrawals(self, wallet_manager):
        """A freshly requested withdrawal appears in the pending list."""
        wallet_manager.create_wallet(user_id=1)
        wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='test')
        wallet_manager.request_withdrawal(
            user_id=1,
            amount_satoshis=10000,
            destination_address=self.WITHDRAWAL_ADDRESS,
        )

        queued = wallet_manager.get_pending_withdrawals()

        assert len(queued) == 1
        assert queued[0]['user_id'] == 1
        assert queued[0]['amount_satoshis'] == 10000

    def test_auto_sweep_no_wallet(self, wallet_manager):
        """auto_sweep reports an error for a user without a wallet."""
        outcome = wallet_manager.auto_sweep(user_id=999, amount_satoshis=10000)

        assert not outcome['success']
        assert 'no wallet' in outcome['error'].lower()

    def test_auto_sweep_no_sweep_address(self, wallet_manager):
        """auto_sweep reports an error when the wallet has no sweep address."""
        wallet_manager.create_wallet(user_id=1)
        # Remove sweep address
        wallet_manager.db.data['wallets'][1]['sweep_address'] = None

        outcome = wallet_manager.auto_sweep(user_id=1, amount_satoshis=10000)

        assert not outcome['success']
        assert 'sweep address' in outcome['error'].lower()

    def test_process_withdrawal_not_found(self, wallet_manager):
        """Processing an unknown withdrawal id reports 'not found'."""
        outcome = wallet_manager.process_withdrawal(withdrawal_id=999)

        assert not outcome['success']
        assert 'not found' in outcome['error'].lower()

    def test_fail_withdrawal_reverses_credit(self, wallet_manager):
        """Failing a withdrawal gives the reserved amount back to the user."""
        wallet_manager.create_wallet(user_id=1)
        wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='test')
        # Request withdrawal (debits 10000)
        wallet_manager.request_withdrawal(
            user_id=1,
            amount_satoshis=10000,
            destination_address=self.WITHDRAWAL_ADDRESS,
        )
        # Balance should be reduced
        assert wallet_manager.get_credits_balance(user_id=1) == 90000

        wallet_manager._fail_withdrawal(withdrawal_id=1, error_message='Test failure')

        # Balance should be restored
        assert wallet_manager.get_credits_balance(user_id=1) == 100000

    def test_process_withdrawal_none_txhash_reverses_credit(self, wallet_manager):
        """None tx hash from send_transaction should fail and reverse reservation."""
        wallet_manager.create_wallet(user_id=1)
        wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='test')
        wallet_manager.request_withdrawal(
            user_id=1,
            amount_satoshis=10000,
            destination_address=self.WITHDRAWAL_ADDRESS,
        )
        assert wallet_manager.get_credits_balance(user_id=1) == 90000

        # Make the broadcast step return no transaction hash.
        send_target = 'wallet.wallet_manager.BitcoinService.send_transaction'
        with patch(send_target, return_value=None):
            outcome = wallet_manager.process_withdrawal(withdrawal_id=1)

        assert not outcome['success']
        assert 'no tx hash' in outcome['error'].lower()
        assert wallet_manager.get_credits_balance(user_id=1) == 100000