mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-10 14:20:49 +08:00)

commit 4349fac71a
parent b149031748

    fix torch compile for language models
@@ -62,11 +62,93 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         self._on_set_processor(self._processor)
         self._model_type = ""
         self._original_transformers_managed_model: weakref.ReferenceType["TransformersManagedModel"] = weakref.ref(self)
+        self.wrappers = {}
+        self.callbacks = {}
+        self._hook_mode = None
+        self._model_options = {"transformer_options": {}}
+
         if model.device != self.offload_device:
             model.to(device=self.offload_device)
 
+    @property
+    def hook_mode(self):
+        from ..hooks import EnumHookMode
+        if self._hook_mode is None:
+            self._hook_mode = EnumHookMode.MaxSpeed
+        return self._hook_mode
+
+    @hook_mode.setter
+    def hook_mode(self, value):
+        self._hook_mode = value
+
+    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
+        return
+
+    def model_patches_models(self):
+        return []
+
+    def restore_hook_patches(self):
+        return
+
+    def cleanup(self):
+        pass
+
+    def pre_run(self):
+        pass
+
+    def prepare_state(self, *args, **kwargs):
+        pass
+
+    def register_all_hook_patches(self, a, b, c, d):
+        pass
+
+    def get_nested_additional_models(self):
+        return []
+
+    def apply_hooks(self, *args, **kwargs):
+        return {}
+
+    def add_wrapper(self, wrapper_type: str, wrapper: Callable):
+        self.add_wrapper_with_key(wrapper_type, None, wrapper)
+
+    def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
+        w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
+        w.append(wrapper)
+
+    def remove_wrappers_with_key(self, wrapper_type: str, key: str):
+        w = self.wrappers.get(wrapper_type, {})
+        if key in w:
+            w.pop(key)
+
+    def get_wrappers_with_key(self, wrapper_type: str, key: str):
+        w_list = []
+        w_list.extend(self.wrappers.get(wrapper_type, {}).get(key, []))
+        return w_list
+
+    def get_all_wrappers(self, wrapper_type: str):
+        w_list = []
+        for w in self.wrappers.get(wrapper_type, {}).values():
+            w_list.extend(w)
+        return w_list
+
+    @property
+    def model_options(self):
+        return self._model_options
+
+    @model_options.setter
+    def model_options(self, value):
+        self._model_options = value
+
+    @property
+    def diffusion_model(self):
+        return self.model
+
+    @diffusion_model.setter
+    def diffusion_model(self, value):
+        self.add_object_patch("model", value)
+
     @staticmethod
-    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None, config_dict: PretrainedConfig | dict | None = None) -> "TransformersManagedModel":
+    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None, config_dict: PretrainedConfig | dict | None = None, **kwargs) -> "TransformersManagedModel":
         hub_kwargs = {}
         if subfolder is not None and subfolder.strip() != "":
             hub_kwargs["subfolder"] = subfolder
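The wrapper registry added above is keyed twice: first by wrapper type, then by an optional key, with each leaf holding a list of callables. A minimal usage sketch against an already constructed TransformersManagedModel instance; the wrapper type string and the wrapper function are hypothetical placeholders, not names from this repository:

    # Hypothetical sketch: exercises only the registry methods added in the hunk above.
    def passthrough_wrapper(executor, *args, **kwargs):
        # A wrapper is expected to call through to the wrapped callable (the executor) to continue.
        return executor(*args, **kwargs)

    model.add_wrapper("example_type", passthrough_wrapper)                      # stored under key None
    model.add_wrapper_with_key("example_type", "tracing", passthrough_wrapper)
    assert model.get_all_wrappers("example_type") == [passthrough_wrapper, passthrough_wrapper]
    model.remove_wrappers_with_key("example_type", "tracing")
    assert model.get_wrappers_with_key("example_type", "tracing") == []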
@@ -89,7 +171,8 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
 
         from_pretrained_kwargs = {
             "pretrained_model_name_or_path": ckpt_name,
-            **hub_kwargs
+            **hub_kwargs,
+            **kwargs,
         }
 
         # language models prefer to use bfloat16 over float16
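With **kwargs now forwarded into from_pretrained_kwargs, loader-level options such as quantization flags reach the underlying transformers call without extra plumbing; this is what the TransformersLoaderQuantized node further down relies on. A hedged example, with the checkpoint name taken from the new test rather than being a requirement:

    # Illustrative only: extra keyword arguments pass straight through to transformers' from_pretrained.
    model = TransformersManagedModel.from_pretrained(
        "microsoft/phi-4",
        load_in_4bit=True,  # forwarded via **kwargs
    )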
@@ -149,7 +232,16 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
                     tokenizer = processor
                     processor = None
                 else:
-                    tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try)
+                    try:
+                        tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try)
+                    except Exception:
+                        try:
+                            tokenizer = AutoTokenizer.from_pretrained(ckpt_name, use_fast=True, legacy=False, **hub_kwargs, **kwargs_to_try)
+                        except Exception:
+                            if repo_id != ckpt_name:
+                                tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True, legacy=False, **hub_kwargs, **kwargs_to_try)
+                            else:
+                                raise
                 if tokenizer is not None or processor is not None:
                     break
             except Exception as exc_info:
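The tokenizer lookup now degrades in stages: use the processor's bundled tokenizer when present, otherwise try AutoTokenizer as before, then retry with use_fast=True and legacy=False, and finally fall back to the repo id before re-raising. A condensed sketch of the same chain outside the class; the parameter names are placeholders:

    # Condensed sketch of the fallback chain above; ckpt_name and repo_id are placeholders.
    from transformers import AutoTokenizer

    def load_tokenizer_with_fallback(ckpt_name: str, repo_id: str, **kwargs):
        try:
            return AutoTokenizer.from_pretrained(ckpt_name, **kwargs)
        except Exception:
            try:
                return AutoTokenizer.from_pretrained(ckpt_name, use_fast=True, legacy=False, **kwargs)
            except Exception:
                if repo_id != ckpt_name:
                    return AutoTokenizer.from_pretrained(repo_id, use_fast=True, legacy=False, **kwargs)
                raise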
@@ -252,13 +344,23 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         with seed_for_block(seed), torch.inference_mode(mode=True) if has_triton else contextlib.nullcontext():
             if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
                 inputs.pop("attention_mask")
 
-            output_ids = transformers_model.generate(
+            from ..patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers
+
+            def _generate(inputs, streamer, max_new_tokens, **generate_kwargs):
+                return transformers_model.generate(
                     **inputs,
-                    streamer=text_streamer if num_beams <= 1 else None,
+                    streamer=streamer,
                     max_new_tokens=max_new_tokens,
                     **generate_kwargs
                 )
+
+            output_ids = WrapperExecutor.new_class_executor(
+                _generate,
+                self,
+                get_all_wrappers(WrappersMP.APPLY_MODEL, self.model_options)
+            ).execute(inputs, text_streamer if num_beams <= 1 else None, max_new_tokens, **generate_kwargs)
+
             if not transformers_model.config.is_encoder_decoder:
                 start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
                 output_ids = output_ids[:, start_position:]
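generate() is now invoked through a WrapperExecutor built from get_all_wrappers(WrappersMP.APPLY_MODEL, self.model_options), so anything stored as an APPLY_MODEL wrapper in the model options can intercept the call; that is the hook a compiled or instrumented forward path can attach to. A minimal sketch of the wrapper shape, assuming the comfy.patcher_extension convention that each wrapper receives the executor first and calls it to continue the chain; how the wrapper ends up in model_options is not shown in this hunk:

    # Hypothetical wrapper shape, assuming wrappers are invoked as wrapper(executor, *args, **kwargs).
    import time

    def timed_generate(executor, inputs, streamer, max_new_tokens, **generate_kwargs):
        start = time.perf_counter()
        try:
            return executor(inputs, streamer, max_new_tokens, **generate_kwargs)  # next wrapper or _generate
        finally:
            print(f"generate() took {time.perf_counter() - start:.2f}s")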
@@ -336,6 +438,9 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         return self._size
 
     def model_patches_to(self, arg: torch.device | torch.dtype):
+        if getattr(self.model, "is_loaded_in_4bit", False) or getattr(self.model, "is_loaded_in_8bit", False):
+            return
+
         if isinstance(arg, torch.device):
             self.model.to(device=arg)
         else:
@@ -3,7 +3,7 @@ from __future__ import annotations
 import copy
 import dataclasses
 from abc import ABCMeta, abstractmethod
-from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, TYPE_CHECKING
+from typing import Any, Callable, Protocol, runtime_checkable, Optional, TypeVar, NamedTuple, TYPE_CHECKING
 
 import torch
 import torch.nn
@@ -26,8 +26,8 @@ class DeviceSettable(Protocol):
 
 @runtime_checkable
 class HooksSupport(Protocol):
-    wrappers: dict[str, dict[str, list[Callable]]]
-    callbacks: dict[str, dict[str, list[Callable]]]
+    wrappers: dict[str, list[Callable]]
+    callbacks: dict[str, list[Callable]]
     hook_mode: "EnumHookMode"
 
     def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): ...
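HooksSupport is a runtime_checkable Protocol, so isinstance() only checks that the listed attributes and methods exist; that structural check is why TransformersManagedModel needed the stub methods added earlier, and why TorchCompileModel below can treat it and ModelPatcher uniformly. A self-contained illustration of the mechanism with a reduced stand-in protocol; SupportsHookMode and Stub are hypothetical, not names from this repository:

    # Hypothetical illustration: a runtime_checkable Protocol matches by member presence only.
    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class SupportsHookMode(Protocol):  # stand-in for HooksSupport, reduced to a single member
        hook_mode: object

    class Stub:
        hook_mode = None  # having the attribute is enough to satisfy isinstance()

    assert isinstance(Stub(), SupportsHookMode)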
@@ -219,6 +219,23 @@ class TransformersLoader1(TransformersLoader):
             }
         }
 
+
+class TransformersLoaderQuantized(TransformersLoader):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "ckpt_name": ("STRING", {}),
+                "load_in_4bit": ("BOOLEAN", {"default": False}),
+                "load_in_8bit": ("BOOLEAN", {"default": False}),
+            },
+            "optional": {
+                "subfolder": ("STRING", {}),
+            }
+        }
+
+    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, load_in_4bit: bool = False, load_in_8bit: bool = False, *args, **kwargs) -> tuple[TransformersManagedModel]:
+        return TransformersManagedModel.from_pretrained(ckpt_name, subfolder, load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit),
 
 class TransformersTokenize(CustomNode):
     @classmethod
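TransformersLoaderQuantized simply forwards its toggles to the extended from_pretrained, so the node and the Python API stay in sync. A hedged example of calling the node's execute() directly, bypassing the graph; the checkpoint name is taken from the new test:

    # Illustrative only: the node returns a 1-tuple, hence the unpacking comma.
    model, = TransformersLoaderQuantized().execute("microsoft/phi-4", load_in_4bit=True)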
@@ -9,6 +9,7 @@ from torch.nn import LayerNorm
 
 from comfy import model_management
 from comfy.language.transformers_model_management import TransformersManagedModel
+from comfy.model_management_types import HooksSupport
 from comfy.model_patcher import ModelPatcher
 from comfy.nodes.package_typing import CustomNode, InputTypes
 from comfy.sd import VAE
@@ -100,18 +101,8 @@ class TorchCompileModel(CustomNode):
                 "make_refittable": True,
             }
             del compile_kwargs["mode"]
-        if isinstance(model, TransformersManagedModel):
-            to_return = model.clone()
-            model = to_return.model
-
-            model_management.unload_all_models()
-            model.to(device=model_management.get_torch_device())
-            res = torch.compile(model=model, **compile_kwargs),
-            model.to(device=model_management.unet_offload_device())
-
-            to_return.add_object_patch("model", res)
-            return to_return,
-        elif isinstance(model, (ModelPatcher, VAE)):
+        if isinstance(model, HooksSupport):
             to_return = model.clone()
             object_patches = [p.strip() for p in object_patch.split(",")]
             patcher: ModelPatcher
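With the TransformersManagedModel special case removed, any model object satisfying HooksSupport now flows through the same clone-and-patch path as ModelPatcher and VAE. A rough sketch of the idea, assuming the branch ultimately attaches the compiled module via add_object_patch the way the removed code and the diffusion_model setter above do; the helper name and exact flow are hypothetical, since the body of the new branch is not visible in this hunk:

    # Hypothetical sketch only; the actual body of the HooksSupport branch is not shown above.
    import torch

    def compile_via_object_patch(managed_model, **compile_kwargs):
        to_return = managed_model.clone()
        compiled = torch.compile(model=to_return.model, **compile_kwargs)
        to_return.add_object_patch("model", compiled)  # same mechanism as the diffusion_model setter
        return to_return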
New file: tests/language/__init__.py (0 lines)

New file: tests/language/test_phi4_loading.py (20 lines)
@@ -0,0 +1,20 @@
+import pytest
+from comfy_execution.graph_utils import GraphBuilder
+from comfy.client.embedded_comfy_client import Comfy
+from comfy.api.components.schema.prompt import Prompt
+
+
+class TestPhi4Loading:
+    @pytest.mark.asyncio
+    async def test_phi4_loading(self):
+        graph = GraphBuilder()
+        model_loader = graph.node("TransformersLoaderQuantized", ckpt_name="microsoft/phi-4", load_in_4bit=True, load_in_8bit=False)
+        tokenizer = graph.node("OneShotInstructTokenize", model=model_loader.out(0), prompt="Hello", chat_template="default")
+        generation = graph.node("TransformersGenerate", model=model_loader.out(0), tokens=tokenizer.out(0), max_new_tokens=1, seed=42)
+        graph.node("SaveString", value=generation.out(0), filename_prefix="phi4_test")
+
+        workflow = graph.finalize()
+        prompt = Prompt.validate(workflow)
+
+        async with Comfy() as client:
+            await client.queue_prompt(prompt)
New file: tests/language/test_torch_compile_transformers.py (31 lines)
@@ -0,0 +1,31 @@
+import pytest
+import torch
+from comfy_execution.graph_utils import GraphBuilder
+from comfy.client.embedded_comfy_client import Comfy
+from comfy.api.components.schema.prompt import Prompt
+
+
+class TestTorchCompileTransformers:
+    @pytest.mark.asyncio
+    async def test_torch_compile_transformers(self):
+        graph = GraphBuilder()
+        model_loader = graph.node("TransformersLoader1", ckpt_name="Qwen/Qwen2.5-0.5B")
+        compiled_model = graph.node("TorchCompileModel", model=model_loader.out(0), backend="inductor", mode="max-autotune")
+        tokenizer = graph.node("OneShotInstructTokenize", model=compiled_model.out(0), prompt="Hello, world!", chat_template="default")
+        generation = graph.node("TransformersGenerate", model=compiled_model.out(0), tokens=tokenizer.out(0), max_new_tokens=10, seed=42)
+
+        save_string = graph.node("SaveString", value=generation.out(0), filename_prefix="test_output")
+
+        workflow = graph.finalize()
+        prompt = Prompt.validate(workflow)
+
+        from unittest.mock import patch
+        with patch("torch.compile", side_effect=torch.compile) as mock_compile:
+            async with Comfy() as client:
+                outputs = await client.queue_prompt(prompt)
+
+        assert mock_compile.called, "torch.compile should have been called"
+
+        assert len(outputs) > 0
+        assert save_string.id in outputs
+        assert outputs[save_string.id]["string"][0] is not None