From e3311c9229f10c68437065af904f12ddbf68e282 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Mon, 8 Sep 2025 18:15:09 +0300 Subject: [PATCH] feat: support for in-memory SQLite databases --- app/database/db.py | 64 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/app/database/db.py b/app/database/db.py index 67ddf412b..eaf6648db 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -36,14 +36,39 @@ def _absolutize_sqlite_url(db_url: str) -> str: if not u.drivername.startswith("sqlite"): return db_url - # Make path absolute if relative - db_path = u.database or "" + db_path: str = u.database or "" + if isinstance(db_path, str) and db_path.startswith("file:"): + return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared" if not os.path.isabs(db_path): db_path = os.path.abspath(os.path.join(os.getcwd(), db_path)) u = u.set(database=db_path) return str(u) +def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]: + """ + If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory), + rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present. + Returns: (normalized_url, is_memory) + """ + try: + u = make_url(db_url) + except Exception: + return db_url, False + if not u.drivername.startswith("sqlite"): + return db_url, False + + db = u.database or "" + if db == ":memory:": + u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true") + return str(u), True + if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db: + if "uri=true" not in db: + u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true")) + return str(u), True + return str(u), False + + def _to_sync_driver_url(async_url: str) -> str: """Convert an async SQLAlchemy URL to a sync URL for Alembic.""" u = make_url(async_url) @@ -70,7 +95,10 @@ def _get_sqlite_file_path(sync_url: str) -> Optional[str]: if not u.drivername.startswith("sqlite"): return None - return u.database + db_path = u.database + if isinstance(db_path, str) and db_path.startswith("file:"): + return None # Not a real file if it is a URI like "file:...?" + return db_path def _get_alembic_config(sync_url: str) -> Config: @@ -96,8 +124,8 @@ async def init_db_engine() -> None: if not raw_url: raise RuntimeError("Database URL is not configured.") - # Absolutize SQLite path for async engine - db_url = _absolutize_sqlite_url(raw_url) + db_url, is_mem = _normalize_sqlite_memory_url(raw_url) + db_url = _absolutize_sqlite_url(db_url) # Prepare async engine connect_args = {} @@ -106,6 +134,8 @@ async def init_db_engine() -> None: "check_same_thread": False, "timeout": 12, } + if is_mem: + connect_args["uri"] = True ENGINE = create_async_engine( db_url, @@ -117,18 +147,19 @@ async def init_db_engine() -> None: # Enforce SQLite pragmas on the async engine if db_url.startswith("sqlite"): async with ENGINE.begin() as conn: - # WAL for concurrency and durability, Foreign Keys for referential integrity - current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar() - if str(current_mode).lower() != "wal": - new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar() - if str(new_mode).lower() != "wal": - raise RuntimeError("Failed to set SQLite journal mode to WAL.") - LOGGER.info("SQLite journal mode set to WAL.") + if not is_mem: + # WAL for concurrency and durability, Foreign Keys for referential integrity + current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar() + if str(current_mode).lower() != "wal": + new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar() + if str(new_mode).lower() != "wal": + raise RuntimeError("Failed to set SQLite journal mode to WAL.") + LOGGER.info("SQLite journal mode set to WAL.") await conn.execute(text("PRAGMA foreign_keys = ON;")) await conn.execute(text("PRAGMA synchronous = NORMAL;")) - await _run_migrations(raw_url=db_url) + await _run_migrations(raw_url=db_url, connect_args=connect_args) SESSION = async_sessionmaker( bind=ENGINE, @@ -139,7 +170,7 @@ async def init_db_engine() -> None: ) -async def _run_migrations(raw_url: str) -> None: +async def _run_migrations(raw_url: str, connect_args: dict) -> None: """ Run Alembic migrations up to head. @@ -148,12 +179,11 @@ async def _run_migrations(raw_url: str) -> None: """ # Convert to sync URL and make SQLite URL an absolute one sync_url = _to_sync_driver_url(raw_url) + sync_url, is_mem = _normalize_sqlite_memory_url(sync_url) sync_url = _absolutize_sqlite_url(sync_url) cfg = _get_alembic_config(sync_url) - - # Inspect current and target heads - engine = create_engine(sync_url, future=True) + engine = create_engine(sync_url, future=True, connect_args=connect_args) with engine.connect() as conn: context = MigrationContext.configure(conn) current_rev = context.get_current_revision()