diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py
index 31f465b43..83d703d5e 100644
--- a/comfy/language/transformers_model_management.py
+++ b/comfy/language/transformers_model_management.py
@@ -62,11 +62,93 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         self._on_set_processor(self._processor)
         self._model_type = ""
         self._original_transformers_managed_model: weakref.ReferenceType["TransformersManagedModel"] = weakref.ref(self)
+        self.wrappers = {}
+        self.callbacks = {}
+        self._hook_mode = None
+        self._model_options = {"transformer_options": {}}
+
         if model.device != self.offload_device:
             model.to(device=self.offload_device)

+    @property
+    def hook_mode(self):
+        from ..hooks import EnumHookMode
+        if self._hook_mode is None:
+            self._hook_mode = EnumHookMode.MaxSpeed
+        return self._hook_mode
+
+    @hook_mode.setter
+    def hook_mode(self, value):
+        self._hook_mode = value
+
+    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
+        return
+
+    def model_patches_models(self):
+        return []
+
+    def restore_hook_patches(self):
+        return
+
+    def cleanup(self):
+        pass
+
+    def pre_run(self):
+        pass
+
+    def prepare_state(self, *args, **kwargs):
+        pass
+
+    def register_all_hook_patches(self, a, b, c, d):
+        pass
+
+    def get_nested_additional_models(self):
+        return []
+
+    def apply_hooks(self, *args, **kwargs):
+        return {}
+
+    def add_wrapper(self, wrapper_type: str, wrapper: Callable):
+        self.add_wrapper_with_key(wrapper_type, None, wrapper)
+
+    def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
+        w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
+        w.append(wrapper)
+
+    def remove_wrappers_with_key(self, wrapper_type: str, key: str):
+        w = self.wrappers.get(wrapper_type, {})
+        if key in w:
+            w.pop(key)
+
+    def get_wrappers_with_key(self, wrapper_type: str, key: str):
+        w_list = []
+        w_list.extend(self.wrappers.get(wrapper_type, {}).get(key, []))
+        return w_list
+
+    def get_all_wrappers(self, wrapper_type: str):
+        w_list = []
+        for w in self.wrappers.get(wrapper_type, {}).values():
+            w_list.extend(w)
+        return w_list
+
+    @property
+    def model_options(self):
+        return self._model_options
+
+    @model_options.setter
+    def model_options(self, value):
+        self._model_options = value
+
+    @property
+    def diffusion_model(self):
+        return self.model
+
+    @diffusion_model.setter
+    def diffusion_model(self, value):
+        self.add_object_patch("model", value)
+
     @staticmethod
-    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None, config_dict: PretrainedConfig | dict | None = None) -> "TransformersManagedModel":
+    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None, config_dict: PretrainedConfig | dict | None = None, **kwargs) -> "TransformersManagedModel":
         hub_kwargs = {}
         if subfolder is not None and subfolder.strip() != "":
             hub_kwargs["subfolder"] = subfolder
@@ -89,7 +171,8 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):

         from_pretrained_kwargs = {
             "pretrained_model_name_or_path": ckpt_name,
-            **hub_kwargs
+            **hub_kwargs,
+            **kwargs,
         }

         # language models prefer to use bfloat16 over float16
@@ -149,7 +232,16 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
                     tokenizer = processor
                     processor = None
                 else:
-                    tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try)
+                    try:
+                        tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try)
+                    except Exception:
+                        try:
+                            tokenizer = AutoTokenizer.from_pretrained(ckpt_name, use_fast=True, legacy=False, **hub_kwargs, **kwargs_to_try)
+                        except Exception:
+                            if repo_id != ckpt_name:
+                                tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True, legacy=False, **hub_kwargs, **kwargs_to_try)
+                            else:
+                                raise
                 if tokenizer is not None or processor is not None:
                     break
             except Exception as exc_info:
