mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 01:52:59 +08:00
268 lines
8.3 KiB
Python
268 lines
8.3 KiB
Python
import logging
|
|
import os
|
|
import shutil
|
|
from contextlib import asynccontextmanager
|
|
from typing import Optional
|
|
|
|
from app.logger import log_startup_warning
|
|
from utils.install_util import get_missing_requirements_message
|
|
from comfy.cli_args import args
|
|
|
|
|
|
LOGGER = logging.getLogger(__name__)

# Attempt imports which may not exist in some environments.
# If any of them raises ImportError the module still imports: the problem is
# reported once through log_startup_warning(), _DB_AVAILABLE is set to False
# (see dependencies_available()), and ENGINE/SESSION stay None.
try:
    from alembic import command
    from alembic.config import Config
    from alembic.runtime.migration import MigrationContext
    from alembic.script import ScriptDirectory
    from sqlalchemy import create_engine, text
    from sqlalchemy.engine import make_url
    from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

    # All optional database dependencies imported successfully.
    _DB_AVAILABLE = True
    # Process-wide async engine; assigned by init_db_engine().
    ENGINE: AsyncEngine | None = None
    # Process-wide session factory; assigned by init_db_engine().
    SESSION: async_sessionmaker | None = None
except ImportError as e:
    # Names imported before the failing statement remain bound; code must
    # check dependencies_available() instead of relying on any of them.
    log_startup_warning(
        (
            "------------------------------------------------------------------------\n"
            f"Error importing DB dependencies: {e}\n"
            f"{get_missing_requirements_message()}\n"
            "This error is happening because ComfyUI now uses a local database.\n"
            "------------------------------------------------------------------------"
        ).strip()
    )
    _DB_AVAILABLE = False
    ENGINE = None
    SESSION = None
|
|
|
|
|
|
def dependencies_available() -> bool:
    """Report whether the optional database packages were importable.

    Returns:
        The flag recorded by the guarded import block at module load:
        True if alembic/sqlalchemy imported cleanly, otherwise False.
    """
    return bool(_DB_AVAILABLE)
|
|
|
|
|
|
def _root_paths():
|
|
"""Resolve alembic.ini and migrations script folder."""
|
|
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
|
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
|
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
|
return config_path, scripts_path
|
|
|
|
|
|
def _absolutize_sqlite_url(db_url: str) -> str:
    """Make a SQLite database path absolute. No-op for non-SQLite URLs.

    A relative SQLite file path is resolved against the current working
    directory. These inputs are returned unchanged:

    * URLs that cannot be parsed;
    * non-SQLite URLs;
    * in-memory SQLite targets (no database component, or ``:memory:``).
      Resolving those would point SQLite at a regular file named
      ``:memory:`` or at the working directory itself;
    * SQLite URLs whose path is already absolute.

    Args:
        db_url: SQLAlchemy database URL string.

    Returns:
        The URL string, with the SQLite path made absolute when applicable.
    """
    try:
        u = make_url(db_url)
    except Exception:
        # Best effort: malformed URLs (or make_url being unbound when the
        # optional SQLAlchemy import failed) are left for the caller to report.
        return db_url

    if not u.drivername.startswith("sqlite"):
        return db_url

    db_path = u.database
    if not db_path or db_path == ":memory:":
        return db_url
    if os.path.isabs(db_path):
        return db_url

    # Render with the password intact: str(URL) masks it as "***" in
    # SQLAlchemy 2.x, which would corrupt the URL we hand back.
    return u.set(database=os.path.abspath(db_path)).render_as_string(hide_password=False)
|
|
|
|
|
|
def _to_sync_driver_url(async_url: str) -> str:
    """Convert an async SQLAlchemy URL to a sync URL for Alembic.

    The ``+driver`` suffix is stripped from the driver name, so SQLAlchemy
    falls back to the dialect's default synchronous DBAPI. For example,
    ``sqlite+aiosqlite`` becomes ``sqlite`` and ``postgresql+asyncpg``
    becomes ``postgresql``. A driver name without ``+`` is kept as is.

    Args:
        async_url: SQLAlchemy URL string, typically with an async driver.

    Returns:
        The equivalent URL string for a synchronous engine. The password
        is preserved.

    Raises:
        Whatever ``make_url`` raises for an unparseable URL.
    """
    u = make_url(async_url)
    dialect = u.drivername.split("+", 1)[0]
    if dialect != u.drivername:
        u = u.set(drivername=dialect)

    # str(URL) renders the password as "***" in SQLAlchemy 2.x, which would
    # make Alembic fail to authenticate against password-protected servers.
    return u.render_as_string(hide_password=False)
