fix torch compile for language models

This commit is contained in:
doctorpangloss 2025-12-03 14:38:03 -08:00
parent b149031748
commit 4349fac71a
8 changed files with 187 additions and 23 deletions

View File

@ -62,11 +62,93 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
self._on_set_processor(self._processor)
self._model_type = ""
self._original_transformers_managed_model: weakref.ReferenceType["TransformersManagedModel"] = weakref.ref(self)
self.wrappers = {}
self.callbacks = {}
self._hook_mode = None
self._model_options = {"transformer_options": {}}
if model.device != self.offload_device:
model.to(device=self.offload_device)
@property
def hook_mode(self):
from ..hooks import EnumHookMode
if self._hook_mode is None:
self._hook_mode = EnumHookMode.MaxSpeed
return self._hook_mode
@hook_mode.setter
def hook_mode(self, value):
self._hook_mode = value
def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
return
def model_patches_models(self):
return []
def restore_hook_patches(self):
return
def cleanup(self):
pass
def pre_run(self):
pass
def prepare_state(self, *args, **kwargs):
pass
def register_all_hook_patches(self, a, b, c, d):
pass
def get_nested_additional_models(self):
return []
def apply_hooks(self, *args, **kwargs):
return {}
def add_wrapper(self, wrapper_type: str, wrapper: Callable):
self.add_wrapper_with_key(wrapper_type, None, wrapper)
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
w = self.wrappers.get(wrapper_type, {})
if key in w:
w.pop(key)
def get_wrappers_with_key(self, wrapper_type: str, key: str):
w_list = []
w_list.extend(self.wrappers.get(wrapper_type, {}).get(key, []))
return w_list
def get_all_wrappers(self, wrapper_type: str):
w_list = []
for w in self.wrappers.get(wrapper_type, {}).values():
w_list.extend(w)
return w_list
@property
def model_options(self):
return self._model_options
@model_options.setter
def model_options(self, value):
self._model_options = value
@property
def diffusion_model(self):
return self.model
@diffusion_model.setter
def diffusion_model(self, value):
self.add_object_patch("model", value)
@staticmethod
def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None, config_dict: PretrainedConfig | dict | None = None) -> "TransformersManagedModel":
def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None, config_dict: PretrainedConfig | dict | None = None, **kwargs) -> "TransformersManagedModel":
hub_kwargs = {}
if subfolder is not None and subfolder.strip() != "":
hub_kwargs["subfolder"] = subfolder
@ -89,7 +171,8 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
from_pretrained_kwargs = {
"pretrained_model_name_or_path": ckpt_name,
**hub_kwargs
**hub_kwargs,
**kwargs,
}
# language models prefer to use bfloat16 over float16
@ -149,7 +232,16 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
tokenizer = processor
processor = None
else:
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try)
try:
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try)
except Exception:
try:
tokenizer = AutoTokenizer.from_pretrained(ckpt_name, use_fast=True, legacy=False, **hub_kwargs, **kwargs_to_try)
except Exception:
if repo_id != ckpt_name:
tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True, legacy=False, **hub_kwargs, **kwargs_to_try)
else:
raise
if tokenizer is not None or processor is not None:
break
except Exception as exc_info:
@ -252,12 +344,22 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
with seed_for_block(seed), torch.inference_mode(mode=True) if has_triton else contextlib.nullcontext():
if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
inputs.pop("attention_mask")
output_ids = transformers_model.generate(
**inputs,
streamer=text_streamer if num_beams <= 1 else None,
max_new_tokens=max_new_tokens,
**generate_kwargs
)
from ..patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers
def _generate(inputs, streamer, max_new_tokens, **generate_kwargs):
return transformers_model.generate(
**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
**generate_kwargs
)
output_ids = WrapperExecutor.new_class_executor(
_generate,
self,
get_all_wrappers(WrappersMP.APPLY_MODEL, self.model_options)
).execute(inputs, text_streamer if num_beams <= 1 else None, max_new_tokens, **generate_kwargs)
if not transformers_model.config.is_encoder_decoder:
start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
@ -336,6 +438,9 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
return self._size
def model_patches_to(self, arg: torch.device | torch.dtype):
if getattr(self.model, "is_loaded_in_4bit", False) or getattr(self.model, "is_loaded_in_8bit", False):
return
if isinstance(arg, torch.device):
self.model.to(device=arg)
else:

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import copy
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, TYPE_CHECKING
from typing import Any, Callable, Protocol, runtime_checkable, Optional, TypeVar, NamedTuple, TYPE_CHECKING
import torch
import torch.nn
@ -26,8 +26,8 @@ class DeviceSettable(Protocol):
@runtime_checkable
class HooksSupport(Protocol):
wrappers: dict[str, dict[str, list[Callable]]]
callbacks: dict[str, dict[str, list[Callable]]]
wrappers: dict[str, list[Callable]]
callbacks: dict[str, list[Callable]]
hook_mode: "EnumHookMode"
def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): ...