@@ -252,12 +344,22 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         with seed_for_block(seed), torch.inference_mode(mode=True) if has_triton else contextlib.nullcontext():
             if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
                 inputs.pop("attention_mask")
-            output_ids = transformers_model.generate(
-                **inputs,
-                streamer=text_streamer if num_beams <= 1 else None,
-                max_new_tokens=max_new_tokens,
-                **generate_kwargs
-            )
+
+            from ..patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers
+
+            def _generate(inputs, streamer, max_new_tokens, **generate_kwargs):
+                return transformers_model.generate(
+                    **inputs,
+                    streamer=streamer,
+                    max_new_tokens=max_new_tokens,
+                    **generate_kwargs
+                )
+
+            output_ids = WrapperExecutor.new_class_executor(
+                _generate,
+                self,
+                get_all_wrappers(WrappersMP.APPLY_MODEL, self.model_options)
+            ).execute(inputs, text_streamer if num_beams <= 1 else None, max_new_tokens, **generate_kwargs)

         if not transformers_model.config.is_encoder_decoder:
             start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
@@ -336,6 +438,9 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         return self._size

     def model_patches_to(self, arg: torch.device | torch.dtype):
+        if getattr(self.model, "is_loaded_in_4bit", False) or getattr(self.model, "is_loaded_in_8bit", False):
+            return
+
         if isinstance(arg, torch.device):
             self.model.to(device=arg)
         else:
diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py
index d51e56758..07731e92c 100644
--- a/comfy/model_management_types.py
+++ b/comfy/model_management_types.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import copy
 import dataclasses
 from abc import ABCMeta, abstractmethod
-from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, TYPE_CHECKING
+from typing import Any, Callable, Protocol, runtime_checkable, Optional, TypeVar, NamedTuple, TYPE_CHECKING

 import torch
 import torch.nn
@@ -26,8 +26,8 @@ class DeviceSettable(Protocol):

 @runtime_checkable
 class HooksSupport(Protocol):
-    wrappers: dict[str, dict[str, list[Callable]]]
-    callbacks: dict[str, dict[str, list[Callable]]]
+    wrappers: dict[str, list[Callable]]
+    callbacks: dict[str, list[Callable]]
     hook_mode: "EnumHookMode"

     def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): ...
diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py
index ae3c7be04..02547fd57 100644
--- a/comfy_extras/nodes/nodes_language.py
+++ b/comfy_extras/nodes/nodes_language.py
@@ -219,6 +219,23 @@ class TransformersLoader1(TransformersLoader):
             }
         }

+class TransformersLoaderQuantized(TransformersLoader):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "ckpt_name": ("STRING", {}),
+                "load_in_4bit": ("BOOLEAN", {"default": False}),
+                "load_in_8bit": ("BOOLEAN", {"default": False}),
+            },
+            "optional": {
+                "subfolder": ("STRING", {}),
+            }
+        }
+
+    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, load_in_4bit: bool = False, load_in_8bit: bool = False, *args, **kwargs) -> tuple[TransformersManagedModel]:
+        return TransformersManagedModel.from_pretrained(ckpt_name, subfolder, load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit),
+

 class TransformersTokenize(CustomNode):
     @classmethod
diff --git a/comfy_extras/nodes/nodes_torch_compile.py b/comfy_extras/nodes/nodes_torch_compile.py
index 6a4a3f476..858c109d2 100644
--- a/comfy_extras/nodes/nodes_torch_compile.py
+++ b/comfy_extras/nodes/nodes_torch_compile.py
@@ -9,6 +9,7 @@ from torch.nn import LayerNorm

 from comfy import model_management
 from comfy.language.transformers_model_management import TransformersManagedModel
+from comfy.model_management_types import HooksSupport
 from comfy.model_patcher import ModelPatcher
 from comfy.nodes.package_typing import CustomNode, InputTypes
 from comfy.sd import VAE
@@ -100,18 +101,8 @@ class TorchCompileModel(CustomNode):
                 "make_refittable": True,
             }
             del compile_kwargs["mode"]

