ComfyUI/comfy_extras/nodes/nodes_torch_compile.py

53 lines
1.5 KiB
Python

import logging
import torch
from comfy.model_patcher import ModelPatcher
# Name of the model object that gets compiled when no other `object_patch`
# value is supplied; used as the default of that input in TorchCompileModel.
DIFFUSION_MODEL = "diffusion_model"
class TorchCompileModel:
    """Node that wraps a model, or one object inside it, with ``torch.compile``.

    ``INPUT_TYPES`` declares the inputs, ``RETURN_TYPES`` the single ``MODEL``
    output, and ``FUNCTION`` holds the name of the method implemented below
    (``patch``).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input declaration for this node.

        One required ``MODEL`` input, plus optional compile settings whose
        defaults mirror the defaults of :meth:`patch`.
        """
        return {
            "required": {
                "model": ("MODEL",),
            },
            "optional": {
                "object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
                "fullgraph": ("BOOLEAN", {"default": False}),
                "dynamic": ("BOOLEAN", {"default": False}),
                "backend": ("STRING", {"default": "inductor"}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(
        self,
        model: ModelPatcher,
        object_patch: str | None = DIFFUSION_MODEL,
        fullgraph: bool = False,
        dynamic: bool = False,
        backend: str = "inductor",
    ):
        """Compile ``model`` with ``torch.compile`` and return it in a 1-tuple.

        Args:
            model: A ``ModelPatcher`` or a ``torch.nn.Module``. Any other
                object is returned unchanged after a warning is logged.
            object_patch: Name of the object inside a ``ModelPatcher`` that
                should be compiled. ``None`` or a blank string falls back to
                ``DIFFUSION_MODEL``. Not used when ``model`` is a bare
                ``torch.nn.Module``.
            fullgraph: Forwarded unchanged to ``torch.compile``.
            dynamic: Forwarded unchanged to ``torch.compile``.
            backend: Forwarded unchanged to ``torch.compile``.

        Returns:
            A 1-tuple holding, depending on the type of ``model``: a clone of
            the ``ModelPatcher`` with the compiled object registered as an
            object patch, the compiled module, or ``model`` itself.
        """
        # A blank string cannot name an object to patch, so treat it the same
        # way as an omitted value instead of passing it on to the lookup.
        is_blank = isinstance(object_patch, str) and not object_patch.strip()
        if object_patch is None or is_blank:
            object_patch = DIFFUSION_MODEL

        compile_kwargs = {
            "fullgraph": fullgraph,
            "dynamic": dynamic,
            "backend": backend,
        }

        if isinstance(model, ModelPatcher):
            # The patch is registered on a clone rather than on `model` itself.
            m = model.clone()
            target = m.get_model_object(object_patch)
            m.add_object_patch(object_patch, torch.compile(model=target, **compile_kwargs))
            return (m,)

        if isinstance(model, torch.nn.Module):
            return (torch.compile(model=model, **compile_kwargs),)

        # Neither a patcher nor a module: hand the input back untouched.
        logging.warning("Encountered a model that cannot be compiled")
        return (model,)
# Maps a node identifier string to the class that implements it.
# NOTE(review): presumably read by the node loader to register this node — confirm.
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}