"""Async SQLAlchemy engine/session setup and Alembic migrations."""

import logging
import os
import shutil
from contextlib import asynccontextmanager
from typing import Optional

from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine, event, text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from comfy.cli_args import args

LOGGER = logging.getLogger(__name__)
# Module-level singletons populated by init_db_engine().
ENGINE: Optional[AsyncEngine] = None
SESSION: Optional[async_sessionmaker] = None


def _root_paths():
    """Resolve alembic.ini and migrations script folder.

    Returns:
        ``(config_path, scripts_path)``: absolute path of ``alembic.ini`` one
        directory above this module, and of the ``alembic_db`` folder that
        sits next to this module.
    """
    root_path = os.path.abspath(os.path.dirname(__file__))
    config_path = os.path.abspath(os.path.join(root_path, "../alembic.ini"))
    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
    return config_path, scripts_path


def _absolutize_sqlite_url(db_url: str) -> str:
    """Make a relative SQLite database path absolute (relative to the CWD).

    Returned unchanged: non-SQLite URLs, unparsable URLs, in-memory databases
    (empty database or ``:memory:``) and SQLite URI filenames (``file:...``).
    """
    try:
        u = make_url(db_url)
    except Exception:
        return db_url

    if not u.drivername.startswith("sqlite"):
        return db_url

    db_path = u.database or ""
    if not db_path or db_path == ":memory:" or db_path.startswith("file:"):
        # Nothing on disk to resolve. An empty database ("sqlite://") must not
        # be turned into the CWD path.
        return str(u)
    if not os.path.isabs(db_path):
        u = u.set(database=os.path.abspath(db_path))
    return str(u)


def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]:
    """Rewrite in-memory SQLite URLs to a named, shared-cache URI form.

    ``sqlite://``, ``sqlite:///:memory:`` and ``file:...`` URIs with
    ``mode=memory`` are rewritten so that ``uri=true`` is present. Plain
    in-memory URLs are additionally renamed to a per-process named database
    with ``cache=shared``.

    Returns:
        ``(normalized_url, is_memory)``. Non-SQLite or unparsable URLs are
        returned unchanged with ``is_memory=False``.
    """
    try:
        u = make_url(db_url)
    except Exception:
        return db_url, False

    if not u.drivername.startswith("sqlite"):
        return db_url, False

    db = u.database or ""
    if db in ("", ":memory:"):
        u = u.set(database=f"file:comfyui_db_{os.getpid()}").update_query_dict(
            {"mode": "memory", "cache": "shared", "uri": "true"}
        )
        return str(u), True

    if db.startswith("file:"):
        # make_url() moves everything after "?" into u.query, so for URLs
        # parsed from a string "mode=memory" lives in the query dict.
        if u.query.get("mode") == "memory":
            if u.query.get("uri") != "true":
                u = u.update_query_dict({"uri": "true"})
            return str(u), True
        # Programmatically built URLs may still embed the query in the name.
        if "mode=memory" in db:
            if "uri=true" not in db:
                u = u.set(database=db + ("&" if "?" in db else "?") + "uri=true")
            return str(u), True

    return str(u), False


def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
    """Return the on-disk path for a SQLite URL, else None.

    None is returned for non-SQLite URLs, unparsable URLs, in-memory databases
    and ``file:`` URI databases.
    """
    try:
        u = make_url(sync_url)
    except Exception:
        return None
    if not u.drivername.startswith("sqlite"):
        return None
    db_path = u.database
    if not db_path or db_path == ":memory:" or db_path.startswith("file:"):
        return None
    return db_path


def _get_alembic_config(sync_url: str) -> Config:
    """Prepare Alembic Config with script location and DB URL."""
    config_path, scripts_path = _root_paths()
    cfg = Config(config_path)
    # Alembic's Config is ConfigParser-backed: a literal "%" (e.g. in a
    # percent-encoded password or path) must be doubled to survive interpolation.
    cfg.set_main_option("script_location", scripts_path.replace("%", "%%"))
    cfg.set_main_option("sqlalchemy.url", sync_url.replace("%", "%%"))
    return cfg


def _install_sqlite_pragmas(engine: AsyncEngine) -> None:
    """Run the per-connection SQLite pragmas on every new DBAPI connection.

    ``foreign_keys`` and ``synchronous`` are connection-scoped settings, so
    running them once on a single connection does not cover the rest of the pool.
    """

    @event.listens_for(engine.sync_engine, "connect")
    def _on_connect(dbapi_connection, _connection_record):
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys = ON;")
            cursor.execute("PRAGMA synchronous = NORMAL;")
        finally:
            cursor.close()


