mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-19 19:13:02 +08:00
feat: support for in-memory SQLite databases
This commit is contained in:
parent
3fa0fc496c
commit
e3311c9229
@ -36,14 +36,39 @@ def _absolutize_sqlite_url(db_url: str) -> str:
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return db_url
|
||||
|
||||
# Make path absolute if relative
|
||||
db_path = u.database or ""
|
||||
db_path: str = u.database or ""
|
||||
if isinstance(db_path, str) and db_path.startswith("file:"):
|
||||
return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared"
|
||||
if not os.path.isabs(db_path):
|
||||
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
|
||||
u = u.set(database=db_path)
|
||||
return str(u)
|
||||
|
||||
|
||||
def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]:
|
||||
"""
|
||||
If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory),
|
||||
rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present.
|
||||
Returns: (normalized_url, is_memory)
|
||||
"""
|
||||
try:
|
||||
u = make_url(db_url)
|
||||
except Exception:
|
||||
return db_url, False
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return db_url, False
|
||||
|
||||
db = u.database or ""
|
||||
if db == ":memory:":
|
||||
u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true")
|
||||
return str(u), True
|
||||
if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db:
|
||||
if "uri=true" not in db:
|
||||
u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true"))
|
||||
return str(u), True
|
||||
return str(u), False
|
||||
|
||||
|
||||
def _to_sync_driver_url(async_url: str) -> str:
|
||||
"""Convert an async SQLAlchemy URL to a sync URL for Alembic."""
|
||||
u = make_url(async_url)
|
||||
@ -70,7 +95,10 @@ def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
|
||||
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return None
|
||||
return u.database
|
||||
db_path = u.database
|
||||
if isinstance(db_path, str) and db_path.startswith("file:"):
|
||||
return None # Not a real file if it is a URI like "file:...?"
|
||||
return db_path
|
||||
|
||||
|
||||
def _get_alembic_config(sync_url: str) -> Config:
|
||||
@ -96,8 +124,8 @@ async def init_db_engine() -> None:
|
||||
if not raw_url:
|
||||
raise RuntimeError("Database URL is not configured.")
|
||||
|
||||
# Absolutize SQLite path for async engine
|
||||
db_url = _absolutize_sqlite_url(raw_url)
|
||||
db_url, is_mem = _normalize_sqlite_memory_url(raw_url)
|
||||
db_url = _absolutize_sqlite_url(db_url)
|
||||
|
||||
# Prepare async engine
|
||||
connect_args = {}
|
||||
@ -106,6 +134,8 @@ async def init_db_engine() -> None:
|
||||
"check_same_thread": False,
|
||||
"timeout": 12,
|
||||
}
|
||||
if is_mem:
|
||||
connect_args["uri"] = True
|
||||
|
||||
ENGINE = create_async_engine(
|
||||
db_url,
|
||||
@ -117,18 +147,19 @@ async def init_db_engine() -> None:
|
||||
# Enforce SQLite pragmas on the async engine
|
||||
if db_url.startswith("sqlite"):
|
||||
async with ENGINE.begin() as conn:
|
||||
# WAL for concurrency and durability, Foreign Keys for referential integrity
|
||||
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
|
||||
if str(current_mode).lower() != "wal":
|
||||
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
|
||||
if str(new_mode).lower() != "wal":
|
||||
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
|
||||
LOGGER.info("SQLite journal mode set to WAL.")
|
||||
if not is_mem:
|
||||
# WAL for concurrency and durability, Foreign Keys for referential integrity
|
||||
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
|
||||
if str(current_mode).lower() != "wal":
|
||||
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
|
||||
if str(new_mode).lower() != "wal":
|
||||
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
|
||||
LOGGER.info("SQLite journal mode set to WAL.")
|
||||
|
||||
await conn.execute(text("PRAGMA foreign_keys = ON;"))
|
||||
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
|
||||
|
||||
await _run_migrations(raw_url=db_url)
|
||||
await _run_migrations(raw_url=db_url, connect_args=connect_args)
|
||||
|
||||
SESSION = async_sessionmaker(
|
||||
bind=ENGINE,
|
||||
@ -139,7 +170,7 @@ async def init_db_engine() -> None:
|
||||
)
|
||||
|
||||
|
||||
async def _run_migrations(raw_url: str) -> None:
|
||||
async def _run_migrations(raw_url: str, connect_args: dict) -> None:
|
||||
"""
|
||||
Run Alembic migrations up to head.
|
||||
|
||||
@ -148,12 +179,11 @@ async def _run_migrations(raw_url: str) -> None:
|
||||
"""
|
||||
# Convert to sync URL and make SQLite URL an absolute one
|
||||
sync_url = _to_sync_driver_url(raw_url)
|
||||
sync_url, is_mem = _normalize_sqlite_memory_url(sync_url)
|
||||
sync_url = _absolutize_sqlite_url(sync_url)
|
||||
|
||||
cfg = _get_alembic_config(sync_url)
|
||||
|
||||
# Inspect current and target heads
|
||||
engine = create_engine(sync_url, future=True)
|
||||
engine = create_engine(sync_url, future=True, connect_args=connect_args)
|
||||
with engine.connect() as conn:
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user