From bc201cea4dad149370db0e4e30922d166e08c3c1 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Thu, 18 Sep 2025 13:25:08 -0700 Subject: [PATCH] Improve tests, fix issues with alternate filenames, improve group offloading support for transformers models --- README.md | 2 +- comfy/cli_args.py | 2 +- .../language/transformers_model_management.py | 35 ++-- comfy/lora_types.py | 3 +- comfy/model_downloader.py | 29 ++-- comfy/model_management_types.py | 155 +++++++++++++----- comfy/model_patcher.py | 2 +- comfy/utils.py | 27 ++- comfy_extras/nodes/nodes_custom_sampler.py | 2 +- comfy_extras/nodes/nodes_group_offloading.py | 20 ++- comfy_extras/nodes/nodes_lt.py | 3 +- comfy_extras/nodes/nodes_upscale_model.py | 4 +- .../downloader/test_huggingface_downloads.py | 2 +- .../workflows/{phi-3-0.json => phi-4-0.json} | 4 +- 14 files changed, 203 insertions(+), 87 deletions(-) rename tests/inference/workflows/{phi-3-0.json => phi-4-0.json} (95%) diff --git a/README.md b/README.md index e97b090b8..1ba9b2a42 100644 --- a/README.md +++ b/README.md @@ -242,7 +242,7 @@ ComfyUI LTS supports text and multi-modal LLM models from the `transformers` eco In this example, LLAVA-NEXT (LLAVA 1.6) is prompted to describe an image. -You can try the [LLAVA-NEXT](tests/inference/workflows/llava-0.json), [Phi-3](tests/inference/workflows/phi-3-0.json), and two [translation](tests/inference/workflows/translation-0.json) [workflows](tests/inference/workflows/translation-1.json). +You can try the [LLAVA-NEXT](tests/inference/workflows/llava-0.json), [Phi-4](tests/inference/workflows/phi-4-0.json), and two [translation](tests/inference/workflows/translation-0.json) [workflows](tests/inference/workflows/translation-1.json). 
# SVG Conversion and String Saving diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 9faf7b461..f77656a03 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -130,7 +130,7 @@ def _create_parser() -> EnhancedConfigArgParser: parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") parser.add_argument("--disable-smart-memory", action="store_true", - help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") + help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in VRAM when it can.") parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. 
Note that this might not make images deterministic in all cases.") diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py index cd78dec2b..d9c9ac303 100644 --- a/comfy/language/transformers_model_management.py +++ b/comfy/language/transformers_model_management.py @@ -12,9 +12,12 @@ from typing import Optional, Any, Callable import torch import transformers +from huggingface_hub.errors import EntryNotFoundError from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \ BatchFeature, AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \ PretrainedConfig, TextStreamer, LogitsProcessor +from huggingface_hub import hf_api +from huggingface_hub.file_download import hf_hub_download from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES @@ -25,7 +28,7 @@ from .. 
import model_management from ..component_model.tensor_types import RGBImageBatch from ..model_downloader import get_or_download_huggingface_repo from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu -from ..model_management_types import ModelManageable +from ..model_management_types import ModelManageableStub from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block logger = logging.getLogger(__name__) @@ -37,7 +40,7 @@ _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING _DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2', 'paligemma'} -class TransformersManagedModel(ModelManageable, LanguageModel): +class TransformersManagedModel(ModelManageableStub, LanguageModel): def __init__( self, repo_id: str, @@ -69,7 +72,20 @@ class TransformersManagedModel(ModelManageable, LanguageModel): hub_kwargs["subfolder"] = subfolder repo_id = ckpt_name with comfy_tqdm(): - ckpt_name = get_or_download_huggingface_repo(ckpt_name) + ckpt_name = get_or_download_huggingface_repo(repo_id) + + if config_dict is None: + config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs) + elif isinstance(config_dict, PretrainedConfig): + config_dict: dict = config_dict.to_dict() + else: + config_dict = {} + + try: + model_type = config_dict["model_type"] + except KeyError: + logger.debug(f"Configuration was missing for repo_id={repo_id}") + model_type = "" from_pretrained_kwargs = { "pretrained_model_name_or_path": ckpt_name, @@ -77,19 +93,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel): **hub_kwargs } - # compute bitsandbytes configuration - try: - import bitsandbytes - except ImportError: - pass - - if config_dict is None: - config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs) - elif isinstance(config_dict, PretrainedConfig): - config_dict: dict = config_dict.to_dict() - model_type = config_dict["model_type"] # language models prefer to use bfloat16 over 
float16 - kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)), + kwargs_to_try = ({"dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)), "low_cpu_mem_usage": True, "device_map": str(unet_offload_device()), }, {}) diff --git a/comfy/lora_types.py b/comfy/lora_types.py index c08006467..ebe4db3d8 100644 --- a/comfy/lora_types.py +++ b/comfy/lora_types.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any, TypeAlias, Union +from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any, TypeAlias, Union, runtime_checkable import torch PatchOffset = tuple[int, int, int] @@ -31,6 +31,7 @@ class PatchTuple(NamedTuple): ModelPatchesDictValue: TypeAlias = list[Union[PatchTuple, PatchWeightTuple]] +@runtime_checkable class PatchSupport(Protocol): """ Defines the interface for a model that supports LoRA patching. 
diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 430fcda76..2b5538883 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -45,7 +45,7 @@ def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[ known_files = _get_known_models_for_folder_name(folder_name) existing = folder_paths.get_filename_list(folder_name) - + downloadable_files = [] if not args.disable_known_models: downloadable_files = known_files @@ -725,15 +725,15 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]: return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids) -def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]: +def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None, force: bool = False, subset: bool = False) -> Optional[str]: with comfy_tqdm(): - return _get_or_download_huggingface_repo(repo_id, cache_dirs, local_dirs) + return _get_or_download_huggingface_repo(repo_id, cache_dirs, local_dirs, force=force, subset=subset) -def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]: +def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None, force: bool = False, subset: bool = False) -> Optional[str]: cache_dirs = cache_dirs or folder_paths.get_folder_paths("huggingface_cache") local_dirs = local_dirs or folder_paths.get_folder_paths("huggingface") - cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id) + cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id, subset=subset) local_dirs_cache_hit = len(local_dirs_snapshots) > 0 cache_dirs_cache_hit = len(cache_dirs_snapshots) > 0 @@ -742,25 +742,25 @@ def 
_get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = # if we're in forced local directory mode, only use the local dir snapshots, and otherwise, download if args.force_hf_local_dir_mode: # todo: we still have to figure out a way to download things to the right places by default - if len(local_dirs_snapshots) > 0: + if len(local_dirs_snapshots) > 0 and not force: return local_dirs_snapshots[0] elif not args.disable_known_models: destination = os.path.join(local_dirs[0], repo_id) logger.debug(f"downloading repo_id={repo_id}, local_dir={destination}") - return snapshot_download(repo_id, local_dir=destination) + return snapshot_download(repo_id, local_dir=destination, force_download=force) snapshots = local_dirs_snapshots + cache_dirs_snapshots - if len(snapshots) > 0: + if len(snapshots) > 0 and not force: return snapshots[0] elif not args.disable_known_models: logger.debug(f"downloading repo_id={repo_id}") - return snapshot_download(repo_id) + return snapshot_download(repo_id, force_download=force) # this repo was not found return None -def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id): +def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id, subset=False): local_dirs_snapshots = [] cache_dirs_snapshots = [] # find all the pre-existing downloads for this repo_id @@ -772,13 +772,12 @@ def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_i if len(repo_files) > 0: for local_dir in local_dirs: local_path = Path(local_dir) / repo_id - local_files = set(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file()) + local_files = frozenset(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file()) # fix path representation - local_files = set(f.replace("\\", "/") for f in local_files) + local_files = frozenset(f.replace("\\", "/") for f in local_files) # remove .huggingface - local_files = set(f for f in 
local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache")) - # local_files.issubsetof(repo_files) - if len(local_files) > 0 and local_files.issubset(repo_files): + local_files = frozenset(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache")) + if len(local_files) > 0 and ((subset and local_files.issubset(repo_files)) or (not subset and repo_files.issubset(local_files))): local_dirs_snapshots.append(str(local_path)) else: # an empty repository or unknown repository info, trust that if the directory exists, it matches diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py index 259455bb3..3f44c647b 100644 --- a/comfy/model_management_types.py +++ b/comfy/model_management_types.py @@ -1,8 +1,9 @@ from __future__ import annotations +import copy import dataclasses from abc import ABCMeta, abstractmethod -from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override +from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override, TYPE_CHECKING import torch import torch.nn @@ -11,22 +12,48 @@ from typing_extensions import TypedDict, NotRequired from .comfy_types import UnetWrapperFunction from .latent_formats import LatentFormat +if TYPE_CHECKING: + from .hooks import EnumHookMode + ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable') LatentFormatT = TypeVar('LatentFormatT', bound=LatentFormat) @runtime_checkable class DeviceSettable(Protocol): - @property - def device(self) -> torch.device: - ... - - @device.setter - def device(self, value: torch.device): - ... 
+ device: torch.device -class HooksSupport(Protocol, metaclass=ABCMeta): +@runtime_checkable +class HooksSupport(Protocol): + wrappers: dict[str, dict[str, list[Callable]]] + callbacks: dict[str, dict[str, list[Callable]]] + hook_mode: "EnumHookMode" + + def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): ... + + def model_patches_models(self) -> list[ModelManageableT]: ... + + def restore_hook_patches(self): ... + + def cleanup(self): ... + + def pre_run(self): ... + + def prepare_state(self, *args, **kwargs): ... + + def register_all_hook_patches(self, a, b, c, d): ... + + def get_nested_additional_models(self): ... + + def apply_hooks(self, *args, **kwargs): ... + + def add_wrapper(self, wrapper_type: str, wrapper: Callable): ... + + def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): ... + + +class HooksSupportStub(HooksSupport, metaclass=ABCMeta): def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): return @@ -82,6 +109,8 @@ class HooksSupport(Protocol, metaclass=ABCMeta): if isinstance(model, BaseModel) or hasattr(model, "current_patcher") and isinstance(self, ModelManageable): model.current_patcher = self + + def prepare_state(self, *args, **kwargs): pass @@ -94,8 +123,22 @@ class HooksSupport(Protocol, metaclass=ABCMeta): def apply_hooks(self, *args, **kwargs): return {} + def add_wrapper(self, wrapper_type: str, wrapper: Callable): + self.add_wrapper_with_key(wrapper_type, None, wrapper) -class TrainingSupport(Protocol, metaclass=ABCMeta): + def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): + w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, []) + w.append(wrapper) + + +@runtime_checkable +class TrainingSupport(Protocol): + def set_model_compute_dtype(self, dtype: torch.dtype): ... + + def add_weight_wrapper(self, name, function): ... 
+ + +class TrainingSupportStub(TrainingSupport, metaclass=ABCMeta): def set_model_compute_dtype(self, dtype: torch.dtype): return @@ -103,13 +146,68 @@ class TrainingSupport(Protocol, metaclass=ABCMeta): return -class ModelManageableExtras(Protocol, metaclass=ABCMeta): +@runtime_checkable +class ModelManageable(HooksSupport, TrainingSupport, Protocol): + """ + Objects which implement this protocol can be managed by + + >>> from comfy.model_management import load_models_gpu + >>> class ModelWrapper(ModelManageable): + >>> ... + >>> + >>> some_model = ModelWrapper() + >>> load_models_gpu([some_model]) + + The minimum required + """ + load_device: torch.device + offload_device: torch.device + model: torch.nn.Module + @property - def current_device(self) -> torch.device: - return torch.device("cpu") + def current_device(self) -> torch.device: ... + + def is_clone(self, other: ModelManageableT) -> bool: ... + + def clone_has_same_weights(self, clone: ModelManageableT) -> bool: ... + + def model_size(self) -> int: ... + + def model_patches_to(self, arg: torch.device | torch.dtype): ... + + def model_dtype(self) -> torch.dtype: ... + + def lowvram_patch_counter(self) -> int: ... + + def partially_load(self, device_to: torch.device, extra_memory: int = 0, force_patch_weights: bool = False) -> int: ... + + def partially_unload(self, device_to: torch.device, memory_to_free: int = 0) -> int: ... + + def memory_required(self, input_shape: torch.Size) -> int: ... + + def loaded_size(self) -> int: ... + + def current_loaded_device(self) -> torch.device: ... + + def get_model_object(self, name: str) -> torch.nn.Module: ... + + @property + def model_options(self) -> ModelOptions: ... + + @model_options.setter + def model_options(self, value): ... + + def __del__(self): ... + + @property + def parent(self) -> ModelManageableT | None: ... + + def detach(self, unpatch_all: bool = True): ... + + def clone(self) -> ModelManageableT: ... 
-class ModelManageableRequired(Protocol, metaclass=ABCMeta): +class ModelManageableStub(HooksSupportStub, TrainingSupportStub, ModelManageable, metaclass=ABCMeta): """ The bare minimum that must be implemented to support model management when inheriting from ModelManageable @@ -120,12 +218,11 @@ class ModelManageableRequired(Protocol, metaclass=ABCMeta): :see: ModelManageable :see: PatchSupport """ - load_device: torch.device - offload_device: torch.device - model: torch.nn.Module + @abstractmethod - def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module: + def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, + force_patch_weights: bool = False) -> torch.nn.Module: """ Called by ModelManageable @@ -155,25 +252,6 @@ class ModelManageableRequired(Protocol, metaclass=ABCMeta): """ ... - -@runtime_checkable -class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol, metaclass=ABCMeta): - """ - Objects which implement this protocol can be managed by - - >>> from comfy.model_management import load_models_gpu - >>> class ModelWrapper(ModelManageable): - >>> ... 
- >>> - >>> some_model = ModelWrapper() - >>> load_models_gpu([some_model]) - - The minimum required - """ - load_device: torch.device - offload_device: torch.device - model: torch.nn.Module - @property @override def current_device(self) -> torch.device: @@ -265,6 +343,9 @@ class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol, self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) return self.model + def clone(self) -> ModelManageableT: + return copy.copy(self) + @dataclasses.dataclass class MemoryMeasurements: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 650e1b757..9b5f8bed2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -230,7 +230,7 @@ class GGUFQuantization: patch_on_device: bool = False -class ModelPatcher(ModelManageable, TrainingSupport, HooksSupport, PatchSupport): +class ModelPatcher(ModelManageable, PatchSupport): def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None): self.size = size self.model: BaseModel | torch.nn.Module = model diff --git a/comfy/utils.py b/comfy/utils.py index 7e5d72104..8649a996f 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -27,6 +27,7 @@ import os import random import struct import sys +import threading import warnings from contextlib import contextmanager from pathlib import Path @@ -57,6 +58,9 @@ DISABLE_MMAP = args.disable_mmap logger = logging.getLogger(__name__) ALWAYS_SAFE_LOAD = False + +_lock = threading.RLock() + if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated class ModelCheckpoint: pass @@ -1178,7 +1182,9 @@ class ProgressBar: def update_absolute(self, value, total=None, preview_image_or_output=None): if total is not None: self.total = total - if value > self.total: + if value is None: + return 
+ if self.total is not None and value > self.total: value = self.total self.current = value _progress_bar_update(self.current, self.total, preview_image_or_output, server=self.server, node_id=self.node_id) @@ -1198,31 +1204,39 @@ def comfy_tqdm() -> Generator[TqdmWatcher, None, None]: Monkey patches child calls to tqdm, sends progress to the UI, and yields a watcher object for stall detection. """ + with _lock: + if hasattr(tqdm, "__patched_by_comfyui__"): + yield getattr(tqdm, "__patched_by_comfyui__") + return + + watcher = TqdmWatcher() + setattr(tqdm, "__patched_by_comfyui__", watcher) + _original_init = tqdm.__init__ _original_call = tqdm.__call__ _original_update = tqdm.update # Create the watcher instance that the patched methods will update # and that will be yielded to the caller. - watcher = TqdmWatcher() context = contextvars.copy_context() try: + # These inner functions are closures; they capture the `watcher` variable # from the enclosing scope. def __init(self, *args, **kwargs): - context.run(lambda: _original_init(self, *args, **kwargs)) + _original_init(self, *args, **kwargs) self._progress_bar = context.run(lambda: ProgressBar(self.total)) watcher.tick() # Signal progress on initialization def __update(self, n=1): assert self._progress_bar is not None - context.run(lambda: _original_update(self, n)) + _original_update(self, n) context.run(lambda: self._progress_bar.update(n)) watcher.tick() # Signal progress on update def __call(self, *args, **kwargs): - instance = context.run(lambda: _original_call(self, *args, **kwargs)) + instance = _original_call(self, *args, **kwargs) return instance tqdm.__init__ = __init @@ -1236,10 +1250,11 @@ def comfy_tqdm() -> Generator[TqdmWatcher, None, None]: tqdm.__init__ = _original_init tqdm.__call__ = _original_call tqdm.update = _original_update + delattr(tqdm, "__patched_by_comfyui__") @contextmanager -def comfy_progress(total: float) -> ProgressBar: +def comfy_progress(total: float) -> Generator[ProgressBar, 
Any, None]: ctx = current_execution_context() if ctx.server.receive_all_progress_notifications: yield ProgressBar(total) diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py index c0794de83..f025a95a7 100644 --- a/comfy_extras/nodes/nodes_custom_sampler.py +++ b/comfy_extras/nodes/nodes_custom_sampler.py @@ -765,7 +765,7 @@ class DualCFGGuider: FUNCTION = "get_guider" CATEGORY = "sampling/custom_sampling/guiders" - def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style): + def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style="regular"): guider = Guider_DualCFG(model) guider.set_conds(cond1, cond2, negative) guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested")) diff --git a/comfy_extras/nodes/nodes_group_offloading.py b/comfy_extras/nodes/nodes_group_offloading.py index f60101268..2f3204120 100644 --- a/comfy_extras/nodes/nodes_group_offloading.py +++ b/comfy_extras/nodes/nodes_group_offloading.py @@ -2,7 +2,9 @@ import torch from diffusers import HookRegistry from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook +from comfy.language.transformers_model_management import TransformersManagedModel from comfy.model_management import vram_state, VRAMState +from comfy.model_management_types import HooksSupport, ModelManageable from comfy.model_patcher import ModelPatcher from comfy.node_helpers import export_custom_nodes from comfy.nodes.package_typing import CustomNode @@ -117,9 +119,21 @@ class GroupOffload(CustomNode): RETURN_TYPES = ("MODEL",) FUNCTION = "execute" - def execute(self, model: ModelPatcher) -> tuple[ModelPatcher,]: - model = model.clone() - model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device)) + def execute(self, model: ModelManageable | HooksSupport | TransformersManagedModel) -> tuple[ModelPatcher,]: + if 
isinstance(model, ModelManageable): + model = model.clone() + if isinstance(model, TransformersManagedModel): + apply_group_offloading( + model.model, + model.load_device, + model.offload_device, + use_stream=True, + record_stream=True, + low_cpu_mem_usage=vram_state in (VRAMState.LOW_VRAM,), + num_blocks_per_group=1 + ) + elif isinstance(model, HooksSupport) and isinstance(model, ModelManageable): + model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device)) return model, diff --git a/comfy_extras/nodes/nodes_lt.py b/comfy_extras/nodes/nodes_lt.py index 981d24317..fd013bc07 100644 --- a/comfy_extras/nodes/nodes_lt.py +++ b/comfy_extras/nodes/nodes_lt.py @@ -42,6 +42,7 @@ class LTXVImgToVideo: "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, "optional": { "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), }} @@ -51,7 +52,7 @@ class LTXVImgToVideo: CATEGORY = "conditioning/video_models" FUNCTION = "generate" - def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength): + def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength=1.0): pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) encode_pixels = pixels[:, :, :, :3] t = vae.encode(encode_pixels) diff --git a/comfy_extras/nodes/nodes_upscale_model.py b/comfy_extras/nodes/nodes_upscale_model.py index 2f3b1accf..a1a31e202 100644 --- a/comfy_extras/nodes/nodes_upscale_model.py +++ b/comfy_extras/nodes/nodes_upscale_model.py @@ -9,7 +9,7 @@ from comfy import utils from comfy.component_model.tensor_types import RGBImageBatch from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, 
get_or_download from comfy.model_management import load_models_gpu -from comfy.model_management_types import ModelManageable +from comfy.model_management_types import ModelManageableStub logger = logging.getLogger(__name__) try: @@ -22,7 +22,7 @@ except: pass -class UpscaleModelManageable(ModelManageable): +class UpscaleModelManageable(ModelManageableStub): def __init__(self, model_descriptor: ImageModelDescriptor, ckpt_name: str): self.ckpt_name = ckpt_name self.model_descriptor = model_descriptor diff --git a/tests/downloader/test_huggingface_downloads.py b/tests/downloader/test_huggingface_downloads.py index ff1bc8666..dcf51f63e 100644 --- a/tests/downloader/test_huggingface_downloads.py +++ b/tests/downloader/test_huggingface_downloads.py @@ -182,7 +182,7 @@ async def test_huggingface_alternate_filenames_in_combo(): ) # 3. Get the list of files as the UI would - filename_list = get_filename_list_with_downloadable("checkpoints", known_files=[known_file]) + filename_list = get_filename_list_with_downloadable("__xxx___", known_files=[known_file]) # 4. Assert that both the main and alternate filenames are present assert main_filename in filename_list diff --git a/tests/inference/workflows/phi-3-0.json b/tests/inference/workflows/phi-4-0.json similarity index 95% rename from tests/inference/workflows/phi-3-0.json rename to tests/inference/workflows/phi-4-0.json index 7779dc859..d0f2eda50 100644 --- a/tests/inference/workflows/phi-3-0.json +++ b/tests/inference/workflows/phi-4-0.json @@ -1,7 +1,7 @@ { "1": { "inputs": { - "ckpt_name": "microsoft/Phi-3-mini-4k-instruct", + "ckpt_name": "microsoft/Phi-4-mini-instruct", "subfolder": "" }, "class_type": "TransformersLoader", @@ -33,7 +33,7 @@ "4": { "inputs": { "prompt": "What comes after apple?", - "chat_template": "phi-3", + "chat_template": "default", "model": [ "1", 0