"""SQLite-backed server-side sessions for the web UI.""" from __future__ import annotations import json import uuid from datetime import datetime, timedelta from typing import Optional from flask.sessions import SessionInterface, SessionMixin from werkzeug.datastructures import CallbackDict from cmdforge.registry.db import connect_db SESSION_TTL = timedelta(days=7) class SQLiteSession(CallbackDict, SessionMixin): def __init__(self, initial=None, session_id: Optional[str] = None): super().__init__(initial or {}) self.session_id = session_id class SQLiteSessionInterface(SessionInterface): def __init__(self, cookie_name: str = "cmdforge_session"): self.cookie_name = cookie_name def open_session(self, app, request): session_id = request.cookies.get(self.cookie_name) if not session_id: return SQLiteSession(session_id=self._new_session_id()) conn = connect_db() try: row = conn.execute( "SELECT data, expires_at FROM web_sessions WHERE session_id = ?", [session_id], ).fetchone() if not row: return SQLiteSession(session_id=self._new_session_id()) expires_at = self._parse_dt(row["expires_at"]) if expires_at and expires_at < datetime.utcnow(): conn.execute("DELETE FROM web_sessions WHERE session_id = ?", [session_id]) conn.commit() return SQLiteSession(session_id=self._new_session_id()) data = json.loads(row["data"] or "{}") return SQLiteSession(initial=data, session_id=session_id) finally: conn.close() def save_session(self, app, session, response): if session is None: return session_id = session.session_id or self._new_session_id() expires_at = datetime.utcnow() + SESSION_TTL data = json.dumps(dict(session)) conn = connect_db() try: conn.execute( """ INSERT INTO web_sessions (session_id, data, created_at, expires_at) VALUES (?, ?, ?, ?) ON CONFLICT(session_id) DO UPDATE SET data = excluded.data, expires_at = excluded.expires_at """, [session_id, data, datetime.utcnow().isoformat(), expires_at.isoformat()], ) conn.commit() finally: conn.close() response.set_cookie( self.cookie_name, session_id, httponly=True, samesite="Lax", secure=app.config.get("SESSION_COOKIE_SECURE", False), max_age=int(SESSION_TTL.total_seconds()), ) def rotate_session(self, session) -> None: session.session_id = self._new_session_id() @staticmethod def _new_session_id() -> str: return uuid.uuid4().hex @staticmethod def _parse_dt(value: str | None) -> Optional[datetime]: if not value: return None try: return datetime.fromisoformat(value) except ValueError: return None def cleanup_expired_sessions() -> int: """Remove expired sessions from the database.""" conn = connect_db() try: now = datetime.utcnow().isoformat() cursor = conn.execute("DELETE FROM web_sessions WHERE expires_at < ?", [now]) conn.commit() return cursor.rowcount or 0 finally: conn.close()