|
|
|
|
|
|
def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
    """Return the on-disk path for a SQLite URL, else None.

    None is returned for unparseable URLs, non-SQLite URLs, and in-memory
    SQLite databases (no database component, or ``:memory:``). In-memory
    databases have no file to back up or restore.

    Args:
        sync_url: SQLAlchemy URL string.

    Returns:
        The database file path exactly as stored in the URL, or None.
    """
    try:
        u = make_url(sync_url)
    except Exception:
        return None

    if not u.drivername.startswith("sqlite"):
        return None

    db_path = u.database
    if not db_path or db_path == ":memory:":
        return None
    return db_path
|
|
|
|
|
|
def _get_alembic_config(sync_url: str) -> Config:
|
|
"""Prepare Alembic Config with script location and DB URL."""
|
|
config_path, scripts_path = _root_paths()
|
|
cfg = Config(config_path)
|
|
cfg.set_main_option("script_location", scripts_path)
|
|
cfg.set_main_option("sqlalchemy.url", sync_url)
|
|
return cfg
|
|
|
|
|
|
def _install_sqlite_pragma_listener(engine) -> None:
    """Run the per-connection SQLite PRAGMAs on every new DBAPI connection of *engine*."""
    # Local import: `event` is only needed on this path.
    from sqlalchemy import event

    @event.listens_for(engine.sync_engine, "connect")
    def _set_sqlite_pragmas(dbapi_connection, _connection_record):
        # foreign_keys and synchronous apply to a single SQLite connection.
        # Setting them once on whichever pooled connection happens to be
        # checked out leaves every later connection without them.
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys = ON;")
            cursor.execute("PRAGMA synchronous = NORMAL;")
        finally:
            cursor.close()


async def _ensure_sqlite_wal_mode(engine) -> None:
    """Switch the SQLite database behind *engine* to WAL journal mode if needed.

    Raises:
        RuntimeError: if SQLite does not report ``wal`` after the switch.
    """
    async with engine.begin() as conn:
        # WAL for concurrency and durability. The journal mode is stored in
        # the database file, so setting it through one connection suffices.
        current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
        if str(current_mode).lower() == "wal":
            return
        new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
        if str(new_mode).lower() != "wal":
            raise RuntimeError("Failed to set SQLite journal mode to WAL.")
        LOGGER.info("SQLite journal mode set to WAL.")


