mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-10 06:10:50 +08:00)

Improve tests, fix issues with alternate filenames, improve group offloading support for transformers models

commit bc201cea4d, parent 79b8723f61
@@ -242,7 +242,7 @@ ComfyUI LTS supports text and multi-modal LLM models from the `transformers` eco

 In this example, LLAVA-NEXT (LLAVA 1.6) is prompted to describe an image.

-You can try the [LLAVA-NEXT](tests/inference/workflows/llava-0.json), [Phi-3](tests/inference/workflows/phi-3-0.json), and two [translation](tests/inference/workflows/translation-0.json) [workflows](tests/inference/workflows/translation-1.json).
+You can try the [LLAVA-NEXT](tests/inference/workflows/llava-0.json), [Phi-3](tests/inference/workflows/phi-4-0.json), and two [translation](tests/inference/workflows/translation-0.json) [workflows](tests/inference/workflows/translation-1.json).

 # SVG Conversion and String Saving

@@ -130,7 +130,7 @@ def _create_parser() -> EnhancedConfigArgParser:
     parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
     parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
     parser.add_argument("--disable-smart-memory", action="store_true",
-                        help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
+                        help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in VRAM when it can.")
     parser.add_argument("--deterministic", action="store_true",
                         help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

@@ -12,9 +12,12 @@ from typing import Optional, Any, Callable

 import torch
 import transformers
+from huggingface_hub.errors import EntryNotFoundError
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
     BatchFeature, AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \
     PretrainedConfig, TextStreamer, LogitsProcessor
+from huggingface_hub import hf_api
+from huggingface_hub.file_download import hf_hub_download
 from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

@@ -25,7 +28,7 @@ from .. import model_management
 from ..component_model.tensor_types import RGBImageBatch
 from ..model_downloader import get_or_download_huggingface_repo
 from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
-from ..model_management_types import ModelManageable
+from ..model_management_types import ModelManageableStub
 from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block

 logger = logging.getLogger(__name__)
@@ -37,7 +40,7 @@ _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING
 _DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2', 'paligemma'}


-class TransformersManagedModel(ModelManageable, LanguageModel):
+class TransformersManagedModel(ModelManageableStub, LanguageModel):
     def __init__(
             self,
             repo_id: str,
@@ -69,7 +72,20 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
             hub_kwargs["subfolder"] = subfolder
             repo_id = ckpt_name
         with comfy_tqdm():
-            ckpt_name = get_or_download_huggingface_repo(ckpt_name)
+            ckpt_name = get_or_download_huggingface_repo(repo_id)
+
+        if config_dict is None:
+            config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
+        elif isinstance(config_dict, PretrainedConfig):
+            config_dict: dict = config_dict.to_dict()
+        else:
+            config_dict = {}
+
+        try:
+            model_type = config_dict["model_type"]
+        except KeyError:
+            logger.debug(f"Configuration was missing for repo_id={repo_id}")
+            model_type = ""
+
         from_pretrained_kwargs = {
             "pretrained_model_name_or_path": ckpt_name,
@@ -77,19 +93,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
             **hub_kwargs
         }

-        # compute bitsandbytes configuration
-        try:
-            import bitsandbytes
-        except ImportError:
-            pass
-
-        if config_dict is None:
-            config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
-        elif isinstance(config_dict, PretrainedConfig):
-            config_dict: dict = config_dict.to_dict()
-        model_type = config_dict["model_type"]
         # language models prefer to use bfloat16 over float16
-        kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
+        kwargs_to_try = ({"dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
                           "low_cpu_mem_usage": True,
                           "device_map": str(unet_offload_device()), }, {})

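Note on the `torch_dtype` to `dtype` rename: newer `transformers` releases document `dtype` on `from_pretrained`, while older ones only accept `torch_dtype`, and the `kwargs_to_try` tuple above already falls back to an empty kwargs dict when the richer set fails. A minimal sketch of that fallback pattern (the helper name and kwargs are assumptions, not the ComfyUI implementation):

```python
import torch
from transformers import AutoModelForCausalLM

def load_with_fallback(repo_id: str,
                       kwargs_to_try=({"dtype": torch.bfloat16, "low_cpu_mem_usage": True}, {})):
    last_error = None
    # try the most specific kwargs first, then progressively simpler candidates
    for candidate in kwargs_to_try:
        try:
            return AutoModelForCausalLM.from_pretrained(repo_id, **candidate)
        except (TypeError, ValueError, OSError) as exc:
            last_error = exc
    raise last_error

# model = load_with_fallback("microsoft/Phi-4-mini-instruct")
```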
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any, TypeAlias, Union
+from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any, TypeAlias, Union, runtime_checkable
 import torch

 PatchOffset = tuple[int, int, int]
@@ -31,6 +31,7 @@ class PatchTuple(NamedTuple):
 ModelPatchesDictValue: TypeAlias = list[Union[PatchTuple, PatchWeightTuple]]


+@runtime_checkable
 class PatchSupport(Protocol):
     """
     Defines the interface for a model that supports LoRA patching.
@@ -45,7 +45,7 @@ def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[
         known_files = _get_known_models_for_folder_name(folder_name)

     existing = folder_paths.get_filename_list(folder_name)

     downloadable_files = []
     if not args.disable_known_models:
         downloadable_files = known_files
@@ -725,15 +725,15 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
     return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids)


-def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]:
+def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None, force: bool = False, subset: bool = False) -> Optional[str]:
     with comfy_tqdm():
-        return _get_or_download_huggingface_repo(repo_id, cache_dirs, local_dirs)
+        return _get_or_download_huggingface_repo(repo_id, cache_dirs, local_dirs, force=force, subset=subset)


-def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]:
+def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None, force: bool = False, subset: bool = False) -> Optional[str]:
     cache_dirs = cache_dirs or folder_paths.get_folder_paths("huggingface_cache")
     local_dirs = local_dirs or folder_paths.get_folder_paths("huggingface")
-    cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id)
+    cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id, subset=subset)

     local_dirs_cache_hit = len(local_dirs_snapshots) > 0
     cache_dirs_cache_hit = len(cache_dirs_snapshots) > 0
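A rough usage sketch of the new `force` and `subset` parameters (the repo id is borrowed from the test workflow further down; how the path resolves still depends on your configured `huggingface` and `huggingface_cache` folders):

```python
from comfy.model_downloader import get_or_download_huggingface_repo

# Re-download the snapshot even if a copy already exists on disk:
path = get_or_download_huggingface_repo("microsoft/Phi-4-mini-instruct", force=True)

# Treat a partial local copy (local files are a subset of the repo listing) as a cache hit:
path = get_or_download_huggingface_repo("microsoft/Phi-4-mini-instruct", subset=True)
```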
@@ -742,25 +742,25 @@ def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] =
     # if we're in forced local directory mode, only use the local dir snapshots, and otherwise, download
     if args.force_hf_local_dir_mode:
         # todo: we still have to figure out a way to download things to the right places by default
-        if len(local_dirs_snapshots) > 0:
+        if len(local_dirs_snapshots) > 0 and not force:
             return local_dirs_snapshots[0]
         elif not args.disable_known_models:
             destination = os.path.join(local_dirs[0], repo_id)
             logger.debug(f"downloading repo_id={repo_id}, local_dir={destination}")
-            return snapshot_download(repo_id, local_dir=destination)
+            return snapshot_download(repo_id, local_dir=destination, force_download=force)

     snapshots = local_dirs_snapshots + cache_dirs_snapshots
-    if len(snapshots) > 0:
+    if len(snapshots) > 0 and not force:
         return snapshots[0]
     elif not args.disable_known_models:
         logger.debug(f"downloading repo_id={repo_id}")
-        return snapshot_download(repo_id)
+        return snapshot_download(repo_id, force_download=force)

     # this repo was not found
     return None


-def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id):
+def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id, subset=False):
     local_dirs_snapshots = []
     cache_dirs_snapshots = []
     # find all the pre-existing downloads for this repo_id
@@ -772,13 +772,12 @@ def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_i
     if len(repo_files) > 0:
         for local_dir in local_dirs:
             local_path = Path(local_dir) / repo_id
-            local_files = set(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
+            local_files = frozenset(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
             # fix path representation
-            local_files = set(f.replace("\\", "/") for f in local_files)
+            local_files = frozenset(f.replace("\\", "/") for f in local_files)
             # remove .huggingface
-            local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
+            local_files = frozenset(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
-            # local_files.issubsetof(repo_files)
-            if len(local_files) > 0 and local_files.issubset(repo_files):
+            if len(local_files) > 0 and ((subset and local_files.issubset(repo_files)) or (not subset and repo_files.issubset(local_files))):
                 local_dirs_snapshots.append(str(local_path))
     else:
         # an empty repository or unknown repository info, trust that if the directory exists, it matches
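For clarity, the cache-hit rule above in isolation: with `subset=False` the local directory must contain every file the hub lists for the repo; with `subset=True` a partial copy is accepted as long as everything on disk also exists upstream. A self-contained sketch with hypothetical file names:

```python
repo_files = frozenset({"org/repo/config.json", "org/repo/model.safetensors"})
partial_copy = frozenset({"org/repo/config.json"})

def cache_hit(local_files: frozenset, repo_files: frozenset, subset: bool = False) -> bool:
    if not local_files:
        return False
    # subset=True: local files only need to be a subset of the repo listing
    # subset=False: the repo listing must be a subset of what is on disk (a complete copy)
    return local_files.issubset(repo_files) if subset else repo_files.issubset(local_files)

assert cache_hit(partial_copy, repo_files, subset=True)
assert not cache_hit(partial_copy, repo_files, subset=False)
```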
@@ -1,8 +1,9 @@
 from __future__ import annotations

+import copy
 import dataclasses
 from abc import ABCMeta, abstractmethod
-from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override
+from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override, TYPE_CHECKING

 import torch
 import torch.nn
@@ -11,22 +12,48 @@ from typing_extensions import TypedDict, NotRequired
 from .comfy_types import UnetWrapperFunction
 from .latent_formats import LatentFormat

+if TYPE_CHECKING:
+    from .hooks import EnumHookMode
+
 ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
 LatentFormatT = TypeVar('LatentFormatT', bound=LatentFormat)


 @runtime_checkable
 class DeviceSettable(Protocol):
-    @property
-    def device(self) -> torch.device:
-        ...
-
-    @device.setter
-    def device(self, value: torch.device):
-        ...
+    device: torch.device


-class HooksSupport(Protocol, metaclass=ABCMeta):
+@runtime_checkable
+class HooksSupport(Protocol):
+    wrappers: dict[str, dict[str, list[Callable]]]
+    callbacks: dict[str, dict[str, list[Callable]]]
+    hook_mode: "EnumHookMode"
+
+    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): ...
+
+    def model_patches_models(self) -> list[ModelManageableT]: ...
+
+    def restore_hook_patches(self): ...
+
+    def cleanup(self): ...
+
+    def pre_run(self): ...
+
+    def prepare_state(self, *args, **kwargs): ...
+
+    def register_all_hook_patches(self, a, b, c, d): ...
+
+    def get_nested_additional_models(self): ...
+
+    def apply_hooks(self, *args, **kwargs): ...
+
+    def add_wrapper(self, wrapper_type: str, wrapper: Callable): ...
+
+    def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): ...
+
+
+class HooksSupportStub(HooksSupport, metaclass=ABCMeta):
     def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
         return

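Marking these protocols `@runtime_checkable` is what lets later code dispatch with `isinstance()` (see the `GroupOffload.execute` change further down); the check is structural and only verifies that the required members exist. A minimal sketch with a made-up class, not part of the ComfyUI API:

```python
from typing import Protocol, runtime_checkable

import torch

@runtime_checkable
class DeviceSettableSketch(Protocol):
    device: torch.device

class AnyWrapper:
    def __init__(self):
        self.device = torch.device("cpu")

# passes: isinstance() on a runtime_checkable Protocol only checks member presence
assert isinstance(AnyWrapper(), DeviceSettableSketch)
```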
@@ -82,6 +109,8 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
         if isinstance(model, BaseModel) or hasattr(model, "current_patcher") and isinstance(self, ModelManageable):
             model.current_patcher = self


     def prepare_state(self, *args, **kwargs):
         pass

@@ -94,8 +123,22 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
     def apply_hooks(self, *args, **kwargs):
         return {}

+    def add_wrapper(self, wrapper_type: str, wrapper: Callable):
+        self.add_wrapper_with_key(wrapper_type, None, wrapper)
+
-class TrainingSupport(Protocol, metaclass=ABCMeta):
+    def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
+        w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
+        w.append(wrapper)
+
+
+@runtime_checkable
+class TrainingSupport(Protocol):
+    def set_model_compute_dtype(self, dtype: torch.dtype): ...
+
+    def add_weight_wrapper(self, name, function): ...
+
+
+class TrainingSupportStub(TrainingSupport, metaclass=ABCMeta):
     def set_model_compute_dtype(self, dtype: torch.dtype):
         return

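The stub's default `add_wrapper`/`add_wrapper_with_key` imply a registry shaped as `wrapper_type -> key -> [callables]`. A small standalone sketch of that shape (names here are illustrative, not the ComfyUI API):

```python
from typing import Callable, Optional

wrappers: dict[str, dict[Optional[str], list[Callable]]] = {}

def add_wrapper_with_key(wrapper_type: str, key: Optional[str], wrapper: Callable) -> None:
    # create the nested dict and list on first use, then append
    wrappers.setdefault(wrapper_type, {}).setdefault(key, []).append(wrapper)

def add_wrapper(wrapper_type: str, wrapper: Callable) -> None:
    add_wrapper_with_key(wrapper_type, None, wrapper)

add_wrapper("prepare_sampling", lambda executor, *args, **kwargs: executor(*args, **kwargs))
assert wrappers["prepare_sampling"][None]
```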
@@ -103,13 +146,68 @@ class TrainingSupport(Protocol, metaclass=ABCMeta):
         return


-class ModelManageableExtras(Protocol, metaclass=ABCMeta):
+@runtime_checkable
+class ModelManageable(HooksSupport, TrainingSupport, Protocol):
+    """
+    Objects which implement this protocol can be managed by
+
+    >>> from comfy.model_management import load_models_gpu
+    >>> class ModelWrapper(ModelManageable):
+    >>>     ...
+    >>>
+    >>> some_model = ModelWrapper()
+    >>> load_models_gpu([some_model])
+
+    The minimum required
+    """
+    load_device: torch.device
+    offload_device: torch.device
+    model: torch.nn.Module
+
     @property
-    def current_device(self) -> torch.device:
-        return torch.device("cpu")
+    def current_device(self) -> torch.device: ...
+
+    def is_clone(self, other: ModelManageableT) -> bool: ...
+
+    def clone_has_same_weights(self, clone: ModelManageableT) -> bool: ...
+
+    def model_size(self) -> int: ...
+
+    def model_patches_to(self, arg: torch.device | torch.dtype): ...
+
+    def model_dtype(self) -> torch.dtype: ...
+
+    def lowvram_patch_counter(self) -> int: ...
+
+    def partially_load(self, device_to: torch.device, extra_memory: int = 0, force_patch_weights: bool = False) -> int: ...
+
+    def partially_unload(self, device_to: torch.device, memory_to_free: int = 0) -> int: ...
+
+    def memory_required(self, input_shape: torch.Size) -> int: ...
+
+    def loaded_size(self) -> int: ...
+
+    def current_loaded_device(self) -> torch.device: ...
+
+    def get_model_object(self, name: str) -> torch.nn.Module: ...
+
+    @property
+    def model_options(self) -> ModelOptions: ...
+
+    @model_options.setter
+    def model_options(self, value): ...
+
+    def __del__(self): ...
+
+    @property
+    def parent(self) -> ModelManageableT | None: ...
+
+    def detach(self, unpatch_all: bool = True): ...
+
+    def clone(self) -> ModelManageableT: ...
+
+
-class ModelManageableRequired(Protocol, metaclass=ABCMeta):
+class ModelManageableStub(HooksSupportStub, TrainingSupportStub, ModelManageable, metaclass=ABCMeta):
     """
     The bare minimum that must be implemented to support model management when inheriting from ModelManageable

@@ -120,12 +218,11 @@ class ModelManageableRequired(Protocol, metaclass=ABCMeta):
     :see: ModelManageable
     :see: PatchSupport
     """
-    load_device: torch.device
-    offload_device: torch.device
-    model: torch.nn.Module
-
     @abstractmethod
-    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
+    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True,
+                    force_patch_weights: bool = False) -> torch.nn.Module:
         """
         Called by ModelManageable

@@ -155,25 +252,6 @@ class ModelManageableRequired(Protocol, metaclass=ABCMeta):
         """
         ...


-@runtime_checkable
-class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol, metaclass=ABCMeta):
-    """
-    Objects which implement this protocol can be managed by
-
-    >>> from comfy.model_management import load_models_gpu
-    >>> class ModelWrapper(ModelManageable):
-    >>>     ...
-    >>>
-    >>> some_model = ModelWrapper()
-    >>> load_models_gpu([some_model])
-
-    The minimum required
-    """
-    load_device: torch.device
-    offload_device: torch.device
-    model: torch.nn.Module
-
     @property
     @override
     def current_device(self) -> torch.device:
@@ -265,6 +343,9 @@ class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol,
         self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
         return self.model

+    def clone(self) -> ModelManageableT:
+        return copy.copy(self)
+

 @dataclasses.dataclass
 class MemoryMeasurements:
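With this refactor, wrappers such as `TransformersManagedModel` and `UpscaleModelManageable` (changed below) inherit the stubbed hooks/training behavior and mainly supply the device attributes plus the patching methods. A rough sketch of what such a subclass could look like, assuming `patch_model`/`unpatch_model` are the remaining methods to provide (signatures inferred from this diff, so treat them as an approximation):

```python
import torch

from comfy.model_management_types import ModelManageableStub

class MyManagedModule(ModelManageableStub):
    def __init__(self, module: torch.nn.Module, load_device: torch.device, offload_device: torch.device):
        self.model = module
        self.load_device = load_device
        self.offload_device = offload_device

    def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
        # move the wrapped module onto the requested device when asked to load
        if device_to is not None:
            self.model.to(device_to)
        return self.model

    def unpatch_model(self, device_to=None, unpatch_weights=True):
        # move the wrapped module back to the offload device
        if device_to is not None:
            self.model.to(device_to)
        return self.model
```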
@@ -230,7 +230,7 @@ class GGUFQuantization:
     patch_on_device: bool = False


-class ModelPatcher(ModelManageable, TrainingSupport, HooksSupport, PatchSupport):
+class ModelPatcher(ModelManageable, PatchSupport):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
         self.model: BaseModel | torch.nn.Module = model
@@ -27,6 +27,7 @@ import os
 import random
 import struct
 import sys
+import threading
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
@@ -57,6 +58,9 @@ DISABLE_MMAP = args.disable_mmap
 logger = logging.getLogger(__name__)

 ALWAYS_SAFE_LOAD = False

+_lock = threading.RLock()
+
 if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
     class ModelCheckpoint:
         pass
@@ -1178,7 +1182,9 @@ class ProgressBar:
     def update_absolute(self, value, total=None, preview_image_or_output=None):
         if total is not None:
             self.total = total
-        if value > self.total:
+        if value is None:
+            return
+        if self.total is not None and value > self.total:
             value = self.total
         self.current = value
         _progress_bar_update(self.current, self.total, preview_image_or_output, server=self.server, node_id=self.node_id)
@@ -1198,31 +1204,39 @@ def comfy_tqdm() -> Generator[TqdmWatcher, None, None]:
     Monkey patches child calls to tqdm, sends progress to the UI,
     and yields a watcher object for stall detection.
     """
+    with _lock:
+        if hasattr(tqdm, "__patched_by_comfyui__"):
+            yield getattr(tqdm, "__patched_by_comfyui__")
+            return
+
+        watcher = TqdmWatcher()
+        setattr(tqdm, "__patched_by_comfyui__", watcher)
+
     _original_init = tqdm.__init__
     _original_call = tqdm.__call__
     _original_update = tqdm.update

     # Create the watcher instance that the patched methods will update
     # and that will be yielded to the caller.
-    watcher = TqdmWatcher()
     context = contextvars.copy_context()

     try:
         # These inner functions are closures; they capture the `watcher` variable
         # from the enclosing scope.
         def __init(self, *args, **kwargs):
-            context.run(lambda: _original_init(self, *args, **kwargs))
+            _original_init(self, *args, **kwargs)
             self._progress_bar = context.run(lambda: ProgressBar(self.total))
             watcher.tick()  # Signal progress on initialization

         def __update(self, n=1):
             assert self._progress_bar is not None
-            context.run(lambda: _original_update(self, n))
+            _original_update(self, n)
             context.run(lambda: self._progress_bar.update(n))
             watcher.tick()  # Signal progress on update

         def __call(self, *args, **kwargs):
-            instance = context.run(lambda: _original_call(self, *args, **kwargs))
+            instance = _original_call(self, *args, **kwargs)
             return instance

         tqdm.__init__ = __init
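The `__patched_by_comfyui__` sentinel plus the module-level `RLock` make the patch idempotent: nested `comfy_tqdm()` calls reuse the outer watcher instead of stacking wrappers. A generic, standalone sketch of that guard pattern (not the ComfyUI implementation; the class and sentinel names are made up):

```python
import threading
from contextlib import contextmanager

_lock = threading.RLock()

@contextmanager
def patch_once(target: type, sentinel: str = "__patched_by_sketch__"):
    """Yield a shared state object; only the outermost caller installs and removes it."""
    with _lock:
        if hasattr(target, sentinel):
            # nested use: reuse the existing state, do not re-patch or clean up
            yield getattr(target, sentinel)
            return
        state = {}
        setattr(target, sentinel, state)
    try:
        yield state
    finally:
        delattr(target, sentinel)

class Dummy:
    pass

with patch_once(Dummy) as outer:
    with patch_once(Dummy) as inner:
        assert outer is inner        # the inner context sees the outer state
assert not hasattr(Dummy, "__patched_by_sketch__")  # cleaned up exactly once
```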
@@ -1236,10 +1250,11 @@ def comfy_tqdm() -> Generator[TqdmWatcher, None, None]:
         tqdm.__init__ = _original_init
         tqdm.__call__ = _original_call
         tqdm.update = _original_update
+        delattr(tqdm, "__patched_by_comfyui__")


 @contextmanager
-def comfy_progress(total: float) -> ProgressBar:
+def comfy_progress(total: float) -> Generator[ProgressBar, Any, None]:
     ctx = current_execution_context()
     if ctx.server.receive_all_progress_notifications:
         yield ProgressBar(total)
@@ -765,7 +765,7 @@ class DualCFGGuider:
     FUNCTION = "get_guider"
     CATEGORY = "sampling/custom_sampling/guiders"

-    def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style):
+    def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style="regular"):
         guider = Guider_DualCFG(model)
         guider.set_conds(cond1, cond2, negative)
         guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))
@@ -2,7 +2,9 @@ import torch
 from diffusers import HookRegistry
 from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook

+from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_management import vram_state, VRAMState
+from comfy.model_management_types import HooksSupport, ModelManageable
 from comfy.model_patcher import ModelPatcher
 from comfy.node_helpers import export_custom_nodes
 from comfy.nodes.package_typing import CustomNode
@@ -117,9 +119,21 @@ class GroupOffload(CustomNode):
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "execute"

-    def execute(self, model: ModelPatcher) -> tuple[ModelPatcher,]:
-        model = model.clone()
-        model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device))
+    def execute(self, model: ModelManageable | HooksSupport | TransformersManagedModel) -> tuple[ModelPatcher,]:
+        if isinstance(model, ModelManageable):
+            model = model.clone()
+        if isinstance(model, TransformersManagedModel):
+            apply_group_offloading(
+                model.model,
+                model.load_device,
+                model.offload_device,
+                use_stream=True,
+                record_stream=True,
+                low_cpu_mem_usage=vram_state in (VRAMState.LOW_VRAM,),
+                num_blocks_per_group=1
+            )
+        elif isinstance(model, HooksSupport) and isinstance(model, ModelManageable):
+            model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device))
         return model,

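For reference, the `TransformersManagedModel` branch above hands the underlying `torch.nn.Module` straight to diffusers' group offloading. A hedged standalone sketch of the same call outside ComfyUI (the devices and model id are placeholders, and `use_stream=True` assumes a CUDA device is available):

```python
import torch
from diffusers.hooks import apply_group_offloading
from transformers import AutoModelForCausalLM

module = AutoModelForCausalLM.from_pretrained("gpt2")  # any torch.nn.Module works here

apply_group_offloading(
    module,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    use_stream=True,         # overlap weight transfers with compute via a CUDA stream
    record_stream=True,
    num_blocks_per_group=1,  # offload block-by-block, mirroring the node above
)
```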
@@ -42,6 +42,7 @@ class LTXVImgToVideo:
                 "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                 "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
                 "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+            }, "optional": {
                 "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
             }}

@@ -51,7 +52,7 @@ class LTXVImgToVideo:
     CATEGORY = "conditioning/video_models"
     FUNCTION = "generate"

-    def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
+    def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength=1.0):
         pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)
@@ -9,7 +9,7 @@ from comfy import utils
 from comfy.component_model.tensor_types import RGBImageBatch
 from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
 from comfy.model_management import load_models_gpu
-from comfy.model_management_types import ModelManageable
+from comfy.model_management_types import ModelManageableStub

 logger = logging.getLogger(__name__)
 try:
@@ -22,7 +22,7 @@ except:
     pass


-class UpscaleModelManageable(ModelManageable):
+class UpscaleModelManageable(ModelManageableStub):
     def __init__(self, model_descriptor: ImageModelDescriptor, ckpt_name: str):
         self.ckpt_name = ckpt_name
         self.model_descriptor = model_descriptor
@@ -182,7 +182,7 @@ async def test_huggingface_alternate_filenames_in_combo():
     )

     # 3. Get the list of files as the UI would
-    filename_list = get_filename_list_with_downloadable("checkpoints", known_files=[known_file])
+    filename_list = get_filename_list_with_downloadable("__xxx___", known_files=[known_file])

     # 4. Assert that both the main and alternate filenames are present
     assert main_filename in filename_list
@@ -1,7 +1,7 @@
 {
   "1": {
     "inputs": {
-      "ckpt_name": "microsoft/Phi-3-mini-4k-instruct",
+      "ckpt_name": "microsoft/Phi-4-mini-instruct",
       "subfolder": ""
     },
     "class_type": "TransformersLoader",
@@ -33,7 +33,7 @@
   "4": {
     "inputs": {
       "prompt": "What comes after apple?",
-      "chat_template": "phi-3",
+      "chat_template": "default",
       "model": [
         "1",
         0