mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-20 03:23:00 +08:00
feat: support for in-memory SQLite databases
This commit is contained in:
parent
3fa0fc496c
commit
e3311c9229
@ -36,14 +36,39 @@ def _absolutize_sqlite_url(db_url: str) -> str:
|
|||||||
if not u.drivername.startswith("sqlite"):
|
if not u.drivername.startswith("sqlite"):
|
||||||
return db_url
|
return db_url
|
||||||
|
|
||||||
# Make path absolute if relative
|
db_path: str = u.database or ""
|
||||||
db_path = u.database or ""
|
if isinstance(db_path, str) and db_path.startswith("file:"):
|
||||||
|
return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared"
|
||||||
if not os.path.isabs(db_path):
|
if not os.path.isabs(db_path):
|
||||||
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
|
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
|
||||||
u = u.set(database=db_path)
|
u = u.set(database=db_path)
|
||||||
return str(u)
|
return str(u)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]:
|
||||||
|
"""
|
||||||
|
If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory),
|
||||||
|
rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present.
|
||||||
|
Returns: (normalized_url, is_memory)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
u = make_url(db_url)
|
||||||
|
except Exception:
|
||||||
|
return db_url, False
|
||||||
|
if not u.drivername.startswith("sqlite"):
|
||||||
|
return db_url, False
|
||||||
|
|
||||||
|
db = u.database or ""
|
||||||
|
if db == ":memory:":
|
||||||
|
u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true")
|
||||||
|
return str(u), True
|
||||||
|
if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db:
|
||||||
|
if "uri=true" not in db:
|
||||||
|
u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true"))
|
||||||
|
return str(u), True
|
||||||
|
return str(u), False
|
||||||
|
|
||||||
|
|
||||||
def _to_sync_driver_url(async_url: str) -> str:
|
def _to_sync_driver_url(async_url: str) -> str:
|
||||||
"""Convert an async SQLAlchemy URL to a sync URL for Alembic."""
|
"""Convert an async SQLAlchemy URL to a sync URL for Alembic."""
|
||||||
u = make_url(async_url)
|
u = make_url(async_url)
|
||||||
@ -70,7 +95,10 @@ def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
|
|||||||
|
|
||||||
if not u.drivername.startswith("sqlite"):
|
if not u.drivername.startswith("sqlite"):
|
||||||
return None
|
return None
|
||||||
return u.database
|
db_path = u.database
|
||||||
|
if isinstance(db_path, str) and db_path.startswith("file:"):
|
||||||
|
return None # Not a real file if it is a URI like "file:...?"
|
||||||
|
return db_path
|
||||||
|
|
||||||
|
|
||||||
def _get_alembic_config(sync_url: str) -> Config:
|
def _get_alembic_config(sync_url: str) -> Config:
|
||||||
@ -96,8 +124,8 @@ async def init_db_engine() -> None:
|
|||||||
if not raw_url:
|
if not raw_url:
|
||||||
raise RuntimeError("Database URL is not configured.")
|
raise RuntimeError("Database URL is not configured.")
|
||||||
|
|
||||||
# Absolutize SQLite path for async engine
|
db_url, is_mem = _normalize_sqlite_memory_url(raw_url)
|
||||||
db_url = _absolutize_sqlite_url(raw_url)
|
db_url = _absolutize_sqlite_url(db_url)
|
||||||
|
|
||||||
# Prepare async engine
|
# Prepare async engine
|
||||||
connect_args = {}
|
connect_args = {}
|
||||||
@ -106,6 +134,8 @@ async def init_db_engine() -> None:
|
|||||||
"check_same_thread": False,
|
"check_same_thread": False,
|
||||||
"timeout": 12,
|
"timeout": 12,
|
||||||
}
|
}
|
||||||
|
if is_mem:
|
||||||
|
connect_args["uri"] = True
|
||||||
|
|
||||||
ENGINE = create_async_engine(
|
ENGINE = create_async_engine(
|
||||||
db_url,
|
db_url,
|
||||||
@ -117,6 +147,7 @@ async def init_db_engine() -> None:
|
|||||||
# Enforce SQLite pragmas on the async engine
|
# Enforce SQLite pragmas on the async engine
|
||||||
if db_url.startswith("sqlite"):
|
if db_url.startswith("sqlite"):
|
||||||
async with ENGINE.begin() as conn:
|
async with ENGINE.begin() as conn:
|
||||||
|
if not is_mem:
|
||||||
# WAL for concurrency and durability, Foreign Keys for referential integrity
|
# WAL for concurrency and durability, Foreign Keys for referential integrity
|
||||||
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
|
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
|
||||||
if str(current_mode).lower() != "wal":
|
if str(current_mode).lower() != "wal":
|
||||||
@ -128,7 +159,7 @@ async def init_db_engine() -> None:
|
|||||||
await conn.execute(text("PRAGMA foreign_keys = ON;"))
|
await conn.execute(text("PRAGMA foreign_keys = ON;"))
|
||||||
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
|
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
|
||||||
|
|
||||||
await _run_migrations(raw_url=db_url)
|
await _run_migrations(raw_url=db_url, connect_args=connect_args)
|
||||||
|
|
||||||
SESSION = async_sessionmaker(
|
SESSION = async_sessionmaker(
|
||||||
bind=ENGINE,
|
bind=ENGINE,
|
||||||
@ -139,7 +170,7 @@ async def init_db_engine() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _run_migrations(raw_url: str) -> None:
|
async def _run_migrations(raw_url: str, connect_args: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Run Alembic migrations up to head.
|
Run Alembic migrations up to head.
|
||||||
|
|
||||||
@ -148,12 +179,11 @@ async def _run_migrations(raw_url: str) -> None:
|
|||||||
"""
|
"""
|
||||||
# Convert to sync URL and make SQLite URL an absolute one
|
# Convert to sync URL and make SQLite URL an absolute one
|
||||||
sync_url = _to_sync_driver_url(raw_url)
|
sync_url = _to_sync_driver_url(raw_url)
|
||||||
|
sync_url, is_mem = _normalize_sqlite_memory_url(sync_url)
|
||||||
sync_url = _absolutize_sqlite_url(sync_url)
|
sync_url = _absolutize_sqlite_url(sync_url)
|
||||||
|
|
||||||
cfg = _get_alembic_config(sync_url)
|
cfg = _get_alembic_config(sync_url)
|
||||||
|
engine = create_engine(sync_url, future=True, connect_args=connect_args)
|
||||||
# Inspect current and target heads
|
|
||||||
engine = create_engine(sync_url, future=True)
|
|
||||||
with engine.connect() as conn:
|
with engine.connect() as conn:
|
||||||
context = MigrationContext.configure(conn)
|
context = MigrationContext.configure(conn)
|
||||||
current_rev = context.get_current_revision()
|
current_rev = context.get_current_revision()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user