Add custom zimage ops node.

This commit is contained in:
kunshen1 2026-03-30 08:44:35 +00:00
parent 55e6478526
commit 19d51dadb2
3 changed files with 192 additions and 0 deletions

View File

@ -0,0 +1,5 @@
# Re-export the node registration tables defined in nodes_baymax_zimage so
# they are reachable directly from the package.
from .nodes_baymax_zimage import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
# Names exposed by a star-import of this package.
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
# NOTE(review): presumably signals to the host application that this package
# ships no web/front-end assets -- confirm against the custom-node loader docs.
WEB_DIRECTORY = None

View File

@ -0,0 +1,157 @@
from __future__ import annotations
import importlib
import importlib.util
import inspect
import logging
import threading
from pathlib import Path
from types import ModuleType
from typing import Callable, Dict, Iterable, Tuple
# Directory holding this module, and the user-editable implementation file
# that is expected to sit next to it.
_PACKAGE_DIR = Path(__file__).resolve().parents[0]
_USER_IMPL_PATH = _PACKAGE_DIR.joinpath("user_impl.py")

# Shared logger for every message emitted by this node.
logger = logging.getLogger("baymax-zimage")

# Guards the patch/restore sequence, which mutates module attributes.
_PATCH_LOCK = threading.Lock()

# (module name, attribute name) -> function object that was in place before
# the first patch, so it can be put back later.
_ORIGINAL_FUNCTIONS: Dict[Tuple[str, str], Callable] = {}
def _available_targets() -> Tuple[Tuple[ModuleType, str], ...]:
    """Return the candidate targets whose attribute actually exists.

    Every ``(module, attribute_name)`` pair produced by ``_target_modules`` is
    checked with ``hasattr``; pairs that fail the check are reported with a
    warning and left out of the result.

    Returns:
        A tuple of ``(module, attribute_name)`` pairs, in the order they were
        produced by ``_target_modules``.
    """
    usable = []
    for candidate_module, candidate_attr in _target_modules():
        if not hasattr(candidate_module, candidate_attr):
            logger.warning(
                "[baymax-zimage] Skipping %s.%s because it does not exist",
                candidate_module.__name__,
                candidate_attr,
            )
            continue
        usable.append((candidate_module, candidate_attr))
    return tuple(usable)
def _target_modules() -> Iterable[Tuple[ModuleType, str]]:
    """Import the modules whose ``apply_rope`` attribute should be patched.

    A module that cannot be imported is skipped with a warning instead of
    aborting the whole patch, so that an installation lacking one of the
    candidate modules can still have the remaining ones patched.

    Returns:
        A tuple of ``(module, attribute_name)`` pairs for every candidate
        module that imported successfully. May be empty.
    """
    candidates = (
        ("comfy.ldm.lumina.model", "apply_rope"),
        ("comfy.ldm.lumina.controlnet", "apply_rope"),
    )
    resolved = []
    for module_name, attribute_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError as exc:
            # getLogger returns the same named logger instance used by the
            # rest of this package, so the message lands in the same place.
            logging.getLogger("baymax-zimage").warning(
                "[baymax-zimage] Skipping %s.%s because the module could not be imported: %s",
                module_name,
                attribute_name,
                exc,
            )
            continue
        resolved.append((module, attribute_name))
    return tuple(resolved)
def _load_user_module() -> ModuleType:
    """Execute the user implementation file and return it as a fresh module.

    The file at ``_USER_IMPL_PATH`` is loaded under the module name
    ``baymax_zimage_user_impl``. The module is built and executed on each
    call; it is not registered in ``sys.modules`` by this function.

    Returns:
        The newly executed module object.

    Raises:
        RuntimeError: If no import spec or loader can be created for the file.
    """
    module_spec = importlib.util.spec_from_file_location("baymax_zimage_user_impl", _USER_IMPL_PATH)
    loader = None if module_spec is None else module_spec.loader
    if loader is None:
        raise RuntimeError(f"Unable to load user implementation from {_USER_IMPL_PATH}")
    fresh_module = importlib.util.module_from_spec(module_spec)
    loader.exec_module(fresh_module)
    return fresh_module