-        if isinstance(model, TransformersManagedModel):
-            to_return = model.clone()
-            model = to_return.model
-            model_management.unload_all_models()
-            model.to(device=model_management.get_torch_device())
-            res = torch.compile(model=model, **compile_kwargs),
-            model.to(device=model_management.unet_offload_device())
-
-            to_return.add_object_patch("model", res)
-            return to_return,
-        elif isinstance(model, (ModelPatcher, VAE)):
+        if isinstance(model, HooksSupport):
             to_return = model.clone()
             object_patches = [p.strip() for p in object_patch.split(",")]
             patcher: ModelPatcher
diff --git a/tests/language/__init__.py b/tests/language/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/inference/test_language.py b/tests/language/test_language.py
similarity index 100%
rename from tests/inference/test_language.py
rename to tests/language/test_language.py
diff --git a/tests/language/test_phi4_loading.py b/tests/language/test_phi4_loading.py
new file mode 100644
index 000000000..9e42187c5
--- /dev/null
+++ b/tests/language/test_phi4_loading.py
@@ -0,0 +1,20 @@
+import pytest
+from comfy_execution.graph_utils import GraphBuilder
+from comfy.client.embedded_comfy_client import Comfy
+from comfy.api.components.schema.prompt import Prompt
+
+
+class TestPhi4Loading:
+    @pytest.mark.asyncio
+    async def test_phi4_loading(self):
+        graph = GraphBuilder()
+        model_loader = graph.node("TransformersLoaderQuantized", ckpt_name="microsoft/phi-4", load_in_4bit=True, load_in_8bit=False)
+        tokenizer = graph.node("OneShotInstructTokenize", model=model_loader.out(0), prompt="Hello", chat_template="default")
+        generation = graph.node("TransformersGenerate", model=model_loader.out(0), tokens=tokenizer.out(0), max_new_tokens=1, seed=42)
+        graph.node("SaveString", value=generation.out(0), filename_prefix="phi4_test")
+
+        workflow = graph.finalize()
+        prompt = Prompt.validate(workflow)
+
+        async with Comfy() as client:
+            await client.queue_prompt(prompt)
diff --git a/tests/language/test_torch_compile_transformers.py b/tests/language/test_torch_compile_transformers.py
new file mode 100644
index 000000000..9b494324a
--- /dev/null
+++ b/tests/language/test_torch_compile_transformers.py
@@ -0,0 +1,31 @@
+import pytest
+import torch
+from comfy_execution.graph_utils import GraphBuilder
+from comfy.client.embedded_comfy_client import Comfy
+from comfy.api.components.schema.prompt import Prompt
+
+
+class TestTorchCompileTransformers:
+    @pytest.mark.asyncio
+    async def test_torch_compile_transformers(self):
+        graph = GraphBuilder()
+        model_loader = graph.node("TransformersLoader1", ckpt_name="Qwen/Qwen2.5-0.5B")
+        compiled_model = graph.node("TorchCompileModel", model=model_loader.out(0), backend="inductor", mode="max-autotune")
+        tokenizer = graph.node("OneShotInstructTokenize", model=compiled_model.out(0), prompt="Hello, world!", chat_template="default")
+        generation = graph.node("TransformersGenerate", model=compiled_model.out(0), tokens=tokenizer.out(0), max_new_tokens=10, seed=42)
+
+        save_string = graph.node("SaveString", value=generation.out(0), filename_prefix="test_output")
+
+        workflow = graph.finalize()
+        prompt = Prompt.validate(workflow)
+
+        from unittest.mock import patch
+        with patch("torch.compile", side_effect=torch.compile) as mock_compile:
+            async with Comfy() as client:
+                outputs = await client.queue_prompt(prompt)
+
+        assert mock_compile.called, "torch.compile should have been called"
+
+        assert len(outputs) > 0
+        assert save_string.id in outputs
+        assert outputs[save_string.id]["string"][0] is not None