fix torch compile for language models

parent b149031748
commit 4349fac71a
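For context on what this commit addresses, here is a minimal plain-PyTorch sketch (not part of the diff, and a stand-in module rather than a real language model): torch.compile() returns a new wrapper module instead of modifying the model in place, so whatever holds the language model has to be re-pointed at the compiled object, which the TorchCompileModel hunk below does via clone() and add_object_patch("model", ...).

import torch


class TinyLM(torch.nn.Module):  # illustrative stand-in, not the real model
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


original = TinyLM()
compiled = torch.compile(original)   # returns a new OptimizedModule wrapper
assert compiled is not original      # the original object is left untouched
out = compiled(torch.randn(2, 16))   # calls now go through the compiled graph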
@@ -62,11 +62,93 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
        self._on_set_processor(self._processor)
        self._model_type = ""
        self._original_transformers_managed_model: weakref.ReferenceType["TransformersManagedModel"] = weakref.ref(self)
        self.wrappers = {}
        self.callbacks = {}
        self._hook_mode = None
        self._model_options = {"transformer_options": {}}

        if model.device != self.offload_device:
            model.to(device=self.offload_device)

    @property
    def hook_mode(self):
        from ..hooks import EnumHookMode
        if self._hook_mode is None:
            self._hook_mode = EnumHookMode.MaxSpeed
        return self._hook_mode

    @hook_mode.setter
    def hook_mode(self, value):
        self._hook_mode = value

    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
        return

    def model_patches_models(self):
        return []

    def restore_hook_patches(self):
        return

    def cleanup(self):
        pass

    def pre_run(self):
        pass

    def prepare_state(self, *args, **kwargs):
        pass

    def register_all_hook_patches(self, a, b, c, d):
        pass

    def get_nested_additional_models(self):
        return []

    def apply_hooks(self, *args, **kwargs):
        return {}

    def add_wrapper(self, wrapper_type: str, wrapper: Callable):
        self.add_wrapper_with_key(wrapper_type, None, wrapper)

    def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
        w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
        w.append(wrapper)

    def remove_wrappers_with_key(self, wrapper_type: str, key: str):
        w = self.wrappers.get(wrapper_type, {})
        if key in w:
            w.pop(key)

    def get_wrappers_with_key(self, wrapper_type: str, key: str):
        w_list = []
        w_list.extend(self.wrappers.get(wrapper_type, {}).get(key, []))
        return w_list

    def get_all_wrappers(self, wrapper_type: str):
        w_list = []
        for w in self.wrappers.get(wrapper_type, {}).values():
            w_list.extend(w)
        return w_list

    @property
    def model_options(self):
        return self._model_options

    @model_options.setter
    def model_options(self, value):
        self._model_options = value

    @property
    def diffusion_model(self):
        return self.model

    @diffusion_model.setter
    def diffusion_model(self, value):
        self.add_object_patch("model", value)

    @staticmethod
    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None, config_dict: PretrainedConfig | dict | None = None) -> "TransformersManagedModel":
    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None, config_dict: PretrainedConfig | dict | None = None, **kwargs) -> "TransformersManagedModel":
        hub_kwargs = {}
        if subfolder is not None and subfolder.strip() != "":
            hub_kwargs["subfolder"] = subfolder
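Not part of the commit: a self-contained sketch of the wrapper bookkeeping the new methods implement. Wrappers are stored as a nested dict keyed by wrapper type and then by an optional key, which is why add_wrapper() simply delegates to add_wrapper_with_key() with key=None. The class below is a hypothetical stand-in, not the real TransformersManagedModel.

from typing import Callable, Optional


class WrapperRegistry:  # hypothetical stand-in for the managed model
    def __init__(self):
        self.wrappers: dict[str, dict[Optional[str], list[Callable]]] = {}

    def add_wrapper_with_key(self, wrapper_type: str, key: Optional[str], wrapper: Callable):
        # Layout: {wrapper_type: {key: [wrapper, ...]}}
        self.wrappers.setdefault(wrapper_type, {}).setdefault(key, []).append(wrapper)

    def add_wrapper(self, wrapper_type: str, wrapper: Callable):
        self.add_wrapper_with_key(wrapper_type, None, wrapper)

    def get_all_wrappers(self, wrapper_type: str) -> list[Callable]:
        out: list[Callable] = []
        for group in self.wrappers.get(wrapper_type, {}).values():
            out.extend(group)
        return out


registry = WrapperRegistry()
registry.add_wrapper("apply_model", lambda executor, *args, **kwargs: executor(*args, **kwargs))
assert len(registry.get_all_wrappers("apply_model")) == 1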
@@ -89,7 +171,8 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):

        from_pretrained_kwargs = {
            "pretrained_model_name_or_path": ckpt_name,
            **hub_kwargs
            **hub_kwargs,
            **kwargs,
        }

        # language models prefer to use bfloat16 over float16
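The comment above refers to a dtype preference handled elsewhere in the method; as a hedged standalone illustration, the usual way to prefer bfloat16 when the hardware supports it and fall back to float16 otherwise:

import torch

# Prefer bfloat16 where supported (e.g. Ampere and newer GPUs), else float16.
preferred_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)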
@@ -149,7 +232,16 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
                tokenizer = processor
                processor = None
            else:
                try:
                    tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try)
                except Exception:
                    try:
                        tokenizer = AutoTokenizer.from_pretrained(ckpt_name, use_fast=True, legacy=False, **hub_kwargs, **kwargs_to_try)
                    except Exception:
                        if repo_id != ckpt_name:
                            tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True, legacy=False, **hub_kwargs, **kwargs_to_try)
                        else:
                            raise
            if tokenizer is not None or processor is not None:
                break
        except Exception as exc_info:
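Not part of the commit: a simplified standalone version of the fallback order the loader uses above — reuse a processor's bundled tokenizer when it has one, otherwise try AutoTokenizer, retrying with use_fast=True and legacy=False before giving up. ckpt_name here is any Hugging Face repo id; the helper name is hypothetical.

from transformers import AutoTokenizer


def load_tokenizer(ckpt_name: str, processor=None):
    # Reuse the tokenizer bundled with a processor when one exists.
    if processor is not None and hasattr(processor, "tokenizer"):
        return processor.tokenizer
    try:
        return AutoTokenizer.from_pretrained(ckpt_name)
    except Exception:
        # Some repos only load cleanly with the fast tokenizer / non-legacy behavior.
        return AutoTokenizer.from_pretrained(ckpt_name, use_fast=True, legacy=False)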
@@ -252,13 +344,23 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
        with seed_for_block(seed), torch.inference_mode(mode=True) if has_triton else contextlib.nullcontext():
            if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
                inputs.pop("attention_mask")
            output_ids = transformers_model.generate(

            from ..patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers

            def _generate(inputs, streamer, max_new_tokens, **generate_kwargs):
                return transformers_model.generate(
                    **inputs,
                    streamer=text_streamer if num_beams <= 1 else None,
                    streamer=streamer,
                    max_new_tokens=max_new_tokens,
                    **generate_kwargs
                )

            output_ids = WrapperExecutor.new_class_executor(
                _generate,
                self,
                get_all_wrappers(WrappersMP.APPLY_MODEL, self.model_options)
            ).execute(inputs, text_streamer if num_beams <= 1 else None, max_new_tokens, **generate_kwargs)

        if not transformers_model.config.is_encoder_decoder:
            start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
            output_ids = output_ids[:, start_position:]
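Not part of the commit: a standalone sketch of the wrapper-executor pattern the generate() call is routed through above. Each wrapper receives a callable that continues the chain, so it can run code before and after the wrapped call. The real implementation lives in comfy.patcher_extension; the helper below is a stand-in and its exact signatures may differ from the real WrapperExecutor.

from typing import Callable


def run_with_wrappers(fn: Callable, wrappers: list[Callable], *args, **kwargs):
    # Build the call chain so the last wrapper registered runs outermost.
    call = fn
    for wrapper in reversed(wrappers):
        call = (lambda w, inner: lambda *a, **kw: w(inner, *a, **kw))(wrapper, call)
    return call(*args, **kwargs)


def timing_wrapper(executor, *args, **kwargs):
    # Example wrapper: run code around the wrapped function, then continue the chain.
    print("before generate")
    result = executor(*args, **kwargs)
    print("after generate")
    return result


result = run_with_wrappers(lambda x: x * 2, [timing_wrapper], 21)
assert result == 42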
@@ -336,6 +438,9 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
        return self._size

    def model_patches_to(self, arg: torch.device | torch.dtype):
        if getattr(self.model, "is_loaded_in_4bit", False) or getattr(self.model, "is_loaded_in_8bit", False):
            return

        if isinstance(arg, torch.device):
            self.model.to(device=arg)
        else:
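Not part of the commit: the early return above relies on flags that transformers sets on bitsandbytes-quantized models; such models cannot be freely moved or cast with .to(), so device/dtype patching is skipped for them. A hedged standalone check with a hypothetical helper name:

import torch


def can_move_or_cast(module: torch.nn.Module) -> bool:
    # transformers sets these attributes when a model is loaded via bitsandbytes.
    return not (
        getattr(module, "is_loaded_in_4bit", False)
        or getattr(module, "is_loaded_in_8bit", False)
    )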
@@ -3,7 +3,7 @@ from __future__ import annotations
import copy
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, TYPE_CHECKING
from typing import Any, Callable, Protocol, runtime_checkable, Optional, TypeVar, NamedTuple, TYPE_CHECKING

import torch
import torch.nn
@@ -26,8 +26,8 @@ class DeviceSettable(Protocol):

@runtime_checkable
class HooksSupport(Protocol):
    wrappers: dict[str, dict[str, list[Callable]]]
    callbacks: dict[str, dict[str, list[Callable]]]
    wrappers: dict[str, list[Callable]]
    callbacks: dict[str, list[Callable]]
    hook_mode: "EnumHookMode"

    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): ...
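Not part of the commit: why declaring HooksSupport as a @runtime_checkable Protocol matters — TorchCompileModel below switches to isinstance(model, HooksSupport), and runtime-checkable protocols are checked structurally, verifying only that the listed members exist, not their annotated types or any inheritance relationship. A small self-contained illustration with a hypothetical protocol:

from typing import Callable, Protocol, runtime_checkable


@runtime_checkable
class SupportsWrappers(Protocol):  # hypothetical, mirrors the idea of HooksSupport
    wrappers: dict

    def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None: ...


class AnyModel:
    def __init__(self):
        self.wrappers = {}

    def add_wrapper(self, wrapper_type, wrapper):
        self.wrappers.setdefault(wrapper_type, []).append(wrapper)


# Structural check: AnyModel never inherits from SupportsWrappers.
assert isinstance(AnyModel(), SupportsWrappers)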
@@ -219,6 +219,23 @@ class TransformersLoader1(TransformersLoader):
            }
        }

class TransformersLoaderQuantized(TransformersLoader):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "ckpt_name": ("STRING", {}),
                "load_in_4bit": ("BOOLEAN", {"default": False}),
                "load_in_8bit": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "subfolder": ("STRING", {}),
            }
        }

    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, load_in_4bit: bool = False, load_in_8bit: bool = False, *args, **kwargs) -> tuple[TransformersManagedModel]:
        return TransformersManagedModel.from_pretrained(ckpt_name, subfolder, load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit),


class TransformersTokenize(CustomNode):
    @classmethod
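Not part of the commit: the load_in_4bit / load_in_8bit flags above are forwarded through the new **kwargs plumbing into the underlying transformers from_pretrained call. A hedged plain-transformers equivalent, assuming bitsandbytes and a CUDA GPU are available, using the quantization_config API:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-4",                                   # same checkpoint the new test loads
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.bfloat16,
    device_map="auto",
)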
@@ -9,6 +9,7 @@ from torch.nn import LayerNorm

from comfy import model_management
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_management_types import HooksSupport
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.sd import VAE
@@ -100,18 +101,8 @@ class TorchCompileModel(CustomNode):
                "make_refittable": True,
            }
            del compile_kwargs["mode"]
        if isinstance(model, TransformersManagedModel):
            to_return = model.clone()
            model = to_return.model

            model_management.unload_all_models()
            model.to(device=model_management.get_torch_device())
            res = torch.compile(model=model, **compile_kwargs),
            model.to(device=model_management.unet_offload_device())

            to_return.add_object_patch("model", res)
            return to_return,
        elif isinstance(model, (ModelPatcher, VAE)):
        if isinstance(model, HooksSupport):
            to_return = model.clone()
            object_patches = [p.strip() for p in object_patch.split(",")]
            patcher: ModelPatcher
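Not part of the commit: the hunk above drops the TransformersManagedModel special case in favor of the generic HooksSupport path, which attaches the compiled module through an object patch. add_object_patch stores a replacement for a named attribute on a cloned patcher and only swaps it in when the patches are applied, so the original model object is never mutated. A minimal standalone sketch of that object-patch idea (hypothetical names, not the real ModelPatcher):

import torch


class PatchedHolder:  # hypothetical stand-in for a cloned model patcher
    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.object_patches: dict[str, object] = {}

    def add_object_patch(self, name: str, replacement: object):
        self.object_patches[name] = replacement

    def apply_patches(self):
        for name, replacement in self.object_patches.items():
            setattr(self, name, replacement)


holder = PatchedHolder(torch.nn.Linear(8, 8))
holder.add_object_patch("model", torch.compile(holder.model))
holder.apply_patches()                # the compiled wrapper now replaces .model
out = holder.model(torch.randn(1, 8))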
0   tests/language/__init__.py   Normal file
20  tests/language/test_phi4_loading.py   Normal file
@@ -0,0 +1,20 @@
import pytest
from comfy_execution.graph_utils import GraphBuilder
from comfy.client.embedded_comfy_client import Comfy
from comfy.api.components.schema.prompt import Prompt


class TestPhi4Loading:
    @pytest.mark.asyncio
    async def test_phi4_loading(self):
        graph = GraphBuilder()
        model_loader = graph.node("TransformersLoaderQuantized", ckpt_name="microsoft/phi-4", load_in_4bit=True, load_in_8bit=False)
        tokenizer = graph.node("OneShotInstructTokenize", model=model_loader.out(0), prompt="Hello", chat_template="default")
        generation = graph.node("TransformersGenerate", model=model_loader.out(0), tokens=tokenizer.out(0), max_new_tokens=1, seed=42)
        graph.node("SaveString", value=generation.out(0), filename_prefix="phi4_test")

        workflow = graph.finalize()
        prompt = Prompt.validate(workflow)

        async with Comfy() as client:
            await client.queue_prompt(prompt)
31  tests/language/test_torch_compile_transformers.py   Normal file
@@ -0,0 +1,31 @@
import pytest
import torch
from comfy_execution.graph_utils import GraphBuilder
from comfy.client.embedded_comfy_client import Comfy
from comfy.api.components.schema.prompt import Prompt


class TestTorchCompileTransformers:
    @pytest.mark.asyncio
    async def test_torch_compile_transformers(self):
        graph = GraphBuilder()
        model_loader = graph.node("TransformersLoader1", ckpt_name="Qwen/Qwen2.5-0.5B")
        compiled_model = graph.node("TorchCompileModel", model=model_loader.out(0), backend="inductor", mode="max-autotune")
        tokenizer = graph.node("OneShotInstructTokenize", model=compiled_model.out(0), prompt="Hello, world!", chat_template="default")
        generation = graph.node("TransformersGenerate", model=compiled_model.out(0), tokens=tokenizer.out(0), max_new_tokens=10, seed=42)

        save_string = graph.node("SaveString", value=generation.out(0), filename_prefix="test_output")

        workflow = graph.finalize()
        prompt = Prompt.validate(workflow)

        from unittest.mock import patch
        with patch("torch.compile", side_effect=torch.compile) as mock_compile:
            async with Comfy() as client:
                outputs = await client.queue_prompt(prompt)

        assert mock_compile.called, "torch.compile should have been called"

        assert len(outputs) > 0
        assert save_string.id in outputs
        assert outputs[save_string.id]["string"][0] is not None
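Both new tests can be run with pytest, e.g. pytest tests/language -v. Note that they download the referenced checkpoints from the Hugging Face Hub, and the phi-4 test additionally assumes a CUDA GPU with bitsandbytes available for 4-bit loading.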