def _call_user_impl(user_function: Callable, original_function: Callable, xq, xk, freqs_cis):
    """Call ``user_function``, handing it the original function if it accepts one.

    The calling convention is chosen from the user function's signature:

    1. If it has a keyword-capable parameter named ``original_apply_rope``, or
       accepts ``**kwargs``, the original is passed as
       ``original_apply_rope=...``.
    2. Otherwise, if it has at least four positional slots or accepts
       ``*args``, the original is passed as the fourth positional argument.
    3. Otherwise it is called with ``(xq, xk, freqs_cis)`` only.

    A parameter named ``original_apply_rope`` that is declared positional-only
    cannot be filled by keyword, so in that case rule 1 is skipped and rule 2
    decides.

    Args:
        user_function: The user-supplied rotary implementation.
        original_function: The function that was in place before patching.
        xq: Forwarded unchanged as the first argument.
        xk: Forwarded unchanged as the second argument.
        freqs_cis: Forwarded unchanged as the third argument.

    Returns:
        Whatever ``user_function`` returns.
    """
    parameters = inspect.signature(user_function).parameters
    kinds = [parameter.kind for parameter in parameters.values()]
    positional_slots = sum(
        kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        for kind in kinds
    )
    has_varargs = inspect.Parameter.VAR_POSITIONAL in kinds
    has_varkw = inspect.Parameter.VAR_KEYWORD in kinds

    named = parameters.get("original_apply_rope")
    # A positional-only parameter rejects keyword passing, even when **kwargs
    # is also present (the keyword would land in kwargs and leave the real
    # parameter unfilled).
    named_is_positional_only = (
        named is not None and named.kind is inspect.Parameter.POSITIONAL_ONLY
    )

    if not named_is_positional_only and (named is not None or has_varkw):
        return user_function(xq, xk, freqs_cis, original_apply_rope=original_function)
    if positional_slots >= 4 or has_varargs:
        return user_function(xq, xk, freqs_cis, original_function)
    return user_function(xq, xk, freqs_cis)
def _make_wrapper(module_name: str, user_function: Callable) -> Callable:
    """Build the replacement ``apply_rope`` for one target module.

    The returned function takes ``(xq, xk, freqs_cis)`` and routes the call
    through ``_call_user_impl`` together with the original function recorded
    for ``module_name`` in ``_ORIGINAL_FUNCTIONS``.

    Args:
        module_name: Name of the module whose saved original should be used.
        user_function: The user-supplied rotary implementation.

    Returns:
        The wrapper, with ``__name__`` set to ``"apply_rope"`` and
        ``__module__`` set to this module's name.

    Raises:
        KeyError: If no original was recorded for ``module_name``.
    """
    # Looked up once here so the wrapper keeps using the function that was
    # saved at the time the wrapper was built.
    saved_original = _ORIGINAL_FUNCTIONS[(module_name, "apply_rope")]

    def patched_apply_rope(xq, xk, freqs_cis):
        return _call_user_impl(user_function, saved_original, xq, xk, freqs_cis)

    for attribute, value in (("__name__", "apply_rope"), ("__module__", __name__)):
        setattr(patched_apply_rope, attribute, value)
    return patched_apply_rope
def _patch_rotary(enable: bool, reload_user_impl: bool) -> str:
    """Install the user rotary implementation on every target, or remove it.

    The whole operation runs under ``_PATCH_LOCK``. Before anything is
    replaced, the function currently bound on each target is recorded in
    ``_ORIGINAL_FUNCTIONS`` unless an entry for that target already exists.

    Args:
        enable: When true, load ``user_impl.py`` and bind a wrapper around its
            ``apply_rotary_emb`` on each target. When false, rebind the
            recorded original functions.
        reload_user_impl: Accepted for interface compatibility and ignored.

    Returns:
        ``"patched"`` when the user implementation was installed,
        ``"restored"`` when the originals were put back.

    Raises:
        RuntimeError: If no target is available, or if the user module does
            not define a callable ``apply_rotary_emb``.
    """
    del reload_user_impl  # The user module is reloaded on every execution.
    with _PATCH_LOCK:
        keyed_targets = [
            ((module.__name__, attribute_name), module, attribute_name)
            for module, attribute_name in _available_targets()
        ]
        if not keyed_targets:
            raise RuntimeError("No z-Image rotary targets with apply_rope were found")

        # Record what is bound right now, but never overwrite an earlier
        # record: on a second run the bound function may already be a wrapper.
        for key, module, attribute_name in keyed_targets:
            _ORIGINAL_FUNCTIONS.setdefault(key, getattr(module, attribute_name))

        if enable:
            # Load and validate before touching any target so a broken user
            # file leaves every target as it was.
            user_function = getattr(_load_user_module(), "apply_rotary_emb", None)
            if not callable(user_function):
                raise RuntimeError(f"{_USER_IMPL_PATH} must define a callable apply_rotary_emb function")
            for _key, module, attribute_name in keyed_targets:
                setattr(module, attribute_name, _make_wrapper(module.__name__, user_function))
            logger.info("[baymax-zimage] Installed user rotary implementation from %s", _USER_IMPL_PATH)
            return "patched"

        for key, module, attribute_name in keyed_targets:
            setattr(module, attribute_name, _ORIGINAL_FUNCTIONS[key])
        logger.info("[baymax-zimage] Restored the default z-Image rotary implementation")
        return "restored"
