# File-viewer metadata captured with this paste: 828 lines, 34 KiB, Python.
"""
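Tests for the wallet module.

Tests key encryption (KeyEncryption), the Bitcoin service (BitcoinService),
and wallet manager (WalletManager) functionality, using in-memory mock
databases in place of a real one.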
"""

import pytest
import tempfile  # NOTE(review): not referenced anywhere in this file — confirm before removing
import os
import sys
from unittest.mock import patch

# Add src to path for imports. Inserting at index 0 gives the repository's
# ``src`` tree precedence over every other ``sys.path`` entry; this must run
# before the ``wallet.*`` imports below or they will not resolve to it.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from wallet.encryption import KeyEncryption
from wallet.bitcoin_service import BitcoinService
from wallet.wallet_manager import WalletManager, BALANCE_CAP_SATOSHIS


class TestKeyEncryption:
    """Exercise KeyEncryption round-trips, key versioning and re-encryption."""

    def test_encrypt_decrypt(self):
        """A value encrypted with the default key decrypts back unchanged."""
        cipher = KeyEncryption({1: 'test_master_key_v1'})
        plaintext = "my_secret_private_key"

        token = cipher.encrypt(plaintext)

        assert cipher.decrypt(token, version=1) == plaintext
        assert token != plaintext

    def test_versioned_keys(self):
        """Each key version round-trips its own ciphertext, and versions differ."""
        cipher = KeyEncryption({1: 'key_v1', 2: 'key_v2'})
        payload = "secret_data"

        tokens_by_version = {
            version: cipher.encrypt(payload, version=version) for version in (1, 2)
        }

        # Both should decrypt correctly with their respective versions
        for version, token in tokens_by_version.items():
            assert cipher.decrypt(token, version=version) == payload

        # Encrypted data should be different for different versions
        assert tokens_by_version[1] != tokens_by_version[2]

    def test_re_encrypt(self):
        """Ciphertext made with an old key version can be rotated to a newer one."""
        cipher = KeyEncryption({1: 'old_key', 2: 'new_key'})
        payload = "secret_data"
        legacy_token = cipher.encrypt(payload, version=1)

        # Re-encrypt to version 2
        rotated_token, rotated_version = cipher.re_encrypt(
            legacy_token, old_version=1, new_version=2
        )

        assert rotated_version == 2
        assert cipher.decrypt(rotated_token, version=2) == payload

    def test_unknown_version_raises(self):
        """Decrypting with a version that has no registered key is rejected."""
        cipher = KeyEncryption({1: 'test_key'})

        with pytest.raises(ValueError, match="Unknown encryption key version"):
            cipher.decrypt("some_data", version=99)


class TestBitcoinService:
    """Exercise BitcoinService keypair generation, address validation and fees."""

    def test_generate_keypair_testnet(self):
        """A testnet keypair exposes all three fields and a testnet address prefix."""
        generated = BitcoinService(testnet=True).generate_keypair()

        for field in ('address', 'public_key', 'private_key'):
            assert field in generated

        # Testnet addresses start with m, n, or tb1
        assert generated['address'].startswith(('m', 'n', 'tb1'))

    def test_validate_testnet_address(self):
        """Testnet validation accepts generated addresses and rejects junk/mainnet."""
        svc = BitcoinService(testnet=True)

        # A freshly generated testnet address must validate
        own_address = svc.generate_keypair()['address']
        assert svc.validate_address(own_address)

        # Invalid addresses (the last one is a mainnet address)
        rejected = ['', None, 'invalid', '1BvBMSEYstWetqTFn5Au4m4GFg7xJaNVN2']
        for candidate in rejected:
            assert not svc.validate_address(candidate)

    def test_validate_mainnet_address(self):
        """Mainnet validation accepts P2PKH/P2SH/bech32 samples and rejects testnet."""
        svc = BitcoinService(testnet=False)

        accepted = (
            '1BvBMSEYstWetqTFn5Au4m4GFg7xJaNVN2',
            '3J98t1WpEZ73CNmQviecrnyiWrnqRhWNLy',
            'bc1qar0srrr7xfkvy5l643lydnw9re59gtzzwf5mdq',
        )
        for candidate in accepted:
            assert svc.validate_address(candidate)

        # Testnet addresses should be invalid on mainnet
        assert not svc.validate_address('mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt')

    def test_estimate_fee(self):
        """Fee estimates are positive ints that grow with transaction size."""
        svc = BitcoinService(testnet=True)

        small = svc.estimate_fee(num_inputs=1, num_outputs=2)
        assert small > 0
        assert isinstance(small, int)

        # More inputs/outputs should cost more
        larger = svc.estimate_fee(num_inputs=3, num_outputs=5)
        assert larger > small


class MockDatabase:
    """Mock database for testing WalletManager.

    ``execute_sql`` does not parse SQL. It routes each statement by its
    leading keyword and by table/column substrings found in the statement,
    and reads/writes plain Python structures held in ``self.data``:

    * ``'wallets'``        -> ``{user_id: row_dict}``
    * ``'credits_ledger'`` -> ``[entry_dict, ...]`` (appended in insert order)
    * ``'pending_fees'``   -> ``{strategy_run_id: row_dict}``
    * ``'withdrawals'``    -> ``[row_dict, ...]`` (appended in insert order)
    """

    def __init__(self):
        self.tables = {}  # NOTE(review): never read or written by this mock
        self.data = {}

    def execute_sql(self, sql, params=None, fetch_one=False, fetch_all=False):
        """Simple mock SQL execution.

        Args:
            sql: SQL text; only its lower-cased, stripped form is inspected.
            params: Positional bind parameters in the order the SQL uses them.
            fetch_one: For a SELECT, return a single row tuple (or None).
            fetch_all: For a SELECT, return a list of row tuples.

        Returns:
            A row tuple / list of row tuples for recognised SELECTs; None for
            every other statement.
        """
        sql_lower = sql.lower().strip()

        if sql_lower.startswith(('create table', 'create index', 'alter table')):
            # Schema creation and migrations are no-ops in memory.
            return None

        if sql_lower.startswith('insert'):
            self._insert(sql_lower, params)
            return None

        if sql_lower.startswith('select'):
            if fetch_one:
                return self._select_one(sql_lower, params)
            if fetch_all:
                return self._select_all(sql_lower, params)
            return None

        if sql_lower.startswith('update'):
            self._update(sql_lower, params)
            return None

        if sql_lower.startswith('delete'):
            self._delete(sql_lower, params)
            return None

        return None

    def _insert(self, sql_lower, params):
        """Store an INSERT's row in the matching in-memory table."""
        if 'wallets' in sql_lower:
            # Two-address model: params are (user_id, btc_address, public_key,
            # private_key, sweep_address, version, network). A shorter tuple
            # without sweep_address shifts version/network down by one.
            two_address = len(params) > 6
            self.data.setdefault('wallets', {})[params[0]] = {
                'user_id': params[0],
                'btc_address': params[1],
                'public_key_encrypted': params[2],
                'private_key_encrypted': params[3],
                'sweep_address': params[4] if two_address else None,
                'encryption_key_version': params[5] if two_address else params[4],
                'network': params[6] if two_address else params[5],
                'is_disabled': 0,
                'created_at': 'now',
            }
        elif 'credits_ledger' in sql_lower:
            self.data.setdefault('credits_ledger', []).append({
                'user_id': params[0],
                'amount_satoshis': params[1],
                'tx_type': params[2],
                'reference_id': params[3],
                'idempotency_key': params[-1],  # bound as the last parameter
            })
        elif 'pending_strategy_fees' in sql_lower:
            self.data.setdefault('pending_fees', {})[params[0]] = {
                'strategy_run_id': params[0],
                'user_id': params[1],
                'creator_user_id': params[2],
                'fee_percent': params[3] if len(params) > 3 else 10,
                'accumulated_satoshis': 0,
                'trade_count': 0,
            }
        elif 'withdrawal_requests' in sql_lower:
            self.data.setdefault('withdrawals', []).append({
                'user_id': params[0],
                'amount_satoshis': params[1],
                'destination_address': params[2],
                'status': 'reserved',  # status is hardcoded in SQL
            })

    def _select_one(self, sql_lower, params):
        """Answer a ``fetch_one`` SELECT with a single row tuple, or None."""
        if 'wallets' in sql_lower and 'user_id' in sql_lower:
            wallet = self.data.get('wallets', {}).get(params[0])
            if not wallet:
                return None
            if 'private_key_encrypted' in sql_lower:
                # get_wallet_keys query (includes encrypted keys)
                return (wallet['btc_address'], wallet['public_key_encrypted'],
                        wallet['private_key_encrypted'], wallet['encryption_key_version'],
                        wallet['network'])
            # Regular get_wallet query (5 columns including sweep_address)
            return (wallet['btc_address'], wallet['network'],
                    wallet['is_disabled'], wallet['created_at'],
                    wallet.get('sweep_address'))

        if 'credits_ledger' in sql_lower and 'sum' in sql_lower:
            ledger = self.data.get('credits_ledger', [])
            return (sum(e['amount_satoshis'] for e in ledger if e['user_id'] == params[0]),)

        if 'credits_ledger' in sql_lower and 'idempotency' in sql_lower:
            ledger = self.data.get('credits_ledger', [])
            if any(e.get('idempotency_key') == params[0] for e in ledger):
                return (1,)
            return None

        if 'pending_strategy_fees' in sql_lower:
            fees = self.data.get('pending_fees', {}).get(params[0])
            if not fees:
                return None
            # Different queries select different columns
            if 'fee_percent' in sql_lower and 'accumulated' not in sql_lower:
                # accumulate_trade_fee: SELECT fee_percent
                return (fees.get('fee_percent', 10),)
            select_list = sql_lower.split('select')[1].split('from')[0]
            if 'accumulated_satoshis, trade_count' in sql_lower and 'user_id' not in select_list:
                # get_pending_fees: SELECT accumulated_satoshis, trade_count
                return (fees['accumulated_satoshis'], fees['trade_count'])
            # settle_accumulated_fees: SELECT user_id, creator_user_id, accumulated_satoshis, trade_count
            return (fees['user_id'], fees['creator_user_id'],
                    fees['accumulated_satoshis'], fees['trade_count'])

        return None

    def _select_all(self, sql_lower, params):
        """Answer a ``fetch_all`` SELECT; only ledger history is modelled.

        For ledger queries ``params`` is ``(user_id, limit)``. The newest
        ``limit`` entries are returned in insertion order.
        """
        if 'credits_ledger' not in sql_lower:
            return []
        user_id, limit = params[0], params[1]
        entries = [e for e in self.data.get('credits_ledger', []) if e['user_id'] == user_id]
        if limit == 0:
            # LIMIT 0 yields no rows; a bare entries[-0:] slice would return all of them.
            entries = []
        elif limit > 0:
            entries = entries[-limit:]
        # A negative LIMIT leaves the result unbounded, as in SQLite.
        return [(e['tx_type'], e['amount_satoshis'], e['reference_id'], 'now')
                for e in entries]

    def _update(self, sql_lower, params):
        """Apply a recognised UPDATE to the in-memory tables."""
        if 'wallets' in sql_lower and 'is_disabled' in sql_lower:
            user_id = params[0]  # Only param is user_id
            wallets = self.data.get('wallets', {})
            if user_id in wallets:
                # Parse the value from the SQL itself (is_disabled = 0 or is_disabled = 1)
                wallets[user_id]['is_disabled'] = 1 if 'is_disabled = 1' in sql_lower else 0
        elif 'pending_strategy_fees' in sql_lower and 'accumulated_satoshis' in sql_lower:
            # accumulated_satoshis = accumulated_satoshis + ?, trade_count = trade_count + 1
            fee_amount, run_id = params[0], params[1]
            row = self.data.get('pending_fees', {}).get(run_id)
            if row is not None:
                row['accumulated_satoshis'] += fee_amount
                row['trade_count'] += 1

    def _delete(self, sql_lower, params):
        """Apply a recognised DELETE (only pending strategy fees are modelled)."""
        if 'pending_strategy_fees' in sql_lower:
            self.data.get('pending_fees', {}).pop(params[0], None)

    def execute_in_transaction(self, statements):
        """Execute multiple statements atomically (mock version).

        Statements run in order through ``execute_sql``. If any statement
        raises, ``self.data`` is restored to its state before the first
        statement and the exception propagates, mirroring a rollback.

        Returns:
            True when every statement succeeded.
        """
        import copy  # local: only needed for the rollback snapshot

        snapshot = copy.deepcopy(self.data)
        try:
            for sql, params in statements:
                self.execute_sql(sql, params)
        except Exception:
            # Restore in place so existing references to self.data stay valid.
            self.data.clear()
            self.data.update(snapshot)
            raise
        return True


class TestWalletManager:
    """Exercise WalletManager end to end against the in-memory MockDatabase."""

    @pytest.fixture
    def wallet_manager(self):
        """Build a testnet WalletManager backed by a fresh MockDatabase."""
        return WalletManager(
            MockDatabase(), {1: 'test_encryption_key'}, default_network='testnet'
        )

    @staticmethod
    def _fund_and_start(manager, run_id, fee_percent):
        """Credit user 1 and open a fee run for user 1 paying creator user 2."""
        manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial')
        return manager.start_fee_accumulation(
            strategy_run_id=run_id,
            user_id=1,
            creator_user_id=2,
            fee_percent=fee_percent,
            estimated_trades=5,
        )

    @staticmethod
    def _wallet_with_credit(manager, amount):
        """Create user 1's wallet and admin-credit ``amount`` satoshis to it."""
        manager.create_wallet(user_id=1)
        manager.admin_credit(user_id=1, amount_satoshis=amount, reason='test')

    @staticmethod
    def _withdraw(manager, destination='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt'):
        """Request a 10,000-satoshi withdrawal for user 1 to ``destination``."""
        return manager.request_withdrawal(
            user_id=1,
            amount_satoshis=10000,
            destination_address=destination,
        )

    def test_create_wallet(self, wallet_manager):
        """Creating a wallet returns both fee and sweep keypairs (two-address model)."""
        created = wallet_manager.create_wallet(user_id=1)

        assert created['success']
        # Fee Address (we store keys), then Sweep Address (we don't store private key)
        for key in ('fee_address', 'fee_private_key', 'sweep_address', 'sweep_private_key'):
            assert key in created
        assert created['network'] == 'testnet'

    def test_create_wallet_duplicate(self, wallet_manager):
        """A second wallet for the same user is refused."""
        wallet_manager.create_wallet(user_id=1)
        second = wallet_manager.create_wallet(user_id=1)

        assert not second['success']
        assert 'already exists' in second['error'].lower()

    def test_get_wallet(self, wallet_manager):
        """Wallet lookup returns public info only."""
        wallet_manager.create_wallet(user_id=1)
        info = wallet_manager.get_wallet(user_id=1)

        assert info is not None
        assert 'address' in info
        assert info['network'] == 'testnet'
        assert 'private_key' not in info  # Private key should not be exposed

    def test_credits_balance_empty(self, wallet_manager):
        """A user with no ledger entries has a zero balance."""
        assert wallet_manager.get_credits_balance(user_id=1) == 0

    def test_admin_credit(self, wallet_manager):
        """An admin credit shows up in the balance."""
        wallet_manager.admin_credit(user_id=1, amount_satoshis=10000, reason='test')
        assert wallet_manager.get_credits_balance(user_id=1) == 10000

    def test_credit_deposit(self, wallet_manager):
        """A deposit is credited to the user's balance."""
        outcome = wallet_manager.credit_deposit(
            user_id=1,
            amount_satoshis=50000,
            network='testnet',
            tx_hash='abc123',
            vout=0,
        )

        assert outcome['success']
        assert wallet_manager.get_credits_balance(user_id=1) == 50000

    def test_deposit_idempotency(self, wallet_manager):
        """Replaying the same deposit credits it only once."""
        deposit = dict(
            user_id=1, amount_satoshis=50000,
            network='testnet', tx_hash='abc123', vout=0,
        )
        for _ in range(2):
            wallet_manager.credit_deposit(**deposit)

        # Should only be credited once
        assert wallet_manager.get_credits_balance(user_id=1) == 50000

    def test_fee_accumulation(self, wallet_manager):
        """At the default 10%, only profitable trades accrue fees."""
        started = self._fund_and_start(wallet_manager, 'run_001', fee_percent=10)
        assert started['success']

        trades = [(10000, True), (5000, True), (8000, False)]  # last one should not accumulate
        for exchange_fee, profitable in trades:
            wallet_manager.accumulate_trade_fee(
                'run_001', exchange_fee_satoshis=exchange_fee, is_profitable=profitable
            )

        pending = wallet_manager.get_pending_fees('run_001')
        assert pending['accumulated_satoshis'] == 1500  # 10% of 10000 + 10% of 5000
        assert pending['trade_count'] == 2  # Only profitable trades counted

    def test_fee_accumulation_custom_percent(self, wallet_manager):
        """A 50% fee run accrues half of each profitable exchange fee."""
        started = self._fund_and_start(wallet_manager, 'run_002', fee_percent=50)
        assert started['success']

        # Accumulate fee: 50% of 10000 = 5000
        wallet_manager.accumulate_trade_fee('run_002', exchange_fee_satoshis=10000, is_profitable=True)

        pending = wallet_manager.get_pending_fees('run_002')
        assert pending['accumulated_satoshis'] == 5000
        assert pending['trade_count'] == 1

    def test_fee_accumulation_100_percent(self, wallet_manager):
        """A 100% fee run accrues the full exchange commission."""
        started = self._fund_and_start(wallet_manager, 'run_003', fee_percent=100)
        assert started['success']

        # Accumulate fee: 100% of 10000 = 10000
        wallet_manager.accumulate_trade_fee('run_003', exchange_fee_satoshis=10000, is_profitable=True)

        pending = wallet_manager.get_pending_fees('run_003')
        assert pending['accumulated_satoshis'] == 10000
        assert pending['trade_count'] == 1

    def test_fee_settlement(self, wallet_manager):
        """Stopping a strategy moves accrued fees from the user to the creator."""
        wallet_manager.admin_credit(user_id=1, amount_satoshis=100000, reason='initial')
        wallet_manager.admin_credit(user_id=2, amount_satoshis=0, reason='creator setup')
        wallet_manager.start_fee_accumulation('run_001', user_id=1, creator_user_id=2)
        wallet_manager.accumulate_trade_fee('run_001', exchange_fee_satoshis=10000, is_profitable=True)

        settlement = wallet_manager.settle_accumulated_fees('run_001')

        assert settlement['success']
        assert settlement['settled'] == 1000  # 10% of 10000
        assert settlement['trades'] == 1

        balances = {uid: wallet_manager.get_credits_balance(user_id=uid) for uid in (1, 2)}
        assert balances[1] == 99000  # 100000 - 1000
        assert balances[2] == 1000  # Received fee

    def test_insufficient_credits_for_strategy(self, wallet_manager):
        """A fee run cannot start for a user with no credits."""
        attempt = wallet_manager.start_fee_accumulation(
            strategy_run_id='run_001',
            user_id=1,
            creator_user_id=2,
            estimated_trades=10,
        )

        assert not attempt['success']
        assert 'insufficient' in attempt['error'].lower()

    def test_withdrawal_request(self, wallet_manager):
        """A valid withdrawal reserves funds immediately."""
        self._wallet_with_credit(wallet_manager, 50000)

        outcome = self._withdraw(wallet_manager)

        assert outcome['success']
        # Balance should be reduced immediately
        assert wallet_manager.get_credits_balance(user_id=1) == 40000

    def test_withdrawal_insufficient_balance(self, wallet_manager):
        """Withdrawing more than the balance is refused."""
        self._wallet_with_credit(wallet_manager, 5000)

        outcome = self._withdraw(wallet_manager)

        assert not outcome['success']
        assert 'insufficient' in outcome['error'].lower()

    def test_withdrawal_invalid_address(self, wallet_manager):
        """Withdrawing to a malformed address is refused."""
        self._wallet_with_credit(wallet_manager, 50000)

        outcome = self._withdraw(wallet_manager, destination='invalid_address')

        assert not outcome['success']
        assert 'invalid' in outcome['error'].lower()

    def test_withdrawal_reservation_transaction_failure(self, wallet_manager):
        """A failing reservation transaction leaves the balance untouched."""
        self._wallet_with_credit(wallet_manager, 50000)

        def fail_transaction(_statements):
            raise RuntimeError('db transaction failed')

        wallet_manager.db.execute_in_transaction = fail_transaction

        outcome = self._withdraw(wallet_manager)

        assert not outcome['success']
        assert 'failed to queue' in outcome['error'].lower()
        # Balance should remain unchanged because reservation wasn't committed.
        assert wallet_manager.get_credits_balance(user_id=1) == 50000

    def test_own_strategy_no_fees(self, wallet_manager):
        """Running one's own strategy succeeds without any fee reservation."""
        outcome = wallet_manager.start_fee_accumulation(
            strategy_run_id='run_001',
            user_id=1,
            creator_user_id=1,  # Same user
            estimated_trades=10,
        )

        assert outcome['success']
        assert 'no fees' in outcome.get('message', '').lower()

    def test_transaction_history(self, wallet_manager):
        """History lists every ledger entry with its type and amount."""
        for amount, reason in ((10000, 'credit1'), (5000, 'credit2')):
            wallet_manager.admin_credit(user_id=1, amount_satoshis=amount, reason=reason)

        history = wallet_manager.get_transaction_history(user_id=1, limit=10)

        assert len(history) == 2
        assert all('type' in tx and 'amount' in tx for tx in history)


class TestBalanceCap:
    """Check that crossing BALANCE_CAP_SATOSHIS toggles a wallet's disabled flag."""

    @pytest.fixture
    def wallet_manager(self):
        """Build a testnet WalletManager backed by a fresh MockDatabase."""
        return WalletManager(MockDatabase(), {1: 'test_key'}, default_network='testnet')

    @staticmethod
    def _credit_relative_to_cap(manager, delta, reason):
        """Create user 1's wallet and credit it ``BALANCE_CAP_SATOSHIS + delta``."""
        manager.create_wallet(user_id=1)
        manager.admin_credit(
            user_id=1,
            amount_satoshis=BALANCE_CAP_SATOSHIS + delta,
            reason=reason,
        )

    def test_balance_cap_disables_wallet(self, wallet_manager):
        """A balance above the cap is reported and disables the wallet."""
        self._credit_relative_to_cap(wallet_manager, 10000, 'over cap')

        assert wallet_manager.check_balance_cap(user_id=1)['over_cap']
        # Wallet should show as disabled
        assert wallet_manager.get_wallet(user_id=1)['is_disabled']

    def test_balance_under_cap_enables_wallet(self, wallet_manager):
        """A balance below the cap leaves the wallet enabled."""
        self._credit_relative_to_cap(wallet_manager, -10000, 'under cap')

        assert not wallet_manager.check_balance_cap(user_id=1)['over_cap']
        assert not wallet_manager.get_wallet(user_id=1)['is_disabled']


class MockDatabaseForBackgroundJobs(MockDatabase):
    """Extended mock database for testing background job methods.

    Adds handlers for the over-cap, wallet-listing, deposit and withdrawal
    queries, and defers every other statement to ``MockDatabase``.
    Withdrawals are identified by their 1-based position in
    ``self.data['withdrawals']``.
    """

    @staticmethod
    def _assigned_withdrawal_status(sql_lower):
        """Return the status literal assigned in an UPDATE's SET clause, or None.

        Only the text between ``SET`` and ``WHERE`` is inspected. A status
        mentioned in the WHERE clause (e.g. ``... WHERE status = 'processing'``),
        or a column such as ``completed_at``, therefore cannot be mistaken for
        the new status.
        """
        import re  # local: only this helper needs it

        set_clause = re.search(r'\bset\b(.*?)(?:\bwhere\b|$)', sql_lower, re.S)
        if set_clause is None:
            return None
        assigned = re.search(r"\bstatus\s*=\s*'([a-z_]+)'", set_clause.group(1))
        return assigned.group(1) if assigned else None

    def execute_sql(self, sql, params=None, fetch_one=False, fetch_all=False):
        """Route background-job queries to in-memory handlers, else defer to the parent.

        Returns:
            Row lists for recognised ``fetch_all`` queries, a row tuple for
            recognised ``fetch_one`` queries, otherwise None (or whatever
            ``MockDatabase.execute_sql`` returns for unrecognised SQL).
        """
        sql_lower = sql.lower().strip()

        # Handle get_wallets_over_cap query: wallets with a sweep address whose
        # ledger balance exceeds params[0].
        if 'left join credits_ledger' in sql_lower and 'having balance' in sql_lower:
            ledger = self.data.get('credits_ledger', [])
            results = []
            for user_id, wallet in self.data.get('wallets', {}).items():
                if not wallet.get('sweep_address'):
                    continue
                balance = sum(e['amount_satoshis'] for e in ledger if e['user_id'] == user_id)
                if balance > params[0]:  # params[0] is BALANCE_CAP_SATOSHIS
                    results.append((user_id, wallet.get('sweep_address'), balance))
            return results if fetch_all else None

        # Handle get_all_active_wallets query
        if 'select user_id, btc_address, network' in sql_lower and 'is_disabled = 0' in sql_lower:
            results = [
                (user_id, wallet['btc_address'], wallet['network'])
                for user_id, wallet in self.data.get('wallets', {}).items()
                if not wallet.get('is_disabled')
            ]
            return results if fetch_all else None

        # Handle get_wallets_for_deposit_monitoring query (includes disabled wallets)
        if ('select user_id, btc_address, network' in sql_lower
                and 'from wallets' in sql_lower
                and 'is_disabled' not in sql_lower):
            results = [
                (user_id, wallet['btc_address'], wallet['network'])
                for user_id, wallet in self.data.get('wallets', {}).items()
            ]
            return results if fetch_all else None

        # Handle deposits table queries. Match the verb at the start so other
        # statement kinds that merely mention "insert" are not captured.
        if sql_lower.startswith('insert') and 'deposits' in sql_lower:
            self.data.setdefault('deposits', []).append({
                'user_id': params[0],
                'network': params[1],
                'tx_hash': params[2],
                'vout': params[3],
                'amount_satoshis': params[4],
                'credited': params[5] if len(params) > 5 else 0,
            })
            return None

        if 'select' in sql_lower and 'deposits' in sql_lower and 'network' in sql_lower and fetch_one:
            for d in self.data.get('deposits', []):
                if d['network'] == params[0] and d['tx_hash'] == params[1] and d['vout'] == params[2]:
                    return (d.get('id', 1),)
            return None

        # Handle get_pending_withdrawals query
        if 'withdrawal_requests' in sql_lower and "status = 'reserved'" in sql_lower and fetch_all:
            return [
                (i + 1, w['user_id'], w['amount_satoshis'], w['destination_address'])
                for i, w in enumerate(self.data.get('withdrawals', []))
                if w.get('status') == 'reserved'
            ]

        # Handle withdrawal_requests SELECT for process_withdrawal and _fail_withdrawal
        if 'withdrawal_requests' in sql_lower and 'select' in sql_lower and 'where id' in sql_lower and fetch_one:
            withdrawals = self.data.get('withdrawals', [])
            withdrawal_id = params[0]
            if 0 < withdrawal_id <= len(withdrawals):
                w = withdrawals[withdrawal_id - 1]
                # process_withdrawal query includes status, _fail_withdrawal does not
                if 'status' in sql_lower:
                    return (w['user_id'], w['amount_satoshis'], w['destination_address'], w.get('status', 'reserved'))
                return (w['user_id'], w['amount_satoshis'], w['destination_address'])
            return None

        # Handle withdrawal_requests UPDATE. Matching the verb at the start keeps
        # SELECTs that mention e.g. "updated_at" from landing here.
        if sql_lower.startswith('update') and 'withdrawal_requests' in sql_lower:
            withdrawal_id = params[-1]  # Last param is the ID in WHERE clause
            withdrawals = self.data.get('withdrawals', [])
            if 0 < withdrawal_id <= len(withdrawals):
                row = withdrawals[withdrawal_id - 1]
                new_status = self._assigned_withdrawal_status(sql_lower)
                if new_status is None:
                    # Status not assigned as a SET literal: fall back to keyword sniffing.
                    new_status = next(
                        (s for s in ('processing', 'completed', 'failed') if s in sql_lower),
                        None,
                    )
                if new_status == 'processing':
                    row['status'] = 'processing'
                elif new_status == 'completed':
                    row['status'] = 'completed'
                    row['btc_txhash'] = params[0]
                elif new_status == 'failed':
                    row['status'] = 'failed'
                    row['error_message'] = params[0]
            return None

        # Fall back to parent implementation
        return super().execute_sql(sql, params, fetch_one, fetch_all)


class TestBackgroundJobSupport:
    """Tests for background job support methods in WalletManager."""

    @pytest.fixture
    def wallet_manager(self):
        """Build a testnet WalletManager over the background-job mock database."""
        return WalletManager(
            MockDatabaseForBackgroundJobs(), {1: 'test_key'}, default_network='testnet'
        )

    @staticmethod
    def _funded_wallet(manager):
        """Create user 1's wallet and credit it 100,000 satoshis."""
        manager.create_wallet(user_id=1)
        manager.admin_credit(user_id=1, amount_satoshis=100000, reason='test')

    @staticmethod
    def _reserve_withdrawal(manager):
        """Request a 10,000-satoshi testnet withdrawal for user 1."""
        return manager.request_withdrawal(
            user_id=1,
            amount_satoshis=10000,
            destination_address='mkHS9ne12qx9pS9VojpwU5xtRd4T7X7ZUt',
        )

    def test_get_wallets_over_cap(self, wallet_manager):
        """A wallet credited past the cap is reported with its sweep address."""
        # Create wallet with sweep address, then credit above cap
        wallet_manager.create_wallet(user_id=1)
        surplus_balance = BALANCE_CAP_SATOSHIS + 50000
        wallet_manager.admin_credit(user_id=1, amount_satoshis=surplus_balance, reason='over cap')

        flagged = wallet_manager.get_wallets_over_cap()

        assert len(flagged) == 1
        entry = flagged[0]
        assert entry['user_id'] == 1
        assert entry['balance'] == surplus_balance
        assert entry['sweep_address'] is not None

    def test_get_wallets_over_cap_empty(self, wallet_manager):
        """No wallets are reported when every balance is under the cap."""
        wallet_manager.create_wallet(user_id=1)
        wallet_manager.admin_credit(user_id=1, amount_satoshis=10000, reason='under cap')

        assert len(wallet_manager.get_wallets_over_cap()) == 0

    def test_get_all_active_wallets(self, wallet_manager):
        """Disabled wallets are excluded from the active listing."""
        for uid in (1, 2):
            wallet_manager.create_wallet(user_id=uid)
        # Disable user 2's wallet
        wallet_manager.db.data['wallets'][2]['is_disabled'] = 1

        active = wallet_manager.get_all_active_wallets()

        assert len(active) == 1
        assert active[0]['user_id'] == 1

    def test_get_wallets_for_deposit_monitoring_includes_disabled(self, wallet_manager):
        """Deposit monitoring should include disabled wallets."""
        for uid in (1, 2):
            wallet_manager.create_wallet(user_id=uid)
        wallet_manager.db.data['wallets'][2]['is_disabled'] = 1

        monitored = {w['user_id'] for w in wallet_manager.get_wallets_for_deposit_monitoring()}

        assert monitored == {1, 2}

    def test_get_pending_withdrawals(self, wallet_manager):
        """A freshly requested withdrawal shows up as pending."""
        self._funded_wallet(wallet_manager)
        self._reserve_withdrawal(wallet_manager)

        queued = wallet_manager.get_pending_withdrawals()

        assert len(queued) == 1
        assert queued[0]['user_id'] == 1
        assert queued[0]['amount_satoshis'] == 10000

    def test_auto_sweep_no_wallet(self, wallet_manager):
        """auto_sweep refuses users without a wallet."""
        outcome = wallet_manager.auto_sweep(user_id=999, amount_satoshis=10000)

        assert not outcome['success']
        assert 'no wallet' in outcome['error'].lower()

    def test_auto_sweep_no_sweep_address(self, wallet_manager):
        """auto_sweep refuses wallets lacking a sweep address."""
        wallet_manager.create_wallet(user_id=1)
        # Remove sweep address
        wallet_manager.db.data['wallets'][1]['sweep_address'] = None

        outcome = wallet_manager.auto_sweep(user_id=1, amount_satoshis=10000)

        assert not outcome['success']
        assert 'sweep address' in outcome['error'].lower()

    def test_process_withdrawal_not_found(self, wallet_manager):
        """Processing an unknown withdrawal id reports it as not found."""
        outcome = wallet_manager.process_withdrawal(withdrawal_id=999)

        assert not outcome['success']
        assert 'not found' in outcome['error'].lower()

    def test_fail_withdrawal_reverses_credit(self, wallet_manager):
        """Failing a withdrawal refunds the reserved amount to the ledger."""
        self._funded_wallet(wallet_manager)

        # Request withdrawal (debits 10000)
        self._reserve_withdrawal(wallet_manager)
        assert wallet_manager.get_credits_balance(user_id=1) == 90000

        wallet_manager._fail_withdrawal(withdrawal_id=1, error_message='Test failure')

        # Balance should be restored
        assert wallet_manager.get_credits_balance(user_id=1) == 100000

    def test_process_withdrawal_none_txhash_reverses_credit(self, wallet_manager):
        """None tx hash from send_transaction should fail and reverse reservation."""
        self._funded_wallet(wallet_manager)
        self._reserve_withdrawal(wallet_manager)
        assert wallet_manager.get_credits_balance(user_id=1) == 90000

        broadcast_target = 'wallet.wallet_manager.BitcoinService.send_transaction'
        with patch(broadcast_target, return_value=None):
            outcome = wallet_manager.process_withdrawal(withdrawal_id=1)

        assert not outcome['success']
        assert 'no tx hash' in outcome['error'].lower()
        assert wallet_manager.get_credits_balance(user_id=1) == 100000