Add the ability to specify commit ID /tag or branch when installing nodes

This commit is contained in:
chenyijian 2024-12-24 19:07:27 +08:00
parent 50fc1389b0
commit ae8ac2ea9f
2 changed files with 48 additions and 15 deletions

View File

@ -183,16 +183,27 @@ class Ctx:
cmd_ctx = Ctx()
def parse_node(node: str):
if '@@' in node:
name, commit = node.split('@@', 1)
else:
name, commit = node, None
return name, commit
def install_node(node_spec_str, is_all=False, cnt_msg=''):
def install_node(node_spec_str, is_all=False, cnt_msg='', **kwargs):
_, commit_id = parse_node(node_spec_str)
exit_on_fail = kwargs.get('exit_on_fail', False)
print(f"node_spec_str:{node_spec_str}, exit_on_fail:{exit_on_fail}...")
if core.is_valid_url(node_spec_str):
# install via urls
res = asyncio.run(core.gitclone_install(node_spec_str, no_deps=cmd_ctx.no_deps))
res = asyncio.run(core.gitclone_install(node_spec_str, no_deps=cmd_ctx.no_deps, commit_id=commit_id))
if not res.result:
print(res.msg)
print(f"[bold red]ERROR: An error occurred while installing '{node_spec_str}'.[/bold red]")
if exit_on_fail:
sys.exit(1)
else:
print(f"{cnt_msg} [INSTALLED] {node_spec_str:50}")
print(f"{cnt_msg} [INSTALLED] {node_spec_str:50} => {commit_id}")
else:
node_spec = unified_manager.resolve_node_spec(node_spec_str)
@ -205,7 +216,7 @@ def install_node(node_spec_str, is_all=False, cnt_msg=''):
if not is_specified:
version_spec = None
res = asyncio.run(unified_manager.install_by_id(node_name, version_spec, cmd_ctx.channel, cmd_ctx.mode, instant_execution=True, no_deps=cmd_ctx.no_deps))
res = asyncio.run(unified_manager.install_by_id(node_name, version_spec, cmd_ctx.channel, cmd_ctx.mode, instant_execution=True, no_deps=cmd_ctx.no_deps, commit_id=commit_id))
if res.action == 'skip':
print(f"{cnt_msg} [ SKIP ] {node_name:50} => Already installed")
@ -225,6 +236,8 @@ def install_node(node_spec_str, is_all=False, cnt_msg=''):
print("")
else:
print(f"[bold red]ERROR: An error occurred while installing '{node_name}'.\n{res.msg}[/bold red]")
if exit_on_fail:
sys.exit(1)
def reinstall_node(node_spec_str, is_all=False, cnt_msg=''):
@ -586,7 +599,7 @@ def get_all_installed_node_specs():
return res
def for_each_nodes(nodes, act, allow_all=True):
def for_each_nodes(nodes, act, allow_all=True, **kwargs):
is_all = False
if allow_all and 'all' in nodes:
is_all = True
@ -598,7 +611,7 @@ def for_each_nodes(nodes, act, allow_all=True):
i = 1
for x in nodes:
try:
act(x, is_all=is_all, cnt_msg=f'{i}/{total}')
act(x, is_all=is_all, cnt_msg=f'{i}/{total}', **kwargs)
except Exception as e:
print(f"ERROR: {e}")
traceback.print_exc()
@ -642,13 +655,17 @@ def install(
None,
help="user directory"
),
exit_on_fail: bool = typer.Option(
False,
help="Exit on failure"
)
):
cmd_ctx.set_user_directory(user_directory)
cmd_ctx.set_channel_mode(channel, mode)
cmd_ctx.set_no_deps(no_deps)
pip_fixer = manager_util.PIPFixer(manager_util.get_installed_packages(), comfy_path, core.manager_files_path)
for_each_nodes(nodes, act=install_node)
for_each_nodes(nodes, act=install_node, exit_on_fail=exit_on_fail)
pip_fixer.fix_broken()
@ -679,7 +696,7 @@ def reinstall(
user_directory: str = typer.Option(
None,
help="user directory"
),
)
):
cmd_ctx.set_user_directory(user_directory)
cmd_ctx.set_channel_mode(channel, mode)

View File

@ -1271,7 +1271,7 @@ class UnifiedManager:
return result
def repo_install(self, url: str, repo_path: str, instant_execution=False, no_deps=False, return_postinstall=False):
def repo_install(self, url: str, repo_path: str, instant_execution=False, no_deps=False, return_postinstall=False, commit_id=None):
result = ManagedResult('install-git')
result.append(url)
@ -1294,6 +1294,14 @@ class UnifiedManager:
return result.fail(f"Failed to clone repo: {clone_url}")
else:
repo = git.Repo.clone_from(clone_url, repo_path, recursive=True, progress=GitProgress())
if commit_id:
print(f"Checkout commit: {commit_id}")
try:
# Try checking out as a commit, branch, or tag
repo.git.checkout(commit_id)
except Exception as checkout_error:
print(f"Error checking out {commit_id}: {checkout_error}")
return False
repo.git.clear_cache()
repo.close()
@ -1398,7 +1406,7 @@ class UnifiedManager:
else:
return self.cnr_switch_version(node_id, instant_execution=instant_execution, no_deps=no_deps, return_postinstall=return_postinstall).with_ver('cnr')
async def install_by_id(self, node_id: str, version_spec=None, channel=None, mode=None, instant_execution=False, no_deps=False, return_postinstall=False):
async def install_by_id(self, node_id: str, version_spec=None, channel=None, mode=None, instant_execution=False, no_deps=False, return_postinstall=False, commit_id=None):
"""
priority if version_spec == None
1. CNR latest
@ -1448,7 +1456,7 @@ class UnifiedManager:
self.unified_disable(node_id, False)
to_path = os.path.abspath(os.path.join(get_default_custom_nodes_path(), node_id))
res = self.repo_install(repo_url, to_path, instant_execution=instant_execution, no_deps=no_deps, return_postinstall=return_postinstall)
res = self.repo_install(repo_url, to_path, instant_execution=instant_execution, no_deps=no_deps, return_postinstall=return_postinstall, commit_id=commit_id)
if res.result:
if version_spec == 'unknown':
self.unknown_active_nodes[node_id] = repo_url, to_path
@ -2073,12 +2081,11 @@ def is_valid_url(url):
return False
async def gitclone_install(url, instant_execution=False, msg_prefix='', no_deps=False):
async def gitclone_install(url, instant_execution=False, msg_prefix='', no_deps=False, commit_id=None):
await unified_manager.reload('cache')
await unified_manager.get_custom_nodes('default', 'cache')
print(f"{msg_prefix}Install: {url}")
print(f"{msg_prefix}Install: {url}:{commit_id}")
result = ManagedResult('install-git')
if not is_valid_url(url):
@ -2123,7 +2130,16 @@ async def gitclone_install(url, instant_execution=False, msg_prefix='', no_deps=
if res != 0:
return result.fail(f"Failed to clone '{clone_url}' into '{repo_path}'")
else:
repo = git.Repo.clone_from(clone_url, repo_path, recursive=True, progress=GitProgress())
repo = git.Repo.clone_from(url, repo_path, recursive=True, progress=GitProgress())
if commit_id:
print(f"Checkout commit: {commit_id}")
try:
# Try checking out as a commit, branch, or tag
repo.git.checkout(commit_id)
except Exception as checkout_error:
print(f"Error checking out {commit_id}: {checkout_error}")
return False
repo.git.clear_cache()
repo.close()