Add the ability to specify commit ID /tag or branch when installing nodes

This commit is contained in:
chenyijian 2024-12-24 19:07:27 +08:00
parent 50fc1389b0
commit ae8ac2ea9f
2 changed files with 48 additions and 15 deletions

View File

@ -183,16 +183,27 @@ class Ctx:
cmd_ctx = Ctx() cmd_ctx = Ctx()
def parse_node(node: str):
if '@@' in node:
name, commit = node.split('@@', 1)
else:
name, commit = node, None
return name, commit
def install_node(node_spec_str, is_all=False, cnt_msg=''): def install_node(node_spec_str, is_all=False, cnt_msg='', **kwargs):
_, commit_id = parse_node(node_spec_str)
exit_on_fail = kwargs.get('exit_on_fail', False)
print(f"node_spec_str:{node_spec_str}, exit_on_fail:{exit_on_fail}...")
if core.is_valid_url(node_spec_str): if core.is_valid_url(node_spec_str):
# install via urls # install via urls
res = asyncio.run(core.gitclone_install(node_spec_str, no_deps=cmd_ctx.no_deps)) res = asyncio.run(core.gitclone_install(node_spec_str, no_deps=cmd_ctx.no_deps, commit_id=commit_id))
if not res.result: if not res.result:
print(res.msg) print(res.msg)
print(f"[bold red]ERROR: An error occurred while installing '{node_spec_str}'.[/bold red]") print(f"[bold red]ERROR: An error occurred while installing '{node_spec_str}'.[/bold red]")
if exit_on_fail:
sys.exit(1)
else: else:
print(f"{cnt_msg} [INSTALLED] {node_spec_str:50}") print(f"{cnt_msg} [INSTALLED] {node_spec_str:50} => {commit_id}")
else: else:
node_spec = unified_manager.resolve_node_spec(node_spec_str) node_spec = unified_manager.resolve_node_spec(node_spec_str)
@ -205,7 +216,7 @@ def install_node(node_spec_str, is_all=False, cnt_msg=''):
if not is_specified: if not is_specified:
version_spec = None version_spec = None
res = asyncio.run(unified_manager.install_by_id(node_name, version_spec, cmd_ctx.channel, cmd_ctx.mode, instant_execution=True, no_deps=cmd_ctx.no_deps)) res = asyncio.run(unified_manager.install_by_id(node_name, version_spec, cmd_ctx.channel, cmd_ctx.mode, instant_execution=True, no_deps=cmd_ctx.no_deps, commit_id=commit_id))
if res.action == 'skip': if res.action == 'skip':
print(f"{cnt_msg} [ SKIP ] {node_name:50} => Already installed") print(f"{cnt_msg} [ SKIP ] {node_name:50} => Already installed")
@ -225,6 +236,8 @@ def install_node(node_spec_str, is_all=False, cnt_msg=''):
print("") print("")
else: else:
print(f"[bold red]ERROR: An error occurred while installing '{node_name}'.\n{res.msg}[/bold red]") print(f"[bold red]ERROR: An error occurred while installing '{node_name}'.\n{res.msg}[/bold red]")
if exit_on_fail:
sys.exit(1)
def reinstall_node(node_spec_str, is_all=False, cnt_msg=''): def reinstall_node(node_spec_str, is_all=False, cnt_msg=''):
@ -586,7 +599,7 @@ def get_all_installed_node_specs():
return res return res
def for_each_nodes(nodes, act, allow_all=True): def for_each_nodes(nodes, act, allow_all=True, **kwargs):
is_all = False is_all = False
if allow_all and 'all' in nodes: if allow_all and 'all' in nodes:
is_all = True is_all = True
@ -598,7 +611,7 @@ def for_each_nodes(nodes, act, allow_all=True):
i = 1 i = 1
for x in nodes: for x in nodes:
try: try:
act(x, is_all=is_all, cnt_msg=f'{i}/{total}') act(x, is_all=is_all, cnt_msg=f'{i}/{total}', **kwargs)
except Exception as e: except Exception as e:
print(f"ERROR: {e}") print(f"ERROR: {e}")
traceback.print_exc() traceback.print_exc()
@ -642,13 +655,17 @@ def install(
None, None,
help="user directory" help="user directory"
), ),
exit_on_fail: bool = typer.Option(
False,
help="Exit on failure"
)
): ):
cmd_ctx.set_user_directory(user_directory) cmd_ctx.set_user_directory(user_directory)
cmd_ctx.set_channel_mode(channel, mode) cmd_ctx.set_channel_mode(channel, mode)
cmd_ctx.set_no_deps(no_deps) cmd_ctx.set_no_deps(no_deps)
pip_fixer = manager_util.PIPFixer(manager_util.get_installed_packages(), comfy_path, core.manager_files_path) pip_fixer = manager_util.PIPFixer(manager_util.get_installed_packages(), comfy_path, core.manager_files_path)
for_each_nodes(nodes, act=install_node) for_each_nodes(nodes, act=install_node, exit_on_fail=exit_on_fail)
pip_fixer.fix_broken() pip_fixer.fix_broken()
@ -679,7 +696,7 @@ def reinstall(
user_directory: str = typer.Option( user_directory: str = typer.Option(
None, None,
help="user directory" help="user directory"
), )
): ):
cmd_ctx.set_user_directory(user_directory) cmd_ctx.set_user_directory(user_directory)
cmd_ctx.set_channel_mode(channel, mode) cmd_ctx.set_channel_mode(channel, mode)

