110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
"""SQLite-backed server-side sessions for the web UI."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional
|
|
|
|
from flask.sessions import SessionInterface, SessionMixin
|
|
from werkzeug.datastructures import CallbackDict
|
|
|
|
from cmdforge.registry.db import connect_db
|
|
|
|
# Lifetime of a stored web session. The expiry is pushed forward by this
# amount every time the session is saved, and the same value (in seconds)
# is used as the cookie's max_age.
SESSION_TTL = timedelta(weeks=1)
|
|
|
|
|
|
class SQLiteSession(CallbackDict, SessionMixin):
    """Dict-like Flask session whose contents are stored server-side.

    ``session_id`` is the key the session is stored under; it is also the
    opaque value placed in the client's cookie.

    ``modified`` starts out False and becomes True once the session dict is
    changed through its mapping API (CallbackDict's ``on_update`` hook).
    Mutating a nested value in place (e.g. appending to a stored list) does
    not trigger the hook.
    """

    # Overrides SessionMixin's default of True so the flag is meaningful.
    modified = False

    def __init__(self, initial=None, session_id: Optional[str] = None):
        def on_update(session):
            # Invoked by CallbackDict with the dict instance whenever it is
            # changed through its API.
            session.modified = True

        super().__init__(initial or {}, on_update=on_update)
        self.session_id = session_id
|
|
|
|
|
|
class SQLiteSessionInterface(SessionInterface):
    """Flask session interface that keeps session data in ``web_sessions``.

    The client only receives an opaque random session id in the cookie named
    ``cookie_name``. The JSON-encoded session dict and its expiry are stored
    in SQLite through ``connect_db()``. Timestamps are written as naive UTC
    ISO-8601 strings (``datetime.utcnow().isoformat()``).
    """

    def __init__(self, cookie_name: str = "cmdforge_session"):
        self.cookie_name = cookie_name

    def open_session(self, app, request):
        """Load the session named by the request cookie, or start a fresh one.

        A fresh session is returned in these cases:

        * the cookie is absent;
        * the id is unknown;
        * the stored expiry is in the past, missing, or unparseable;
        * the stored payload is not a JSON object.

        A fresh session always gets a newly generated id, never the one the
        client supplied.
        """
        session_id = request.cookies.get(self.cookie_name)
        if not session_id:
            return self._fresh_session()

        conn = connect_db()
        try:
            # NOTE(review): row["..."] access assumes connect_db() installs a
            # mapping-style row_factory such as sqlite3.Row; confirm.
            row = conn.execute(
                "SELECT data, expires_at FROM web_sessions WHERE session_id = ?",
                [session_id],
            ).fetchone()
            if not row:
                return self._fresh_session()

            expires_at = self._parse_dt(row["expires_at"])
            # Fail closed: a row without a readable expiry is treated as
            # expired rather than as a session that never expires.
            expired = expires_at is None or expires_at < datetime.utcnow()
            data = None if expired else self._load_data(row["data"])
            if data is None:
                # Expired or corrupt rows are dropped. Otherwise a corrupt
                # row would fail every request that presents this cookie.
                conn.execute("DELETE FROM web_sessions WHERE session_id = ?", [session_id])
                conn.commit()
                return self._fresh_session()

            return SQLiteSession(initial=data, session_id=session_id)
        finally:
            conn.close()

    def save_session(self, app, session, response):
        """Persist ``session`` and (re)issue the session cookie.

        Empty sessions are not stored:

        * a session created during this request that never received data is
          dropped without touching the database;
        * a previously stored session that is now empty (e.g. after
          ``session.clear()``) has its row deleted and its cookie removed.

        Any other session is upserted with an expiry of ``SESSION_TTL`` from
        now, and the cookie's ``max_age`` is refreshed to match.
        """
        if session is None:
            return

        session_id = session.session_id or self._new_session_id()
        # Keep the in-memory object in sync with what is stored and sent.
        session.session_id = session_id

        if not session:
            if not getattr(session, "new", False):
                self._delete_session_row(session_id)
                response.delete_cookie(self.cookie_name)
            return

        now = datetime.utcnow()
        expires_at = now + SESSION_TTL
        data = json.dumps(dict(session))

        conn = connect_db()
        try:
            # created_at is only written on first insert. Later saves refresh
            # data and push expires_at forward.
            conn.execute(
                """
                INSERT INTO web_sessions (session_id, data, created_at, expires_at)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(session_id) DO UPDATE SET data = excluded.data, expires_at = excluded.expires_at
                """,
                [session_id, data, now.isoformat(), expires_at.isoformat()],
            )
            conn.commit()
        finally:
            conn.close()

        response.set_cookie(
            self.cookie_name,
            session_id,
            httponly=True,
            samesite="Lax",
            secure=app.config.get("SESSION_COOKIE_SECURE", False),
            max_age=int(SESSION_TTL.total_seconds()),
        )

    def rotate_session(self, session) -> None:
        """Give ``session`` a new id and invalidate the old one.

        The row stored under the previous id is deleted immediately, so a
        cookie carrying the old id can no longer load it. The session's data
        stays in memory and is written under the new id by ``save_session``.
        """
        old_id = session.session_id
        session.session_id = self._new_session_id()
        if old_id:
            self._delete_session_row(old_id)

    def _fresh_session(self) -> SQLiteSession:
        """Return an empty session with a newly generated id, marked as new."""
        session = SQLiteSession(session_id=self._new_session_id())
        # SessionMixin.new: lets save_session skip persisting sessions that
        # were created during this request and never received any data.
        session.new = True
        return session

    @staticmethod
    def _delete_session_row(session_id: str) -> None:
        """Delete the stored row for ``session_id``; a no-op if it is absent."""
        conn = connect_db()
        try:
            conn.execute("DELETE FROM web_sessions WHERE session_id = ?", [session_id])
            conn.commit()
        finally:
            conn.close()

    @staticmethod
    def _load_data(raw) -> Optional[dict]:
        """Decode a stored session payload.

        Returns ``None`` if the payload is not valid JSON or is not an object.
        """
        try:
            data = json.loads(raw or "{}")
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _new_session_id() -> str:
        """Return a new random session id (32 hex chars from uuid4)."""
        return uuid.uuid4().hex

    @staticmethod
    def _parse_dt(value: str | datetime | None) -> Optional[datetime]:
        """Parse a stored timestamp into a naive UTC ``datetime``.

        Accepts ISO-8601 strings or ``datetime`` objects (as a sqlite type
        converter might return). Returns ``None`` for empty or unparseable
        values. Timezone-aware values are shifted to UTC and made naive, so
        they compare safely with ``datetime.utcnow()``.
        """
        if not value:
            return None
        if isinstance(value, datetime):
            parsed = value
        else:
            try:
                parsed = datetime.fromisoformat(value)
            except (TypeError, ValueError):
                return None
        offset = parsed.utcoffset()
        if offset is not None:
            parsed = parsed.replace(tzinfo=None) - offset
        return parsed
|
|
|
|
|
|
def cleanup_expired_sessions() -> int:
    """Delete every stored web session whose expiry is in the past.

    ``expires_at`` values are compared as ISO-8601 strings against the current
    naive UTC time rendered with ``isoformat()``.

    Returns:
        The number of rows removed, or 0 if the cursor reports none.
    """
    conn = connect_db()
    try:
        cutoff = datetime.utcnow().isoformat()
        removed = conn.execute(
            "DELETE FROM web_sessions WHERE expires_at < ?", [cutoff]
        ).rowcount
        conn.commit()
    finally:
        conn.close()
    return removed or 0
|