diff --git a/cm_cli/__main__.py b/cm_cli/__main__.py index 554a7e03..d6ad5e53 100644 --- a/cm_cli/__main__.py +++ b/cm_cli/__main__.py @@ -11,7 +11,6 @@ import typer from rich import print from typing_extensions import List, Annotated import re -import git import importlib @@ -62,9 +61,10 @@ if os.path.exists(os.path.join(manager_util.comfyui_manager_path, "pip_blacklist def check_comfyui_hash(): try: - repo = git.Repo(comfy_path) - core.comfy_ui_revision = len(list(repo.iter_commits('HEAD'))) - core.comfy_ui_commit_datetime = repo.head.commit.committed_datetime + from comfyui_manager.common.git_compat import open_repo + with open_repo(comfy_path) as repo: + core.comfy_ui_revision = repo.iter_commits_count() + core.comfy_ui_commit_datetime = repo.head_commit_datetime except Exception: print('[bold yellow]INFO: Frozen ComfyUI mode.[/bold yellow]') core.comfy_ui_revision = 0 diff --git a/comfyui_manager/common/context.py b/comfyui_manager/common/context.py index 88fa9089..774bad67 100644 --- a/comfyui_manager/common/context.py +++ b/comfyui_manager/common/context.py @@ -3,7 +3,7 @@ import os import logging from . import manager_util import toml -import git +from .git_compat import open_repo # read env vars @@ -98,8 +98,8 @@ def get_current_comfyui_ver(): def get_comfyui_tag(): try: - with git.Repo(comfy_path) as repo: - return repo.git.describe('--tags') + with open_repo(comfy_path) as repo: + return repo.describe_tags() except Exception: return None diff --git a/comfyui_manager/common/git_compat.py b/comfyui_manager/common/git_compat.py new file mode 100644 index 00000000..5fa0501d --- /dev/null +++ b/comfyui_manager/common/git_compat.py @@ -0,0 +1,844 @@ +""" +git_compat.py — Compatibility layer for git operations in ComfyUI-Manager. + +Wraps either GitPython (`git` module) or `pygit2`, depending on availability +and the CM_USE_PYGIT2 environment variable (set by Desktop 2.0 Launcher). + +Exports: + USE_PYGIT2 — bool: which backend is active + GitCommandError — exception class for git command failures + open_repo(path) — returns a repo wrapper object + clone_repo(url, dest, progress=None) — clone a repository + get_comfyui_tag(repo_path) — get describe --tags output + setup_git_environment(git_exe) — configure git executable path +""" + +import os +import sys +from abc import ABC, abstractmethod +from collections import deque +from datetime import datetime, timezone, timedelta + +# --------------------------------------------------------------------------- +# Backend selection +# --------------------------------------------------------------------------- + +_PYGIT2_REQUESTED = os.environ.get('CM_USE_PYGIT2', '').strip() == '1' +USE_PYGIT2 = _PYGIT2_REQUESTED + +if not USE_PYGIT2: + try: + import git as _git + _git.Git().execute(['git', '--version']) + except Exception: + USE_PYGIT2 = True + +if USE_PYGIT2: + try: + import pygit2 as _pygit2 + except ImportError: + # pygit2 not available either — fall back to GitPython and let it + # fail at the point of use, preserving pre-existing behavior. + USE_PYGIT2 = False + _PYGIT2_REQUESTED = False + import git as _git + else: + # Disable owner validation once at import time. + # Required for Desktop 2.0 standalone installs where repo directories + # may be owned by a different user (e.g., system-installed paths). + # See CVE-2022-24765 for context on this validation. + _pygit2.option(_pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0) + +if not USE_PYGIT2: + import git as _git + +# --------------------------------------------------------------------------- +# Shared exception type +# --------------------------------------------------------------------------- + +if USE_PYGIT2: + class GitCommandError(Exception): + """Stand-in for git.GitCommandError when using pygit2 backend.""" + pass +else: + from git import GitCommandError # noqa: F401 + +# --------------------------------------------------------------------------- +# Banner +# --------------------------------------------------------------------------- + +if USE_PYGIT2: + if _PYGIT2_REQUESTED: + print("[ComfyUI-Manager] Using pygit2 backend (CM_USE_PYGIT2=1)") + else: + print("[ComfyUI-Manager] Using pygit2 backend (system git not available)") +else: + print("[ComfyUI-Manager] Using GitPython backend") + + +# =================================================================== +# Abstract base class +# =================================================================== + +class GitRepo(ABC): + """Abstract interface for git repository operations.""" + + @property + @abstractmethod + def working_dir(self) -> str: ... + + @property + @abstractmethod + def head_commit_hexsha(self) -> str: ... + + @property + @abstractmethod + def head_is_detached(self) -> bool: ... + + @property + @abstractmethod + def head_commit_datetime(self): ... + + @property + @abstractmethod + def active_branch_name(self) -> str: ... + + @abstractmethod + def is_dirty(self) -> bool: ... + + @abstractmethod + def get_tracking_remote_name(self) -> str: ... + + @abstractmethod + def get_remote(self, name: str): ... + + @abstractmethod + def has_ref(self, ref_name: str) -> bool: ... + + @abstractmethod + def get_ref_commit_hexsha(self, ref_name: str) -> str: ... + + @abstractmethod + def get_ref_commit_datetime(self, ref_name: str): ... + + @abstractmethod + def list_remotes(self) -> list: ... + + @abstractmethod + def get_remote_url(self, index_or_name) -> str: ... + + @abstractmethod + def iter_commits_count(self) -> int: ... + + @abstractmethod + def symbolic_ref(self, ref: str) -> str: ... + + @abstractmethod + def describe_tags(self, exact_match=False): ... + + @abstractmethod + def list_tags(self) -> list: ... + + @abstractmethod + def list_heads(self) -> list: ... + + @abstractmethod + def list_branches(self) -> list: ... + + @abstractmethod + def get_head_by_name(self, name: str): ... + + @abstractmethod + def head_commit_equals(self, other_commit) -> bool: ... + + @abstractmethod + def get_ref_object(self, ref_name: str): ... + + @abstractmethod + def stash(self): ... + + @abstractmethod + def pull_ff_only(self): ... + + @abstractmethod + def reset_hard(self, ref: str): ... + + @abstractmethod + def create_backup_branch(self, name: str): ... + + @abstractmethod + def checkout(self, ref): ... + + @abstractmethod + def checkout_new_branch(self, branch_name: str, start_point: str): ... + + @abstractmethod + def submodule_update(self): ... + + @abstractmethod + def clear_cache(self): ... + + @abstractmethod + def fetch_remote_by_index(self, index): ... + + @abstractmethod + def pull_remote_by_index(self, index): ... + + @abstractmethod + def close(self): ... + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +# =================================================================== +# Helper types for tag/head/ref proxies +# =================================================================== + +class _TagProxy: + """Mimics a GitPython tag reference: .name and .commit.""" + def __init__(self, name, commit_obj): + self.name = name + self.commit = commit_obj + + +class _HeadProxy: + """Mimics a GitPython head reference: .name and .commit.""" + def __init__(self, name, commit_obj=None): + self.name = name + self.commit = commit_obj + + +class _RefProxy: + """Mimics a GitPython ref: .object.hexsha, .object.committed_datetime, .reference.commit.""" + def __init__(self, hexsha, committed_datetime, commit_obj=None): + self.object = type('obj', (), { + 'hexsha': hexsha, + 'committed_datetime': committed_datetime, + })() + self.reference = type('ref', (), {'commit': commit_obj})() if commit_obj is not None else None + + +class _RemoteProxy: + """Mimics a GitPython remote: .name, .url, .fetch(), .pull().""" + def __init__(self, name, url, fetch_fn, pull_fn=None): + self.name = name + self.url = url + self._fetch = fetch_fn + self._pull = pull_fn + + def fetch(self): + return self._fetch() + + def pull(self): + if self._pull is not None: + return self._pull() + raise GitCommandError("pull not supported on this remote") + + +# =================================================================== +# GitPython wrapper — 1:1 pass-throughs +# =================================================================== + +class _GitPythonRepo(GitRepo): + def __init__(self, path): + self._repo = _git.Repo(path) + + @property + def working_dir(self): + return self._repo.working_dir + + @property + def head_commit_hexsha(self): + return self._repo.head.commit.hexsha + + @property + def head_is_detached(self): + return self._repo.head.is_detached + + @property + def head_commit_datetime(self): + return self._repo.head.commit.committed_datetime + + @property + def active_branch_name(self): + return self._repo.active_branch.name + + def is_dirty(self): + return self._repo.is_dirty() + + def get_tracking_remote_name(self): + return self._repo.active_branch.tracking_branch().remote_name + + def get_remote(self, name): + r = self._repo.remote(name=name) + return _RemoteProxy(r.name, r.url, r.fetch, getattr(r, 'pull', None)) + + def has_ref(self, ref_name): + return ref_name in self._repo.refs + + def get_ref_commit_hexsha(self, ref_name): + return self._repo.refs[ref_name].object.hexsha + + def get_ref_commit_datetime(self, ref_name): + return self._repo.refs[ref_name].object.committed_datetime + + def list_remotes(self): + return [_RemoteProxy(r.name, r.url, r.fetch, getattr(r, 'pull', None)) + for r in self._repo.remotes] + + def get_remote_url(self, index_or_name): + if isinstance(index_or_name, int): + return self._repo.remotes[index_or_name].url + return self._repo.remote(name=index_or_name).url + + def iter_commits_count(self): + return len(list(self._repo.iter_commits('HEAD'))) + + def symbolic_ref(self, ref): + return self._repo.git.symbolic_ref(ref) + + def describe_tags(self, exact_match=False): + try: + if exact_match: + return self._repo.git.describe('--tags', '--exact-match') + else: + return self._repo.git.describe('--tags') + except Exception: + return None + + def list_tags(self): + return [_TagProxy(t.name, t.commit) for t in self._repo.tags] + + def list_heads(self): + return [_HeadProxy(h.name, h.commit) for h in self._repo.heads] + + def list_branches(self): + return [_HeadProxy(b.name, b.commit) for b in self._repo.branches] + + def get_head_by_name(self, name): + head = getattr(self._repo.heads, name) + return _HeadProxy(head.name, head.commit) + + def head_commit_equals(self, other_commit): + return self._repo.head.commit == other_commit + + def get_ref_object(self, ref_name): + ref = self._repo.refs[ref_name] + try: + ref_commit = ref.reference.commit + except (TypeError, AttributeError): + ref_commit = ref.object + return _RefProxy( + ref.object.hexsha, + ref.object.committed_datetime, + commit_obj=ref_commit, + ) + + def stash(self): + self._repo.git.stash() + + def pull_ff_only(self): + self._repo.git.pull('--ff-only') + + def reset_hard(self, ref): + self._repo.git.reset('--hard', ref) + + def create_backup_branch(self, name): + self._repo.create_head(name) + + def checkout(self, ref): + self._repo.git.checkout(ref) + + def checkout_new_branch(self, branch_name, start_point): + self._repo.git.checkout('-b', branch_name, start_point) + + def submodule_update(self): + self._repo.git.submodule('update', '--init', '--recursive') + + def clear_cache(self): + self._repo.git.clear_cache() + + def fetch_remote_by_index(self, index): + self._repo.remotes[index].fetch() + + def pull_remote_by_index(self, index): + self._repo.remotes[index].pull() + + def close(self): + self._repo.close() + + +# =================================================================== +# Pygit2 wrapper +# =================================================================== + +class _Pygit2Repo(GitRepo): + def __init__(self, path): + repo_path = os.path.abspath(path) + git_dir = os.path.join(repo_path, '.git') + for sub in ['refs/heads', 'refs/tags', 'refs/remotes']: + try: + os.makedirs(os.path.join(git_dir, sub), exist_ok=True) + except OSError: + pass + self._repo = _pygit2.Repository(git_dir) + self._working_dir = repo_path + + @property + def working_dir(self): + return self._working_dir + + @property + def head_commit_hexsha(self): + return str(self._repo.head.peel(_pygit2.Commit).id) + + @property + def head_is_detached(self): + return self._repo.head_is_detached + + @property + def head_commit_datetime(self): + commit = self._repo.head.peel(_pygit2.Commit) + ts = commit.commit_time + offset_minutes = commit.commit_time_offset + tz = timezone(timedelta(minutes=offset_minutes)) + return datetime.fromtimestamp(ts, tz=tz) + + @property + def active_branch_name(self): + ref = self._repo.head.name + if ref.startswith('refs/heads/'): + return ref[len('refs/heads/'):] + return ref + + def is_dirty(self): + st = self._repo.status() + for flags in st.values(): + if flags == _pygit2.GIT_STATUS_CURRENT: + continue + if flags == _pygit2.GIT_STATUS_IGNORED: + continue + if flags == _pygit2.GIT_STATUS_WT_NEW: + continue + return True + return False + + def get_tracking_remote_name(self): + branch = self._repo.branches.get(self.active_branch_name) + if branch is None: + raise GitCommandError("Cannot determine tracking branch: HEAD is detached or branch not found") + upstream = branch.upstream + if upstream is None: + raise GitCommandError(f"No upstream configured for branch '{self.active_branch_name}'") + # upstream.name can be "origin/master" or "refs/remotes/origin/master" + name = upstream.name + if name.startswith('refs/remotes/'): + name = name[len('refs/remotes/'):] + return name.split('/')[0] + + def get_remote(self, name): + remote = self._repo.remotes[name] + + def _pull(): + remote.fetch() + branch_name = self.active_branch_name + branch = self._repo.branches.get(branch_name) + if branch and branch.upstream: + remote_commit = branch.upstream.peel(_pygit2.Commit) + analysis, _ = self._repo.merge_analysis(remote_commit.id) + if analysis & _pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD: + self._repo.checkout_tree(self._repo.get(remote_commit.id)) + branch_ref = self._repo.references.get(f'refs/heads/{branch_name}') + if branch_ref is not None: + branch_ref.set_target(remote_commit.id) + self._repo.head.set_target(remote_commit.id) + + return _RemoteProxy(remote.name, remote.url, remote.fetch, _pull) + + def has_ref(self, ref_name): + for prefix in [f'refs/remotes/{ref_name}', f'refs/heads/{ref_name}', + f'refs/tags/{ref_name}', ref_name]: + try: + if self._repo.references.get(prefix) is not None: + return True + except Exception: + pass + return False + + def _resolve_ref(self, ref_name): + for prefix in [f'refs/remotes/{ref_name}', f'refs/heads/{ref_name}', + f'refs/tags/{ref_name}', ref_name]: + ref = self._repo.references.get(prefix) + if ref is not None: + return ref.peel(_pygit2.Commit) + raise GitCommandError(f"Reference not found: {ref_name}") + + def get_ref_commit_hexsha(self, ref_name): + return str(self._resolve_ref(ref_name).id) + + def get_ref_commit_datetime(self, ref_name): + commit = self._resolve_ref(ref_name) + ts = commit.commit_time + offset_minutes = commit.commit_time_offset + tz = timezone(timedelta(minutes=offset_minutes)) + return datetime.fromtimestamp(ts, tz=tz) + + def list_remotes(self): + result = [] + for r in self._repo.remotes: + result.append(_RemoteProxy(r.name, r.url, r.fetch)) + return result + + def get_remote_url(self, index_or_name): + if isinstance(index_or_name, int): + remotes = list(self._repo.remotes) + return remotes[index_or_name].url + return self._repo.remotes[index_or_name].url + + def iter_commits_count(self): + count = 0 + head_commit = self._repo.head.peel(_pygit2.Commit) + visited = set() + queue = deque([head_commit.id]) + while queue: + oid = queue.popleft() + if oid in visited: + continue + visited.add(oid) + count += 1 + commit = self._repo.get(oid) + if commit is not None: + for parent_id in commit.parent_ids: + if parent_id not in visited: + queue.append(parent_id) + return count + + def symbolic_ref(self, ref): + git_dir = self._repo.path + ref_file = os.path.join(git_dir, ref) + if os.path.isfile(ref_file): + with open(ref_file, 'r') as f: + content = f.read().strip() + if content.startswith('ref: '): + return content[5:] + return content + ref_obj = self._repo.references.get(ref) + if ref_obj is not None and ref_obj.type == _pygit2.GIT_REFERENCE_SYMBOLIC: + return ref_obj.target + raise GitCommandError(f"Not a symbolic reference: {ref}") + + def describe_tags(self, exact_match=False): + try: + if exact_match: + head_oid = self._repo.head.peel(_pygit2.Commit).id + for ref_name in self._repo.references: + if not ref_name.startswith('refs/tags/'): + continue + ref = self._repo.references.get(ref_name) + if ref is None: + continue + try: + if ref.peel(_pygit2.Commit).id == head_oid: + return ref_name[len('refs/tags/'):] + except Exception: + pass + return None + else: + import math + num_objects = sum(1 for _ in self._repo.odb) + abbrev = max(7, math.ceil(math.log2(max(num_objects, 1)) / 2)) if num_objects > 0 else 7 + return self._repo.describe( + describe_strategy=1, + abbreviated_size=abbrev, + ) + except Exception: + return None + + def list_tags(self): + tags = [] + for ref_name in self._repo.references: + if ref_name.startswith('refs/tags/'): + tag_name = ref_name[len('refs/tags/'):] + ref = self._repo.references.get(ref_name) + if ref is not None: + try: + commit = ref.peel(_pygit2.Commit) + commit_obj = type('commit', (), { + 'hexsha': str(commit.id), + 'committed_datetime': datetime.fromtimestamp( + commit.commit_time, + tz=timezone(timedelta(minutes=commit.commit_time_offset)) + ), + })() + tags.append(_TagProxy(tag_name, commit_obj)) + except Exception: + tags.append(_TagProxy(tag_name, None)) + return tags + + def list_heads(self): + heads = [] + for ref_name in self._repo.references: + if ref_name.startswith('refs/heads/'): + branch_name = ref_name[len('refs/heads/'):] + ref = self._repo.references.get(ref_name) + commit_obj = None + if ref is not None: + try: + commit = ref.peel(_pygit2.Commit) + commit_obj = type('commit', (), { + 'hexsha': str(commit.id), + 'committed_datetime': datetime.fromtimestamp( + commit.commit_time, + tz=timezone(timedelta(minutes=commit.commit_time_offset)) + ), + })() + except Exception: + pass + heads.append(_HeadProxy(branch_name, commit_obj)) + return heads + + def list_branches(self): + return self.list_heads() + + def get_head_by_name(self, name): + ref = self._repo.references.get(f'refs/heads/{name}') + if ref is None: + raise AttributeError(f"Head '{name}' not found") + try: + commit = ref.peel(_pygit2.Commit) + commit_obj = type('commit', (), { + 'hexsha': str(commit.id), + 'committed_datetime': datetime.fromtimestamp( + commit.commit_time, + tz=timezone(timedelta(minutes=commit.commit_time_offset)) + ), + })() + except Exception: + commit_obj = None + return _HeadProxy(name, commit_obj) + + def head_commit_equals(self, other_commit): + head_sha = str(self._repo.head.peel(_pygit2.Commit).id) + if hasattr(other_commit, 'hexsha'): + return head_sha == other_commit.hexsha + return head_sha == str(other_commit) + + def get_ref_object(self, ref_name): + commit = self._resolve_ref(ref_name) + hexsha = str(commit.id) + dt = datetime.fromtimestamp( + commit.commit_time, + tz=timezone(timedelta(minutes=commit.commit_time_offset)) + ) + commit_obj = type('commit', (), { + 'hexsha': hexsha, + 'committed_datetime': dt, + })() + return _RefProxy(hexsha, dt, commit_obj=commit_obj) + + def stash(self): + sig = _pygit2.Signature('comfyui-manager', 'manager@comfy') + self._repo.stash(sig) + + def pull_ff_only(self): + branch_name = self.active_branch_name + branch = self._repo.branches.get(branch_name) + if branch is None: + raise GitCommandError(f"Branch '{branch_name}' not found") + upstream = branch.upstream + if upstream is None: + raise GitCommandError(f"No upstream for branch '{branch_name}'") + + remote_name = upstream.remote_name + self._repo.remotes[remote_name].fetch() + + upstream = self._repo.branches.get(branch_name).upstream + if upstream is None: + raise GitCommandError(f"Upstream lost after fetch for '{branch_name}'") + + remote_commit = upstream.peel(_pygit2.Commit) + analysis, _ = self._repo.merge_analysis(remote_commit.id) + + if analysis & _pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE: + return + + if analysis & _pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD: + self._repo.checkout_tree(self._repo.get(remote_commit.id)) + branch_ref = self._repo.references.get(f'refs/heads/{branch_name}') + if branch_ref is not None: + branch_ref.set_target(remote_commit.id) + self._repo.head.set_target(remote_commit.id) + else: + raise GitCommandError("Cannot fast-forward; merge or rebase required") + + def reset_hard(self, ref): + commit = None + # Try as hex SHA first + try: + oid = _pygit2.Oid(hex=ref) + commit = self._repo.get(oid) + except (ValueError, Exception): + pass + + if commit is None: + # Try as named reference + for candidate in [ref, f'refs/remotes/{ref}', f'refs/heads/{ref}', f'refs/tags/{ref}']: + try: + ref_obj = self._repo.references.get(candidate) + if ref_obj is not None: + commit = ref_obj.peel(_pygit2.Commit) + break + except Exception: + continue + + if commit is None: + raise GitCommandError(f"Cannot resolve ref: {ref}") + + self._repo.reset(commit.id, _pygit2.GIT_RESET_HARD) + + def create_backup_branch(self, name): + head_commit = self._repo.head.peel(_pygit2.Commit) + self._repo.branches.local.create(name, head_commit) + + def checkout(self, ref): + # ref can be a _HeadProxy from get_head_by_name + if isinstance(ref, _HeadProxy): + ref = ref.name + + branch = self._repo.branches.get(ref) + if branch is not None: + branch_ref = self._repo.lookup_reference(f'refs/heads/{ref}') + self._repo.checkout(branch_ref) + self._repo.set_head(branch_ref.name) + return + + for prefix in [f'refs/remotes/{ref}', f'refs/tags/{ref}']: + ref_obj = self._repo.references.get(prefix) + if ref_obj is not None: + commit = ref_obj.peel(_pygit2.Commit) + self._repo.checkout_tree(self._repo.get(commit.id)) + self._repo.set_head(commit.id) + return + + try: + oid = _pygit2.Oid(hex=ref) + obj = self._repo.get(oid) + if obj is not None: + commit = obj.peel(_pygit2.Commit) + self._repo.checkout_tree(self._repo.get(commit.id)) + self._repo.set_head(commit.id) + return + except Exception: + pass + + raise GitCommandError(f"Cannot resolve ref for checkout: {ref}") + + def checkout_new_branch(self, branch_name, start_point): + commit = self._resolve_ref(start_point) + branch = self._repo.branches.local.create(branch_name, commit) + for prefix in [f'refs/remotes/{start_point}']: + remote_ref = self._repo.references.get(prefix) + if remote_ref is not None: + try: + branch.upstream = remote_ref + except Exception: + pass + break + self._repo.checkout(branch) + self._repo.set_head(branch.name) + + def submodule_update(self): + try: + self._repo.submodules.init() + self._repo.submodules.update() + except Exception: + import subprocess + try: + result = subprocess.run( + ['git', 'submodule', 'update', '--init', '--recursive'], + cwd=self._working_dir, + capture_output=True, timeout=120, + ) + if result.returncode != 0: + raise GitCommandError( + f"submodule update failed (exit {result.returncode}): " + f"{result.stderr.decode(errors='replace')}") + except FileNotFoundError: + print("[ComfyUI-Manager] pygit2: submodule update requires system git (not installed)", file=sys.stderr) + except GitCommandError: + raise + except Exception as sub_e: + print(f"[ComfyUI-Manager] pygit2: submodule update failed: {sub_e}", file=sys.stderr) + + def clear_cache(self): + pass + + def fetch_remote_by_index(self, index): + remotes = list(self._repo.remotes) + remotes[index].fetch() + + def pull_remote_by_index(self, index): + remotes = list(self._repo.remotes) + remote = remotes[index] + remote.fetch() + # After fetch, try to ff-merge tracking branch + try: + branch_name = self.active_branch_name + branch = self._repo.branches.get(branch_name) + if branch and branch.upstream: + remote_commit = branch.upstream.peel(_pygit2.Commit) + analysis, _ = self._repo.merge_analysis(remote_commit.id) + if analysis & _pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD: + self._repo.checkout_tree(self._repo.get(remote_commit.id)) + branch_ref = self._repo.references.get(f'refs/heads/{branch_name}') + if branch_ref is not None: + branch_ref.set_target(remote_commit.id) + self._repo.head.set_target(remote_commit.id) + except Exception: + pass + + def close(self): + self._repo.free() + + +# =================================================================== +# Public API +# =================================================================== + +def open_repo(path) -> GitRepo: + """Open a repository and return a backend-appropriate wrapper.""" + if USE_PYGIT2: + return _Pygit2Repo(path) + else: + return _GitPythonRepo(path) + + +def clone_repo(url, dest, progress=None): + """Clone a repository from *url* into *dest*. + + Returns a repo wrapper that the caller can use for post-clone operations + (checkout, clear_cache, close, etc.). + """ + if USE_PYGIT2: + _pygit2.clone_repository(url, dest) + repo = _Pygit2Repo(dest) + repo.submodule_update() + return repo + else: + if progress is None: + r = _git.Repo.clone_from(url, dest, recursive=True) + else: + r = _git.Repo.clone_from(url, dest, recursive=True, progress=progress) + return _GitPythonRepo(r.working_dir) + + +def setup_git_environment(git_exe): + """Configure the git executable path (GitPython only).""" + if USE_PYGIT2: + return + if git_exe: + _git.Git().update_environment(GIT_PYTHON_GIT_EXECUTABLE=git_exe) diff --git a/comfyui_manager/common/git_helper.py b/comfyui_manager/common/git_helper.py index 11605b74..f2b10991 100644 --- a/comfyui_manager/common/git_helper.py +++ b/comfyui_manager/common/git_helper.py @@ -3,12 +3,18 @@ import sys import os import traceback -import git +# Make git_compat importable as a standalone subprocess script +sys.path.insert(0, os.path.dirname(__file__)) + +from git_compat import open_repo, clone_repo, GitCommandError, setup_git_environment import json import yaml import requests from tqdm.auto import tqdm -from git.remote import RemoteProgress +try: + from git.remote import RemoteProgress +except ImportError: + RemoteProgress = object comfy_path = os.environ.get('COMFYUI_PATH') @@ -79,7 +85,7 @@ def get_backup_branch_name(repo=None): return base_name try: - existing_branches = {b.name for b in repo.heads} + existing_branches = {b.name for b in repo.list_heads()} except Exception: return base_name @@ -117,61 +123,60 @@ def gitclone(custom_nodes_path, url, target_hash=None, repo_path=None): # Disable tqdm progress when stderr is piped to avoid deadlock on Windows. progress = GitProgress() if sys.stderr.isatty() else None - repo = git.Repo.clone_from(url, repo_path, recursive=True, progress=progress) + repo = clone_repo(url, repo_path, progress=progress) if target_hash is not None: print(f"CHECKOUT: {repo_name} [{target_hash}]") - repo.git.checkout(target_hash) + repo.checkout(target_hash) - repo.git.clear_cache() + repo.clear_cache() repo.close() def gitcheck(path, do_fetch=False): try: # Fetch the latest commits from the remote repository - repo = git.Repo(path) + with open_repo(path) as repo: - if repo.head.is_detached: - print("CUSTOM NODE CHECK: True") - return - - current_branch = repo.active_branch - branch_name = current_branch.name - - remote_name = current_branch.tracking_branch().remote_name - remote = repo.remote(name=remote_name) - - if do_fetch: - remote.fetch() - - # Get the current commit hash and the commit hash of the remote branch - commit_hash = repo.head.commit.hexsha - - if f'{remote_name}/{branch_name}' in repo.refs: - remote_commit_hash = repo.refs[f'{remote_name}/{branch_name}'].object.hexsha - else: - print("CUSTOM NODE CHECK: True") # non default branch is treated as updatable - return - - # Compare the commit hashes to determine if the local repository is behind the remote repository - if commit_hash != remote_commit_hash: - # Get the commit dates - commit_date = repo.head.commit.committed_datetime - remote_commit_date = repo.refs[f'{remote_name}/{branch_name}'].object.committed_datetime - - # Compare the commit dates to determine if the local repository is behind the remote repository - if commit_date < remote_commit_date: + if repo.head_is_detached: print("CUSTOM NODE CHECK: True") - else: - print("CUSTOM NODE CHECK: False") + return + + branch_name = repo.active_branch_name + + remote_name = repo.get_tracking_remote_name() + remote = repo.get_remote(remote_name) + + if do_fetch: + remote.fetch() + + # Get the current commit hash and the commit hash of the remote branch + commit_hash = repo.head_commit_hexsha + + if repo.has_ref(f'{remote_name}/{branch_name}'): + remote_commit_hash = repo.get_ref_commit_hexsha(f'{remote_name}/{branch_name}') + else: + print("CUSTOM NODE CHECK: True") # non default branch is treated as updatable + return + + # Compare the commit hashes to determine if the local repository is behind the remote repository + if commit_hash != remote_commit_hash: + # Get the commit dates + commit_date = repo.head_commit_datetime + remote_commit_date = repo.get_ref_commit_datetime(f'{remote_name}/{branch_name}') + + # Compare the commit dates to determine if the local repository is behind the remote repository + if commit_date < remote_commit_date: + print("CUSTOM NODE CHECK: True") + else: + print("CUSTOM NODE CHECK: False") except Exception as e: print(e) print("CUSTOM NODE CHECK: Error") def get_remote_name(repo): - available_remotes = [remote.name for remote in repo.remotes] + available_remotes = [remote.name for remote in repo.list_remotes()] if 'origin' in available_remotes: return 'origin' elif 'upstream' in available_remotes: @@ -196,28 +201,28 @@ def switch_to_default_branch(repo): if remote_name is None: return False - default_branch = repo.git.symbolic_ref(f'refs/remotes/{remote_name}/HEAD').replace(f'refs/remotes/{remote_name}/', '') - repo.git.checkout(default_branch) + default_branch = repo.symbolic_ref(f'refs/remotes/{remote_name}/HEAD').replace(f'refs/remotes/{remote_name}/', '') + repo.checkout(default_branch) return True except Exception: # try checkout master # try checkout main if failed try: - repo.git.checkout(repo.heads.master) + repo.checkout(repo.get_head_by_name('master')) return True except Exception: try: if remote_name is not None: - repo.git.checkout('-b', 'master', f'{remote_name}/master') + repo.checkout_new_branch('master', f'{remote_name}/master') return True except Exception: try: - repo.git.checkout(repo.heads.main) + repo.checkout(repo.get_head_by_name('main')) return True except Exception: try: if remote_name is not None: - repo.git.checkout('-b', 'main', f'{remote_name}/main') + repo.checkout_new_branch('main', f'{remote_name}/main') return True except Exception: pass @@ -232,72 +237,67 @@ def gitpull(path): raise ValueError('Not a git repository') # Pull the latest changes from the remote repository - repo = git.Repo(path) - if repo.is_dirty(): - print(f"STASH: '{path}' is dirty.") - repo.git.stash() - - commit_hash = repo.head.commit.hexsha - try: - if repo.head.is_detached: - switch_to_default_branch(repo) - - current_branch = repo.active_branch - branch_name = current_branch.name - - remote_name = current_branch.tracking_branch().remote_name - remote = repo.remote(name=remote_name) - - if f'{remote_name}/{branch_name}' not in repo.refs: - switch_to_default_branch(repo) - current_branch = repo.active_branch - branch_name = current_branch.name - - remote.fetch() - if f'{remote_name}/{branch_name}' in repo.refs: - remote_commit_hash = repo.refs[f'{remote_name}/{branch_name}'].object.hexsha - else: - print("CUSTOM NODE PULL: Fail") # update fail - return - - if commit_hash == remote_commit_hash: - print("CUSTOM NODE PULL: None") # there is no update - repo.close() - return + with open_repo(path) as repo: + if repo.is_dirty(): + print(f"STASH: '{path}' is dirty.") + repo.stash() + commit_hash = repo.head_commit_hexsha try: - repo.git.pull('--ff-only') - except git.GitCommandError: - backup_name = get_backup_branch_name(repo) - repo.create_head(backup_name) - print(f"[ComfyUI-Manager] Cannot fast-forward. Backup created: {backup_name}") - repo.git.reset('--hard', f'{remote_name}/{branch_name}') - print(f"[ComfyUI-Manager] Reset to {remote_name}/{branch_name}") + if repo.head_is_detached: + switch_to_default_branch(repo) - repo.git.submodule('update', '--init', '--recursive') - new_commit_hash = repo.head.commit.hexsha + branch_name = repo.active_branch_name - if commit_hash != new_commit_hash: - print("CUSTOM NODE PULL: Success") # update success - else: - print("CUSTOM NODE PULL: Fail") # update fail - except Exception as e: - print(e) - print("CUSTOM NODE PULL: Fail") # unknown git error + remote_name = repo.get_tracking_remote_name() + remote = repo.get_remote(remote_name) - repo.close() + if not repo.has_ref(f'{remote_name}/{branch_name}'): + switch_to_default_branch(repo) + branch_name = repo.active_branch_name + + remote.fetch() + if repo.has_ref(f'{remote_name}/{branch_name}'): + remote_commit_hash = repo.get_ref_commit_hexsha(f'{remote_name}/{branch_name}') + else: + print("CUSTOM NODE PULL: Fail") # update fail + return + + if commit_hash == remote_commit_hash: + print("CUSTOM NODE PULL: None") # there is no update + return + + try: + repo.pull_ff_only() + except GitCommandError: + backup_name = get_backup_branch_name(repo) + repo.create_backup_branch(backup_name) + print(f"[ComfyUI-Manager] Cannot fast-forward. Backup created: {backup_name}") + repo.reset_hard(f'{remote_name}/{branch_name}') + print(f"[ComfyUI-Manager] Reset to {remote_name}/{branch_name}") + + repo.submodule_update() + new_commit_hash = repo.head_commit_hexsha + + if commit_hash != new_commit_hash: + print("CUSTOM NODE PULL: Success") # update success + else: + print("CUSTOM NODE PULL: Fail") # update fail + except Exception as e: + print(e) + print("CUSTOM NODE PULL: Fail") # unknown git error def checkout_comfyui_hash(target_hash): - repo = git.Repo(comfy_path) - commit_hash = repo.head.commit.hexsha + with open_repo(comfy_path) as repo: + commit_hash = repo.head_commit_hexsha - if commit_hash != target_hash: - try: - print(f"CHECKOUT: ComfyUI [{target_hash}]") - repo.git.checkout(target_hash) - except git.GitCommandError as e: - print(f"Error checking out the ComfyUI: {str(e)}") + if commit_hash != target_hash: + try: + print(f"CHECKOUT: ComfyUI [{target_hash}]") + repo.checkout(target_hash) + except GitCommandError as e: + print(f"Error checking out the ComfyUI: {str(e)}") def checkout_custom_node_hash(git_custom_node_infos): @@ -359,12 +359,12 @@ def checkout_custom_node_hash(git_custom_node_infos): need_checkout = True if need_checkout: - repo = git.Repo(fullpath) - commit_hash = repo.head.commit.hexsha + with open_repo(fullpath) as repo: + commit_hash = repo.head_commit_hexsha - if commit_hash != item['hash']: - print(f"CHECKOUT: {repo_name} [{item['hash']}]") - repo.git.checkout(item['hash']) + if commit_hash != item['hash']: + print(f"CHECKOUT: {repo_name} [{item['hash']}]") + repo.checkout(item['hash']) except Exception: print(f"Failed to restore snapshots for the custom node '{path}'") @@ -539,7 +539,7 @@ def restore_pip_snapshot(pips, options): def setup_environment(): if git_exe_path is not None: - git.Git().update_environment(GIT_PYTHON_GIT_EXECUTABLE=git_exe_path) + setup_git_environment(git_exe_path) setup_environment() diff --git a/comfyui_manager/common/timestamp_utils.py b/comfyui_manager/common/timestamp_utils.py index 772817c4..f5c597eb 100644 --- a/comfyui_manager/common/timestamp_utils.py +++ b/comfyui_manager/common/timestamp_utils.py @@ -85,7 +85,7 @@ def get_backup_branch_name(repo=None) -> str: # Check if branch exists try: - existing_branches = {b.name for b in repo.heads} + existing_branches = {b.name for b in repo.list_heads()} except Exception: return base_name diff --git a/comfyui_manager/glob/manager_core.py b/comfyui_manager/glob/manager_core.py index 3c9eb23f..187e91e0 100644 --- a/comfyui_manager/glob/manager_core.py +++ b/comfyui_manager/glob/manager_core.py @@ -13,9 +13,12 @@ import shutil import configparser import platform -import git +from ..common.git_compat import open_repo, clone_repo, GitCommandError +try: + from git.remote import RemoteProgress +except ImportError: + RemoteProgress = object from comfyui_manager.common.timestamp_utils import get_timestamp_for_path, get_backup_branch_name -from git.remote import RemoteProgress from urllib.parse import urlparse from tqdm.auto import tqdm import time @@ -41,7 +44,7 @@ from ..common.enums import NetworkMode, SecurityLevel, DBMode from ..common import context -version_code = [4, 1] +version_code = [4, 2] version_str = f"V{version_code[0]}.{version_code[1]}" + (f'.{version_code[2]}' if len(version_code) > 2 else '') @@ -93,6 +96,9 @@ def get_script_env(): if 'COMFYUI_FOLDERS_BASE_PATH' not in new_env: new_env['COMFYUI_FOLDERS_BASE_PATH'] = context.comfy_path + if 'CM_USE_PYGIT2' in os.environ: + new_env['CM_USE_PYGIT2'] = os.environ['CM_USE_PYGIT2'] + return new_env @@ -1344,8 +1350,8 @@ class UnifiedManager: if res != 0: return result.fail(f"Failed to clone repo: {clone_url}") else: - repo = git.Repo.clone_from(clone_url, repo_path, recursive=True, progress=GitProgress()) - repo.git.clear_cache() + repo = clone_repo(clone_url, repo_path, progress=GitProgress()) + repo.clear_cache() repo.close() def postinstall(): @@ -1371,24 +1377,23 @@ class UnifiedManager: return result.fail(f'Path not found: {repo_path}') # version check - with git.Repo(repo_path) as repo: - if repo.head.is_detached: + with open_repo(repo_path) as repo: + if repo.head_is_detached: if not switch_to_default_branch(repo): return result.fail(f"Failed to switch to default branch: {repo_path}") - current_branch = repo.active_branch - branch_name = current_branch.name + branch_name = repo.active_branch_name - if current_branch.tracking_branch() is None: - print(f"[ComfyUI-Manager] There is no tracking branch ({current_branch})") + try: + remote_name = repo.get_tracking_remote_name() + except Exception: + print(f"[ComfyUI-Manager] There is no tracking branch ({branch_name})") remote_name = get_remote_name(repo) - else: - remote_name = current_branch.tracking_branch().remote_name if remote_name is None: return result.fail(f"Failed to get remote when installing: {repo_path}") - remote = repo.remote(name=remote_name) + remote = repo.get_remote(remote_name) try: remote.fetch() @@ -1405,17 +1410,17 @@ class UnifiedManager: f'git config --global --add safe.directory "{safedir_path}"\n' "-----------------------------------------------------------------------------------------\n") - commit_hash = repo.head.commit.hexsha - if f'{remote_name}/{branch_name}' in repo.refs: - remote_commit_hash = repo.refs[f'{remote_name}/{branch_name}'].object.hexsha + commit_hash = repo.head_commit_hexsha + if repo.has_ref(f'{remote_name}/{branch_name}'): + remote_commit_hash = repo.get_ref_commit_hexsha(f'{remote_name}/{branch_name}') else: return result.fail(f"Not updatable branch: {branch_name}") if commit_hash != remote_commit_hash: git_pull(repo_path) - if len(repo.remotes) > 0: - url = repo.remotes[0].url + if len(repo.list_remotes()) > 0: + url = repo.get_remote_url(0) else: url = "unknown repo" @@ -1788,7 +1793,7 @@ def get_config(): def get_remote_name(repo): - available_remotes = [remote.name for remote in repo.remotes] + available_remotes = [remote.name for remote in repo.list_remotes()] if 'origin' in available_remotes: return 'origin' elif 'upstream' in available_remotes: @@ -1813,28 +1818,28 @@ def switch_to_default_branch(repo): if remote_name is None: return False - default_branch = repo.git.symbolic_ref(f'refs/remotes/{remote_name}/HEAD').replace(f'refs/remotes/{remote_name}/', '') - repo.git.checkout(default_branch) + default_branch = repo.symbolic_ref(f'refs/remotes/{remote_name}/HEAD').replace(f'refs/remotes/{remote_name}/', '') + repo.checkout(default_branch) return True except Exception: # try checkout master # try checkout main if failed try: - repo.git.checkout(repo.heads.master) + repo.checkout(repo.get_head_by_name('master')) return True except Exception: try: if remote_name is not None: - repo.git.checkout('-b', 'master', f'{remote_name}/master') + repo.checkout_new_branch('master', f'{remote_name}/master') return True except Exception: try: - repo.git.checkout(repo.heads.main) + repo.checkout(repo.get_head_by_name('main')) return True except Exception: try: if remote_name is not None: - repo.git.checkout('-b', 'main', f'{remote_name}/main') + repo.checkout_new_branch('main', f'{remote_name}/main') return True except Exception: pass @@ -2072,104 +2077,95 @@ def git_repo_update_check_with(path, do_fetch=False, do_update=False, no_deps=Fa return updated, success else: # Fetch the latest commits from the remote repository - repo = git.Repo(path) + with open_repo(path) as repo: + remote_name = get_remote_name(repo) - remote_name = get_remote_name(repo) + if remote_name is None: + raise ValueError(f"No remotes are configured for this repository: {path}") - if remote_name is None: - raise ValueError(f"No remotes are configured for this repository: {path}") + remote = repo.get_remote(remote_name) - remote = repo.remote(name=remote_name) + if not do_update and repo.head_is_detached: + if do_fetch: + remote.fetch() - if not do_update and repo.head.is_detached: - if do_fetch: + return True, True # detached branch is treated as updatable + + if repo.head_is_detached: + if not switch_to_default_branch(repo): + raise ValueError(f"Failed to switch detached branch to default branch: {path}") + + branch_name = repo.active_branch_name + + # Get the current commit hash + commit_hash = repo.head_commit_hexsha + + if do_fetch or do_update: remote.fetch() - return True, True # detached branch is treated as updatable + if do_update: + if repo.is_dirty(): + print(f"\nSTASH: '{path}' is dirty.") + repo.stash() - if repo.head.is_detached: - if not switch_to_default_branch(repo): - raise ValueError(f"Failed to switch detached branch to default branch: {path}") + if not repo.has_ref(f'{remote_name}/{branch_name}'): + if not switch_to_default_branch(repo): + raise ValueError(f"Failed to switch to default branch while updating: {path}") - current_branch = repo.active_branch - branch_name = current_branch.name + branch_name = repo.active_branch_name - # Get the current commit hash - commit_hash = repo.head.commit.hexsha - - if do_fetch or do_update: - remote.fetch() - - if do_update: - if repo.is_dirty(): - print(f"\nSTASH: '{path}' is dirty.") - repo.git.stash() - - if f'{remote_name}/{branch_name}' not in repo.refs: - if not switch_to_default_branch(repo): - raise ValueError(f"Failed to switch to default branch while updating: {path}") - - current_branch = repo.active_branch - branch_name = current_branch.name - - if f'{remote_name}/{branch_name}' in repo.refs: - remote_commit_hash = repo.refs[f'{remote_name}/{branch_name}'].object.hexsha - else: - return False, False - - if commit_hash == remote_commit_hash: - repo.close() - return False, True - - try: - try: - repo.git.pull('--ff-only') - except git.GitCommandError: - backup_name = get_backup_branch_name(repo) - repo.create_head(backup_name) - logging.info(f"[ComfyUI-Manager] Cannot fast-forward. Backup created: {backup_name}") - repo.git.reset('--hard', f'{remote_name}/{branch_name}') - logging.info(f"[ComfyUI-Manager] Reset to {remote_name}/{branch_name}") - - repo.git.submodule('update', '--init', '--recursive') - new_commit_hash = repo.head.commit.hexsha - - if commit_hash != new_commit_hash: - execute_install_script(None, path, no_deps=no_deps) - print(f"\x1b[2K\rUpdated: {path}") - return True, True + if repo.has_ref(f'{remote_name}/{branch_name}'): + remote_commit_hash = repo.get_ref_commit_hexsha(f'{remote_name}/{branch_name}') else: return False, False - except Exception as e: - print(f"\nUpdating failed: {path}\n{e}", file=sys.stderr) - return False, False + if commit_hash == remote_commit_hash: + return False, True - if repo.head.is_detached: - repo.close() - return True, True + try: + try: + repo.pull_ff_only() + except GitCommandError: + backup_name = get_backup_branch_name(repo) + repo.create_backup_branch(backup_name) + logging.info(f"[ComfyUI-Manager] Cannot fast-forward. Backup created: {backup_name}") + repo.reset_hard(f'{remote_name}/{branch_name}') + logging.info(f"[ComfyUI-Manager] Reset to {remote_name}/{branch_name}") - # Get commit hash of the remote branch - current_branch = repo.active_branch - branch_name = current_branch.name + repo.submodule_update() + new_commit_hash = repo.head_commit_hexsha - if f'{remote_name}/{branch_name}' in repo.refs: - remote_commit_hash = repo.refs[f'{remote_name}/{branch_name}'].object.hexsha - else: - return True, True # Assuming there's an update if it's not the default branch. + if commit_hash != new_commit_hash: + execute_install_script(None, path, no_deps=no_deps) + print(f"\x1b[2K\rUpdated: {path}") + return True, True + else: + return False, False - # Compare the commit hashes to determine if the local repository is behind the remote repository - if commit_hash != remote_commit_hash: - # Get the commit dates - commit_date = repo.head.commit.committed_datetime - remote_commit_date = repo.refs[f'{remote_name}/{branch_name}'].object.committed_datetime + except Exception as e: + print(f"\nUpdating failed: {path}\n{e}", file=sys.stderr) + return False, False - # Compare the commit dates to determine if the local repository is behind the remote repository - if commit_date < remote_commit_date: - repo.close() + if repo.head_is_detached: return True, True - repo.close() + # Get commit hash of the remote branch + branch_name = repo.active_branch_name + + if repo.has_ref(f'{remote_name}/{branch_name}'): + remote_commit_hash = repo.get_ref_commit_hexsha(f'{remote_name}/{branch_name}') + else: + return True, True # Assuming there's an update if it's not the default branch. + + # Compare the commit hashes to determine if the local repository is behind the remote repository + if commit_hash != remote_commit_hash: + # Get the commit dates + commit_date = repo.head_commit_datetime + remote_commit_date = repo.get_ref_commit_datetime(f'{remote_name}/{branch_name}') + + # Compare the commit dates to determine if the local repository is behind the remote repository + if commit_date < remote_commit_date: + return True, True return False, True @@ -2259,12 +2255,12 @@ async def gitclone_install(url, instant_execution=False, msg_prefix='', no_deps= if res != 0: return result.fail(f"Failed to clone '{clone_url}' into '{repo_path}'") else: - repo = git.Repo.clone_from(clone_url, repo_path, recursive=True, progress=GitProgress()) + repo = clone_repo(clone_url, repo_path, progress=GitProgress()) if commit_id!= "": - repo.git.checkout(commit_id) - repo.git.submodule('update', '--init', '--recursive') + repo.checkout(commit_id) + repo.submodule_update() - repo.git.clear_cache() + repo.clear_cache() repo.close() execute_install_script(url, repo_path, instant_execution=instant_execution, no_deps=no_deps) @@ -2286,32 +2282,28 @@ def git_pull(path): if platform.system() == "Windows": return __win_check_git_pull(path) else: - repo = git.Repo(path) + with open_repo(path) as repo: + if repo.is_dirty(): + print(f"STASH: '{path}' is dirty.") + repo.stash() - if repo.is_dirty(): - print(f"STASH: '{path}' is dirty.") - repo.git.stash() + if repo.head_is_detached: + if not switch_to_default_branch(repo): + raise ValueError(f"Failed to switch to default branch while pulling: {path}") - if repo.head.is_detached: - if not switch_to_default_branch(repo): - raise ValueError(f"Failed to switch to default branch while pulling: {path}") + branch_name = repo.active_branch_name + remote_name = repo.get_tracking_remote_name() - current_branch = repo.active_branch - remote_name = current_branch.tracking_branch().remote_name - branch_name = current_branch.name + try: + repo.pull_ff_only() + except GitCommandError: + backup_name = get_backup_branch_name(repo) + repo.create_backup_branch(backup_name) + logging.info(f"[ComfyUI-Manager] Cannot fast-forward. Backup created: {backup_name}") + repo.reset_hard(f'{remote_name}/{branch_name}') + logging.info(f"[ComfyUI-Manager] Reset to {remote_name}/{branch_name}") - try: - repo.git.pull('--ff-only') - except git.GitCommandError: - backup_name = get_backup_branch_name(repo) - repo.create_head(backup_name) - logging.info(f"[ComfyUI-Manager] Cannot fast-forward. Backup created: {backup_name}") - repo.git.reset('--hard', f'{remote_name}/{branch_name}') - logging.info(f"[ComfyUI-Manager] Reset to {remote_name}/{branch_name}") - - repo.git.submodule('update', '--init', '--recursive') - - repo.close() + repo.submodule_update() return True @@ -2567,33 +2559,33 @@ def gitclone_update(files, instant_execution=False, skip_script=False, msg_prefi def update_to_stable_comfyui(repo_path): try: - repo = git.Repo(repo_path) - try: - repo.git.checkout(repo.heads.master) - except Exception: - logging.error(f"[ComfyUI-Manager] Failed to checkout 'master' branch.\nrepo_path={repo_path}\nAvailable branches:") - for branch in repo.branches: - logging.error('\t'+branch.name) - return "fail", None + with open_repo(repo_path) as repo: + try: + repo.checkout(repo.get_head_by_name('master')) + except Exception: + logging.error(f"[ComfyUI-Manager] Failed to checkout 'master' branch.\nrepo_path={repo_path}\nAvailable branches:") + for branch in repo.list_branches(): + logging.error('\t'+branch.name) + return "fail", None - versions, current_tag, latest_tag = get_comfyui_versions(repo) + versions, current_tag, latest_tag = get_comfyui_versions(repo) - if latest_tag is None: - logging.info("[ComfyUI-Manager] Unable to update to the stable ComfyUI version.") - return "fail", None + if latest_tag is None: + logging.info("[ComfyUI-Manager] Unable to update to the stable ComfyUI version.") + return "fail", None - tag_ref = next((t for t in repo.tags if t.name == latest_tag), None) - if tag_ref is None: - logging.info(f"[ComfyUI-Manager] Unable to locate tag '{latest_tag}' in repository.") - return "fail", None + tag_ref = next((t for t in repo.list_tags() if t.name == latest_tag), None) + if tag_ref is None: + logging.info(f"[ComfyUI-Manager] Unable to locate tag '{latest_tag}' in repository.") + return "fail", None - if repo.head.commit == tag_ref.commit: - return "skip", None - else: - logging.info(f"[ComfyUI-Manager] Updating ComfyUI: {current_tag} -> {latest_tag}") - repo.git.checkout(tag_ref.name) - execute_install_script("ComfyUI", repo_path, instant_execution=False, no_deps=False) - return 'updated', latest_tag + if repo.head_commit_equals(tag_ref.commit): + return "skip", None + else: + logging.info(f"[ComfyUI-Manager] Updating ComfyUI: {current_tag} -> {latest_tag}") + repo.checkout(tag_ref.name) + execute_install_script("ComfyUI", repo_path, instant_execution=False, no_deps=False) + return 'updated', latest_tag except Exception: traceback.print_exc() return "fail", None @@ -2604,56 +2596,54 @@ def update_path(repo_path, instant_execution=False, no_deps=False): return "fail" # version check - repo = git.Repo(repo_path) - - is_switched = False - if repo.head.is_detached: - if not switch_to_default_branch(repo): - return "fail" - else: - is_switched = True - - current_branch = repo.active_branch - branch_name = current_branch.name - - if current_branch.tracking_branch() is None: - print(f"[ComfyUI-Manager] There is no tracking branch ({current_branch})") - remote_name = get_remote_name(repo) - else: - remote_name = current_branch.tracking_branch().remote_name - remote = repo.remote(name=remote_name) - - try: - remote.fetch() - except Exception as e: - if 'detected dubious' in str(e): - print(f"[ComfyUI-Manager] Try fixing 'dubious repository' error on '{repo_path}' repository") - safedir_path = repo_path.replace('\\', '/') - subprocess.run(['git', 'config', '--global', '--add', 'safe.directory', safedir_path]) - try: - remote.fetch() - except Exception: - print(f"\n[ComfyUI-Manager] Failed to fixing repository setup. Please execute this command on cmd: \n" - f"-----------------------------------------------------------------------------------------\n" - f'git config --global --add safe.directory "{safedir_path}"\n' - f"-----------------------------------------------------------------------------------------\n") + with open_repo(repo_path) as repo: + is_switched = False + if repo.head_is_detached: + if not switch_to_default_branch(repo): return "fail" + else: + is_switched = True - commit_hash = repo.head.commit.hexsha + branch_name = repo.active_branch_name - if f'{remote_name}/{branch_name}' in repo.refs: - remote_commit_hash = repo.refs[f'{remote_name}/{branch_name}'].object.hexsha - else: - return "fail" + try: + remote_name = repo.get_tracking_remote_name() + except Exception: + print(f"[ComfyUI-Manager] There is no tracking branch ({branch_name})") + remote_name = get_remote_name(repo) + remote = repo.get_remote(remote_name) - if commit_hash != remote_commit_hash: - git_pull(repo_path) - execute_install_script("ComfyUI", repo_path, instant_execution=instant_execution, no_deps=no_deps) - return "updated" - elif is_switched: - return "updated" - else: - return "skipped" + try: + remote.fetch() + except Exception as e: + if 'detected dubious' in str(e): + print(f"[ComfyUI-Manager] Try fixing 'dubious repository' error on '{repo_path}' repository") + safedir_path = repo_path.replace('\\', '/') + subprocess.run(['git', 'config', '--global', '--add', 'safe.directory', safedir_path]) + try: + remote.fetch() + except Exception: + print(f"\n[ComfyUI-Manager] Failed to fixing repository setup. Please execute this command on cmd: \n" + f"-----------------------------------------------------------------------------------------\n" + f'git config --global --add safe.directory "{safedir_path}"\n' + f"-----------------------------------------------------------------------------------------\n") + return "fail" + + commit_hash = repo.head_commit_hexsha + + if repo.has_ref(f'{remote_name}/{branch_name}'): + remote_commit_hash = repo.get_ref_commit_hexsha(f'{remote_name}/{branch_name}') + else: + return "fail" + + if commit_hash != remote_commit_hash: + git_pull(repo_path) + execute_install_script("ComfyUI", repo_path, instant_execution=instant_execution, no_deps=no_deps) + return "updated" + elif is_switched: + return "updated" + else: + return "skipped" def lookup_customnode_by_url(data, target): @@ -2752,8 +2742,8 @@ async def get_current_snapshot(custom_nodes_only = False): comfyui_commit_hash = None if not custom_nodes_only: if os.path.exists(os.path.join(repo_path, '.git')): - repo = git.Repo(repo_path) - comfyui_commit_hash = repo.head.commit.hexsha + with open_repo(repo_path) as repo: + comfyui_commit_hash = repo.head_commit_hexsha git_custom_nodes = {} cnr_custom_nodes = {} @@ -3409,101 +3399,98 @@ async def restore_snapshot(snapshot_path, git_helper_extras=None): def get_comfyui_versions(repo=None): - repo = repo or git.Repo(context.comfy_path) - - remote_name = None - try: - remote_name = get_remote_name(repo) - repo.remotes[remote_name].fetch() - except Exception: - logging.error("[ComfyUI-Manager] Failed to fetch ComfyUI") - - def parse_semver(tag_name): - match = re.match(r'^v(\d+)\.(\d+)\.(\d+)$', tag_name) - return tuple(int(x) for x in match.groups()) if match else None - - def normalize_describe(tag_name): - if not tag_name: - return None - base = tag_name.split('-', 1)[0] - return base if parse_semver(base) else None - - # Collect semver tags and sort descending (highest first) - semver_tags = [] - for tag in repo.tags: - semver = parse_semver(tag.name) - if semver: - semver_tags.append((semver, tag.name)) - semver_tags.sort(key=lambda x: x[0], reverse=True) - semver_tags = [name for _, name in semver_tags] - - latest_tag = semver_tags[0] if semver_tags else None + created_repo = repo is None + repo = repo or open_repo(context.comfy_path) try: - described = repo.git.describe('--tags') - except Exception: - described = '' - - try: - exact_tag = repo.git.describe('--tags', '--exact-match') - except Exception: - exact_tag = '' - - head_is_default = False - if remote_name: + remote_name = None try: - default_head_ref = repo.refs[f'{remote_name}/HEAD'] - default_commit = default_head_ref.reference.commit - head_is_default = repo.head.commit == default_commit + remote_name = get_remote_name(repo) + repo.get_remote(remote_name).fetch() except Exception: - # Fallback: compare directly with master branch + logging.error("[ComfyUI-Manager] Failed to fetch ComfyUI") + + def parse_semver(tag_name): + match = re.match(r'^v(\d+)\.(\d+)\.(\d+)$', tag_name) + return tuple(int(x) for x in match.groups()) if match else None + + def normalize_describe(tag_name): + if not tag_name: + return None + base = tag_name.split('-', 1)[0] + return base if parse_semver(base) else None + + # Collect semver tags and sort descending (highest first) + semver_tags = [] + for tag in repo.list_tags(): + semver = parse_semver(tag.name) + if semver: + semver_tags.append((semver, tag.name)) + semver_tags.sort(key=lambda x: x[0], reverse=True) + semver_tags = [name for _, name in semver_tags] + + latest_tag = semver_tags[0] if semver_tags else None + + described = repo.describe_tags() or '' + + exact_tag = repo.describe_tags(exact_match=True) or '' + + head_is_default = False + if remote_name: try: - if 'master' in [h.name for h in repo.heads]: - head_is_default = repo.head.commit == repo.heads.master.commit + default_head_ref = repo.get_ref_object(f'{remote_name}/HEAD') + default_commit = default_head_ref.reference.commit + head_is_default = repo.head_commit_equals(default_commit) except Exception: - head_is_default = False + # Fallback: compare directly with master branch + try: + if 'master' in [h.name for h in repo.list_heads()]: + head_is_default = repo.head_commit_equals(repo.get_head_by_name('master').commit) + except Exception: + head_is_default = False - nearest_semver = normalize_describe(described) - exact_semver = exact_tag if parse_semver(exact_tag) else None + nearest_semver = normalize_describe(described) + exact_semver = exact_tag if parse_semver(exact_tag) else None - if head_is_default and not exact_tag: - current_tag = 'nightly' - else: - current_tag = exact_tag or described or 'nightly' + if head_is_default and not exact_tag: + current_tag = 'nightly' + else: + current_tag = exact_tag or described or 'nightly' - # Prepare semver list for display: top 4 plus the current/nearest semver if missing - display_semver_tags = semver_tags[:4] - if exact_semver and exact_semver not in display_semver_tags: - display_semver_tags.append(exact_semver) - elif nearest_semver and nearest_semver not in display_semver_tags: - display_semver_tags.append(nearest_semver) + # Prepare semver list for display: top 4 plus the current/nearest semver if missing + display_semver_tags = semver_tags[:4] + if exact_semver and exact_semver not in display_semver_tags: + display_semver_tags.append(exact_semver) + elif nearest_semver and nearest_semver not in display_semver_tags: + display_semver_tags.append(nearest_semver) - versions = ['nightly'] + versions = ['nightly'] - if current_tag and not exact_semver and current_tag not in versions and current_tag not in display_semver_tags: - versions.append(current_tag) + if current_tag and not exact_semver and current_tag not in versions and current_tag not in display_semver_tags: + versions.append(current_tag) - for tag in display_semver_tags: - if tag not in versions: - versions.append(tag) + for tag in display_semver_tags: + if tag not in versions: + versions.append(tag) - versions = versions[:6] + versions = versions[:6] - return versions, current_tag, latest_tag + return versions, current_tag, latest_tag + finally: + if created_repo: + repo.close() def switch_comfyui(tag): - repo = git.Repo(context.comfy_path) - - if tag == 'nightly': - repo.git.checkout('master') - tracking_branch = repo.active_branch.tracking_branch() - remote_name = tracking_branch.remote_name - repo.remotes[remote_name].pull() - print("[ComfyUI-Manager] ComfyUI version is switched to the latest 'master' version") - else: - repo.git.checkout(tag) - print(f"[ComfyUI-Manager] ComfyUI version is switched to '{tag}'") + with open_repo(context.comfy_path) as repo: + if tag == 'nightly': + repo.checkout('master') + remote_name = repo.get_tracking_remote_name() + repo.get_remote(remote_name).pull() + print("[ComfyUI-Manager] ComfyUI version is switched to the latest 'master' version") + else: + repo.checkout(tag) + print(f"[ComfyUI-Manager] ComfyUI version is switched to '{tag}'") def resolve_giturl_from_path(fullpath): @@ -3527,11 +3514,11 @@ def resolve_giturl_from_path(fullpath): def repo_switch_commit(repo_path, commit_hash): try: - repo = git.Repo(repo_path) - if repo.head.commit.hexsha == commit_hash: - return False + with open_repo(repo_path) as repo: + if repo.head_commit_hexsha == commit_hash: + return False - repo.git.checkout(commit_hash) - return True + repo.checkout(commit_hash) + return True except Exception: return None diff --git a/comfyui_manager/glob/utils/environment_utils.py b/comfyui_manager/glob/utils/environment_utils.py index 2cacab0b..10b65abb 100644 --- a/comfyui_manager/glob/utils/environment_utils.py +++ b/comfyui_manager/glob/utils/environment_utils.py @@ -1,5 +1,5 @@ import os -import git +from comfyui_manager.common.git_compat import open_repo, setup_git_environment import logging import traceback @@ -20,17 +20,17 @@ def print_comfyui_version(): is_detached = False try: - repo = git.Repo(os.path.dirname(folder_paths.__file__)) - core.comfy_ui_revision = len(list(repo.iter_commits("HEAD"))) + with open_repo(os.path.dirname(folder_paths.__file__)) as repo: + core.comfy_ui_revision = repo.iter_commits_count() - comfy_ui_hash = repo.head.commit.hexsha - cm_global.variables["comfyui.revision"] = core.comfy_ui_revision + comfy_ui_hash = repo.head_commit_hexsha + cm_global.variables["comfyui.revision"] = core.comfy_ui_revision - core.comfy_ui_commit_datetime = repo.head.commit.committed_datetime - cm_global.variables["comfyui.commit_datetime"] = core.comfy_ui_commit_datetime + core.comfy_ui_commit_datetime = repo.head_commit_datetime + cm_global.variables["comfyui.commit_datetime"] = core.comfy_ui_commit_datetime - is_detached = repo.head.is_detached - current_branch = repo.active_branch.name + is_detached = repo.head_is_detached + current_branch = repo.active_branch_name comfyui_tag = context.get_comfyui_tag() @@ -103,7 +103,7 @@ def setup_environment(): git_exe = core.get_config()["git_exe"] if git_exe != "": - git.Git().update_environment(GIT_PYTHON_GIT_EXECUTABLE=git_exe) + setup_git_environment(git_exe) def initialize_environment(): diff --git a/comfyui_manager/legacy/manager_core.py b/comfyui_manager/legacy/manager_core.py index 5d522149..56e01f3f 100644 --- a/comfyui_manager/legacy/manager_core.py +++ b/comfyui_manager/legacy/manager_core.py @@ -42,7 +42,7 @@ from ..common.enums import NetworkMode, SecurityLevel, DBMode from ..common import context -version_code = [4, 1] +version_code = [4, 2] version_str = f"V{version_code[0]}.{version_code[1]}" + (f'.{version_code[2]}' if len(version_code) > 2 else '') diff --git a/pyproject.toml b/pyproject.toml index 271ed03f..cb5a4132 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "comfyui-manager" license = { text = "GPL-3.0-only" } -version = "4.1" +version = "4.2b1" requires-python = ">= 3.9" description = "ComfyUI-Manager provides features to install and manage custom nodes for ComfyUI, as well as various functionalities to assist with ComfyUI." readme = "README.md" @@ -39,7 +39,7 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["pre-commit", "pytest", "ruff", "pytest-cov"] +dev = ["pre-commit", "pytest", "ruff", "pytest-cov", "pygit2"] [project.urls] Repository = "https://github.com/ltdrdata/ComfyUI-Manager" diff --git a/requirements.txt b/requirements.txt index d225a63c..22be6974 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ GitPython +pygit2 PyGithub # matrix-nio transformers diff --git a/tests/test_git_compat.py b/tests/test_git_compat.py new file mode 100644 index 00000000..21156047 --- /dev/null +++ b/tests/test_git_compat.py @@ -0,0 +1,829 @@ +""" +Tests for comfyui_manager.common.git_compat + +Each test spawns a subprocess with/without CM_USE_PYGIT2=1 to fully isolate +backend selection. Both backends are tested against the same local git +repository and the results are compared for behavioral parity. + +Requirements: + - Both `pygit2` and `GitPython` installed in the test venv. + - A working `git` CLI (so GitPython backend can function). +""" + +import json +import os +import subprocess +import sys +import tempfile +import textwrap +import unittest + +# Path to the Python interpreter that has both pygit2 and GitPython +PYTHON = sys.executable + +# The git_compat module lives here +COMPAT_DIR = os.path.join(os.path.dirname(__file__), '..', 'comfyui_manager', 'common') +COMPAT_DIR = os.path.abspath(COMPAT_DIR) + + +def _run_snippet(snippet: str, repo_path: str, *, use_pygit2: bool) -> dict: + """Run a Python snippet in a subprocess and return JSON output. + + The snippet must print a single JSON line to stdout. + """ + env = os.environ.copy() + if use_pygit2: + env['CM_USE_PYGIT2'] = '1' + else: + env.pop('CM_USE_PYGIT2', None) + + full_code = textwrap.dedent(f"""\ + import sys, os, json + sys.path.insert(0, {COMPAT_DIR!r}) + os.environ.setdefault('CM_USE_PYGIT2', os.environ.get('CM_USE_PYGIT2', '')) + REPO_PATH = {repo_path!r} + from git_compat import open_repo, clone_repo, GitCommandError, setup_git_environment, USE_PYGIT2 + """) + textwrap.dedent(snippet) + + result = subprocess.run( + [PYTHON, '-c', full_code], + capture_output=True, text=True, env=env, timeout=60, + ) + if result.returncode != 0: + raise RuntimeError( + f"Subprocess failed (pygit2={use_pygit2}):\n" + f"STDOUT: {result.stdout}\n" + f"STDERR: {result.stderr}" + ) + # Find the last JSON line in stdout (skip banner lines) + for line in reversed(result.stdout.strip().split('\n')): + line = line.strip() + if line.startswith('{'): + return json.loads(line) + raise RuntimeError( + f"No JSON output found (pygit2={use_pygit2}):\n" + f"STDOUT: {result.stdout}\n" + f"STDERR: {result.stderr}" + ) + + +def _run_both(snippet: str, repo_path: str) -> tuple: + """Run snippet with both backends and return (gitpython_result, pygit2_result).""" + gp = _run_snippet(snippet, repo_path, use_pygit2=False) + p2 = _run_snippet(snippet, repo_path, use_pygit2=True) + return gp, p2 + + +class TestGitCompat(unittest.TestCase): + """Test suite comparing GitPython and pygit2 backends.""" + + @classmethod + def setUpClass(cls): + """Create a temporary git repository for testing.""" + cls._tmpdir = tempfile.mkdtemp(prefix='test_git_compat_') + cls.repo_path = os.path.join(cls._tmpdir, 'test_repo') + os.makedirs(cls.repo_path) + + # Initialize a git repo with a commit + _git = lambda *args: subprocess.run( + ['git'] + list(args), + cwd=cls.repo_path, capture_output=True, text=True, check=True, + ) + _git('init', '-b', 'master') + _git('config', 'user.email', 'test@test.com') + _git('config', 'user.name', 'Test') + + # Create initial commit + with open(os.path.join(cls.repo_path, 'file.txt'), 'w') as f: + f.write('hello') + _git('add', '.') + _git('commit', '-m', 'initial commit') + + # Create a tag + _git('tag', 'v1.0.0') + + # Create a second commit + with open(os.path.join(cls.repo_path, 'file2.txt'), 'w') as f: + f.write('world') + _git('add', '.') + _git('commit', '-m', 'second commit') + + # Create another tag + _git('tag', 'v1.1.0') + + # Create a branch + _git('branch', 'feature-branch') + + # Store the HEAD commit hash for assertions + result = subprocess.run( + ['git', 'rev-parse', 'HEAD'], + cwd=cls.repo_path, capture_output=True, text=True, check=True, + ) + cls.head_sha = result.stdout.strip() + + # Store first commit hash + result = subprocess.run( + ['git', 'rev-parse', 'HEAD~1'], + cwd=cls.repo_path, capture_output=True, text=True, check=True, + ) + cls.first_sha = result.stdout.strip() + + # Create a bare remote to test fetch/tracking + cls.remote_path = os.path.join(cls._tmpdir, 'remote_repo.git') + subprocess.run( + ['git', 'clone', '--bare', cls.repo_path, cls.remote_path], + capture_output=True, check=True, + ) + _git('remote', 'add', 'origin', cls.remote_path) + _git('push', '-u', 'origin', 'master') + + @classmethod + def tearDownClass(cls): + import shutil + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + # === Backend selection === + + def test_backend_selection_gitpython(self): + gp = _run_snippet('print(json.dumps({"backend": "pygit2" if USE_PYGIT2 else "gitpython"}))', + self.repo_path, use_pygit2=False) + self.assertEqual(gp['backend'], 'gitpython') + + def test_backend_selection_pygit2(self): + p2 = _run_snippet('print(json.dumps({"backend": "pygit2" if USE_PYGIT2 else "gitpython"}))', + self.repo_path, use_pygit2=True) + self.assertEqual(p2['backend'], 'pygit2') + + # === head_commit_hexsha === + + def test_head_commit_hexsha(self): + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"sha": repo.head_commit_hexsha})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['sha'], self.head_sha) + self.assertEqual(p2['sha'], self.head_sha) + + # === head_is_detached === + + def test_head_is_detached_false(self): + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"detached": repo.head_is_detached})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertFalse(gp['detached']) + self.assertFalse(p2['detached']) + + # === head_commit_datetime === + + def test_head_commit_datetime(self): + snippet = """ +repo = open_repo(REPO_PATH) +dt = repo.head_commit_datetime +print(json.dumps({"ts": dt.timestamp()})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertAlmostEqual(gp['ts'], p2['ts'], places=0) + + # === active_branch_name === + + def test_active_branch_name(self): + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"branch": repo.active_branch_name})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['branch'], 'master') + self.assertEqual(p2['branch'], 'master') + + # === is_dirty === + + def test_is_dirty_clean(self): + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"dirty": repo.is_dirty()})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertFalse(gp['dirty']) + self.assertFalse(p2['dirty']) + + def test_is_dirty_modified(self): + # Modify a file temporarily + filepath = os.path.join(self.repo_path, 'file.txt') + with open(filepath, 'r') as f: + original = f.read() + with open(filepath, 'w') as f: + f.write('modified') + try: + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"dirty": repo.is_dirty()})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['dirty']) + self.assertTrue(p2['dirty']) + finally: + with open(filepath, 'w') as f: + f.write(original) + + def test_is_dirty_untracked_not_dirty(self): + # Untracked files should NOT make is_dirty() return True + untracked = os.path.join(self.repo_path, 'untracked_file.txt') + with open(untracked, 'w') as f: + f.write('untracked') + try: + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"dirty": repo.is_dirty()})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertFalse(gp['dirty']) + self.assertFalse(p2['dirty']) + finally: + os.remove(untracked) + + # === working_dir === + + def test_working_dir(self): + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"wd": repo.working_dir})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(os.path.normcase(gp['wd']), os.path.normcase(self.repo_path)) + self.assertEqual(os.path.normcase(p2['wd']), os.path.normcase(self.repo_path)) + + # === list_remotes === + + def test_list_remotes(self): + snippet = """ +repo = open_repo(REPO_PATH) +remotes = repo.list_remotes() +print(json.dumps({"names": [r.name for r in remotes]})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertIn('origin', gp['names']) + self.assertIn('origin', p2['names']) + + # === get_remote === + + def test_get_remote(self): + snippet = """ +repo = open_repo(REPO_PATH) +r = repo.get_remote('origin') +print(json.dumps({"name": r.name, "has_url": bool(r.url)})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['name'], 'origin') + self.assertTrue(gp['has_url']) + self.assertEqual(p2['name'], 'origin') + self.assertTrue(p2['has_url']) + + # === get_tracking_remote_name === + + def test_get_tracking_remote_name(self): + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"remote": repo.get_tracking_remote_name()})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['remote'], 'origin') + self.assertEqual(p2['remote'], 'origin') + + # === has_ref === + + def test_has_ref_true(self): + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"has": repo.has_ref('origin/master')})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['has']) + self.assertTrue(p2['has']) + + def test_has_ref_false(self): + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"has": repo.has_ref('origin/nonexistent')})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertFalse(gp['has']) + self.assertFalse(p2['has']) + + # === get_ref_commit_hexsha === + + def test_get_ref_commit_hexsha(self): + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"sha": repo.get_ref_commit_hexsha('origin/master')})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['sha'], self.head_sha) + self.assertEqual(p2['sha'], self.head_sha) + + # === get_ref_commit_datetime === + + def test_get_ref_commit_datetime(self): + snippet = """ +repo = open_repo(REPO_PATH) +dt = repo.get_ref_commit_datetime('origin/master') +print(json.dumps({"ts": dt.timestamp()})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertAlmostEqual(gp['ts'], p2['ts'], places=0) + + # === iter_commits_count === + + def test_iter_commits_count(self): + snippet = """ +repo = open_repo(REPO_PATH) +print(json.dumps({"count": repo.iter_commits_count()})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['count'], 2) + self.assertEqual(p2['count'], 2) + + # === symbolic_ref === + + def test_symbolic_ref(self): + snippet = """ +repo = open_repo(REPO_PATH) +try: + ref = repo.symbolic_ref('refs/remotes/origin/HEAD') + print(json.dumps({"ref": ref})) +except Exception as e: + print(json.dumps({"error": str(e)})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + # Both should return refs/remotes/origin/master or error consistently + if 'ref' in gp: + self.assertIn('master', gp['ref']) + if 'ref' in p2: + self.assertIn('master', p2['ref']) + + # === describe_tags === + + def test_describe_tags(self): + snippet = """ +repo = open_repo(REPO_PATH) +desc = repo.describe_tags() +print(json.dumps({"desc": desc})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + # HEAD is at v1.1.0, so describe should return v1.1.0 + self.assertIsNotNone(gp['desc']) + self.assertIsNotNone(p2['desc']) + self.assertIn('v1.1.0', gp['desc']) + self.assertIn('v1.1.0', p2['desc']) + + def test_describe_tags_exact_match(self): + snippet = """ +repo = open_repo(REPO_PATH) +desc = repo.describe_tags(exact_match=True) +print(json.dumps({"desc": desc})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['desc'], 'v1.1.0') + self.assertEqual(p2['desc'], 'v1.1.0') + + # === list_tags === + + def test_list_tags(self): + snippet = """ +repo = open_repo(REPO_PATH) +tags = [t.name for t in repo.list_tags()] +print(json.dumps({"tags": sorted(tags)})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['tags'], ['v1.0.0', 'v1.1.0']) + self.assertEqual(p2['tags'], ['v1.0.0', 'v1.1.0']) + + # === list_heads === + + def test_list_heads(self): + snippet = """ +repo = open_repo(REPO_PATH) +heads = sorted([h.name for h in repo.list_heads()]) +print(json.dumps({"heads": heads})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertIn('master', gp['heads']) + self.assertIn('feature-branch', gp['heads']) + self.assertIn('master', p2['heads']) + self.assertIn('feature-branch', p2['heads']) + + # === list_branches === + + def test_list_branches(self): + snippet = """ +repo = open_repo(REPO_PATH) +branches = sorted([b.name for b in repo.list_branches()]) +print(json.dumps({"branches": branches})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['branches'], p2['branches']) + + # === get_head_by_name === + + def test_get_head_by_name(self): + snippet = """ +repo = open_repo(REPO_PATH) +h = repo.get_head_by_name('master') +print(json.dumps({"name": h.name, "has_commit": h.commit is not None})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['name'], 'master') + self.assertTrue(gp['has_commit']) + self.assertEqual(p2['name'], 'master') + self.assertTrue(p2['has_commit']) + + def test_get_head_by_name_not_found(self): + snippet = """ +repo = open_repo(REPO_PATH) +try: + h = repo.get_head_by_name('nonexistent') + print(json.dumps({"error": False})) +except (AttributeError, Exception): + print(json.dumps({"error": True})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['error']) + self.assertTrue(p2['error']) + + # === head_commit_equals === + + def test_head_commit_equals_same(self): + snippet = """ +repo = open_repo(REPO_PATH) +h = repo.get_head_by_name('master') +print(json.dumps({"eq": repo.head_commit_equals(h.commit)})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['eq']) + self.assertTrue(p2['eq']) + + def test_head_commit_equals_different(self): + snippet = """ +repo = open_repo(REPO_PATH) +h = repo.get_head_by_name('feature-branch') +# feature-branch points to same commit as master in setup, so this should be True +print(json.dumps({"eq": repo.head_commit_equals(h.commit)})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['eq'], p2['eq']) + + # === context manager === + + def test_context_manager(self): + snippet = """ +with open_repo(REPO_PATH) as repo: + sha = repo.head_commit_hexsha +print(json.dumps({"sha": sha})) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['sha'], self.head_sha) + self.assertEqual(p2['sha'], self.head_sha) + + # === get_remote_url === + + def test_get_remote_url_by_name(self): + snippet = """ +repo = open_repo(REPO_PATH) +url = repo.get_remote_url('origin') +print(json.dumps({"has_url": bool(url)})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['has_url']) + self.assertTrue(p2['has_url']) + + def test_get_remote_url_by_index(self): + snippet = """ +repo = open_repo(REPO_PATH) +url = repo.get_remote_url(0) +print(json.dumps({"has_url": bool(url)})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['has_url']) + self.assertTrue(p2['has_url']) + + # === clone_repo === + + def test_clone_repo(self): + snippet = """ +import tempfile, shutil +dest = tempfile.mkdtemp() +try: + repo = clone_repo(REPO_PATH, os.path.join(dest, 'cloned')) + sha = repo.head_commit_hexsha + repo.close() + print(json.dumps({"sha": sha})) +finally: + shutil.rmtree(dest, ignore_errors=True) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['sha'], self.head_sha) + self.assertEqual(p2['sha'], self.head_sha) + + # === checkout === + + def test_checkout_tag(self): + # Test in a clone to avoid messing up the shared repo + head = self.head_sha + snippet = f""" +import tempfile, shutil +dest = tempfile.mkdtemp() +try: + repo = clone_repo(REPO_PATH, os.path.join(dest, 'cloned')) + repo.checkout('v1.0.0') + sha = repo.head_commit_hexsha + detached = repo.head_is_detached + repo.close() + print(json.dumps({{"detached": detached, "not_head": sha != {head!r}}})) +finally: + shutil.rmtree(dest, ignore_errors=True) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['detached']) + self.assertTrue(gp['not_head']) + self.assertTrue(p2['detached']) + self.assertTrue(p2['not_head']) + + # === checkout_new_branch === + + def test_checkout_new_branch(self): + snippet = """ +import tempfile, shutil +dest = tempfile.mkdtemp() +try: + repo = clone_repo(REPO_PATH, os.path.join(dest, 'cloned')) + repo.checkout_new_branch('test-branch', 'origin/master') + name = repo.active_branch_name + repo.close() + print(json.dumps({"branch": name})) +finally: + shutil.rmtree(dest, ignore_errors=True) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['branch'], 'test-branch') + self.assertEqual(p2['branch'], 'test-branch') + + # === create_backup_branch === + + def test_create_backup_branch(self): + snippet = """ +import tempfile, shutil +dest = tempfile.mkdtemp() +try: + repo = clone_repo(REPO_PATH, os.path.join(dest, 'cloned')) + repo.create_backup_branch('backup_test') + heads = [h.name for h in repo.list_heads()] + repo.close() + print(json.dumps({"has_backup": 'backup_test' in heads})) +finally: + shutil.rmtree(dest, ignore_errors=True) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['has_backup']) + self.assertTrue(p2['has_backup']) + + # === stash === + + def test_stash(self): + snippet = """ +import tempfile, shutil +dest = tempfile.mkdtemp() +try: + repo = clone_repo(REPO_PATH, os.path.join(dest, 'cloned')) + # Make dirty + with open(os.path.join(dest, 'cloned', 'file.txt'), 'w') as f: + f.write('dirty') + dirty_before = repo.is_dirty() + repo.stash() + dirty_after = repo.is_dirty() + repo.close() + print(json.dumps({"dirty_before": dirty_before, "dirty_after": dirty_after})) +finally: + shutil.rmtree(dest, ignore_errors=True) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['dirty_before']) + self.assertFalse(gp['dirty_after']) + self.assertTrue(p2['dirty_before']) + self.assertFalse(p2['dirty_after']) + + # === reset_hard === + + def test_reset_hard(self): + first = self.first_sha + snippet = f""" +import tempfile, shutil +dest = tempfile.mkdtemp() +try: + repo = clone_repo(REPO_PATH, os.path.join(dest, 'cloned')) + repo.reset_hard({first!r}) + sha = repo.head_commit_hexsha + repo.close() + print(json.dumps({{"sha": sha}})) +finally: + shutil.rmtree(dest, ignore_errors=True) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['sha'], self.first_sha) + self.assertEqual(p2['sha'], self.first_sha) + + # === clear_cache === + + def test_clear_cache(self): + snippet = """ +repo = open_repo(REPO_PATH) +repo.clear_cache() +print(json.dumps({"ok": True})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['ok']) + self.assertTrue(p2['ok']) + + # === close === + + def test_close(self): + snippet = """ +repo = open_repo(REPO_PATH) +repo.close() +print(json.dumps({"ok": True})) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['ok']) + self.assertTrue(p2['ok']) + + # === fetch_remote_by_index === + + def test_fetch_remote_by_index(self): + snippet = """ +repo = open_repo(REPO_PATH) +repo.fetch_remote_by_index(0) +print(json.dumps({"ok": True})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['ok']) + self.assertTrue(p2['ok']) + + # === get_ref_object === + + def test_get_ref_object(self): + snippet = """ +repo = open_repo(REPO_PATH) +ref = repo.get_ref_object('origin/master') +print(json.dumps({"sha": ref.object.hexsha, "has_dt": ref.object.committed_datetime is not None})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['sha'], self.head_sha) + self.assertTrue(gp['has_dt']) + self.assertEqual(p2['sha'], self.head_sha) + self.assertTrue(p2['has_dt']) + + # === tag.commit === + + def test_tag_commit(self): + snippet = """ +repo = open_repo(REPO_PATH) +tags = {t.name: t.commit.hexsha for t in repo.list_tags() if t.commit is not None} +print(json.dumps({"tags": tags})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertIn('v1.0.0', gp['tags']) + self.assertIn('v1.1.0', gp['tags']) + self.assertEqual(gp['tags']['v1.1.0'], self.head_sha) + self.assertEqual(p2['tags']['v1.1.0'], self.head_sha) + self.assertEqual(gp['tags']['v1.0.0'], p2['tags']['v1.0.0']) + + # === setup_git_environment === + + def test_setup_git_environment(self): + snippet = """ +# Just verify it doesn't crash +setup_git_environment('') +setup_git_environment(None) +print(json.dumps({"ok": True})) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['ok']) + self.assertTrue(p2['ok']) + + # === GitCommandError === + + def test_git_command_error(self): + snippet = """ +try: + raise GitCommandError("test error") +except GitCommandError as e: + print(json.dumps({"has_msg": "test error" in str(e)})) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['has_msg']) + self.assertTrue(p2['has_msg']) + + # === pull_ff_only === + + def test_pull_ff_only(self): + snippet = """ +import tempfile, shutil, subprocess +dest = tempfile.mkdtemp() +try: + # Create a bare remote from REPO_PATH so we can push to it + bare = os.path.join(dest, 'bare.git') + subprocess.run(['git', 'clone', '--bare', REPO_PATH, bare], capture_output=True, check=True) + # Clone from the bare remote + repo = clone_repo(bare, os.path.join(dest, 'cloned')) + # Push a new commit to the bare remote via a second clone + work = os.path.join(dest, 'work') + subprocess.run(['git', 'clone', bare, work], capture_output=True, check=True) + with open(os.path.join(work, 'new.txt'), 'w') as f: + f.write('new') + subprocess.run(['git', '-C', work, 'add', '.'], capture_output=True, check=True) + subprocess.run(['git', '-C', work, 'commit', '-m', 'new'], capture_output=True, check=True) + subprocess.run(['git', '-C', work, 'push'], capture_output=True, check=True) + old_sha = repo.head_commit_hexsha + repo.pull_ff_only() + new_sha = repo.head_commit_hexsha + repo.close() + print(json.dumps({"advanced": old_sha != new_sha})) +finally: + shutil.rmtree(dest, ignore_errors=True) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['advanced']) + self.assertTrue(p2['advanced']) + + # === submodule_update === + + def test_submodule_update(self): + snippet = """ +repo = open_repo(REPO_PATH) +repo.submodule_update() +print(json.dumps({"ok": True})) +repo.close() +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertTrue(gp['ok']) + self.assertTrue(p2['ok']) + + # === checkout by SHA === + + def test_checkout_by_sha(self): + first = self.first_sha + snippet = f""" +import tempfile, shutil +dest = tempfile.mkdtemp() +try: + repo = clone_repo(REPO_PATH, os.path.join(dest, 'cloned')) + repo.checkout({first!r}) + sha = repo.head_commit_hexsha + detached = repo.head_is_detached + repo.close() + print(json.dumps({{"sha": sha, "detached": detached}})) +finally: + shutil.rmtree(dest, ignore_errors=True) +""" + gp, p2 = _run_both(snippet, self.repo_path) + self.assertEqual(gp['sha'], self.first_sha) + self.assertTrue(gp['detached']) + self.assertEqual(p2['sha'], self.first_sha) + self.assertTrue(p2['detached']) + + +if __name__ == '__main__': + unittest.main()