From 20dbf31c0ff356a141ef17251fa948eab7d8baa5 Mon Sep 17 00:00:00 2001
From: Haoming
Date: Mon, 29 Dec 2025 16:02:43 +0800
Subject: [PATCH] v2

---
 comfy_extras/nodes_dype.py | 185 +++++++++++++++----------------------
 1 file changed, 74 insertions(+), 111 deletions(-)

diff --git a/comfy_extras/nodes_dype.py b/comfy_extras/nodes_dype.py
index 3cc0be8a7..80cafdfc8 100644
--- a/comfy_extras/nodes_dype.py
+++ b/comfy_extras/nodes_dype.py
@@ -1,6 +1,7 @@
 # adapted from https://github.com/guyyariv/DyPE

 import math
+from typing import Callable

 import numpy as np
 import torch
@@ -39,7 +40,7 @@ def find_newbase_ntk(dim, base, scale):

 def get_1d_rotary_pos_embed(
     dim: int,
-    pos: np.ndarray | int,
+    pos: torch.Tensor,
     theta: float = 10000.0,
     use_real=False,
     linear_factor=1.0,
@@ -49,7 +50,6 @@
     yarn=False,
     max_pe_len=None,
     ori_max_pe_len=64,
-    dype=False,
     current_timestep=1.0,
 ):
     """
@@ -80,8 +80,6 @@
             Maximum position encoding length (current patches for vision models).
         ori_max_pe_len (`int`, *optional*, defaults to 64):
             Original maximum position encoding length (base patches for vision models).
-        dype (`bool`, *optional*, defaults to False):
-            If True, enable DyPE (Dynamic Position Encoding) with timestep-aware scaling.
         current_timestep (`float`, *optional*, defaults to 1.0):
             Current timestep for DyPE, normalized to [0, 1] where 1 is pure noise.

@@ -91,11 +89,6 @@
     """
     assert dim % 2 == 0

-    if isinstance(pos, int):
-        pos = torch.arange(pos)
-    if isinstance(pos, np.ndarray):
-        pos = torch.from_numpy(pos)
-
     device = pos.device

     if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
@@ -104,10 +97,8 @@

         scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)

-        beta_0 = 1.25
-        beta_1 = 0.75
-        gamma_0 = 16
-        gamma_1 = 2
+        beta_0, beta_1 = 1.25, 0.75
+        gamma_0, gamma_1 = 16, 2

         freqs_base = 1.0 / (
             theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
@@ -131,9 +122,8 @@
         if freqs_ntk.dim() > 1:
             freqs_ntk = freqs_ntk.squeeze()

-        if dype:
-            beta_0 = beta_0 ** (2.0 * (current_timestep**2.0))
-            beta_1 = beta_1 ** (2.0 * (current_timestep**2.0))
+        beta_0 = beta_0 ** (2.0 * (current_timestep**2.0))
+        beta_1 = beta_1 ** (2.0 * (current_timestep**2.0))

         low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
         low = max(0, low)
@@ -144,9 +134,8 @@
         )
         freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask

-        if dype:
-            gamma_0 = gamma_0 ** (2.0 * (current_timestep**2.0))
-            gamma_1 = gamma_1 ** (2.0 * (current_timestep**2.0))
+        gamma_0 = gamma_0 ** (2.0 * (current_timestep**2.0))
+        gamma_1 = gamma_1 ** (2.0 * (current_timestep**2.0))

         low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
         low = max(0, low)
@@ -174,50 +163,44 @@
     if is_npu:
         freqs = freqs.float()

-    if use_real and repeat_interleave_real:
-        freqs_cos = (
-            freqs.cos()
-            .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
-            .float()
-        )
-        freqs_sin = (
-            freqs.sin()
-            .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
-            .float()
-        )
+    if use_real:
+        if repeat_interleave_real:
+            freqs_cos = (
+                freqs.cos()
+                .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
+                .float()
+            )
+            freqs_sin = (
+                freqs.sin()
+                .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
+                .float()
+            )

-        if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
-            mscale = torch.where(
-                scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0
-            ).to(scale)
-            freqs_cos = freqs_cos * mscale
-            freqs_sin = freqs_sin * mscale
+            if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
+                mscale = torch.where(
+                    scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0
+                ).to(scale)
+                freqs_cos = freqs_cos * mscale
+                freqs_sin = freqs_sin * mscale

-        return freqs_cos, freqs_sin
-    elif use_real:
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
-        return freqs_cos, freqs_sin
+            return freqs_cos, freqs_sin
+        else:
+            freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
+            freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
+            return freqs_cos, freqs_sin
     else:
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
         return freqs_cis


 class FluxPosEmbed(torch.nn.Module):
-    def __init__(
-        self,
-        theta: int,
-        axes_dim: list[int],
-        method: str = "yarn",
-        dype: bool = True,
-    ):
+    def __init__(self, theta: int, axes_dim: list[int], method: str = "yarn"):
         super().__init__()
         self.theta = theta
         self.axes_dim = axes_dim
         self.base_resolution = 1024
         self.base_patches = (self.base_resolution // 8) // 2
         self.method = method
-        self.dype = dype if method != "base" else False
         self.current_timestep = 1.0

     def set_timestep(self, timestep: float):
@@ -244,43 +227,35 @@ class FluxPosEmbed(torch.nn.Module):
             "freqs_dtype": freqs_dtype,
         }

-        if i > 0:
-            max_pos = axis_pos.max().item()
-            current_patches = max_pos + 1
+        max_pos = axis_pos.max().item()
+        current_patches = max_pos + 1

-            if self.method == "yarn" and current_patches > self.base_patches:
-                max_pe_len = torch.tensor(
-                    current_patches, dtype=freqs_dtype, device=pos.device
-                )
-                cos, sin = get_1d_rotary_pos_embed(
-                    **common_kwargs,
-                    yarn=True,
-                    max_pe_len=max_pe_len,
-                    ori_max_pe_len=self.base_patches,
-                    dype=self.dype,
-                    current_timestep=self.current_timestep,
-                )
-
-            elif self.method == "ntk" and current_patches > self.base_patches:
-                base_ntk = (current_patches / self.base_patches) ** (
-                    self.axes_dim[i] / (self.axes_dim[i] - 2)
-                )
-                ntk_factor = (
-                    base_ntk ** (2.0 * (self.current_timestep**2.0))
-                    if self.dype
-                    else base_ntk
-                )
-                ntk_factor = max(1.0, ntk_factor)
-
-                cos, sin = get_1d_rotary_pos_embed(
-                    **common_kwargs, ntk_factor=ntk_factor
-                )
-
-            else:
-                cos, sin = get_1d_rotary_pos_embed(**common_kwargs)
-        else:
+        if i == 0 or current_patches <= self.base_patches:
             cos, sin = get_1d_rotary_pos_embed(**common_kwargs)
+        elif self.method == "yarn":
+            max_pe_len = torch.tensor(
+                current_patches, dtype=freqs_dtype, device=pos.device
+            )
+            cos, sin = get_1d_rotary_pos_embed(
+                **common_kwargs,
+                yarn=True,
+                max_pe_len=max_pe_len,
+                ori_max_pe_len=self.base_patches,
+                current_timestep=self.current_timestep,
+            )
+
+        elif self.method == "ntk":
+            base_ntk = (current_patches / self.base_patches) ** (
+                self.axes_dim[i] / (self.axes_dim[i] - 2)
+            )
+            ntk_factor = base_ntk ** (2.0 * (self.current_timestep**2.0))
+            ntk_factor = max(1.0, ntk_factor)
+
+            cos, sin = get_1d_rotary_pos_embed(
+                **common_kwargs, ntk_factor=ntk_factor
+            )
+
         cos_out.append(cos)
         sin_out.append(sin)
@@ -298,35 +273,26 @@


 def apply_dype_flux(model: ModelPatcher, method: str) -> ModelPatcher:
-    if getattr(model.model, "_dype", None) == method:
-        return model
-
-    m = model.clone()
-    m.model._dype = method
-
-    _pe_embedder = m.model.diffusion_model.pe_embedder
+    _pe_embedder = model.model.diffusion_model.pe_embedder
     _theta, _axes_dim = _pe_embedder.theta, _pe_embedder.axes_dim

-    pos_embedder = FluxPosEmbed(_theta, _axes_dim, method, dype=True)
-    m.add_object_patch("diffusion_model.pe_embedder", pos_embedder)
+    pos_embedder = FluxPosEmbed(_theta, _axes_dim, method)
+    model.add_object_patch("diffusion_model.pe_embedder", pos_embedder)

-    sigma_max = m.model.model_sampling.sigma_max.item()
+    sigma_max: float = model.model.model_sampling.sigma_max.item()

-    def dype_wrapper_function(model_function, args_dict):
-        timestep_tensor = args_dict.get("timestep")
-        if timestep_tensor is not None and timestep_tensor.numel() > 0:
-            current_sigma = timestep_tensor.flatten()[0].item()
+    def dype_wrapper_function(apply_model: Callable, args: dict):
+        timestep: torch.Tensor = args["timestep"]
+        # timestep is batched; take the first element (plain .item() would
+        # raise on tensors with more than one element)
+        sigma: float = timestep.flatten()[0].item()

-            if sigma_max > 0:
-                normalized_timestep = min(max(current_sigma / sigma_max, 0.0), 1.0)
-                pos_embedder.set_timestep(normalized_timestep)
+        normalized_timestep = min(max(sigma / sigma_max, 0.0), 1.0)
+        pos_embedder.set_timestep(normalized_timestep)

-        input_x, c = args_dict.get("input"), args_dict.get("c", {})
-        return model_function(input_x, args_dict.get("timestep"), **c)
+        return apply_model(args["input"], timestep, **args["c"])

-    m.set_model_unet_function_wrapper(dype_wrapper_function)
+    model.set_model_unet_function_wrapper(dype_wrapper_function)

-    return m
+    return model


 class DyPEPatchModelFlux(io.ComfyNode):
@@ -338,11 +304,7 @@ class DyPEPatchModelFlux(io.ComfyNode):
             category="_for_testing",
             inputs=[
                 io.Model.Input("model"),
-                io.Combo.Input(
-                    "method",
-                    options=["yarn", "ntk", "base"],
-                    default="yarn",
-                ),
+                io.Combo.Input("method", options=["yarn", "ntk"], default="yarn"),
             ],
             outputs=[io.Model.Output()],
             is_experimental=True,
@@ -350,8 +312,9 @@

     @classmethod
     def execute(cls, model: ModelPatcher, method: str) -> io.NodeOutput:
-        m = apply_dype_flux(model, method)
-        return io.NodeOutput(m)
+        model = model.clone()
+        model = apply_dype_flux(model, method)
+        return io.NodeOutput(model)


 class DyPEExtension(ComfyExtension):
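
Reviewer note, not part of the patch: the DyPE schedule above rescales each
correction constant as factor ** (2.0 * t**2.0), where t is the current sigma
normalized to [0, 1] by the unet wrapper. A minimal standalone sketch of that
decay follows; the helper name dype_factor is illustrative, not from the code:

    # Pure-Python sketch of the DyPE timestep schedule used in this patch.
    # At t = 1.0 (pure noise) the squared factor applies in full; as t -> 0
    # the exponent goes to 0 and the factor decays to 1.0, so the scheduled
    # paths relax back toward the unscaled rotary embedding.
    def dype_factor(base_factor: float, t: float) -> float:
        return base_factor ** (2.0 * t**2.0)

    for t in (1.0, 0.5, 0.1, 0.0):
        print(t, dype_factor(16.0, t))  # gamma_0 = 16, as in the patch

With gamma_0 = 16 this prints 256.0 at t = 1.0 and 1.0 at t = 0.0; likewise
the NTK path's ntk_factor becomes max(1.0, 1.0) = 1.0 at t = 0, i.e. plain
RoPE, which is consistent with the patch always enabling the schedule and
dropping the separate "base"/dype switches.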