mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
157 lines
5.6 KiB
Python
157 lines
5.6 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib
|
|
import importlib.util
|
|
import inspect
|
|
import logging
|
|
import threading
|
|
from pathlib import Path
|
|
from types import ModuleType
|
|
from typing import Callable, Dict, Iterable, Tuple
|
|
|
|
# Logger shared by every message this module emits.
logger = logging.getLogger("baymax-zimage")

# Held by _patch_rotary for the whole patch/restore sequence.
_PATCH_LOCK = threading.Lock()
# (module name, attribute name) -> the function found there before patching.
_ORIGINAL_FUNCTIONS: Dict[Tuple[str, str], Callable] = dict()
# Directory containing this file; the user implementation is read from
# "user_impl.py" inside it.
_PACKAGE_DIR = Path(__file__).resolve().parent
_USER_IMPL_PATH = _PACKAGE_DIR.joinpath("user_impl.py")
|
|
|
|
|
|
def _available_targets() -> Tuple[Tuple[ModuleType, str], ...]:
    """Return the (module, attribute name) pairs that can be patched right now.

    Every candidate produced by ``_target_modules()`` is kept only when the
    module currently exposes the named attribute. Candidates that do not are
    reported with a warning and left out of the result.
    """
    found = []
    for target_module, name in _target_modules():
        if not hasattr(target_module, name):
            logger.warning(
                "[baymax-zimage] Skipping %s.%s because it does not exist",
                target_module.__name__,
                name,
            )
            continue
        found.append((target_module, name))
    return tuple(found)
|
|
|
|
|
|
def _target_modules() -> Iterable[Tuple[ModuleType, str]]:
    """Return the importable (module, attribute name) patch candidates.

    Each candidate module is imported on its own. A module that cannot be
    imported is logged and skipped rather than aborting the whole lookup, so
    one missing module does not stop the remaining ones from being patched.

    Returns:
        A tuple of ``(module, "apply_rope")`` pairs, one for each candidate
        module that imported successfully. The tuple is empty when none did.
    """
    candidates = (
        ("comfy.ldm.lumina.model", "apply_rope"),
        ("comfy.ldm.lumina.controlnet", "apply_rope"),
    )
    resolved = []
    for module_name, attribute_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError as error:
            # ModuleNotFoundError is a subclass of ImportError, so this also
            # covers a module that simply is not present.
            logger.warning(
                "[baymax-zimage] Skipping %s.%s because the module could not be imported: %s",
                module_name,
                attribute_name,
                error,
            )
            continue
        resolved.append((module, attribute_name))
    return tuple(resolved)
|
|
|
|
|
|
def _load_user_module() -> ModuleType:
    """Execute the file at ``_USER_IMPL_PATH`` and return the resulting module.

    A new module object is created and executed on every call, so each call
    runs whatever the file contains at that moment.

    Raises:
        RuntimeError: If no import spec, or no loader, can be obtained for
            the file.
    """
    module_spec = importlib.util.spec_from_file_location("baymax_zimage_user_impl", _USER_IMPL_PATH)
    loader = None if module_spec is None else module_spec.loader
    if loader is None:
        raise RuntimeError(f"Unable to load user implementation from {_USER_IMPL_PATH}")

    user_module = importlib.util.module_from_spec(module_spec)
    loader.exec_module(user_module)
    return user_module
|
|
|
|
|
|
def _call_user_impl(user_function: Callable, original_function: Callable, xq, xk, freqs_cis):
    """Call the user's function, handing it the original one when it can take it.

    The signature of ``user_function`` decides the calling convention, checked
    in this order:

    1. A positional-only parameter named ``original_apply_rope`` receives
       ``original_function`` positionally, as the fourth argument.
    2. A parameter named ``original_apply_rope`` that may be passed by
       keyword, or a ``**kwargs`` catch-all, receives it as the
       ``original_apply_rope`` keyword argument.
    3. A fourth positional slot, or ``*args``, receives it positionally.
    4. Otherwise only ``xq``, ``xk`` and ``freqs_cis`` are passed.

    Args:
        user_function: The user-supplied replacement.
        original_function: The function that was in place before patching.
        xq, xk, freqs_cis: Forwarded unchanged as the first three positional
            arguments.

    Returns:
        Whatever ``user_function`` returns.
    """
    parameter_kind = inspect.Parameter
    parameters = inspect.signature(user_function).parameters
    kinds = [parameter.kind for parameter in parameters.values()]

    positional_slots = sum(
        kind in (parameter_kind.POSITIONAL_ONLY, parameter_kind.POSITIONAL_OR_KEYWORD)
        for kind in kinds
    )
    has_varargs = parameter_kind.VAR_POSITIONAL in kinds
    has_varkw = parameter_kind.VAR_KEYWORD in kinds

    named = parameters.get("original_apply_rope")
    if named is not None and named.kind is parameter_kind.POSITIONAL_ONLY:
        # A positional-only parameter cannot be bound by keyword; passing it
        # as original_apply_rope=... would raise TypeError.
        return user_function(xq, xk, freqs_cis, original_function)

    named_accepts_keyword = named is not None and named.kind in (
        parameter_kind.POSITIONAL_OR_KEYWORD,
        parameter_kind.KEYWORD_ONLY,
    )
    if named_accepts_keyword or has_varkw:
        return user_function(xq, xk, freqs_cis, original_apply_rope=original_function)

    if positional_slots >= 4 or has_varargs:
        return user_function(xq, xk, freqs_cis, original_function)

    return user_function(xq, xk, freqs_cis)
