remove suffix for enabled custom nodes

This commit is contained in:
Dr.Lt.Data 2024-12-22 01:20:40 +09:00
parent cbcd2e14ce
commit 5e5235f5d1
7 changed files with 162 additions and 127 deletions

View File

@ -481,8 +481,8 @@ def show_list(kind, simple=False):
print(f"{prefix} {title:50} {short_id:30} (author: {author:20}) [UNKNOWN]")
def show_snapshot(simple_mode=False):
json_obj = core.get_current_snapshot()
async def show_snapshot(simple_mode=False):
json_obj = await core.get_current_snapshot()
if simple_mode:
print(f"[{json_obj['comfyui']}] comfyui")
@ -513,8 +513,8 @@ def cancel():
os.remove(cmd_ctx.get_restore_snapshot_path())
def auto_save_snapshot():
path = core.save_snapshot_with_postfix('cli-autosave')
async def auto_save_snapshot():
path = await core.save_snapshot_with_postfix('cli-autosave')
print(f"Current snapshot is saved as `{path}`")
@ -698,7 +698,7 @@ def update(
cmd_ctx.set_channel_mode(channel, mode)
if 'all' in nodes:
auto_save_snapshot()
asyncio.run(auto_save_snapshot())
for x in nodes:
if x.lower() in ['comfyui', 'comfy', 'all']:
@ -734,7 +734,7 @@ def disable(
cmd_ctx.set_channel_mode(channel, mode)
if 'all' in nodes:
auto_save_snapshot()
asyncio.run(auto_save_snapshot())
for_each_nodes(nodes, disable_node, allow_all=True)
@ -765,7 +765,7 @@ def enable(
cmd_ctx.set_channel_mode(channel, mode)
if 'all' in nodes:
auto_save_snapshot()
asyncio.run(auto_save_snapshot())
for_each_nodes(nodes, enable_node, allow_all=True)
@ -796,7 +796,7 @@ def fix(
cmd_ctx.set_channel_mode(channel, mode)
if 'all' in nodes:
auto_save_snapshot()
asyncio.run(auto_save_snapshot())
for_each_nodes(nodes, fix_node, allow_all=True)
@ -997,7 +997,7 @@ def save_snapshot(
):
cmd_ctx.set_user_directory(user_directory)
path = core.save_snapshot_with_postfix('snapshot', output)
path = asyncio.run(core.save_snapshot_with_postfix('snapshot', output))
print(f"Current snapshot is saved as `{path}`")
@ -1123,7 +1123,7 @@ def install_deps(
):
cmd_ctx.set_user_directory(user_directory)
cmd_ctx.set_channel_mode(channel, mode)
auto_save_snapshot()
asyncio.run(auto_save_snapshot())
if not os.path.exists(deps):
print(f"[bold red]File not found: {deps}[/bold red]")

View File

@ -2,6 +2,8 @@ import requests
from dataclasses import dataclass
from typing import List
import manager_util
import toml
import os
base_url = "https://api.comfy.org"
@ -98,3 +100,32 @@ def all_versions_of_node(node_id):
else:
return None
def read_cnr_info(fullpath):
try:
toml_path = os.path.join(fullpath, 'pyproject.toml')
tracking_path = os.path.join(fullpath, '.tracking')
if not os.path.exists(toml_path) or not os.path.exists(tracking_path):
return None # not valid CNR node pack
with open(toml_path, "r", encoding="utf-8") as f:
data = toml.load(f)
project = data.get('project', {})
name = project.get('name')
version = project.get('version')
urls = project.get('urls', {})
repository = urls.get('Repository')
if name and version: # repository is optional
return {
"id": name,
"version": version,
"url": repository
}
return None
except Exception:
return None # not valid CNR node pack

View File

@ -1,16 +1,12 @@
import os
import git
import configparser
def is_git_repo(path: str) -> bool:
""" Check if the path is a git repository. """
try:
# Try to create a Repo object from the path
_ = git.Repo(path).git_dir
return True
except git.exc.InvalidGitRepositoryError:
return False
# NOTE: Checking it through `git.Repo` must be avoided.
# It locks the file, causing issues on Windows.
return os.path.exists(os.path.join(path, '.git'))
def get_commit_hash(fullpath):
@ -30,3 +26,36 @@ def get_commit_hash(fullpath):
return line
return "unknown"
def git_url(fullpath):
"""
resolve version of unclassified custom node based on remote url in .git/config
"""
git_config_path = os.path.join(fullpath, '.git', 'config')
if not os.path.exists(git_config_path):
return None
config = configparser.ConfigParser()
config.read(git_config_path)
for k, v in config.items():
if k.startswith('remote ') and 'url' in v:
return v['url']
return None
def normalize_url(url) -> str:
url = url.replace("git@github.com:", "https://github.com/")
if url.endswith('.git'):
url = url[:-4]
return url
def normalize_url_http(url) -> str:
url = url.replace("https://github.com/", "git@github.com:")
if url.endswith('.git'):
url = url[:-4]
return url

View File

@ -31,6 +31,7 @@ sys.path.append(glob_path)
import cm_global
import cnr_utils
import manager_util
import git_utils
import manager_downloader
from node_package import InstalledNodePackage
@ -323,9 +324,9 @@ class UnifiedManager:
self.custom_node_map_cache = {} # (channel, mode) -> augmented custom node list json
self.processed_install = set()
def get_cnr_by_repo(self, url):
normalized_url = url.replace("git@github.com:", "https://github.com/")
return self.repo_cnr_map.get(normalized_url)
return self.repo_cnr_map.get(git_utils.normalize_url(url))
def resolve_unspecified_version(self, node_name, guess_mode=None):
if guess_mode == 'active':
@ -426,32 +427,36 @@ class UnifiedManager:
return node_name, version_spec, len(spec) > 1
def resolve_ver(self, fullpath):
"""
resolve version of unclassified custom node based on remote url in .git/config
"""
git_config_path = os.path.join(fullpath, '.git', 'config')
def resolve_from_path(self, fullpath):
url = git_utils.git_url(fullpath)
if url:
cnr = self.get_cnr_by_repo(url)
commit_hash = git_utils.get_commit_hash(fullpath)
if cnr:
return {'id': cnr['id'], 'cnr': cnr, 'ver': 'nightly', 'hash': commit_hash}
else:
url = os.path.basename(url)
if url.endswith('.git'):
url = url[:-4]
return {'id': url, 'ver': 'unknown', 'hash': commit_hash}
else:
info = cnr_utils.read_cnr_info(fullpath)
if not os.path.exists(git_config_path):
return "unknown"
config = configparser.ConfigParser()
config.read(git_config_path)
for k, v in config.items():
if k.startswith('remote ') and 'url' in v:
cnr = self.get_cnr_by_repo(v['url'])
if info:
cnr = self.cnr_map.get(info['id'])
if cnr:
return "nightly"
return {'id': cnr['id'], 'cnr': cnr, 'ver': info['version']}
else:
return "unknown"
return None
else:
return None
def update_cache_at_path(self, fullpath):
node_package = InstalledNodePackage.from_fullpath(fullpath)
node_package = InstalledNodePackage.from_fullpath(fullpath, self.resolve_from_path)
self.installed_node_packages[node_package.id] = node_package
if node_package.is_disabled and node_package.is_unknown:
# NOTE: unknown package does not have a url.
# NOTE: unknown package does not have an url.
self.unknown_inactive_nodes[node_package.id] = ('', node_package.fullpath)
if node_package.is_disabled and node_package.is_nightly:
@ -461,7 +466,7 @@ class UnifiedManager:
self.active_nodes[node_package.id] = node_package.version, node_package.fullpath
if node_package.is_enabled and node_package.is_unknown:
# NOTE: unknown package does not have a url.
# NOTE: unknown package does not have an url.
self.unknown_active_nodes[node_package.id] = ('', node_package.fullpath)
if node_package.is_from_cnr and node_package.is_disabled:
@ -629,7 +634,7 @@ class UnifiedManager:
self.cnr_map[x['id']] = x
if 'repository' in x:
normalized_url = x['repository'].replace("git@github.com:", "https://github.com/")
normalized_url = git_utils.normalize_url(x['repository'])
self.repo_cnr_map[normalized_url] = x
# reload node status info from custom_nodes/*
@ -806,8 +811,7 @@ class UnifiedManager:
zip_url = node_info.download_url
from_path = self.active_nodes[node_id][1]
# PTAL(@ltdrdata): how to redesign and drop version_spec here?
target = f"{node_id}@{version_spec.replace('.', '_')}"
target = node_id
to_path = os.path.join(get_default_custom_nodes_path(), target)
def postinstall():
@ -842,7 +846,7 @@ class UnifiedManager:
download_path = os.path.join(get_default_custom_nodes_path(), archive_name)
manager_downloader.download_url(node_info.download_url, get_default_custom_nodes_path(), archive_name)
# 2. extract files into <node_id>@<cur_ver>
# 2. extract files into <node_id>
install_path = self.active_nodes[node_id][1]
extracted = manager_util.extract_package_as_zip(download_path, install_path)
os.remove(download_path)
@ -873,22 +877,16 @@ class UnifiedManager:
if not os.listdir(x):
os.rmdir(x)
# 5. rename dir name <node_id>@<prev_ver> ==> <node_id>@<cur_ver>
# PTAL(@ltdrdata): how to redesign and drop version_spec here
new_install_path = os.path.join(get_default_custom_nodes_path(), f"{node_id}@{version_spec.replace('.', '_')}")
print(f"'{install_path}' is moved to '{new_install_path}'")
shutil.move(install_path, new_install_path)
# 6. create .tracking file
tracking_info_file = os.path.join(new_install_path, '.tracking')
# 5. create .tracking file
tracking_info_file = os.path.join(install_path, '.tracking')
with open(tracking_info_file, "w", encoding='utf-8') as file:
file.write('\n'.join(list(extracted)))
# 7. post install
# 6. post install
result.target = version_spec
def postinstall():
res = self.execute_install_script(f"{node_id}@{version_spec}", new_install_path, instant_execution=instant_execution, no_deps=no_deps)
res = self.execute_install_script(f"{node_id}@{version_spec}", install_path, instant_execution=instant_execution, no_deps=no_deps)
return res
if return_postinstall:
@ -930,8 +928,7 @@ class UnifiedManager:
if repo_and_path is None:
return result.fail(f'Specified inactive node not exists: {node_id}@unknown')
from_path = repo_and_path[1]
# NOTE: Keep original name as possible if unknown node
# to_path = os.path.join(get_default_custom_nodes_path(), f"{node_id}@unknown")
base_path = extract_base_custom_nodes_dir(from_path)
to_path = os.path.join(base_path, node_id)
elif version_spec == 'nightly':
@ -940,7 +937,7 @@ class UnifiedManager:
if from_path is None:
return result.fail(f'Specified inactive node not exists: {node_id}@nightly')
base_path = extract_base_custom_nodes_dir(from_path)
to_path = os.path.join(base_path, f"{node_id}@nightly")
to_path = os.path.join(base_path, node_id)
elif version_spec is not None:
self.unified_disable(node_id, False)
cnr_info = self.cnr_inactive_nodes.get(node_id)
@ -956,8 +953,7 @@ class UnifiedManager:
from_path = cnr_info[version_spec]
base_path = extract_base_custom_nodes_dir(from_path)
# PTAL(@ltdrdata): how to redesign and drop version_spec here
to_path = os.path.join(base_path, f"{node_id}@{version_spec.replace('.', '_')}")
to_path = os.path.join(base_path, node_id)
if from_path is None or not os.path.exists(from_path):
return result.fail(f'Specified inactive node path not exists: {from_path}')
@ -999,9 +995,6 @@ class UnifiedManager:
return result.fail(f'Specified active node not exists: {node_id}')
base_path = extract_base_custom_nodes_dir(repo_and_path[1])
# NOTE: Keep original name as possible if unknown node
# to_path = os.path.join(get_default_custom_nodes_path(), '.disabled', f"{node_id}@unknown")
to_path = os.path.join(base_path, '.disabled', node_id)
shutil.move(repo_and_path[1], to_path)
@ -1018,7 +1011,8 @@ class UnifiedManager:
return result.fail(f'Specified active node not exists: {node_id}')
base_path = extract_base_custom_nodes_dir(ver_and_path[1])
# PTAL(@ltdrdata): how to redesign and drop version_spec here
# NOTE: A disabled node may have multiple versions, so preserve it using the `@ suffix`.
to_path = os.path.join(base_path, '.disabled', f"{node_id}@{ver_and_path[0].replace('.', '_')}")
shutil.move(ver_and_path[1], to_path)
result.append((ver_and_path[1], to_path))
@ -1109,8 +1103,7 @@ class UnifiedManager:
os.remove(download_path)
# install_path
# PTAL(@ltdrdata): how to redesign and drop version_spec here
install_path = os.path.join(get_default_custom_nodes_path(), f"{node_id}@{version_spec.replace('.', '_')}")
install_path = os.path.join(get_default_custom_nodes_path(), node_id)
if os.path.exists(install_path):
return result.fail(f'Install path already exists: {install_path}')
@ -1293,11 +1286,7 @@ class UnifiedManager:
if self.is_enabled(node_id, 'cnr'):
self.unified_disable(node_id, False)
if version_spec == 'unknown':
to_path = os.path.abspath(os.path.join(get_default_custom_nodes_path(), node_id)) # don't attach @unknown
else:
# PTAL(@ltdrdata): how to redesign and drop version_spec here
to_path = os.path.abspath(os.path.join(get_default_custom_nodes_path(), f"{node_id}@{version_spec.replace('.', '_')}"))
to_path = os.path.abspath(os.path.join(get_default_custom_nodes_path(), node_id))
res = self.repo_install(repo_url, to_path, instant_execution=instant_execution, no_deps=no_deps, return_postinstall=return_postinstall)
if res.result:
if version_spec == 'unknown':
@ -1347,18 +1336,6 @@ class UnifiedManager:
new_path = os.path.join(get_default_custom_nodes_path(), '.disabled', f"{x}@nightly")
moves.append((v, new_path))
print("Migration: STAGE 2")
# migrate active nodes
for x, v in self.active_nodes.items():
if v[0] not in ['nightly']:
continue
if v[1].endswith('@nightly'):
continue
new_path = os.path.join(get_default_custom_nodes_path(), f"{x}@nightly")
moves.append((v[1], new_path))
self.reserve_migration(moves)
print("DONE (Migration reserved)")
@ -2265,7 +2242,10 @@ def get_installed_pip_packages():
return res
def get_current_snapshot():
async def get_current_snapshot():
await unified_manager.reload('cache')
await unified_manager.get_custom_nodes('default', 'cache')
# Get ComfyUI hash
repo_path = comfy_path
@ -2282,36 +2262,33 @@ def get_current_snapshot():
# Get custom nodes hash
for custom_nodes_dir in get_custom_nodes_paths():
for path in os.listdir(custom_nodes_dir):
paths = os.listdir(custom_nodes_dir)
disabled_path = os.path.join(custom_nodes_dir, '.disabled')
if os.path.exists(disabled_path):
for x in os.listdir(disabled_path):
paths.append(os.path.join(disabled_path, x))
for path in paths:
if path in ['.disabled', '__pycache__']:
continue
fullpath = os.path.join(custom_nodes_dir, path)
if os.path.isdir(fullpath):
is_disabled = path.endswith(".disabled")
is_disabled = path.endswith(".disabled") or os.path.basename(os.path.dirname(fullpath)) == ".disabled"
try:
git_dir = os.path.join(fullpath, '.git')
info = unified_manager.resolve_from_path(fullpath)
parsed_spec = path.split('@')
if info is None:
continue
if len(parsed_spec) == 1:
node_id = parsed_spec[0]
ver_spec = 'unknown'
else:
node_id, ver_spec = parsed_spec
ver_spec = ver_spec.replace('_', '.')
if len(ver_spec) > 1 and ver_spec not in ['nightly', 'latest', 'unknown']:
if info['ver'] not in ['nightly', 'latest', 'unknown']:
if is_disabled:
continue # don't restore disabled state of CNR node.
cnr_custom_nodes[node_id] = ver_spec
elif not os.path.exists(git_dir):
continue
cnr_custom_nodes[info['id']] = info['ver']
else:
repo = git.Repo(fullpath)
commit_hash = repo.head.commit.hexsha
@ -2341,7 +2318,7 @@ def get_current_snapshot():
}
def save_snapshot_with_postfix(postfix, path=None):
async def save_snapshot_with_postfix(postfix, path=None):
if path is None:
now = datetime.now()
@ -2353,7 +2330,7 @@ def save_snapshot_with_postfix(postfix, path=None):
file_name = path.replace('\\', '/').split('/')[-1]
file_name = file_name.split('.')[-2]
snapshot = get_current_snapshot()
snapshot = await get_current_snapshot()
if path.endswith('.json'):
with open(path, "w") as json_file:
json.dump(snapshot, json_file, indent=4)
@ -2753,8 +2730,8 @@ async def restore_snapshot(snapshot_path, git_helper_extras=None):
if v[0] == 'nightly' and cnr_repo_map.get(k):
repo_url = cnr_repo_map.get(k)
normalized_url1 = repo_url.replace("git@github.com:", "https://github.com/")
normalized_url2 = repo_url.replace("https://github.com/", "git@github.com:")
normalized_url1 = git_utils.normalize_url(repo_url)
normalized_url2 = git_utils.normalize_url_http(repo_url)
if normalized_url1 not in git_info and normalized_url2 not in git_info:
todo_disable.append(k)
@ -2773,8 +2750,8 @@ async def restore_snapshot(snapshot_path, git_helper_extras=None):
if cnr_repo_map.get(k):
repo_url = cnr_repo_map.get(k)
normalized_url1 = repo_url.replace("git@github.com:", "https://github.com/")
normalized_url2 = repo_url.replace("https://github.com/", "git@github.com:")
normalized_url1 = git_utils.normalize_url(repo_url)
normalized_url2 = git_utils.normalize_url_http(repo_url)
if normalized_url1 in git_info:
commit_hash = git_info[normalized_url1]['hash']
@ -2811,7 +2788,7 @@ async def restore_snapshot(snapshot_path, git_helper_extras=None):
skip_node_packs.append(x[0])
for x in git_info.keys():
normalized_url = x.replace("git@github.com:", "https://github.com/")
normalized_url = git_utils.normalize_url(x)
cnr = unified_manager.repo_cnr_map.get(normalized_url)
if cnr is not None:
pack_id = cnr['id']
@ -2837,8 +2814,8 @@ async def restore_snapshot(snapshot_path, git_helper_extras=None):
if repo_url is None:
continue
normalized_url1 = repo_url.replace("git@github.com:", "https://github.com/")
normalized_url2 = repo_url.replace("https://github.com/", "git@github.com:")
normalized_url1 = git_utils.normalize_url(repo_url)
normalized_url2 = git_utils.normalize_url_http(repo_url)
if normalized_url1 not in git_info and normalized_url2 not in git_info:
todo_disable.append(k2)
@ -2859,8 +2836,8 @@ async def restore_snapshot(snapshot_path, git_helper_extras=None):
if repo_url is None:
continue
normalized_url1 = repo_url.replace("git@github.com:", "https://github.com/")
normalized_url2 = repo_url.replace("https://github.com/", "git@github.com:")
normalized_url1 = git_utils.normalize_url(repo_url)
normalized_url2 = git_utils.normalize_url_http(repo_url)
if normalized_url1 in git_info:
commit_hash = git_info[normalized_url1]['hash']

View File

@ -460,7 +460,7 @@ async def update_all(request):
return web.Response(status=403)
try:
core.save_snapshot_with_postfix('autosave')
await core.save_snapshot_with_postfix('autosave')
if request.rel_url.query["mode"] == "local":
channel = 'local'
@ -546,9 +546,12 @@ def populate_markdown(x):
async def installed_list(request):
unified_manager = core.unified_manager
await unified_manager.reload('cache')
await unified_manager.get_custom_nodes('default', 'cache')
return web.json_response({
node_id: package.version if package.is_from_cnr else package.get_commit_hash()
for node_id, package in unified_manager.installed_node_packages.items()
for node_id, package in unified_manager.installed_node_packages.items() if not package.disabled
}, content_type='application/json')
@ -696,7 +699,7 @@ async def restore_snapshot(request):
@routes.get("/snapshot/get_current")
async def get_current_snapshot_api(request):
try:
return web.json_response(core.get_current_snapshot(), content_type='application/json')
return web.json_response(await core.get_current_snapshot(), content_type='application/json')
except:
return web.Response(status=400)
@ -704,7 +707,7 @@ async def get_current_snapshot_api(request):
@routes.get("/snapshot/save")
async def save_snapshot(request):
try:
core.save_snapshot_with_postfix('snapshot')
await core.save_snapshot_with_postfix('snapshot')
return web.Response(status=200)
except:
return web.Response(status=400)

View File

@ -47,10 +47,9 @@ class InstalledNodePackage:
return True
@staticmethod
def from_fullpath(fullpath: str) -> InstalledNodePackage:
parent_folder_name = os.path.split(fullpath)[-2]
def from_fullpath(fullpath: str, resolve_from_path) -> InstalledNodePackage:
parent_folder_name = os.path.basename(os.path.dirname(fullpath))
module_name = os.path.basename(fullpath)
pyproject_toml_path = os.path.join(fullpath, "pyproject.toml")
if module_name.endswith(".disabled"):
node_id = module_name[:-9]
@ -63,16 +62,12 @@ class InstalledNodePackage:
node_id = module_name
disabled = False
if is_git_repo(fullpath):
version = "nightly"
elif os.path.exists(pyproject_toml_path):
# Read project.toml to get the version
with open(pyproject_toml_path, "r", encoding="utf-8") as f:
pyproject_toml = toml.load(f)
# Fallback to 'unknown' if project.version doesn't exist
version = pyproject_toml.get("project", {}).get("version", "unknown")
info = resolve_from_path(fullpath)
if info is None:
version = 'unknown'
else:
version = "unknown"
node_id = info['id'] # robust module guessing
version = info['ver']
return InstalledNodePackage(
id=node_id, fullpath=fullpath, disabled=disabled, version=version

View File

@ -317,7 +317,7 @@ async def share_art(request):
form.add_field("shareWorkflowTitle", title)
form.add_field("shareWorkflowDescription", description)
form.add_field("shareWorkflowIsNSFW", str(is_nsfw).lower())
form.add_field("currentSnapshot", json.dumps(core.get_current_snapshot()))
form.add_field("currentSnapshot", json.dumps(await core.get_current_snapshot()))
form.add_field("modelsInfo", json.dumps(models_info))
async with session.post(