mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Merge 9789de07b8 into 96e0e3585b
This commit is contained in:
commit
9b3df13732
@ -57,19 +57,62 @@ def get_alembic_config():
|
|||||||
|
|
||||||
config = Config(config_path)
|
config = Config(config_path)
|
||||||
config.set_main_option("script_location", scripts_path)
|
config.set_main_option("script_location", scripts_path)
|
||||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
config.set_main_option("sqlalchemy.url", get_database_url())
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def get_db_path():
|
def get_database_url():
|
||||||
|
if getattr(args, "database_url_explicit", False):
|
||||||
|
return args.database_url
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
db_path = os.path.join(folder_paths.get_user_directory(), "comfyui.db")
|
||||||
|
return f"sqlite:///{db_path}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_legacy_default_db_path():
|
||||||
url = args.database_url
|
url = args.database_url
|
||||||
if url.startswith("sqlite:///"):
|
if url.startswith("sqlite:///"):
|
||||||
return url.split("///")[1]
|
return url.split("///", 1)[1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_path():
|
||||||
|
url = get_database_url()
|
||||||
|
if url.startswith("sqlite:///"):
|
||||||
|
return url.split("///", 1)[1]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||||
|
|
||||||
|
|
||||||
|
def copy_legacy_default_db(db_path):
|
||||||
|
if getattr(args, "database_url_explicit", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
legacy_db_path = get_legacy_default_db_path()
|
||||||
|
if legacy_db_path is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if os.path.abspath(legacy_db_path) == os.path.abspath(db_path):
|
||||||
|
return
|
||||||
|
|
||||||
|
if os.path.exists(db_path) or not os.path.exists(legacy_db_path):
|
||||||
|
return
|
||||||
|
|
||||||
|
shutil.copy(legacy_db_path, db_path)
|
||||||
|
logging.info(f"Copied legacy database from '{legacy_db_path}' to '{db_path}'")
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_file_db_path(db_path):
|
||||||
|
db_dir = os.path.dirname(db_path)
|
||||||
|
if db_dir:
|
||||||
|
os.makedirs(db_dir, exist_ok=True)
|
||||||
|
|
||||||
|
copy_legacy_default_db(db_path)
|
||||||
|
|
||||||
|
|
||||||
_db_lock = None
|
_db_lock = None
|
||||||
|
|
||||||
def _acquire_file_lock(db_path):
|
def _acquire_file_lock(db_path):
|
||||||
@ -97,7 +140,7 @@ def _is_memory_db(db_url):
|
|||||||
|
|
||||||
|
|
||||||
def init_db():
|
def init_db():
|
||||||
db_url = args.database_url
|
db_url = get_database_url()
|
||||||
logging.debug(f"Database URL: {db_url}")
|
logging.debug(f"Database URL: {db_url}")
|
||||||
|
|
||||||
if _is_memory_db(db_url):
|
if _is_memory_db(db_url):
|
||||||
@ -134,6 +177,7 @@ def _init_memory_db(db_url):
|
|||||||
def _init_file_db(db_url):
|
def _init_file_db(db_url):
|
||||||
"""Initialize a file-backed SQLite database using Alembic migrations."""
|
"""Initialize a file-backed SQLite database using Alembic migrations."""
|
||||||
db_path = get_db_path()
|
db_path = get_db_path()
|
||||||
|
prepare_file_db_path(db_path)
|
||||||
db_exists = os.path.exists(db_path)
|
db_exists = os.path.exists(db_path)
|
||||||
|
|
||||||
config = get_alembic_config()
|
config = get_alembic_config()
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import enum
|
import enum
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import comfy.options
|
import comfy.options
|
||||||
|
|
||||||
|
|
||||||
@ -246,8 +247,13 @@ parser.add_argument("--list-feature-flags", action="store_true", help="Print the
|
|||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
args.database_url_explicit = any(
|
||||||
|
arg == "--database-url" or arg.startswith("--database-url=")
|
||||||
|
for arg in sys.argv[1:]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
args = parser.parse_args([])
|
args = parser.parse_args([])
|
||||||
|
args.database_url_explicit = False
|
||||||
|
|
||||||
if args.cache_ram is not None and len(args.cache_ram) > 2:
|
if args.cache_ram is not None and len(args.cache_ram) > 2:
|
||||||
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
|
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
|
||||||
|
|||||||
78
tests-unit/app_test/database_path_test.py
Normal file
78
tests-unit/app_test/database_path_test.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from app.database import db
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_database_url_uses_effective_user_directory(monkeypatch, tmp_path):
|
||||||
|
user_dir = tmp_path / "custom_user"
|
||||||
|
user_dir.mkdir()
|
||||||
|
|
||||||
|
monkeypatch.setattr(db.args, "database_url_explicit", False, raising=False)
|
||||||
|
monkeypatch.setattr("folder_paths.get_user_directory", lambda: str(user_dir))
|
||||||
|
|
||||||
|
assert db.get_database_url() == f"sqlite:///{user_dir / 'comfyui.db'}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_explicit_database_url_is_preserved(monkeypatch):
|
||||||
|
database_url = "sqlite:///:memory:"
|
||||||
|
|
||||||
|
monkeypatch.setattr(db.args, "database_url", database_url)
|
||||||
|
monkeypatch.setattr(db.args, "database_url_explicit", True, raising=False)
|
||||||
|
|
||||||
|
assert db.get_database_url() == database_url
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_default_database_is_copied_to_effective_user_directory(monkeypatch, tmp_path):
|
||||||
|
legacy_db = tmp_path / "install" / "user" / "comfyui.db"
|
||||||
|
user_dir = tmp_path / "custom_user"
|
||||||
|
legacy_db.parent.mkdir(parents=True)
|
||||||
|
user_dir.mkdir()
|
||||||
|
legacy_db.write_bytes(b"legacy db")
|
||||||
|
|
||||||
|
monkeypatch.setattr(db.args, "database_url_explicit", False, raising=False)
|
||||||
|
monkeypatch.setattr("folder_paths.get_user_directory", lambda: str(user_dir))
|
||||||
|
monkeypatch.setattr(db, "get_legacy_default_db_path", lambda: str(legacy_db))
|
||||||
|
|
||||||
|
db.copy_legacy_default_db(str(user_dir / "comfyui.db"))
|
||||||
|
|
||||||
|
assert (user_dir / "comfyui.db").read_bytes() == b"legacy db"
|
||||||
|
assert legacy_db.read_bytes() == b"legacy db"
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_default_database_does_not_overwrite_existing_effective_db(monkeypatch, tmp_path):
|
||||||
|
legacy_db = tmp_path / "install" / "user" / "comfyui.db"
|
||||||
|
user_db = tmp_path / "custom_user" / "comfyui.db"
|
||||||
|
legacy_db.parent.mkdir(parents=True)
|
||||||
|
user_db.parent.mkdir(parents=True)
|
||||||
|
legacy_db.write_bytes(b"legacy db")
|
||||||
|
user_db.write_bytes(b"user db")
|
||||||
|
|
||||||
|
monkeypatch.setattr(db.args, "database_url_explicit", False, raising=False)
|
||||||
|
monkeypatch.setattr(db, "get_legacy_default_db_path", lambda: str(legacy_db))
|
||||||
|
|
||||||
|
db.copy_legacy_default_db(str(user_db))
|
||||||
|
|
||||||
|
assert user_db.read_bytes() == b"user db"
|
||||||
|
assert legacy_db.read_bytes() == b"legacy db"
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_file_database_creates_parent_directory(monkeypatch, tmp_path):
|
||||||
|
db_path = tmp_path / "nested" / "comfyui.db"
|
||||||
|
|
||||||
|
monkeypatch.setattr(db.args, "database_url_explicit", False, raising=False)
|
||||||
|
monkeypatch.setattr(db, "copy_legacy_default_db", lambda path: None)
|
||||||
|
|
||||||
|
db.prepare_file_db_path(str(db_path))
|
||||||
|
|
||||||
|
assert db_path.parent.is_dir()
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_file_database_accepts_relative_database_path(monkeypatch, tmp_path):
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
monkeypatch.setattr(db.args, "database_url_explicit", True, raising=False)
|
||||||
|
monkeypatch.setattr(db, "copy_legacy_default_db", lambda path: None)
|
||||||
|
|
||||||
|
db.prepare_file_db_path("relative.db")
|
||||||
|
|
||||||
|
assert os.getcwd() == str(tmp_path)
|
||||||
|
assert not list(tmp_path.iterdir())
|
||||||
Loading…
Reference in New Issue
Block a user