316 lines
9.9 KiB
Python
316 lines
9.9 KiB
Python
"""Tests for session management."""
|
|
|
|
import pytest
|
|
from datetime import datetime, timedelta, timezone
|
|
import asyncio
|
|
|
|
from exchange_data_manager.sessions import Session, ExchangeCredentials, SessionManager
|
|
|
|
|
|
class TestExchangeCredentials:
    """Unit tests covering construction and repr of ExchangeCredentials."""

    def test_create_credentials(self):
        """Supplying only key and secret leaves the optional fields at their defaults."""
        key, secret = "test_key", "test_secret"
        creds = ExchangeCredentials(api_key=key, api_secret=secret)

        assert creds.api_key == key
        assert creds.api_secret == secret
        # Optional fields not passed above.
        assert creds.passphrase is None
        assert creds.testnet is False

    def test_credentials_with_passphrase(self):
        """Passphrase (as used by e.g. KuCoin) and testnet flag are stored as given."""
        creds = ExchangeCredentials(
            api_key="key",
            api_secret="secret",
            passphrase="pass",
            testnet=True,
        )

        assert creds.passphrase == "pass"
        assert creds.testnet is True

    def test_repr_hides_secrets(self):
        """repr() must mask the key and secret rather than print them verbatim."""
        rendered = repr(
            ExchangeCredentials(
                api_key="super_secret_key",
                api_secret="super_secret_secret",
            )
        )

        # Neither secret value may leak; a mask marker must appear instead.
        assert "super_secret" not in rendered
        assert "***" in rendered
|
|
|
|
|
|
class TestSession:
    """Unit tests for the Session dataclass and its credential helpers."""

    def test_create_session(self):
        """A fresh session has a UUID id and creation time, but no expiry or credentials."""
        fresh = Session()

        session_id = fresh.id
        assert session_id is not None
        assert len(session_id) == 36  # length of a canonical UUID string
        assert fresh.created_at is not None
        assert fresh.expires_at is None
        assert len(fresh.credentials) == 0

    @pytest.mark.asyncio
    async def test_add_credentials(self):
        """Registered credentials are found by exchange name, case-insensitively."""
        fresh = Session()
        await fresh.add_credentials(
            "binance", ExchangeCredentials(api_key="key", api_secret="secret")
        )

        assert fresh.has_credentials("binance")
        assert fresh.has_credentials("BINANCE")  # lookup ignores case
        assert not fresh.has_credentials("kucoin")

    @pytest.mark.asyncio
    async def test_remove_credentials(self):
        """Removal reports True once, then False when nothing is left to remove."""
        fresh = Session()
        await fresh.add_credentials(
            "binance", ExchangeCredentials(api_key="key", api_secret="secret")
        )
        assert fresh.has_credentials("binance")

        first_removal = await fresh.remove_credentials("binance")
        assert first_removal is True
        assert not fresh.has_credentials("binance")

        # A second removal finds no entry and reports False.
        second_removal = await fresh.remove_credentials("binance")
        assert second_removal is False

    @pytest.mark.asyncio
    async def test_exchanges_property(self):
        """`exchanges` lists every exchange that has credentials attached."""
        fresh = Session()
        for exchange_name in ("binance", "kucoin"):
            await fresh.add_credentials(exchange_name, ExchangeCredentials("k", "s"))

        registered = fresh.exchanges
        assert "binance" in registered
        assert "kucoin" in registered
        assert len(registered) == 2

    def test_is_expired_no_expiry(self):
        """A session without an expiry timestamp never reports itself expired."""
        assert Session().is_expired() is False

    def test_is_expired_future(self):
        """An expiry one hour ahead is not yet reached."""
        in_an_hour = datetime.now(timezone.utc) + timedelta(hours=1)
        assert Session(expires_at=in_an_hour).is_expired() is False

    def test_is_expired_past(self):
        """An expiry one hour ago has already been reached."""
        an_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1)
        assert Session(expires_at=an_hour_ago).is_expired() is True

    @pytest.mark.asyncio
    async def test_cleanup(self):
        """cleanup() empties both the credential map and the connector cache."""
        fresh = Session()
        for exchange_name in ("binance", "kucoin"):
            await fresh.add_credentials(exchange_name, ExchangeCredentials("k", "s"))

        await fresh.cleanup()

        assert len(fresh.credentials) == 0
        assert len(fresh._connectors) == 0
