mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Improve tests, fix issues with alternate filenames, improve group offloading support for transformers models
This commit is contained in:
parent
79b8723f61
commit
bc201cea4d
@ -242,7 +242,7 @@ ComfyUI LTS supports text and multi-modal LLM models from the `transformers` eco
|
||||
|
||||
In this example, LLAVA-NEXT (LLAVA 1.6) is prompted to describe an image.
|
||||
|
||||
You can try the [LLAVA-NEXT](tests/inference/workflows/llava-0.json), [Phi-3](tests/inference/workflows/phi-3-0.json), and two [translation](tests/inference/workflows/translation-0.json) [workflows](tests/inference/workflows/translation-1.json).
|
||||
You can try the [LLAVA-NEXT](tests/inference/workflows/llava-0.json), [Phi-3](tests/inference/workflows/phi-4-0.json), and two [translation](tests/inference/workflows/translation-0.json) [workflows](tests/inference/workflows/translation-1.json).
|
||||
|
||||
# SVG Conversion and String Saving
|
||||
|
||||
|
||||
@ -130,7 +130,7 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
parser.add_argument("--disable-smart-memory", action="store_true",
|
||||
help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in VRAM when it can.")
|
||||
parser.add_argument("--deterministic", action="store_true",
|
||||
help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||
|
||||
|
||||
@ -12,9 +12,12 @@ from typing import Optional, Any, Callable
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from huggingface_hub.errors import EntryNotFoundError
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
|
||||
BatchFeature, AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \
|
||||
PretrainedConfig, TextStreamer, LogitsProcessor
|
||||
from huggingface_hub import hf_api
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
|
||||
@ -25,7 +28,7 @@ from .. import model_management
|
||||
from ..component_model.tensor_types import RGBImageBatch
|
||||
from ..model_downloader import get_or_download_huggingface_repo
|
||||
from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
|
||||
from ..model_management_types import ModelManageable
|
||||
from ..model_management_types import ModelManageableStub
|
||||
from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -37,7 +40,7 @@ _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2', 'paligemma'}
|
||||
|
||||
|
||||
class TransformersManagedModel(ModelManageable, LanguageModel):
|
||||
class TransformersManagedModel(ModelManageableStub, LanguageModel):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
@ -69,7 +72,20 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
|
||||
hub_kwargs["subfolder"] = subfolder
|
||||
repo_id = ckpt_name
|
||||
with comfy_tqdm():
|
||||
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
|
||||
ckpt_name = get_or_download_huggingface_repo(repo_id)
|
||||
|
||||
if config_dict is None:
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
|
||||
elif isinstance(config_dict, PretrainedConfig):
|
||||
config_dict: dict = config_dict.to_dict()
|
||||
else:
|
||||
config_dict = {}
|
||||
|
||||
try:
|
||||
model_type = config_dict["model_type"]
|
||||
except KeyError:
|
||||
logger.debug(f"Configuration was missing for repo_id={repo_id}")
|
||||
model_type = ""
|
||||
|
||||
from_pretrained_kwargs = {
|
||||
"pretrained_model_name_or_path": ckpt_name,
|
||||
@ -77,19 +93,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
|
||||
**hub_kwargs
|
||||
}
|
||||
|
||||
# compute bitsandbytes configuration
|
||||
try:
|
||||
import bitsandbytes
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if config_dict is None:
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
|
||||
elif isinstance(config_dict, PretrainedConfig):
|
||||
config_dict: dict = config_dict.to_dict()
|
||||
model_type = config_dict["model_type"]
|
||||
# language models prefer to use bfloat16 over float16
|
||||
kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
|
||||
kwargs_to_try = ({"dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
|
||||
"low_cpu_mem_usage": True,
|
||||
"device_map": str(unet_offload_device()), }, {})
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any, TypeAlias, Union
|
||||
from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any, TypeAlias, Union, runtime_checkable
|
||||
import torch
|
||||
|
||||
PatchOffset = tuple[int, int, int]
|
||||
@ -31,6 +31,7 @@ class PatchTuple(NamedTuple):
|
||||
ModelPatchesDictValue: TypeAlias = list[Union[PatchTuple, PatchWeightTuple]]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class PatchSupport(Protocol):
|
||||
"""
|
||||
Defines the interface for a model that supports LoRA patching.
|
||||
|
||||
@ -725,15 +725,15 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
|
||||
return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids)
|
||||
|
||||
|
||||
def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]:
|
||||
def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None, force: bool = False, subset: bool = False) -> Optional[str]:
|
||||
with comfy_tqdm():
|
||||
return _get_or_download_huggingface_repo(repo_id, cache_dirs, local_dirs)
|
||||
return _get_or_download_huggingface_repo(repo_id, cache_dirs, local_dirs, force=force, subset=subset)
|
||||
|
||||
|
||||
def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]:
|
||||
def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None, force: bool = False, subset: bool = False) -> Optional[str]:
|
||||
cache_dirs = cache_dirs or folder_paths.get_folder_paths("huggingface_cache")
|
||||
local_dirs = local_dirs or folder_paths.get_folder_paths("huggingface")
|
||||
cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id)
|
||||
cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id, subset=subset)
|
||||
|
||||
local_dirs_cache_hit = len(local_dirs_snapshots) > 0
|
||||
cache_dirs_cache_hit = len(cache_dirs_snapshots) > 0
|
||||
@ -742,25 +742,25 @@ def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] =
|
||||
# if we're in forced local directory mode, only use the local dir snapshots, and otherwise, download
|
||||
if args.force_hf_local_dir_mode:
|
||||
# todo: we still have to figure out a way to download things to the right places by default
|
||||
if len(local_dirs_snapshots) > 0:
|
||||
if len(local_dirs_snapshots) > 0 and not force:
|
||||
return local_dirs_snapshots[0]
|
||||
elif not args.disable_known_models:
|
||||
destination = os.path.join(local_dirs[0], repo_id)
|
||||
logger.debug(f"downloading repo_id={repo_id}, local_dir={destination}")
|
||||
return snapshot_download(repo_id, local_dir=destination)
|
||||
return snapshot_download(repo_id, local_dir=destination, force_download=force)
|
||||
|
||||
snapshots = local_dirs_snapshots + cache_dirs_snapshots
|
||||
if len(snapshots) > 0:
|
||||
if len(snapshots) > 0 and not force:
|
||||
return snapshots[0]
|
||||
elif not args.disable_known_models:
|
||||
logger.debug(f"downloading repo_id={repo_id}")
|
||||
return snapshot_download(repo_id)
|
||||
return snapshot_download(repo_id, force_download=force)
|
||||
|
||||
# this repo was not found
|
||||
return None
|
||||
|
||||
|
||||
def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id):
|
||||
def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id, subset=False):
|
||||
local_dirs_snapshots = []
|
||||
cache_dirs_snapshots = []
|
||||
# find all the pre-existing downloads for this repo_id
|
||||
@ -772,13 +772,12 @@ def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_i
|
||||
if len(repo_files) > 0:
|
||||
for local_dir in local_dirs:
|
||||
local_path = Path(local_dir) / repo_id
|
||||
local_files = set(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
|
||||
local_files = frozenset(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
|
||||
# fix path representation
|
||||
local_files = set(f.replace("\\", "/") for f in local_files)
|
||||
local_files = frozenset(f.replace("\\", "/") for f in local_files)
|
||||
# remove .huggingface
|
||||
local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
|
||||
# local_files.issubsetof(repo_files)
|
||||
if len(local_files) > 0 and local_files.issubset(repo_files):
|
||||
local_files = frozenset(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
|
||||
if len(local_files) > 0 and ((subset and local_files.issubset(repo_files)) or (not subset and repo_files.issubset(local_files))):
|
||||
local_dirs_snapshots.append(str(local_path))
|
||||
else:
|
||||
# an empty repository or unknown repository info, trust that if the directory exists, it matches
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override
|
||||
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn
|
||||
@ -11,22 +12,48 @@ from typing_extensions import TypedDict, NotRequired
|
||||
from .comfy_types import UnetWrapperFunction
|
||||
from .latent_formats import LatentFormat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .hooks import EnumHookMode
|
||||
|
||||
ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
|
||||
LatentFormatT = TypeVar('LatentFormatT', bound=LatentFormat)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class DeviceSettable(Protocol):
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
...
|
||||
|
||||
@device.setter
|
||||
def device(self, value: torch.device):
|
||||
...
|
||||
device: torch.device
|
||||
|
||||
|
||||
class HooksSupport(Protocol, metaclass=ABCMeta):
|
||||
@runtime_checkable
|
||||
class HooksSupport(Protocol):
|
||||
wrappers: dict[str, dict[str, list[Callable]]]
|
||||
callbacks: dict[str, dict[str, list[Callable]]]
|
||||
hook_mode: "EnumHookMode"
|
||||
|
||||
def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): ...
|
||||
|
||||
def model_patches_models(self) -> list[ModelManageableT]: ...
|
||||
|
||||
def restore_hook_patches(self): ...
|
||||
|
||||
def cleanup(self): ...
|
||||
|
||||
def pre_run(self): ...
|
||||
|
||||
def prepare_state(self, *args, **kwargs): ...
|
||||
|
||||
def register_all_hook_patches(self, a, b, c, d): ...
|
||||
|
||||
def get_nested_additional_models(self): ...
|
||||
|
||||
def apply_hooks(self, *args, **kwargs): ...
|
||||
|
||||
def add_wrapper(self, wrapper_type: str, wrapper: Callable): ...
|
||||
|
||||
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): ...
|
||||
|
||||
|
||||
class HooksSupportStub(HooksSupport, metaclass=ABCMeta):
|
||||
def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
|
||||
return
|
||||
|
||||
@ -82,6 +109,8 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
|
||||
if isinstance(model, BaseModel) or hasattr(model, "current_patcher") and isinstance(self, ModelManageable):
|
||||
model.current_patcher = self
|
||||
|
||||
|
||||
|
||||
def prepare_state(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@ -94,8 +123,22 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
|
||||
def apply_hooks(self, *args, **kwargs):
|
||||
return {}
|
||||
|
||||
def add_wrapper(self, wrapper_type: str, wrapper: Callable):
|
||||
self.add_wrapper_with_key(wrapper_type, None, wrapper)
|
||||
|
||||
class TrainingSupport(Protocol, metaclass=ABCMeta):
|
||||
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
|
||||
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
||||
w.append(wrapper)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TrainingSupport(Protocol):
|
||||
def set_model_compute_dtype(self, dtype: torch.dtype): ...
|
||||
|
||||
def add_weight_wrapper(self, name, function): ...
|
||||
|
||||
|
||||
class TrainingSupportStub(TrainingSupport, metaclass=ABCMeta):
|
||||
def set_model_compute_dtype(self, dtype: torch.dtype):
|
||||
return
|
||||
|
||||
@ -103,13 +146,68 @@ class TrainingSupport(Protocol, metaclass=ABCMeta):
|
||||
return
|
||||
|
||||
|
||||
class ModelManageableExtras(Protocol, metaclass=ABCMeta):
|
||||
@runtime_checkable
|
||||
class ModelManageable(HooksSupport, TrainingSupport, Protocol):
|
||||
"""
|
||||
Objects which implement this protocol can be managed by
|
||||
|
||||
>>> from comfy.model_management import load_models_gpu
|
||||
>>> class ModelWrapper(ModelManageable):
|
||||
>>> ...
|
||||
>>>
|
||||
>>> some_model = ModelWrapper()
|
||||
>>> load_models_gpu([some_model])
|
||||
|
||||
The minimum required
|
||||
"""
|
||||
load_device: torch.device
|
||||
offload_device: torch.device
|
||||
model: torch.nn.Module
|
||||
|
||||
@property
|
||||
def current_device(self) -> torch.device:
|
||||
return torch.device("cpu")
|
||||
def current_device(self) -> torch.device: ...
|
||||
|
||||
def is_clone(self, other: ModelManageableT) -> bool: ...
|
||||
|
||||
def clone_has_same_weights(self, clone: ModelManageableT) -> bool: ...
|
||||
|
||||
def model_size(self) -> int: ...
|
||||
|
||||
def model_patches_to(self, arg: torch.device | torch.dtype): ...
|
||||
|
||||
def model_dtype(self) -> torch.dtype: ...
|
||||
|
||||
def lowvram_patch_counter(self) -> int: ...
|
||||
|
||||
def partially_load(self, device_to: torch.device, extra_memory: int = 0, force_patch_weights: bool = False) -> int: ...
|
||||
|
||||
def partially_unload(self, device_to: torch.device, memory_to_free: int = 0) -> int: ...
|
||||
|
||||
def memory_required(self, input_shape: torch.Size) -> int: ...
|
||||
|
||||
def loaded_size(self) -> int: ...
|
||||
|
||||
def current_loaded_device(self) -> torch.device: ...
|
||||
|
||||
def get_model_object(self, name: str) -> torch.nn.Module: ...
|
||||
|
||||
@property
|
||||
def model_options(self) -> ModelOptions: ...
|
||||
|
||||
@model_options.setter
|
||||
def model_options(self, value): ...
|
||||
|
||||
def __del__(self): ...
|
||||
|
||||
@property
|
||||
def parent(self) -> ModelManageableT | None: ...
|
||||
|
||||
def detach(self, unpatch_all: bool = True): ...
|
||||
|
||||
def clone(self) -> ModelManageableT: ...
|
||||
|
||||
|
||||
class ModelManageableRequired(Protocol, metaclass=ABCMeta):
|
||||
class ModelManageableStub(HooksSupportStub, TrainingSupportStub, ModelManageable, metaclass=ABCMeta):
|
||||
"""
|
||||
The bare minimum that must be implemented to support model management when inheriting from ModelManageable
|
||||
|
||||
@ -120,12 +218,11 @@ class ModelManageableRequired(Protocol, metaclass=ABCMeta):
|
||||
:see: ModelManageable
|
||||
:see: PatchSupport
|
||||
"""
|
||||
load_device: torch.device
|
||||
offload_device: torch.device
|
||||
model: torch.nn.Module
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
|
||||
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True,
|
||||
force_patch_weights: bool = False) -> torch.nn.Module:
|
||||
"""
|
||||
Called by ModelManageable
|
||||
|
||||
@ -155,25 +252,6 @@ class ModelManageableRequired(Protocol, metaclass=ABCMeta):
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol, metaclass=ABCMeta):
|
||||
"""
|
||||
Objects which implement this protocol can be managed by
|
||||
|
||||
>>> from comfy.model_management import load_models_gpu
|
||||
>>> class ModelWrapper(ModelManageable):
|
||||
>>> ...
|
||||
>>>
|
||||
>>> some_model = ModelWrapper()
|
||||
>>> load_models_gpu([some_model])
|
||||
|
||||
The minimum required
|
||||
"""
|
||||
load_device: torch.device
|
||||
offload_device: torch.device
|
||||
model: torch.nn.Module
|
||||
|
||||
@property
|
||||
@override
|
||||
def current_device(self) -> torch.device:
|
||||
@ -265,6 +343,9 @@ class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol,
|
||||
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
|
||||
return self.model
|
||||
|
||||
def clone(self) -> ModelManageableT:
|
||||
return copy.copy(self)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MemoryMeasurements:
|
||||
|
||||
@ -230,7 +230,7 @@ class GGUFQuantization:
|
||||
patch_on_device: bool = False
|
||||
|
||||
|
||||
class ModelPatcher(ModelManageable, TrainingSupport, HooksSupport, PatchSupport):
|
||||
class ModelPatcher(ModelManageable, PatchSupport):
|
||||
def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
|
||||
self.size = size
|
||||
self.model: BaseModel | torch.nn.Module = model
|
||||
|
||||
@ -27,6 +27,7 @@ import os
|
||||
import random
|
||||
import struct
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
@ -57,6 +58,9 @@ DISABLE_MMAP = args.disable_mmap
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALWAYS_SAFE_LOAD = False
|
||||
|
||||
_lock = threading.RLock()
|
||||
|
||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||
class ModelCheckpoint:
|
||||
pass
|
||||
@ -1178,7 +1182,9 @@ class ProgressBar:
|
||||
def update_absolute(self, value, total=None, preview_image_or_output=None):
|
||||
if total is not None:
|
||||
self.total = total
|
||||
if value > self.total:
|
||||
if value is None:
|
||||
return
|
||||
if self.total is not None and value > self.total:
|
||||
value = self.total
|
||||
self.current = value
|
||||
_progress_bar_update(self.current, self.total, preview_image_or_output, server=self.server, node_id=self.node_id)
|
||||
@ -1198,31 +1204,39 @@ def comfy_tqdm() -> Generator[TqdmWatcher, None, None]:
|
||||
Monkey patches child calls to tqdm, sends progress to the UI,
|
||||
and yields a watcher object for stall detection.
|
||||
"""
|
||||
with _lock:
|
||||
if hasattr(tqdm, "__patched_by_comfyui__"):
|
||||
yield getattr(tqdm, "__patched_by_comfyui__")
|
||||
return
|
||||
|
||||
watcher = TqdmWatcher()
|
||||
setattr(tqdm, "__patched_by_comfyui__", watcher)
|
||||
|
||||
_original_init = tqdm.__init__
|
||||
_original_call = tqdm.__call__
|
||||
_original_update = tqdm.update
|
||||
|
||||
# Create the watcher instance that the patched methods will update
|
||||
# and that will be yielded to the caller.
|
||||
watcher = TqdmWatcher()
|
||||
context = contextvars.copy_context()
|
||||
|
||||
try:
|
||||
|
||||
# These inner functions are closures; they capture the `watcher` variable
|
||||
# from the enclosing scope.
|
||||
def __init(self, *args, **kwargs):
|
||||
context.run(lambda: _original_init(self, *args, **kwargs))
|
||||
_original_init(self, *args, **kwargs)
|
||||
self._progress_bar = context.run(lambda: ProgressBar(self.total))
|
||||
watcher.tick() # Signal progress on initialization
|
||||
|
||||
def __update(self, n=1):
|
||||
assert self._progress_bar is not None
|
||||
context.run(lambda: _original_update(self, n))
|
||||
_original_update(self, n)
|
||||
context.run(lambda: self._progress_bar.update(n))
|
||||
watcher.tick() # Signal progress on update
|
||||
|
||||
def __call(self, *args, **kwargs):
|
||||
instance = context.run(lambda: _original_call(self, *args, **kwargs))
|
||||
instance = _original_call(self, *args, **kwargs)
|
||||
return instance
|
||||
|
||||
tqdm.__init__ = __init
|
||||
@ -1236,10 +1250,11 @@ def comfy_tqdm() -> Generator[TqdmWatcher, None, None]:
|
||||
tqdm.__init__ = _original_init
|
||||
tqdm.__call__ = _original_call
|
||||
tqdm.update = _original_update
|
||||
delattr(tqdm, "__patched_by_comfyui__")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def comfy_progress(total: float) -> ProgressBar:
|
||||
def comfy_progress(total: float) -> Generator[ProgressBar, Any, None]:
|
||||
ctx = current_execution_context()
|
||||
if ctx.server.receive_all_progress_notifications:
|
||||
yield ProgressBar(total)
|
||||
|
||||
@ -765,7 +765,7 @@ class DualCFGGuider:
|
||||
FUNCTION = "get_guider"
|
||||
CATEGORY = "sampling/custom_sampling/guiders"
|
||||
|
||||
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style):
|
||||
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style="regular"):
|
||||
guider = Guider_DualCFG(model)
|
||||
guider.set_conds(cond1, cond2, negative)
|
||||
guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))
|
||||
|
||||
@ -2,7 +2,9 @@ import torch
|
||||
from diffusers import HookRegistry
|
||||
from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook
|
||||
|
||||
from comfy.language.transformers_model_management import TransformersManagedModel
|
||||
from comfy.model_management import vram_state, VRAMState
|
||||
from comfy.model_management_types import HooksSupport, ModelManageable
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.node_helpers import export_custom_nodes
|
||||
from comfy.nodes.package_typing import CustomNode
|
||||
@ -117,8 +119,20 @@ class GroupOffload(CustomNode):
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, model: ModelPatcher) -> tuple[ModelPatcher,]:
|
||||
def execute(self, model: ModelManageable | HooksSupport | TransformersManagedModel) -> tuple[ModelPatcher,]:
|
||||
if isinstance(model, ModelManageable):
|
||||
model = model.clone()
|
||||
if isinstance(model, TransformersManagedModel):
|
||||
apply_group_offloading(
|
||||
model.model,
|
||||
model.load_device,
|
||||
model.offload_device,
|
||||
use_stream=True,
|
||||
record_stream=True,
|
||||
low_cpu_mem_usage=vram_state in (VRAMState.LOW_VRAM,),
|
||||
num_blocks_per_group=1
|
||||
)
|
||||
elif isinstance(model, HooksSupport) and isinstance(model, ModelManageable):
|
||||
model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device))
|
||||
return model,
|
||||
|
||||
|
||||
@ -42,6 +42,7 @@ class LTXVImgToVideo:
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
}, "optional": {
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
|
||||
}}
|
||||
|
||||
@ -51,7 +52,7 @@ class LTXVImgToVideo:
|
||||
CATEGORY = "conditioning/video_models"
|
||||
FUNCTION = "generate"
|
||||
|
||||
def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
|
||||
def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength=1.0):
|
||||
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
|
||||
@ -9,7 +9,7 @@ from comfy import utils
|
||||
from comfy.component_model.tensor_types import RGBImageBatch
|
||||
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
|
||||
from comfy.model_management import load_models_gpu
|
||||
from comfy.model_management_types import ModelManageable
|
||||
from comfy.model_management_types import ModelManageableStub
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
try:
|
||||
@ -22,7 +22,7 @@ except:
|
||||
pass
|
||||
|
||||
|
||||
class UpscaleModelManageable(ModelManageable):
|
||||
class UpscaleModelManageable(ModelManageableStub):
|
||||
def __init__(self, model_descriptor: ImageModelDescriptor, ckpt_name: str):
|
||||
self.ckpt_name = ckpt_name
|
||||
self.model_descriptor = model_descriptor
|
||||
|
||||
@ -182,7 +182,7 @@ async def test_huggingface_alternate_filenames_in_combo():
|
||||
)
|
||||
|
||||
# 3. Get the list of files as the UI would
|
||||
filename_list = get_filename_list_with_downloadable("checkpoints", known_files=[known_file])
|
||||
filename_list = get_filename_list_with_downloadable("__xxx___", known_files=[known_file])
|
||||
|
||||
# 4. Assert that both the main and alternate filenames are present
|
||||
assert main_filename in filename_list
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
{
|
||||
"1": {
|
||||
"inputs": {
|
||||
"ckpt_name": "microsoft/Phi-3-mini-4k-instruct",
|
||||
"ckpt_name": "microsoft/Phi-4-mini-instruct",
|
||||
"subfolder": ""
|
||||
},
|
||||
"class_type": "TransformersLoader",
|
||||
@ -33,7 +33,7 @@
|
||||
"4": {
|
||||
"inputs": {
|
||||
"prompt": "What comes after apple?",
|
||||
"chat_template": "phi-3",
|
||||
"chat_template": "default",
|
||||
"model": [
|
||||
"1",
|
||||
0
|
||||
Loading…
Reference in New Issue
Block a user