Improve tests, fix issues with alternate filenames, improve group offloading support for transformers models

doctorpangloss 2025-09-18 13:25:08 -07:00
parent 79b8723f61
commit bc201cea4d
14 changed files with 203 additions and 87 deletions

View File

@@ -242,7 +242,7 @@ ComfyUI LTS supports text and multi-modal LLM models from the `transformers` eco
 In this example, LLAVA-NEXT (LLAVA 1.6) is prompted to describe an image.
-You can try the [LLAVA-NEXT](tests/inference/workflows/llava-0.json), [Phi-3](tests/inference/workflows/phi-3-0.json), and two [translation](tests/inference/workflows/translation-0.json) [workflows](tests/inference/workflows/translation-1.json).
+You can try the [LLAVA-NEXT](tests/inference/workflows/llava-0.json), [Phi-3](tests/inference/workflows/phi-4-0.json), and two [translation](tests/inference/workflows/translation-0.json) [workflows](tests/inference/workflows/translation-1.json).
 # SVG Conversion and String Saving

View File

@@ -130,7 +130,7 @@ def _create_parser() -> EnhancedConfigArgParser:
     parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
     parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
     parser.add_argument("--disable-smart-memory", action="store_true",
-                        help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
+                        help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in VRAM when it can.")
     parser.add_argument("--deterministic", action="store_true",
                         help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

View File

@@ -12,9 +12,12 @@ from typing import Optional, Any, Callable
 import torch
 import transformers
+from huggingface_hub.errors import EntryNotFoundError
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
     BatchFeature, AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \
     PretrainedConfig, TextStreamer, LogitsProcessor
+from huggingface_hub import hf_api
+from huggingface_hub.file_download import hf_hub_download
 from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
@@ -25,7 +28,7 @@ from .. import model_management
 from ..component_model.tensor_types import RGBImageBatch
 from ..model_downloader import get_or_download_huggingface_repo
 from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
-from ..model_management_types import ModelManageable
+from ..model_management_types import ModelManageableStub
 from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
 logger = logging.getLogger(__name__)
@@ -37,7 +40,7 @@ _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING
 _DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2', 'paligemma'}
-class TransformersManagedModel(ModelManageable, LanguageModel):
+class TransformersManagedModel(ModelManageableStub, LanguageModel):
     def __init__(
             self,
             repo_id: str,
@@ -69,7 +72,20 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
                 hub_kwargs["subfolder"] = subfolder
             repo_id = ckpt_name
         with comfy_tqdm():
-            ckpt_name = get_or_download_huggingface_repo(ckpt_name)
+            ckpt_name = get_or_download_huggingface_repo(repo_id)
+        if config_dict is None:
+            config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
+        elif isinstance(config_dict, PretrainedConfig):
+            config_dict: dict = config_dict.to_dict()
+        else:
+            config_dict = {}
+        try:
+            model_type = config_dict["model_type"]
+        except KeyError:
+            logger.debug(f"Configuration was missing for repo_id={repo_id}")
+            model_type = ""
         from_pretrained_kwargs = {
             "pretrained_model_name_or_path": ckpt_name,
@@ -77,19 +93,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
             **hub_kwargs
         }
-        # compute bitsandbytes configuration
-        try:
-            import bitsandbytes
-        except ImportError:
-            pass
-        if config_dict is None:
-            config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
-        elif isinstance(config_dict, PretrainedConfig):
-            config_dict: dict = config_dict.to_dict()
-        model_type = config_dict["model_type"]
         # language models prefer to use bfloat16 over float16
-        kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
+        kwargs_to_try = ({"dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
                           "low_cpu_mem_usage": True,
                           "device_map": str(unet_offload_device()), }, {})

View File

@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any, TypeAlias, Union
+from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any, TypeAlias, Union, runtime_checkable
 import torch
 PatchOffset = tuple[int, int, int]
@@ -31,6 +31,7 @@ class PatchTuple(NamedTuple):
 ModelPatchesDictValue: TypeAlias = list[Union[PatchTuple, PatchWeightTuple]]
+@runtime_checkable
 class PatchSupport(Protocol):
     """
     Defines the interface for a model that supports LoRA patching.
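The new @runtime_checkable decorator is what allows callers to test objects against PatchSupport with isinstance, the same mechanism the protocol reorganization elsewhere in this commit depends on. A small illustration of the behavior; SimplifiedPatchSupport and Dummy below are stand-ins, not the real interface.

# isinstance() against a plain Protocol raises TypeError; @runtime_checkable
# enables a structural check by method name at runtime.
from typing import Protocol, runtime_checkable

@runtime_checkable
class SimplifiedPatchSupport(Protocol):
    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): ...

class Dummy:
    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        return []

print(isinstance(Dummy(), SimplifiedPatchSupport))   # True: only method names are checked
print(isinstance(object(), SimplifiedPatchSupport))  # False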

View File

@@ -45,7 +45,7 @@ def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[
         known_files = _get_known_models_for_folder_name(folder_name)
     existing = folder_paths.get_filename_list(folder_name)
     downloadable_files = []
     if not args.disable_known_models:
         downloadable_files = known_files
@@ -725,15 +725,15 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
     return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids)
-def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]:
+def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None, force: bool = False, subset: bool = False) -> Optional[str]:
     with comfy_tqdm():
-        return _get_or_download_huggingface_repo(repo_id, cache_dirs, local_dirs)
+        return _get_or_download_huggingface_repo(repo_id, cache_dirs, local_dirs, force=force, subset=subset)
-def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]:
+def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None, force: bool = False, subset: bool = False) -> Optional[str]:
     cache_dirs = cache_dirs or folder_paths.get_folder_paths("huggingface_cache")
     local_dirs = local_dirs or folder_paths.get_folder_paths("huggingface")
-    cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id)
+    cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id, subset=subset)
     local_dirs_cache_hit = len(local_dirs_snapshots) > 0
     cache_dirs_cache_hit = len(cache_dirs_snapshots) > 0
@@ -742,25 +742,25 @@ def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] =
     # if we're in forced local directory mode, only use the local dir snapshots, and otherwise, download
     if args.force_hf_local_dir_mode:
         # todo: we still have to figure out a way to download things to the right places by default
-        if len(local_dirs_snapshots) > 0:
+        if len(local_dirs_snapshots) > 0 and not force:
             return local_dirs_snapshots[0]
         elif not args.disable_known_models:
             destination = os.path.join(local_dirs[0], repo_id)
             logger.debug(f"downloading repo_id={repo_id}, local_dir={destination}")
-            return snapshot_download(repo_id, local_dir=destination)
+            return snapshot_download(repo_id, local_dir=destination, force_download=force)
     snapshots = local_dirs_snapshots + cache_dirs_snapshots
-    if len(snapshots) > 0:
+    if len(snapshots) > 0 and not force:
         return snapshots[0]
     elif not args.disable_known_models:
         logger.debug(f"downloading repo_id={repo_id}")
-        return snapshot_download(repo_id)
+        return snapshot_download(repo_id, force_download=force)
     # this repo was not found
     return None
-def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id):
+def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id, subset=False):
     local_dirs_snapshots = []
     cache_dirs_snapshots = []
     # find all the pre-existing downloads for this repo_id
@@ -772,13 +772,12 @@ def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_i
     if len(repo_files) > 0:
         for local_dir in local_dirs:
             local_path = Path(local_dir) / repo_id
-            local_files = set(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
+            local_files = frozenset(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
             # fix path representation
-            local_files = set(f.replace("\\", "/") for f in local_files)
+            local_files = frozenset(f.replace("\\", "/") for f in local_files)
             # remove .huggingface
-            local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
-            # local_files.issubsetof(repo_files)
-            if len(local_files) > 0 and local_files.issubset(repo_files):
+            local_files = frozenset(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
+            if len(local_files) > 0 and ((subset and local_files.issubset(repo_files)) or (not subset and repo_files.issubset(local_files))):
                 local_dirs_snapshots.append(str(local_path))
         else:
             # an empty repository or unknown repository info, trust that if the directory exists, it matches
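
The new subset flag flips the direction of the file-set comparison used to decide whether a local directory counts as a cache hit, and force bypasses cache hits and forwards force_download to snapshot_download. The containment check in isolation, with made-up file sets:

# Sketch of the containment check toggled by `subset`; the file names are illustrative,
# while `repo_files` would normally come from the Hugging Face API listing.
repo_files = frozenset({"org/repo/config.json", "org/repo/model.safetensors"})
local_files = frozenset({"org/repo/config.json"})

def is_cache_hit(local_files, repo_files, subset: bool) -> bool:
    if not local_files:
        return False
    if subset:
        # a partial local copy counts, as long as every local file belongs to the repo
        return local_files.issubset(repo_files)
    # default: the local directory must contain every file the repo lists
    return repo_files.issubset(local_files)

print(is_cache_hit(local_files, repo_files, subset=True))   # True
print(is_cache_hit(local_files, repo_files, subset=False))  # False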

View File

@@ -1,8 +1,9 @@
 from __future__ import annotations
+import copy
 import dataclasses
 from abc import ABCMeta, abstractmethod
-from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override
+from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override, TYPE_CHECKING
 import torch
 import torch.nn
@@ -11,22 +12,48 @@ from typing_extensions import TypedDict, NotRequired
 from .comfy_types import UnetWrapperFunction
 from .latent_formats import LatentFormat
+if TYPE_CHECKING:
+    from .hooks import EnumHookMode
 ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
 LatentFormatT = TypeVar('LatentFormatT', bound=LatentFormat)
 @runtime_checkable
 class DeviceSettable(Protocol):
-    @property
-    def device(self) -> torch.device:
-        ...
-    @device.setter
-    def device(self, value: torch.device):
-        ...
+    device: torch.device
-class HooksSupport(Protocol, metaclass=ABCMeta):
+@runtime_checkable
+class HooksSupport(Protocol):
+    wrappers: dict[str, dict[str, list[Callable]]]
+    callbacks: dict[str, dict[str, list[Callable]]]
+    hook_mode: "EnumHookMode"
+    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): ...
+    def model_patches_models(self) -> list[ModelManageableT]: ...
+    def restore_hook_patches(self): ...
+    def cleanup(self): ...
+    def pre_run(self): ...
+    def prepare_state(self, *args, **kwargs): ...
+    def register_all_hook_patches(self, a, b, c, d): ...
+    def get_nested_additional_models(self): ...
+    def apply_hooks(self, *args, **kwargs): ...
+    def add_wrapper(self, wrapper_type: str, wrapper: Callable): ...
+    def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): ...
+class HooksSupportStub(HooksSupport, metaclass=ABCMeta):
     def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
         return
@@ -82,6 +109,8 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
         if isinstance(model, BaseModel) or hasattr(model, "current_patcher") and isinstance(self, ModelManageable):
             model.current_patcher = self
     def prepare_state(self, *args, **kwargs):
         pass
@@ -94,8 +123,22 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
     def apply_hooks(self, *args, **kwargs):
         return {}
+    def add_wrapper(self, wrapper_type: str, wrapper: Callable):
+        self.add_wrapper_with_key(wrapper_type, None, wrapper)
+    def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
+        w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
+        w.append(wrapper)
+@runtime_checkable
+class TrainingSupport(Protocol):
+    def set_model_compute_dtype(self, dtype: torch.dtype): ...
+    def add_weight_wrapper(self, name, function): ...
-class TrainingSupport(Protocol, metaclass=ABCMeta):
+class TrainingSupportStub(TrainingSupport, metaclass=ABCMeta):
     def set_model_compute_dtype(self, dtype: torch.dtype):
         return
@@ -103,13 +146,68 @@ class TrainingSupport(Protocol, metaclass=ABCMeta):
         return
-class ModelManageableExtras(Protocol, metaclass=ABCMeta):
+@runtime_checkable
+class ModelManageable(HooksSupport, TrainingSupport, Protocol):
+    """
+    Objects which implement this protocol can be managed by
+    >>> from comfy.model_management import load_models_gpu
+    >>> class ModelWrapper(ModelManageable):
+    >>>     ...
+    >>>
+    >>> some_model = ModelWrapper()
+    >>> load_models_gpu([some_model])
+    The minimum required
+    """
+    load_device: torch.device
+    offload_device: torch.device
+    model: torch.nn.Module
     @property
-    def current_device(self) -> torch.device:
-        return torch.device("cpu")
+    def current_device(self) -> torch.device: ...
+    def is_clone(self, other: ModelManageableT) -> bool: ...
+    def clone_has_same_weights(self, clone: ModelManageableT) -> bool: ...
+    def model_size(self) -> int: ...
+    def model_patches_to(self, arg: torch.device | torch.dtype): ...
+    def model_dtype(self) -> torch.dtype: ...
+    def lowvram_patch_counter(self) -> int: ...
+    def partially_load(self, device_to: torch.device, extra_memory: int = 0, force_patch_weights: bool = False) -> int: ...
+    def partially_unload(self, device_to: torch.device, memory_to_free: int = 0) -> int: ...
+    def memory_required(self, input_shape: torch.Size) -> int: ...
+    def loaded_size(self) -> int: ...
+    def current_loaded_device(self) -> torch.device: ...
+    def get_model_object(self, name: str) -> torch.nn.Module: ...
+    @property
+    def model_options(self) -> ModelOptions: ...
+    @model_options.setter
+    def model_options(self, value): ...
+    def __del__(self): ...
+    @property
+    def parent(self) -> ModelManageableT | None: ...
+    def detach(self, unpatch_all: bool = True): ...
+    def clone(self) -> ModelManageableT: ...
-class ModelManageableRequired(Protocol, metaclass=ABCMeta):
+class ModelManageableStub(HooksSupportStub, TrainingSupportStub, ModelManageable, metaclass=ABCMeta):
     """
     The bare minimum that must be implemented to support model management when inheriting from ModelManageable
@@ -120,12 +218,11 @@ class ModelManageableRequired(Protocol, metaclass=ABCMeta):
     :see: ModelManageable
     :see: PatchSupport
     """
-    load_device: torch.device
-    offload_device: torch.device
-    model: torch.nn.Module
     @abstractmethod
-    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
+    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True,
+                    force_patch_weights: bool = False) -> torch.nn.Module:
         """
         Called by ModelManageable
@@ -155,25 +252,6 @@ class ModelManageableRequired(Protocol, metaclass=ABCMeta):
         """
         ...
-@runtime_checkable
-class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol, metaclass=ABCMeta):
-    """
-    Objects which implement this protocol can be managed by
-    >>> from comfy.model_management import load_models_gpu
-    >>> class ModelWrapper(ModelManageable):
-    >>>     ...
-    >>>
-    >>> some_model = ModelWrapper()
-    >>> load_models_gpu([some_model])
-    The minimum required
-    """
-    load_device: torch.device
-    offload_device: torch.device
-    model: torch.nn.Module
     @property
     @override
     def current_device(self) -> torch.device:
@@ -265,6 +343,9 @@ class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol,
             self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
         return self.model
+    def clone(self) -> ModelManageableT:
+        return copy.copy(self)
 @dataclasses.dataclass
 class MemoryMeasurements:
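
With the protocols reorganized, ModelManageableStub is the ABC to inherit from when wrapping an arbitrary module, as TransformersManagedModel and UpscaleModelManageable now do. A hedged sketch of such a wrapper; the exact set of abstract methods beyond patch_model (for example unpatch_model) follows the real base class, and the toy module here is illustrative rather than anything in the repository.

# Sketch only: a minimal wrapper built on ModelManageableStub under the assumptions above.
import torch
from comfy.model_management_types import ModelManageableStub

class ToyManagedModel(ModelManageableStub):
    def __init__(self, module: torch.nn.Module):
        self.model = module
        self.load_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.offload_device = torch.device("cpu")

    def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True,
                    force_patch_weights=False) -> torch.nn.Module:
        # move the wrapped module to the requested device; no weight patches here
        if device_to is not None:
            self.model.to(device_to)
        return self.model

    def unpatch_model(self, device_to=None, unpatch_weights=True) -> torch.nn.Module:
        if device_to is not None:
            self.model.to(device_to)
        return self.model

# Per the docstring above, such an object could then be passed to
# load_models_gpu([ToyManagedModel(torch.nn.Linear(8, 8))]).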

View File

@@ -230,7 +230,7 @@ class GGUFQuantization:
     patch_on_device: bool = False
-class ModelPatcher(ModelManageable, TrainingSupport, HooksSupport, PatchSupport):
+class ModelPatcher(ModelManageable, PatchSupport):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
         self.model: BaseModel | torch.nn.Module = model

View File

@@ -27,6 +27,7 @@ import os
 import random
 import struct
 import sys
+import threading
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
@@ -57,6 +58,9 @@ DISABLE_MMAP = args.disable_mmap
 logger = logging.getLogger(__name__)
 ALWAYS_SAFE_LOAD = False
+_lock = threading.RLock()
 if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
     class ModelCheckpoint:
         pass
@@ -1178,7 +1182,9 @@ class ProgressBar:
     def update_absolute(self, value, total=None, preview_image_or_output=None):
         if total is not None:
             self.total = total
-        if value > self.total:
+        if value is None:
+            return
+        if self.total is not None and value > self.total:
             value = self.total
         self.current = value
         _progress_bar_update(self.current, self.total, preview_image_or_output, server=self.server, node_id=self.node_id)
@@ -1198,31 +1204,39 @@ def comfy_tqdm() -> Generator[TqdmWatcher, None, None]:
     Monkey patches child calls to tqdm, sends progress to the UI,
     and yields a watcher object for stall detection.
     """
+    with _lock:
+        if hasattr(tqdm, "__patched_by_comfyui__"):
+            yield getattr(tqdm, "__patched_by_comfyui__")
+            return
+        watcher = TqdmWatcher()
+        setattr(tqdm, "__patched_by_comfyui__", watcher)
     _original_init = tqdm.__init__
     _original_call = tqdm.__call__
     _original_update = tqdm.update
     # Create the watcher instance that the patched methods will update
     # and that will be yielded to the caller.
-    watcher = TqdmWatcher()
     context = contextvars.copy_context()
     try:
         # These inner functions are closures; they capture the `watcher` variable
         # from the enclosing scope.
         def __init(self, *args, **kwargs):
-            context.run(lambda: _original_init(self, *args, **kwargs))
+            _original_init(self, *args, **kwargs)
            self._progress_bar = context.run(lambda: ProgressBar(self.total))
             watcher.tick()  # Signal progress on initialization
         def __update(self, n=1):
             assert self._progress_bar is not None
-            context.run(lambda: _original_update(self, n))
+            _original_update(self, n)
             context.run(lambda: self._progress_bar.update(n))
             watcher.tick()  # Signal progress on update
         def __call(self, *args, **kwargs):
-            instance = context.run(lambda: _original_call(self, *args, **kwargs))
+            instance = _original_call(self, *args, **kwargs)
             return instance
         tqdm.__init__ = __init
@@ -1236,10 +1250,11 @@ def comfy_tqdm() -> Generator[TqdmWatcher, None, None]:
         tqdm.__init__ = _original_init
         tqdm.__call__ = _original_call
         tqdm.update = _original_update
+        delattr(tqdm, "__patched_by_comfyui__")
 @contextmanager
-def comfy_progress(total: float) -> ProgressBar:
+def comfy_progress(total: float) -> Generator[ProgressBar, Any, None]:
     ctx = current_execution_context()
     if ctx.server.receive_all_progress_notifications:
         yield ProgressBar(total)
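
The __patched_by_comfyui__ marker plus the module-level RLock make comfy_tqdm reentrant: nested uses now reuse the outer watcher instead of stacking monkey patches, and the finally block removes the marker so the next top-level call patches again. A simplified, self-contained version of that guard pattern follows; the names are illustrative, and the real context manager also installs and restores the tqdm patches.

# Reentrancy-guard sketch: the first caller installs a marker, nested callers reuse it.
import threading
from contextlib import contextmanager

_guard_lock = threading.RLock()

class _Target:
    """Stands in for the tqdm class that comfy_tqdm patches."""

@contextmanager
def reentrant_patch():
    with _guard_lock:
        existing = getattr(_Target, "__patched_by_example__", None)
        if existing is not None:
            # nested use: hand back the outer watcher, install nothing
            yield existing
            return
        marker = object()
        setattr(_Target, "__patched_by_example__", marker)
    try:
        # ... the real context manager monkey patches tqdm here ...
        yield marker
    finally:
        delattr(_Target, "__patched_by_example__")

with reentrant_patch() as outer:
    with reentrant_patch() as inner:
        assert inner is outer  # the nested context reuses the outer watcher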

View File

@@ -765,7 +765,7 @@ class DualCFGGuider:
     FUNCTION = "get_guider"
     CATEGORY = "sampling/custom_sampling/guiders"
-    def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style):
+    def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style="regular"):
         guider = Guider_DualCFG(model)
         guider.set_conds(cond1, cond2, negative)
         guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))

View File

@@ -2,7 +2,9 @@ import torch
 from diffusers import HookRegistry
 from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook
+from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_management import vram_state, VRAMState
+from comfy.model_management_types import HooksSupport, ModelManageable
 from comfy.model_patcher import ModelPatcher
 from comfy.node_helpers import export_custom_nodes
 from comfy.nodes.package_typing import CustomNode
@@ -117,9 +119,21 @@ class GroupOffload(CustomNode):
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "execute"
-    def execute(self, model: ModelPatcher) -> tuple[ModelPatcher,]:
-        model = model.clone()
-        model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device))
+    def execute(self, model: ModelManageable | HooksSupport | TransformersManagedModel) -> tuple[ModelPatcher,]:
+        if isinstance(model, ModelManageable):
+            model = model.clone()
+        if isinstance(model, TransformersManagedModel):
+            apply_group_offloading(
+                model.model,
+                model.load_device,
+                model.offload_device,
+                use_stream=True,
+                record_stream=True,
+                low_cpu_mem_usage=vram_state in (VRAMState.LOW_VRAM,),
+                num_blocks_per_group=1
+            )
+        elif isinstance(model, HooksSupport) and isinstance(model, ModelManageable):
+            model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device))
         return model,
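
For TransformersManagedModel the node now hands the wrapped module straight to diffusers' group offloading instead of going through the sampling wrapper. Below is a standalone sketch of that call on a toy block-structured module, assuming a diffusers release that ships apply_group_offloading; the toy model and device choices are illustrative, not the node's behavior.

# Sketch only: group offloading on a toy module with a ModuleList of blocks.
import torch
from diffusers.hooks import apply_group_offloading

class ToyBlocks(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(64, 64) for _ in range(8))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

toy = ToyBlocks()
has_cuda = torch.cuda.is_available()
onload_device = torch.device("cuda" if has_cuda else "cpu")
offload_device = torch.device("cpu")

apply_group_offloading(
    toy,
    onload_device,
    offload_device,
    offload_type="block_level",   # move groups of blocks rather than the whole model
    num_blocks_per_group=1,       # smallest groups, lowest peak memory on the onload device
    use_stream=has_cuda,          # overlap transfers with compute when CUDA is available
    record_stream=has_cuda,
)

with torch.no_grad():
    toy(torch.randn(1, 64))  # blocks are onloaded and offloaded around the forward pass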

View File

@@ -42,6 +42,7 @@ class LTXVImgToVideo:
                 "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                 "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
                 "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+            }, "optional": {
                 "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
             }}
@@ -51,7 +52,7 @@ class LTXVImgToVideo:
     CATEGORY = "conditioning/video_models"
     FUNCTION = "generate"
-    def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
+    def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength=1.0):
         pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)

View File

@@ -9,7 +9,7 @@ from comfy import utils
 from comfy.component_model.tensor_types import RGBImageBatch
 from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
 from comfy.model_management import load_models_gpu
-from comfy.model_management_types import ModelManageable
+from comfy.model_management_types import ModelManageableStub
 logger = logging.getLogger(__name__)
 try:
@@ -22,7 +22,7 @@ except:
     pass
-class UpscaleModelManageable(ModelManageable):
+class UpscaleModelManageable(ModelManageableStub):
     def __init__(self, model_descriptor: ImageModelDescriptor, ckpt_name: str):
         self.ckpt_name = ckpt_name
         self.model_descriptor = model_descriptor

View File

@@ -182,7 +182,7 @@ async def test_huggingface_alternate_filenames_in_combo():
     )
     # 3. Get the list of files as the UI would
-    filename_list = get_filename_list_with_downloadable("checkpoints", known_files=[known_file])
+    filename_list = get_filename_list_with_downloadable("__xxx___", known_files=[known_file])
     # 4. Assert that both the main and alternate filenames are present
     assert main_filename in filename_list

View File

@@ -1,7 +1,7 @@
 {
   "1": {
     "inputs": {
-      "ckpt_name": "microsoft/Phi-3-mini-4k-instruct",
+      "ckpt_name": "microsoft/Phi-4-mini-instruct",
       "subfolder": ""
     },
     "class_type": "TransformersLoader",
@@ -33,7 +33,7 @@
   "4": {
     "inputs": {
       "prompt": "What comes after apple?",
-      "chat_template": "phi-3",
+      "chat_template": "default",
      "model": [
        "1",
        0