async def init_db_engine() -> None:
    """Initialize async engine + sessionmaker and run migrations to head.

    This must be called once on application startup before any DB usage.
    Calling it again after a successful init is a no-op. If initialization
    fails, the engine is disposed and reset so that a later call can retry.

    Raises:
        RuntimeError: if no database URL is configured, or if SQLite refuses
            WAL journal mode.
        ValueError: if the URL uses an unsupported driver (from migrations).
    """
    global ENGINE, SESSION

    if ENGINE is not None:
        return

    raw_url = args.database_url
    if not raw_url:
        raise RuntimeError("Database URL is not configured.")

    db_url, is_mem = _normalize_sqlite_memory_url(raw_url)
    db_url = _absolutize_sqlite_url(db_url)
    is_sqlite = db_url.startswith("sqlite")

    # Prepare async engine
    connect_args: dict = {}
    if is_sqlite:
        connect_args = {
            "check_same_thread": False,
            "timeout": 12,
        }
        if is_mem:
            connect_args["uri"] = True

    ENGINE = create_async_engine(
        db_url,
        connect_args=connect_args,
        pool_pre_ping=True,
        future=True,
    )

    try:
        if is_sqlite:
            _install_sqlite_pragmas(ENGINE)
            # Open a connection up front. It fires the pragma hook, and for
            # in-memory DBs it presumably stays held by the engine's pool, which
            # keeps the shared-cache database alive while the separate migration
            # engine runs (TODO confirm against the pool class SQLAlchemy picks).
            async with ENGINE.begin() as conn:
                if not is_mem:
                    # WAL for concurrency and durability
                    current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
                    if str(current_mode).lower() != "wal":
                        new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
                        if str(new_mode).lower() != "wal":
                            raise RuntimeError("Failed to set SQLite journal mode to WAL.")
                        LOGGER.info("SQLite journal mode set to WAL.")

        await _run_migrations(database_url=db_url, connect_args=connect_args)
    except Exception:
        # Reset so a retry re-initializes instead of returning early with a
        # half-configured engine and no session maker.
        failed_engine, ENGINE = ENGINE, None
        await failed_engine.dispose()
        raise

    SESSION = async_sessionmaker(
        bind=ENGINE,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )


async def _run_migrations(database_url: str, connect_args: dict) -> None:
    """Upgrade the database schema to the Alembic head revision.

    ``postgresql+psycopg*`` URLs are used as-is. ``sqlite+aiosqlite`` URLs are
    converted to the synchronous ``sqlite`` driver for Alembic. For on-disk
    SQLite databases, a ``.bkp`` copy is taken before upgrading and restored if
    the upgrade fails.

    Raises:
        ValueError: for any driver other than the two above.
        Exception: whatever ``command.upgrade`` raised, re-raised after the
            restore attempt.
    """
    driver = make_url(database_url).drivername
    if not driver.startswith("postgresql+psycopg"):
        # SQLite: Alembic runs synchronously, so swap aiosqlite for the stdlib driver.
        if not driver.startswith("sqlite+aiosqlite"):
            raise ValueError(f"Unsupported DB driver: {driver}")
        sync_url = str(make_url(database_url).set(drivername="sqlite"))
        database_url, _ = _normalize_sqlite_memory_url(sync_url)
        database_url = _absolutize_sqlite_url(database_url)

    cfg = _get_alembic_config(database_url)
    engine = create_engine(database_url, future=True, connect_args=connect_args)
    try:
        with engine.connect() as conn:
            current_rev = MigrationContext.configure(conn).get_current_revision()

        script = ScriptDirectory.from_config(cfg)
        target_rev = script.get_current_head()

        if target_rev is None:
            LOGGER.warning("Alembic: no target revision found.")
            return
        if current_rev == target_rev:
            LOGGER.debug("Alembic: database already at head %s", target_rev)
            return

        LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)

        # Optional backup for SQLite file DBs. backup_path is only set once a
        # copy succeeded, so a stale .bkp from an earlier run is never restored.
        # NOTE(review): only the main DB file is copied; pending WAL frames in
        # "-wal" are not included. Consider sqlite3's backup API.
        backup_path = None
        sqlite_path = _get_sqlite_file_path(database_url)
        if sqlite_path and os.path.exists(sqlite_path):
            candidate = sqlite_path + ".bkp"
            try:
                shutil.copy(sqlite_path, candidate)
                backup_path = candidate
            except OSError as exc:
                LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)

        try:
            command.upgrade(cfg, target_rev)
        except Exception:
            if backup_path and os.path.exists(backup_path):
                LOGGER.exception("Error upgrading database, attempting restore from backup.")
                try:
                    shutil.copy(backup_path, sqlite_path)  # restore
                    os.remove(backup_path)
                except OSError as restore_exc:
                    LOGGER.error("Failed to restore SQLite backup: %s", restore_exc)
            else:
                LOGGER.exception("Error upgrading database, backup is not available.")
            raise
    finally:
        # Release the sync engine's pooled connections; it is only needed here.
        engine.dispose()


def get_engine() -> AsyncEngine:
    """Return the global async engine (initialized after init_db_engine())."""
    if ENGINE is None:
        raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
    return ENGINE


def get_session_maker() -> async_sessionmaker:
    """Return the global async_sessionmaker (initialized after init_db_engine())."""
    if SESSION is None:
        raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
    return SESSION


@asynccontextmanager
async def session_scope():
    """Async context manager for a unit of work:

    async with session_scope() as sess:
        ... use sess ...

    Commits when the body completes, rolls back and re-raises on exception.
    """
    maker = get_session_maker()
    async with maker() as sess:
        try:
            yield sess
            await sess.commit()
        except Exception:
            await sess.rollback()
            raise


async def create_session() -> AsyncSession:
    """Convenience helper to acquire a single AsyncSession instance.

    Typical usage:
        async with (await create_session()) as sess:
            ...
    """
    maker = get_session_maker()
    return maker()