View File

@ -219,6 +219,23 @@ class TransformersLoader1(TransformersLoader):
}
}
class TransformersLoaderQuantized(TransformersLoader):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": ("STRING", {}),
"load_in_4bit": ("BOOLEAN", {"default": False}),
"load_in_8bit": ("BOOLEAN", {"default": False}),
},
"optional": {
"subfolder": ("STRING", {}),
}
}
def execute(self, ckpt_name: str, subfolder: Optional[str] = None, load_in_4bit: bool = False, load_in_8bit: bool = False, *args, **kwargs) -> tuple[TransformersManagedModel]:
return TransformersManagedModel.from_pretrained(ckpt_name, subfolder, load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit),
class TransformersTokenize(CustomNode):
@classmethod

View File

@ -9,6 +9,7 @@ from torch.nn import LayerNorm
from comfy import model_management
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_management_types import HooksSupport
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.sd import VAE
@ -100,18 +101,8 @@ class TorchCompileModel(CustomNode):
"make_refittable": True,
}
del compile_kwargs["mode"]
if isinstance(model, TransformersManagedModel):
to_return = model.clone()
model = to_return.model
model_management.unload_all_models()
model.to(device=model_management.get_torch_device())
res = torch.compile(model=model, **compile_kwargs),
model.to(device=model_management.unet_offload_device())
to_return.add_object_patch("model", res)
return to_return,
elif isinstance(model, (ModelPatcher, VAE)):
if isinstance(model, HooksSupport):
to_return = model.clone()
object_patches = [p.strip() for p in object_patch.split(",")]
patcher: ModelPatcher

View File

View File

@ -0,0 +1,20 @@
import pytest
from comfy_execution.graph_utils import GraphBuilder
from comfy.client.embedded_comfy_client import Comfy
from comfy.api.components.schema.prompt import Prompt
class TestPhi4Loading:
@pytest.mark.asyncio
async def test_phi4_loading(self):
graph = GraphBuilder()
model_loader = graph.node("TransformersLoaderQuantized", ckpt_name="microsoft/phi-4", load_in_4bit=True, load_in_8bit=False)
tokenizer = graph.node("OneShotInstructTokenize", model=model_loader.out(0), prompt="Hello", chat_template="default")
generation = graph.node("TransformersGenerate", model=model_loader.out(0), tokens=tokenizer.out(0), max_new_tokens=1, seed=42)
graph.node("SaveString", value=generation.out(0), filename_prefix="phi4_test")
workflow = graph.finalize()
prompt = Prompt.validate(workflow)
async with Comfy() as client:
await client.queue_prompt(prompt)

View File

@ -0,0 +1,31 @@
import pytest
import torch
from comfy_execution.graph_utils import GraphBuilder
from comfy.client.embedded_comfy_client import Comfy
from comfy.api.components.schema.prompt import Prompt
class TestTorchCompileTransformers:
@pytest.mark.asyncio
async def test_torch_compile_transformers(self):
graph = GraphBuilder()
model_loader = graph.node("TransformersLoader1", ckpt_name="Qwen/Qwen2.5-0.5B")
compiled_model = graph.node("TorchCompileModel", model=model_loader.out(0), backend="inductor", mode="max-autotune")
tokenizer = graph.node("OneShotInstructTokenize", model=compiled_model.out(0), prompt="Hello, world!", chat_template="default")
generation = graph.node("TransformersGenerate", model=compiled_model.out(0), tokens=tokenizer.out(0), max_new_tokens=10, seed=42)
save_string = graph.node("SaveString", value=generation.out(0), filename_prefix="test_output")
workflow = graph.finalize()
prompt = Prompt.validate(workflow)
from unittest.mock import patch
with patch("torch.compile", side_effect=torch.compile) as mock_compile:
async with Comfy() as client:
outputs = await client.queue_prompt(prompt)
assert mock_compile.called, "torch.compile should have been called"
assert len(outputs) > 0
assert save_string.id in outputs
assert outputs[save_string.id]["string"][0] is not None