diff --git a/app/database/db.py b/app/database/db.py index 0aab09a49..ab2198814 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -57,19 +57,62 @@ def get_alembic_config(): config = Config(config_path) config.set_main_option("script_location", scripts_path) - config.set_main_option("sqlalchemy.url", args.database_url) + config.set_main_option("sqlalchemy.url", get_database_url()) return config -def get_db_path(): +def get_database_url(): + if getattr(args, "database_url_explicit", False): + return args.database_url + + import folder_paths + + db_path = os.path.join(folder_paths.get_user_directory(), "comfyui.db") + return f"sqlite:///{db_path}" + + +def get_legacy_default_db_path(): url = args.database_url if url.startswith("sqlite:///"): - return url.split("///")[1] + return url.split("///", 1)[1] + return None + + +def get_db_path(): + url = get_database_url() + if url.startswith("sqlite:///"): + return url.split("///", 1)[1] else: raise ValueError(f"Unsupported database URL '{url}'.") +def copy_legacy_default_db(db_path): + if getattr(args, "database_url_explicit", False): + return + + legacy_db_path = get_legacy_default_db_path() + if legacy_db_path is None: + return + + if os.path.abspath(legacy_db_path) == os.path.abspath(db_path): + return + + if os.path.exists(db_path) or not os.path.exists(legacy_db_path): + return + + shutil.copy(legacy_db_path, db_path) + logging.info(f"Copied legacy database from '{legacy_db_path}' to '{db_path}'") + + +def prepare_file_db_path(db_path): + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + + copy_legacy_default_db(db_path) + + _db_lock = None def _acquire_file_lock(db_path): @@ -97,7 +140,7 @@ def _is_memory_db(db_url): def init_db(): - db_url = args.database_url + db_url = get_database_url() logging.debug(f"Database URL: {db_url}") if _is_memory_db(db_url): @@ -134,6 +177,7 @@ def _init_memory_db(db_url): def _init_file_db(db_url): """Initialize a file-backed SQLite database using Alembic migrations.""" db_path = get_db_path() + prepare_file_db_path(db_path) db_exists = os.path.exists(db_path) config = get_alembic_config() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 4bef096fb..9e4c5c0bd 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,6 +1,7 @@ import argparse import enum import os +import sys import comfy.options @@ -246,8 +247,13 @@ parser.add_argument("--list-feature-flags", action="store_true", help="Print the if comfy.options.args_parsing: args = parser.parse_args() + args.database_url_explicit = any( + arg == "--database-url" or arg.startswith("--database-url=") + for arg in sys.argv[1:] + ) else: args = parser.parse_args([]) + args.database_url_explicit = False if args.cache_ram is not None and len(args.cache_ram) > 2: parser.error("--cache-ram accepts at most two values: active GB and inactive GB") diff --git a/tests-unit/app_test/database_path_test.py b/tests-unit/app_test/database_path_test.py new file mode 100644 index 000000000..52b1b6b99 --- /dev/null +++ b/tests-unit/app_test/database_path_test.py @@ -0,0 +1,78 @@ +import os + +from app.database import db + + +def test_default_database_url_uses_effective_user_directory(monkeypatch, tmp_path): + user_dir = tmp_path / "custom_user" + user_dir.mkdir() + + monkeypatch.setattr(db.args, "database_url_explicit", False, raising=False) + monkeypatch.setattr("folder_paths.get_user_directory", lambda: str(user_dir)) + + assert db.get_database_url() == f"sqlite:///{user_dir / 'comfyui.db'}" + + +def test_explicit_database_url_is_preserved(monkeypatch): + database_url = "sqlite:///:memory:" + + monkeypatch.setattr(db.args, "database_url", database_url) + monkeypatch.setattr(db.args, "database_url_explicit", True, raising=False) + + assert db.get_database_url() == database_url + + +def test_legacy_default_database_is_copied_to_effective_user_directory(monkeypatch, tmp_path): + legacy_db = tmp_path / "install" / "user" / "comfyui.db" + user_dir = tmp_path / "custom_user" + legacy_db.parent.mkdir(parents=True) + user_dir.mkdir() + legacy_db.write_bytes(b"legacy db") + + monkeypatch.setattr(db.args, "database_url_explicit", False, raising=False) + monkeypatch.setattr("folder_paths.get_user_directory", lambda: str(user_dir)) + monkeypatch.setattr(db, "get_legacy_default_db_path", lambda: str(legacy_db)) + + db.copy_legacy_default_db(str(user_dir / "comfyui.db")) + + assert (user_dir / "comfyui.db").read_bytes() == b"legacy db" + assert legacy_db.read_bytes() == b"legacy db" + + +def test_legacy_default_database_does_not_overwrite_existing_effective_db(monkeypatch, tmp_path): + legacy_db = tmp_path / "install" / "user" / "comfyui.db" + user_db = tmp_path / "custom_user" / "comfyui.db" + legacy_db.parent.mkdir(parents=True) + user_db.parent.mkdir(parents=True) + legacy_db.write_bytes(b"legacy db") + user_db.write_bytes(b"user db") + + monkeypatch.setattr(db.args, "database_url_explicit", False, raising=False) + monkeypatch.setattr(db, "get_legacy_default_db_path", lambda: str(legacy_db)) + + db.copy_legacy_default_db(str(user_db)) + + assert user_db.read_bytes() == b"user db" + assert legacy_db.read_bytes() == b"legacy db" + + +def test_prepare_file_database_creates_parent_directory(monkeypatch, tmp_path): + db_path = tmp_path / "nested" / "comfyui.db" + + monkeypatch.setattr(db.args, "database_url_explicit", False, raising=False) + monkeypatch.setattr(db, "copy_legacy_default_db", lambda path: None) + + db.prepare_file_db_path(str(db_path)) + + assert db_path.parent.is_dir() + + +def test_prepare_file_database_accepts_relative_database_path(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + monkeypatch.setattr(db.args, "database_url_explicit", True, raising=False) + monkeypatch.setattr(db, "copy_legacy_default_db", lambda path: None) + + db.prepare_file_db_path("relative.db") + + assert os.getcwd() == str(tmp_path) + assert not list(tmp_path.iterdir())