|
|
|
|
|
|
class TestSessionManager:
    """Tests for SessionManager class."""

    @pytest.fixture
    def manager(self):
        """Create a session manager for testing (60 min timeout, 300 s cleanup interval)."""
        return SessionManager(
            session_timeout_minutes=60,
            cleanup_interval_seconds=300,
        )

    def test_create_session(self, manager):
        """Test creating a session registers it and gives it an expiry."""
        session = manager.create_session()

        assert session is not None
        assert session.id in manager._sessions
        assert session.expires_at is not None

    def test_get_session(self, manager):
        """Test getting a session by ID returns the same object."""
        session = manager.create_session()

        retrieved = manager.get_session(session.id)
        assert retrieved is session

    def test_get_session_not_found(self, manager):
        """Test getting a non-existent session returns None."""
        result = manager.get_session("nonexistent-id")
        assert result is None

    def test_get_session_expired(self, manager):
        """Test getting an expired session returns None."""
        session = manager.create_session()
        session.expires_at = datetime.now(timezone.utc) - timedelta(hours=1)

        result = manager.get_session(session.id)
        assert result is None

    @pytest.mark.asyncio
    async def test_destroy_session(self, manager):
        """Test destroying a session removes it from the manager."""
        session = manager.create_session()
        session_id = session.id

        result = await manager.destroy_session(session_id)

        assert result is True
        assert session_id not in manager._sessions

    @pytest.mark.asyncio
    async def test_destroy_session_not_found(self, manager):
        """Test destroying a non-existent session returns False."""
        result = await manager.destroy_session("nonexistent")
        assert result is False

    @pytest.mark.asyncio
    async def test_add_exchange_credentials(self, manager):
        """Test adding credentials to a session."""
        session = manager.create_session()

        result = await manager.add_exchange_credentials(
            session_id=session.id,
            exchange="binance",
            api_key="test_key",
            api_secret="test_secret",
        )

        assert result is True
        assert session.has_credentials("binance")

    @pytest.mark.asyncio
    async def test_add_exchange_credentials_session_not_found(self, manager):
        """Test adding credentials to a non-existent session returns False."""
        result = await manager.add_exchange_credentials(
            session_id="nonexistent",
            exchange="binance",
            api_key="key",
            api_secret="secret",
        )

        assert result is False

    @pytest.mark.asyncio
    async def test_remove_exchange_credentials(self, manager):
        """Test removing credentials from a session."""
        session = manager.create_session()
        await manager.add_exchange_credentials(
            session.id, "binance", "key", "secret"
        )

        result = await manager.remove_exchange_credentials(session.id, "binance")
        assert result is True
        assert not session.has_credentials("binance")

    def test_refresh_session(self, manager):
        """Test refreshing a session pushes its expiry forward.

        Instead of sleeping and hoping the clock advances (flaky on
        coarse-resolution clocks), the expiry is moved back deterministically
        before refreshing. It stays in the future, so the session is still live.
        """
        session = manager.create_session()
        # 30 min earlier than created, i.e. still ~30 min in the future with a 60 min timeout.
        backdated_expiry = session.expires_at - timedelta(minutes=30)
        session.expires_at = backdated_expiry

        result = manager.refresh_session(session.id)

        assert result is True
        assert session.expires_at > backdated_expiry

    def test_refresh_session_not_found(self, manager):
        """Test refreshing a non-existent session returns False."""
        result = manager.refresh_session("nonexistent")
        assert result is False

    def test_active_session_count(self, manager):
        """Test active session count grows with each created session."""
        assert manager.active_session_count == 0

        manager.create_session()
        assert manager.active_session_count == 1

        manager.create_session()
        assert manager.active_session_count == 2

    def test_stats(self, manager):
        """Test session stats distinguish active and expired sessions."""
        manager.create_session()
        session2 = manager.create_session()
        session2.expires_at = datetime.now(timezone.utc) - timedelta(hours=1)

        stats = manager.stats()

        assert stats["total_sessions"] == 2
        assert stats["active_sessions"] == 1
        assert stats["expired_pending_cleanup"] == 1

    @pytest.mark.asyncio
    async def test_start_stop(self, manager):
        """Test starting and stopping the manager."""
        await manager.start()
        try:
            assert manager._running is True
            assert manager._cleanup_task is not None
        finally:
            # Always stop so a failed assertion doesn't leave the cleanup task running.
            await manager.stop()

        assert manager._running is False

    @pytest.mark.asyncio
    async def test_stop_cleans_up_sessions(self, manager):
        """Test that stop cleans up all sessions."""
        session = manager.create_session()
        await session.add_credentials(
            "binance",
            ExchangeCredentials("key", "secret")
        )

        await manager.start()
        await manager.stop()

        assert len(manager._sessions) == 0

    @pytest.mark.asyncio
    async def test_cleanup_loop(self):
        """Test that cleanup loop removes expired sessions."""
        # Use short intervals for testing
        manager = SessionManager(
            session_timeout_minutes=1,
            cleanup_interval_seconds=0.1,  # 100ms
        )

        session = manager.create_session()
        # Expire the session immediately
        session.expires_at = datetime.now(timezone.utc) - timedelta(seconds=1)

        await manager.start()
        try:
            # Poll with a bounded deadline rather than one fixed sleep, so a
            # slow/loaded event loop doesn't make the test flaky.
            loop = asyncio.get_running_loop()
            deadline = loop.time() + 2.0
            while session.id in manager._sessions and loop.time() < deadline:
                await asyncio.sleep(0.05)

            # Session should be cleaned up
            assert session.id not in manager._sessions
        finally:
            # Stop even on failure so the background task doesn't leak.
            await manager.stop()
|