async def init_db_engine() -> None:
    """Initialize async engine + sessionmaker and run migrations to head.

    This must be called once on application startup before any DB usage.
    A call made after the engine already exists returns immediately.

    Raises:
        RuntimeError: if the DB dependencies could not be imported, no
            database URL is configured, or SQLite refuses WAL journal mode.
        Exception: anything raised while connecting or migrating. In that
            case the engine is disposed and ENGINE is reset to None, so a
            later call retries instead of returning early with SESSION unset.
    """
    global ENGINE, SESSION

    if not dependencies_available():
        raise RuntimeError("Database dependencies are not available.")

    if ENGINE is not None:
        return

    raw_url = args.database_url
    if not raw_url:
        raise RuntimeError("Database URL is not configured.")

    # Absolutize SQLite path for async engine
    db_url = _absolutize_sqlite_url(raw_url)
    is_sqlite = db_url.startswith("sqlite")

    connect_args = {}
    if is_sqlite:
        connect_args = {
            # Allow a connection to be used outside the thread that opened it.
            "check_same_thread": False,
            # Seconds to wait on a locked database before raising.
            "timeout": 12,
        }

    ENGINE = create_async_engine(
        db_url,
        connect_args=connect_args,
        pool_pre_ping=True,
        future=True,
    )

    try:
        if is_sqlite:
            # Register before the first connection is opened so that every
            # pooled connection gets the PRAGMAs.
            _install_sqlite_pragma_listener(ENGINE)
            await _ensure_sqlite_wal_mode(ENGINE)

        await _run_migrations(raw_url=db_url)
    except Exception:
        failed_engine, ENGINE = ENGINE, None
        try:
            await failed_engine.dispose()
        except Exception:
            LOGGER.warning("Failed to dispose database engine after initialization error.", exc_info=True)
        raise

    SESSION = async_sessionmaker(
        bind=ENGINE,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
|
|
|
|
|
|
def _get_current_revision(sync_url: str) -> Optional[str]:
    """Return the Alembic revision stamped in the database (None if there is none)."""
    engine = create_engine(sync_url, future=True)
    try:
        with engine.connect() as conn:
            return MigrationContext.configure(conn).get_current_revision()
    finally:
        # Release the pooled connection (and, for SQLite, the open file
        # handle) before the backup/upgrade steps touch the database.
        engine.dispose()


def _sqlite_online_copy(src_path: str, dst_path: str) -> None:
    """Copy SQLite database *src_path* into *dst_path* via SQLite's backup API.

    Unlike a raw file copy, this reads through SQLite, so pages still held in
    a ``-wal`` file are included. It also writes through SQLite's locking and
    journaling, so it is safe to target a database that has a ``-wal`` file.
    """
    import sqlite3
    from contextlib import closing

    with closing(sqlite3.connect(src_path)) as src, closing(sqlite3.connect(dst_path)) as dst:
        src.backup(dst)


def _backup_sqlite_db(sqlite_path: Optional[str]) -> Optional[str]:
    """Best-effort pre-migration backup; return the backup path, or None if none was made."""
    if not sqlite_path or not os.path.exists(sqlite_path):
        return None
    backup_path = sqlite_path + ".bkp"
    try:
        _sqlite_online_copy(sqlite_path, backup_path)
    except Exception as exc:
        # Do not return the path: a stale .bkp left by an earlier run, or a
        # partially written one, must never be restored over the live DB.
        LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
        return None
    return backup_path


async def _run_migrations(raw_url: str) -> None:
    """
    Run Alembic migrations up to head.

    We deliberately use a synchronous engine for migrations because Alembic's
    programmatic API is synchronous by default and this path is robust.

    For SQLite file databases, a ``<db>.bkp`` snapshot is taken first (best
    effort). If the upgrade fails, the snapshot is restored and deleted, and
    the original exception is re-raised.
    """
    # Convert to sync URL and make SQLite URL an absolute one
    sync_url = _absolutize_sqlite_url(_to_sync_driver_url(raw_url))
    cfg = _get_alembic_config(sync_url)

    # Inspect current and target heads
    current_rev = _get_current_revision(sync_url)
    target_rev = ScriptDirectory.from_config(cfg).get_current_head()

    if target_rev is None:
        LOGGER.warning("Alembic: no target revision found.")
        return

    if current_rev == target_rev:
        LOGGER.debug("Alembic: database already at head %s", target_rev)
        return

    LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)

    # Optional backup for SQLite file DBs; None unless a snapshot was written now.
    sqlite_path = _get_sqlite_file_path(sync_url)
    backup_path = _backup_sqlite_db(sqlite_path)

    try:
        command.upgrade(cfg, target_rev)
    except Exception:
        if backup_path and os.path.exists(backup_path):
            LOGGER.exception("Error upgrading database, attempting restore from backup.")
            try:
                _sqlite_online_copy(backup_path, sqlite_path)  # restore
                os.remove(backup_path)
            except Exception as restore_exc:
                LOGGER.error("Failed to restore SQLite backup: %s", restore_exc)
        else:
            LOGGER.exception("Error upgrading database, backup is not available.")
        raise
|
|
|
|
|
|
def get_engine():
    """Fetch the process-wide async engine created by init_db_engine().

    Raises:
        RuntimeError: if init_db_engine() has not created the engine yet.
    """
    engine = ENGINE
    if engine is not None:
        return engine
    raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
|
|
|
|
|
|
def get_session_maker():
    """Fetch the process-wide async_sessionmaker built by init_db_engine().

    Raises:
        RuntimeError: if init_db_engine() has not finished setting it up.
    """
    maker = SESSION
    if maker is not None:
        return maker
    raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
|
|
|
|
|
|
@asynccontextmanager
async def session_scope():
    """Wrap one AsyncSession in a commit-or-rollback unit of work.

    The session is committed when the ``async with`` body finishes normally.
    If the body or the commit raises an ``Exception``, the session is rolled
    back and the exception propagates.

    Example::

        async with session_scope() as sess:
            ... use sess ...
    """
    session_factory = get_session_maker()
    async with session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|
|
|
|
|
|
async def create_session():
    """Return a new AsyncSession from the global session maker.

    The session is returned unopened and unclosed; the caller owns it, e.g.::

        async with (await create_session()) as sess:
            ...
    """
    session_factory = get_session_maker()
    new_session = session_factory()
    return new_session
|