class BaymaxZImage:
    """Node that swaps the z-Image rotary function for the user's implementation.

    The node takes a model, applies or removes the patch through
    ``_patch_rotary``, and returns a clone of the model annotated with the
    outcome. The patch itself is applied to module attributes, so it is not
    limited to the model passed in.
    """

    CATEGORY = "baymax"
    FUNCTION = "patch_model"
    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    DESCRIPTION = (
        "Patch the z-Image NextDiT rotary function with the implementation in "
        "custom_nodes/baymax-zimage/user_impl.py."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: a model plus two boolean toggles."""
        required_inputs = {
            "model": ("MODEL",),
            "enable": ("BOOLEAN", {"default": True}),
            "reload_user_impl": ("BOOLEAN", {"default": True}),
        }
        return {"required": required_inputs}

    def patch_model(self, model, enable=True, reload_user_impl=True):
        """Apply or remove the rotary patch and return an annotated model clone.

        Args:
            model: Object providing ``clone()``; the clone must expose a
                ``model_options`` mapping.
            enable: Forwarded to ``_patch_rotary``.
            reload_user_impl: Forwarded to ``_patch_rotary``.

        Returns:
            A one-element tuple holding the cloned model. Its
            ``model_options["transformer_options"]["baymax_zimage"]`` entry
            records the patch status and the user implementation path.
        """
        inner_model = getattr(model, "model", None)
        diffusion_model = getattr(inner_model, "diffusion_model", None)
        if diffusion_model is not None:
            # Only a warning: the patch is module-wide and is applied anyway.
            if diffusion_model.__class__.__name__ != "NextDiTPixelSpace":
                logger.warning(
                    "[baymax-zimage] Received model type %s; patch still applies globally to z-Image rotary code paths",
                    diffusion_model.__class__.__name__,
                )

        patched_model = model.clone()
        status = _patch_rotary(enable=enable, reload_user_impl=reload_user_impl)

        options = patched_model.model_options.setdefault("transformer_options", {})
        options["baymax_zimage"] = {"status": status, "user_impl": str(_USER_IMPL_PATH)}
        return (patched_model,)
# Node identifier -> class implementing the node. Re-exported by the package
# __init__.
NODE_CLASS_MAPPINGS = {
    "BaymaxZImage": BaymaxZImage,
}
# Node identifier -> label for the node. Keys match NODE_CLASS_MAPPINGS.
# NOTE(review): presumably the label shown in the host UI -- confirm.
NODE_DISPLAY_NAME_MAPPINGS = {
    "BaymaxZImage": "baymax-zimage",
}

View File

@ -0,0 +1,30 @@
"""User-editable rotary implementation for the baymax-zimage node.
Edit apply_rotary_emb and then re-run the baymax-zimage node with enable set;
this file is re-executed on every such run, regardless of the reload toggle.
The z-Image transformer in this repository calls apply_rope internally, so the
node adapts that call into this apply_rotary_emb interface.
"""
def apply_rotary_emb(xq, xk, freqs_cis, original_apply_rope=None):
    """Apply the rotary embedding in ``freqs_cis`` to ``xq`` and ``xk``.

    Args:
        xq: Query tensor, or None.
        xk: Key tensor, or None.
        freqs_cis: Rotary table. Its last axis is indexed at 0 and 1, and its
            dtype is used for the intermediate computation. When None, the
            call is delegated to ``original_apply_rope``.
            NOTE(review): presumably the trailing two axes form a 2x2
            rotation matrix per feature pair -- confirm against the caller.
        original_apply_rope: Optional fallback with the signature
            ``(xq, xk, freqs_cis)``.

    Returns:
        A ``(xq_out, xk_out)`` tuple. Each output has the shape and dtype of
        its input, or is None when the input was None. When delegating, the
        fallback's return value is passed through unchanged.

    Raises:
        RuntimeError: If ``freqs_cis`` is None and no fallback was given.
    """
    if freqs_cis is None:
        if original_apply_rope is not None:
            return original_apply_rope(xq, xk, freqs_cis)
        raise RuntimeError("freqs_cis is None and no original_apply_rope fallback is available")

    def _rotate(tensor):
        # Standalone rotary implementation compatible with z-Image NextDiT paths.
        if tensor is None:
            return None
        # Split the last axis into pairs: (..., D) -> (..., D // 2, 1, 2).
        paired = tensor.to(dtype=freqs_cis.dtype).reshape(*tensor.shape[:-1], -1, 1, 2)
        table = freqs_cis
        # When axis 2 differs and neither side has size 1 there, trim the
        # table along axis 2 to the tensor's size.
        if paired.shape[2] != 1 and table.shape[2] != 1 and paired.shape[2] != table.shape[2]:
            table = table[:, :, :paired.shape[2]]
        # table[..., 0] * x0 + table[..., 1] * x1, accumulated in place.
        rotated = (table[..., 0] * paired[..., 0]).addcmul_(table[..., 1], paired[..., 1])
        return rotated.reshape(*tensor.shape).type_as(tensor)

    return _rotate(xq), _rotate(xk)