feat: support for in-memory SQLite databases

This commit is contained in:
bigcat88 2025-09-08 18:15:09 +03:00
parent 3fa0fc496c
commit e3311c9229
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721

View File

@ -36,14 +36,39 @@ def _absolutize_sqlite_url(db_url: str) -> str:
if not u.drivername.startswith("sqlite"): if not u.drivername.startswith("sqlite"):
return db_url return db_url
# Make path absolute if relative db_path: str = u.database or ""
db_path = u.database or "" if isinstance(db_path, str) and db_path.startswith("file:"):
return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared"
if not os.path.isabs(db_path): if not os.path.isabs(db_path):
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path)) db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
u = u.set(database=db_path) u = u.set(database=db_path)
return str(u) return str(u)
def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]:
"""
If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory),
rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present.
Returns: (normalized_url, is_memory)
"""
try:
u = make_url(db_url)
except Exception:
return db_url, False
if not u.drivername.startswith("sqlite"):
return db_url, False
db = u.database or ""
if db == ":memory:":
u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true")
return str(u), True
if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db:
if "uri=true" not in db:
u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true"))
return str(u), True
return str(u), False
def _to_sync_driver_url(async_url: str) -> str: def _to_sync_driver_url(async_url: str) -> str:
"""Convert an async SQLAlchemy URL to a sync URL for Alembic.""" """Convert an async SQLAlchemy URL to a sync URL for Alembic."""
u = make_url(async_url) u = make_url(async_url)
@ -70,7 +95,10 @@ def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
if not u.drivername.startswith("sqlite"): if not u.drivername.startswith("sqlite"):
return None return None
return u.database db_path = u.database
if isinstance(db_path, str) and db_path.startswith("file:"):
return None # Not a real file if it is a URI like "file:...?"
return db_path
def _get_alembic_config(sync_url: str) -> Config: def _get_alembic_config(sync_url: str) -> Config:
@ -96,8 +124,8 @@ async def init_db_engine() -> None:
if not raw_url: if not raw_url:
raise RuntimeError("Database URL is not configured.") raise RuntimeError("Database URL is not configured.")
# Absolutize SQLite path for async engine db_url, is_mem = _normalize_sqlite_memory_url(raw_url)
db_url = _absolutize_sqlite_url(raw_url) db_url = _absolutize_sqlite_url(db_url)
# Prepare async engine # Prepare async engine
connect_args = {} connect_args = {}
@ -106,6 +134,8 @@ async def init_db_engine() -> None:
"check_same_thread": False, "check_same_thread": False,
"timeout": 12, "timeout": 12,
} }
if is_mem:
connect_args["uri"] = True
ENGINE = create_async_engine( ENGINE = create_async_engine(
db_url, db_url,
@ -117,18 +147,19 @@ async def init_db_engine() -> None:
# Enforce SQLite pragmas on the async engine # Enforce SQLite pragmas on the async engine
if db_url.startswith("sqlite"): if db_url.startswith("sqlite"):
async with ENGINE.begin() as conn: async with ENGINE.begin() as conn:
# WAL for concurrency and durability, Foreign Keys for referential integrity if not is_mem:
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar() # WAL for concurrency and durability, Foreign Keys for referential integrity
if str(current_mode).lower() != "wal": current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar() if str(current_mode).lower() != "wal":
if str(new_mode).lower() != "wal": new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
raise RuntimeError("Failed to set SQLite journal mode to WAL.") if str(new_mode).lower() != "wal":
LOGGER.info("SQLite journal mode set to WAL.") raise RuntimeError("Failed to set SQLite journal mode to WAL.")
LOGGER.info("SQLite journal mode set to WAL.")
await conn.execute(text("PRAGMA foreign_keys = ON;")) await conn.execute(text("PRAGMA foreign_keys = ON;"))
await conn.execute(text("PRAGMA synchronous = NORMAL;")) await conn.execute(text("PRAGMA synchronous = NORMAL;"))
await _run_migrations(raw_url=db_url) await _run_migrations(raw_url=db_url, connect_args=connect_args)
SESSION = async_sessionmaker( SESSION = async_sessionmaker(
bind=ENGINE, bind=ENGINE,
@ -139,7 +170,7 @@ async def init_db_engine() -> None:
) )
async def _run_migrations(raw_url: str) -> None: async def _run_migrations(raw_url: str, connect_args: dict) -> None:
""" """
Run Alembic migrations up to head. Run Alembic migrations up to head.
@ -148,12 +179,11 @@ async def _run_migrations(raw_url: str) -> None:
""" """
# Convert to sync URL and make SQLite URL an absolute one # Convert to sync URL and make SQLite URL an absolute one
sync_url = _to_sync_driver_url(raw_url) sync_url = _to_sync_driver_url(raw_url)
sync_url, is_mem = _normalize_sqlite_memory_url(sync_url)
sync_url = _absolutize_sqlite_url(sync_url) sync_url = _absolutize_sqlite_url(sync_url)
cfg = _get_alembic_config(sync_url) cfg = _get_alembic_config(sync_url)
engine = create_engine(sync_url, future=True, connect_args=connect_args)
# Inspect current and target heads
engine = create_engine(sync_url, future=True)
with engine.connect() as conn: with engine.connect() as conn:
context = MigrationContext.configure(conn) context = MigrationContext.configure(conn)
current_rev = context.get_current_revision() current_rev = context.get_current_revision()