|
|
|
|
|
|
def _make_wrapper(module_name: str, user_function: Callable) -> Callable:
    """Build the replacement ``apply_rope`` for one patched module.

    The function saved in ``_ORIGINAL_FUNCTIONS`` for ``module_name`` is
    looked up once, here, so the returned wrapper always forwards that same
    original to ``_call_user_impl`` together with ``user_function``.

    Raises:
        KeyError: If nothing was saved under ``(module_name, "apply_rope")``.
    """
    saved_original = _ORIGINAL_FUNCTIONS[(module_name, "apply_rope")]

    def patched_apply_rope(xq, xk, freqs_cis):
        return _call_user_impl(user_function, saved_original, xq, xk, freqs_cis)

    # Present the wrapper under the name of the function it replaces, while
    # reporting this module as the place it was defined.
    wrapper = patched_apply_rope
    wrapper.__name__, wrapper.__module__ = "apply_rope", __name__
    return wrapper
|
|
|
|
|
|
def _patch_rotary(enable: bool, reload_user_impl: bool) -> str:
    """Install the user rotary implementation, or put the originals back.

    The whole operation runs while holding ``_PATCH_LOCK``. The function
    currently installed on each available target is saved in
    ``_ORIGINAL_FUNCTIONS`` the first time that target is seen.

    Args:
        enable: When true, load the user module and install a wrapper on
            every target. When false, reinstall the saved originals.
        reload_user_impl: Ignored; the user module is executed again on
            every call made with ``enable`` set.

    Returns:
        ``"patched"`` when the user implementation was installed, or
        ``"restored"`` when the originals were put back.

    Raises:
        RuntimeError: If there is no available target, or the user module
            does not define a callable ``apply_rotary_emb``.
    """
    del reload_user_impl  # The user module is reloaded on every execution.

    with _PATCH_LOCK:
        targets = _available_targets()
        if len(targets) == 0:
            raise RuntimeError("No z-Image rotary targets with apply_rope were found")

        # setdefault keeps the first function ever recorded, so a wrapper
        # installed by an earlier call never overwrites the saved original.
        for target_module, name in targets:
            currently_installed = getattr(target_module, name)
            _ORIGINAL_FUNCTIONS.setdefault((target_module.__name__, name), currently_installed)

        if enable:
            user_module = _load_user_module()
            user_function = getattr(user_module, "apply_rotary_emb", None)
            if not callable(user_function):
                raise RuntimeError(f"{_USER_IMPL_PATH} must define a callable apply_rotary_emb function")

            for target_module, name in targets:
                setattr(target_module, name, _make_wrapper(target_module.__name__, user_function))

            logger.info("[baymax-zimage] Installed user rotary implementation from %s", _USER_IMPL_PATH)
            return "patched"

        for target_module, name in targets:
            setattr(target_module, name, _ORIGINAL_FUNCTIONS[(target_module.__name__, name)])
        logger.info("[baymax-zimage] Restored the default z-Image rotary implementation")
        return "restored"
|
|
|
|
|
|
class BaymaxZImage:
    """Node that patches the z-Image rotary function with a user implementation.

    The patch itself is applied to module-level functions by
    ``_patch_rotary``; the model passed through the node is only cloned and
    tagged with the outcome.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: a model and two boolean toggles."""
        required = dict(
            model=("MODEL",),
            enable=("BOOLEAN", dict(default=True)),
            reload_user_impl=("BOOLEAN", dict(default=True)),
        )
        return dict(required=required)

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "patch_model"
    CATEGORY = "baymax"
    DESCRIPTION = (
        "Patch the z-Image NextDiT rotary function with the implementation in "
        "custom_nodes/baymax-zimage/user_impl.py."
    )

    def patch_model(self, model, enable=True, reload_user_impl=True):
        """Apply or remove the rotary patch and return a tagged clone of ``model``.

        Args:
            model: Object providing ``clone()``; the clone must expose a
                ``model_options`` mapping.
            enable: Forwarded to ``_patch_rotary``.
            reload_user_impl: Forwarded to ``_patch_rotary``.

        Returns:
            A one-element tuple holding the clone, whose
            ``transformer_options["baymax_zimage"]`` entry records the patch
            status and the user implementation path.
        """
        inner_model = getattr(model, "model", None)
        diffusion_model = getattr(inner_model, "diffusion_model", None)
        if diffusion_model is not None:
            class_name = diffusion_model.__class__.__name__
            if class_name != "NextDiTPixelSpace":
                # Only a warning: the patch replaces module-level functions,
                # so it takes effect whatever model was passed in.
                logger.warning(
                    "[baymax-zimage] Received model type %s; patch still applies globally to z-Image rotary code paths",
                    class_name,
                )

        patched_model = model.clone()
        status = _patch_rotary(enable=enable, reload_user_impl=reload_user_impl)

        record = {"status": status, "user_impl": str(_USER_IMPL_PATH)}
        patched_model.model_options.setdefault("transformer_options", {})["baymax_zimage"] = record
        return (patched_model,)
|
|
|
|
|
|
# Registration tables keyed by node identifier: the implementing class, and
# the name shown for the node.
NODE_CLASS_MAPPINGS = {"BaymaxZImage": BaymaxZImage}

NODE_DISPLAY_NAME_MAPPINGS = {"BaymaxZImage": "baymax-zimage"}