View File

@ -1271,7 +1271,7 @@ class UnifiedManager:
return result return result
def repo_install(self, url: str, repo_path: str, instant_execution=False, no_deps=False, return_postinstall=False): def repo_install(self, url: str, repo_path: str, instant_execution=False, no_deps=False, return_postinstall=False, commit_id=None):
result = ManagedResult('install-git') result = ManagedResult('install-git')
result.append(url) result.append(url)
@ -1294,6 +1294,14 @@ class UnifiedManager:
return result.fail(f"Failed to clone repo: {clone_url}") return result.fail(f"Failed to clone repo: {clone_url}")
else: else:
repo = git.Repo.clone_from(clone_url, repo_path, recursive=True, progress=GitProgress()) repo = git.Repo.clone_from(clone_url, repo_path, recursive=True, progress=GitProgress())
if commit_id:
print(f"Checkout commit: {commit_id}")
try:
# Try checking out as a commit, branch, or tag
repo.git.checkout(commit_id)
except Exception as checkout_error:
print(f"Error checking out {commit_id}: {checkout_error}")
return False
repo.git.clear_cache() repo.git.clear_cache()
repo.close() repo.close()
@ -1398,7 +1406,7 @@ class UnifiedManager:
else: else:
return self.cnr_switch_version(node_id, instant_execution=instant_execution, no_deps=no_deps, return_postinstall=return_postinstall).with_ver('cnr') return self.cnr_switch_version(node_id, instant_execution=instant_execution, no_deps=no_deps, return_postinstall=return_postinstall).with_ver('cnr')
async def install_by_id(self, node_id: str, version_spec=None, channel=None, mode=None, instant_execution=False, no_deps=False, return_postinstall=False): async def install_by_id(self, node_id: str, version_spec=None, channel=None, mode=None, instant_execution=False, no_deps=False, return_postinstall=False, commit_id=None):
""" """
priority if version_spec == None priority if version_spec == None
1. CNR latest 1. CNR latest
@ -1448,7 +1456,7 @@ class UnifiedManager:
self.unified_disable(node_id, False) self.unified_disable(node_id, False)
to_path = os.path.abspath(os.path.join(get_default_custom_nodes_path(), node_id)) to_path = os.path.abspath(os.path.join(get_default_custom_nodes_path(), node_id))
res = self.repo_install(repo_url, to_path, instant_execution=instant_execution, no_deps=no_deps, return_postinstall=return_postinstall) res = self.repo_install(repo_url, to_path, instant_execution=instant_execution, no_deps=no_deps, return_postinstall=return_postinstall, commit_id=commit_id)
if res.result: if res.result:
if version_spec == 'unknown': if version_spec == 'unknown':
self.unknown_active_nodes[node_id] = repo_url, to_path self.unknown_active_nodes[node_id] = repo_url, to_path
@ -2073,12 +2081,11 @@ def is_valid_url(url):
return False return False
async def gitclone_install(url, instant_execution=False, msg_prefix='', no_deps=False): async def gitclone_install(url, instant_execution=False, msg_prefix='', no_deps=False, commit_id=None):
await unified_manager.reload('cache') await unified_manager.reload('cache')
await unified_manager.get_custom_nodes('default', 'cache') await unified_manager.get_custom_nodes('default', 'cache')
print(f"{msg_prefix}Install: {url}") print(f"{msg_prefix}Install: {url}:{commit_id}")
result = ManagedResult('install-git') result = ManagedResult('install-git')
if not is_valid_url(url): if not is_valid_url(url):
@ -2123,7 +2130,16 @@ async def gitclone_install(url, instant_execution=False, msg_prefix='', no_deps=
if res != 0: if res != 0:
return result.fail(f"Failed to clone '{clone_url}' into '{repo_path}'") return result.fail(f"Failed to clone '{clone_url}' into '{repo_path}'")
else: else:
repo = git.Repo.clone_from(clone_url, repo_path, recursive=True, progress=GitProgress()) repo = git.Repo.clone_from(url, repo_path, recursive=True, progress=GitProgress())
if commit_id:
print(f"Checkout commit: {commit_id}")
try:
# Try checking out as a commit, branch, or tag
repo.git.checkout(commit_id)
except Exception as checkout_error:
print(f"Error checking out {commit_id}: {checkout_error}")
return False
repo.git.clear_cache() repo.git.clear_cache()
repo.close() repo.close()