mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
add custome zimage ops node.
This commit is contained in:
parent
55e6478526
commit
19d51dadb2
5
custom_nodes/baymax-zimage/__init__.py
Normal file
5
custom_nodes/baymax-zimage/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from .nodes_baymax_zimage import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
|
||||
|
||||
WEB_DIRECTORY = None
|
||||
157
custom_nodes/baymax-zimage/nodes_baymax_zimage.py
Normal file
157
custom_nodes/baymax-zimage/nodes_baymax_zimage.py
Normal file
@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Callable, Dict, Iterable, Tuple
|
||||
|
||||
logger = logging.getLogger("baymax-zimage")
|
||||
|
||||
_PATCH_LOCK = threading.Lock()
|
||||
_ORIGINAL_FUNCTIONS: Dict[Tuple[str, str], Callable] = {}
|
||||
_PACKAGE_DIR = Path(__file__).resolve().parent
|
||||
_USER_IMPL_PATH = _PACKAGE_DIR / "user_impl.py"
|
||||
|
||||
|
||||
def _available_targets() -> Tuple[Tuple[ModuleType, str], ...]:
|
||||
targets = []
|
||||
for module, attribute_name in _target_modules():
|
||||
if hasattr(module, attribute_name):
|
||||
targets.append((module, attribute_name))
|
||||
else:
|
||||
logger.warning(
|
||||
"[baymax-zimage] Skipping %s.%s because it does not exist",
|
||||
module.__name__,
|
||||
attribute_name,
|
||||
)
|
||||
return tuple(targets)
|
||||
|
||||
|
||||
def _target_modules() -> Iterable[Tuple[ModuleType, str]]:
|
||||
return (
|
||||
(importlib.import_module("comfy.ldm.lumina.model"), "apply_rope"),
|
||||
(importlib.import_module("comfy.ldm.lumina.controlnet"), "apply_rope"),
|
||||
)
|
||||
|
||||
|
||||
def _load_user_module() -> ModuleType:
|
||||
spec = importlib.util.spec_from_file_location("baymax_zimage_user_impl", _USER_IMPL_PATH)
|
||||
if spec is None or spec.loader is None:
|
||||
raise RuntimeError(f"Unable to load user implementation from {_USER_IMPL_PATH}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _call_user_impl(user_function: Callable, original_function: Callable, xq, xk, freqs_cis):
|
||||
signature = inspect.signature(user_function)
|
||||
parameters = signature.parameters.values()
|
||||
positional_slots = sum(
|
||||
1
|
||||
for parameter in parameters
|
||||
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||
)
|
||||
has_varargs = any(parameter.kind == inspect.Parameter.VAR_POSITIONAL for parameter in signature.parameters.values())
|
||||
has_varkw = any(parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
if "original_apply_rope" in signature.parameters or has_varkw:
|
||||
return user_function(xq, xk, freqs_cis, original_apply_rope=original_function)
|
||||
|
||||
if positional_slots >= 4 or has_varargs:
|
||||
return user_function(xq, xk, freqs_cis, original_function)
|
||||
|
||||
return user_function(xq, xk, freqs_cis)
|
||||
|
||||
|
||||
def _make_wrapper(module_name: str, user_function: Callable) -> Callable:
|
||||
original_function = _ORIGINAL_FUNCTIONS[(module_name, "apply_rope")]
|
||||
|
||||
def patched_apply_rope(xq, xk, freqs_cis):
|
||||
return _call_user_impl(user_function, original_function, xq, xk, freqs_cis)
|
||||
|
||||
patched_apply_rope.__name__ = "apply_rope"
|
||||
patched_apply_rope.__module__ = __name__
|
||||
return patched_apply_rope
|
||||
|
||||
|
||||
def _patch_rotary(enable: bool, reload_user_impl: bool) -> str:
|
||||
del reload_user_impl # The user module is reloaded on every execution.
|
||||
|
||||
with _PATCH_LOCK:
|
||||
targets = list(_available_targets())
|
||||
if not targets:
|
||||
raise RuntimeError("No z-Image rotary targets with apply_rope were found")
|
||||
|
||||
for module, attribute_name in targets:
|
||||
key = (module.__name__, attribute_name)
|
||||
_ORIGINAL_FUNCTIONS.setdefault(key, getattr(module, attribute_name))
|
||||
|
||||
if not enable:
|
||||
for module, attribute_name in targets:
|
||||
setattr(module, attribute_name, _ORIGINAL_FUNCTIONS[(module.__name__, attribute_name)])
|
||||
logger.info("[baymax-zimage] Restored the default z-Image rotary implementation")
|
||||
return "restored"
|
||||
|
||||
user_module = _load_user_module()
|
||||
user_function = getattr(user_module, "apply_rotary_emb", None)
|
||||
if not callable(user_function):
|
||||
raise RuntimeError(f"{_USER_IMPL_PATH} must define a callable apply_rotary_emb function")
|
||||
|
||||
for module, attribute_name in targets:
|
||||
setattr(module, attribute_name, _make_wrapper(module.__name__, user_function))
|
||||
|
||||
logger.info("[baymax-zimage] Installed user rotary implementation from %s", _USER_IMPL_PATH)
|
||||
return "patched"
|
||||
|
||||
|
||||
class BaymaxZImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"enable": ("BOOLEAN", {"default": True}),
|
||||
"reload_user_impl": ("BOOLEAN", {"default": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("model",)
|
||||
FUNCTION = "patch_model"
|
||||
CATEGORY = "baymax"
|
||||
DESCRIPTION = (
|
||||
"Patch the z-Image NextDiT rotary function with the implementation in "
|
||||
"custom_nodes/baymax-zimage/user_impl.py."
|
||||
)
|
||||
|
||||
def patch_model(self, model, enable=True, reload_user_impl=True):
|
||||
diffusion_model = getattr(getattr(model, "model", None), "diffusion_model", None)
|
||||
if diffusion_model is not None and diffusion_model.__class__.__name__ != "NextDiTPixelSpace":
|
||||
logger.warning(
|
||||
"[baymax-zimage] Received model type %s; patch still applies globally to z-Image rotary code paths",
|
||||
diffusion_model.__class__.__name__,
|
||||
)
|
||||
|
||||
patched_model = model.clone()
|
||||
status = _patch_rotary(enable=enable, reload_user_impl=reload_user_impl)
|
||||
|
||||
transformer_options = patched_model.model_options.setdefault("transformer_options", {})
|
||||
transformer_options["baymax_zimage"] = {
|
||||
"status": status,
|
||||
"user_impl": str(_USER_IMPL_PATH),
|
||||
}
|
||||
return (patched_model,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"BaymaxZImage": BaymaxZImage,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"BaymaxZImage": "baymax-zimage",
|
||||
}
|
||||
30
custom_nodes/baymax-zimage/user_impl.py
Normal file
30
custom_nodes/baymax-zimage/user_impl.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""User-editable rotary implementation for the baymax-zimage node.
|
||||
|
||||
Edit apply_rotary_emb and then run the baymax-zimage node with reload enabled.
|
||||
The z-Image transformer in this repository calls apply_rope internally, so the
|
||||
node adapts that call into this apply_rotary_emb interface.
|
||||
"""
|
||||
|
||||
def apply_rotary_emb(xq, xk, freqs_cis, original_apply_rope=None):
|
||||
if freqs_cis is None:
|
||||
if original_apply_rope is None:
|
||||
raise RuntimeError("freqs_cis is None and no original_apply_rope fallback is available")
|
||||
return original_apply_rope(xq, xk, freqs_cis)
|
||||
|
||||
# Standalone rotary implementation compatible with z-Image NextDiT paths.
|
||||
def _apply_single(x):
|
||||
if x is None:
|
||||
return None
|
||||
|
||||
x_work = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
fc = freqs_cis
|
||||
|
||||
# Match the half-dim slice used by this q/k tensor.
|
||||
if x_work.shape[2] != 1 and fc.shape[2] != 1 and x_work.shape[2] != fc.shape[2]:
|
||||
fc = fc[:, :, :x_work.shape[2]]
|
||||
|
||||
x_out = fc[..., 0] * x_work[..., 0]
|
||||
x_out.addcmul_(fc[..., 1], x_work[..., 1])
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
return _apply_single(xq), _apply_single(xk)
|
||||
Loading…
Reference in New Issue
Block a user