Mirror of https://github.com/comfyanonymous/ComfyUI.git
Commit 20dbf31c0f (parent 5c6fcbda91): "v2"
@@ -1,6 +1,7 @@
 # adapted from https://github.com/guyyariv/DyPE
 
 import math
+from typing import Callable
 
 import numpy as np
 import torch
@@ -39,7 +40,7 @@ def find_newbase_ntk(dim, base, scale):
 
 def get_1d_rotary_pos_embed(
     dim: int,
-    pos: np.ndarray | int,
+    pos: torch.Tensor,
     theta: float = 10000.0,
     use_real=False,
     linear_factor=1.0,
@@ -49,7 +50,6 @@ def get_1d_rotary_pos_embed(
     yarn=False,
     max_pe_len=None,
     ori_max_pe_len=64,
-    dype=False,
     current_timestep=1.0,
 ):
     """
@@ -80,8 +80,6 @@ def get_1d_rotary_pos_embed(
             Maximum position encoding length (current patches for vision models).
         ori_max_pe_len (`int`, *optional*, defaults to 64):
             Original maximum position encoding length (base patches for vision models).
-        dype (`bool`, *optional*, defaults to False):
-            If True, enable DyPE (Dynamic Position Encoding) with timestep-aware scaling.
         current_timestep (`float`, *optional*, defaults to 1.0):
             Current timestep for DyPE, normalized to [0, 1] where 1 is pure noise.
 
@@ -91,11 +89,6 @@ def get_1d_rotary_pos_embed(
     """
    assert dim % 2 == 0
 
-    if isinstance(pos, int):
-        pos = torch.arange(pos)
-    if isinstance(pos, np.ndarray):
-        pos = torch.from_numpy(pos)
-
     device = pos.device
 
     if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
@@ -104,10 +97,8 @@ def get_1d_rotary_pos_embed(
 
         scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
 
-        beta_0 = 1.25
-        beta_1 = 0.75
-        gamma_0 = 16
-        gamma_1 = 2
+        beta_0, beta_1 = 1.25, 0.75
+        gamma_0, gamma_1 = 16, 2
 
         freqs_base = 1.0 / (
             theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
@@ -131,9 +122,8 @@ def get_1d_rotary_pos_embed(
         if freqs_ntk.dim() > 1:
             freqs_ntk = freqs_ntk.squeeze()
 
-        if dype:
-            beta_0 = beta_0 ** (2.0 * (current_timestep**2.0))
-            beta_1 = beta_1 ** (2.0 * (current_timestep**2.0))
+        beta_0 = beta_0 ** (2.0 * (current_timestep**2.0))
+        beta_1 = beta_1 ** (2.0 * (current_timestep**2.0))
 
         low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
         low = max(0, low)
@@ -144,9 +134,8 @@ def get_1d_rotary_pos_embed(
         )
         freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
 
-        if dype:
-            gamma_0 = gamma_0 ** (2.0 * (current_timestep**2.0))
-            gamma_1 = gamma_1 ** (2.0 * (current_timestep**2.0))
+        gamma_0 = gamma_0 ** (2.0 * (current_timestep**2.0))
+        gamma_1 = gamma_1 ** (2.0 * (current_timestep**2.0))
 
         low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
         low = max(0, low)
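Note: the two hunks above fold the old `if dype:` branches into unconditional statements; the exponent `2.0 * current_timestep**2.0` is what keeps the correction factors timestep-aware, equal to their base values at `current_timestep = 1.0` (pure noise) and collapsing to 1.0 as denoising finishes. A standalone sketch of that decay follows (illustration only; just the `base ** (2.0 * t**2)` form is taken from the diff):

# Illustration of the DyPE timestep scaling applied to beta_0/beta_1 and gamma_0/gamma_1.
def dype_scale(base: float, current_timestep: float) -> float:
    return base ** (2.0 * (current_timestep ** 2.0))

for t in (1.0, 0.75, 0.5, 0.0):
    print(t, round(dype_scale(1.25, t), 4), round(dype_scale(16.0, t), 4))
# t=1.0 -> 1.5625 and 256.0 (full correction factors at pure noise)
# t=0.0 -> 1.0 and 1.0 (the factors fade out as the sample is nearly denoised)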
@@ -174,50 +163,44 @@ def get_1d_rotary_pos_embed(
     if is_npu:
         freqs = freqs.float()
 
-    if use_real and repeat_interleave_real:
-        freqs_cos = (
-            freqs.cos()
-            .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
-            .float()
-        )
-        freqs_sin = (
-            freqs.sin()
-            .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
-            .float()
-        )
-
-        if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
-            mscale = torch.where(
-                scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0
-            ).to(scale)
-            freqs_cos = freqs_cos * mscale
-            freqs_sin = freqs_sin * mscale
-
-        return freqs_cos, freqs_sin
-    elif use_real:
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
-        return freqs_cos, freqs_sin
+    if use_real:
+        if repeat_interleave_real:
+            freqs_cos = (
+                freqs.cos()
+                .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
+                .float()
+            )
+            freqs_sin = (
+                freqs.sin()
+                .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
+                .float()
+            )
+
+            if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
+                mscale = torch.where(
+                    scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0
+                ).to(scale)
+                freqs_cos = freqs_cos * mscale
+                freqs_sin = freqs_sin * mscale
+
+            return freqs_cos, freqs_sin
+        else:
+            freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
+            freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
+            return freqs_cos, freqs_sin
     else:
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
         return freqs_cis
 
 
 class FluxPosEmbed(torch.nn.Module):
-    def __init__(
-        self,
-        theta: int,
-        axes_dim: list[int],
-        method: str = "yarn",
-        dype: bool = True,
-    ):
+    def __init__(self, theta: int, axes_dim: list[int], method: str = "yarn"):
         super().__init__()
         self.theta = theta
         self.axes_dim = axes_dim
         self.base_resolution = 1024
         self.base_patches = (self.base_resolution // 8) // 2
         self.method = method
-        self.dype = dype if method != "base" else False
         self.current_timestep = 1.0
 
     def set_timestep(self, timestep: float):
@@ -244,43 +227,35 @@ class FluxPosEmbed(torch.nn.Module):
                 "freqs_dtype": freqs_dtype,
             }
 
-            if i > 0:
-                max_pos = axis_pos.max().item()
-                current_patches = max_pos + 1
-
-                if self.method == "yarn" and current_patches > self.base_patches:
-                    max_pe_len = torch.tensor(
-                        current_patches, dtype=freqs_dtype, device=pos.device
-                    )
-                    cos, sin = get_1d_rotary_pos_embed(
-                        **common_kwargs,
-                        yarn=True,
-                        max_pe_len=max_pe_len,
-                        ori_max_pe_len=self.base_patches,
-                        dype=self.dype,
-                        current_timestep=self.current_timestep,
-                    )
-
-                elif self.method == "ntk" and current_patches > self.base_patches:
-                    base_ntk = (current_patches / self.base_patches) ** (
-                        self.axes_dim[i] / (self.axes_dim[i] - 2)
-                    )
-                    ntk_factor = (
-                        base_ntk ** (2.0 * (self.current_timestep**2.0))
-                        if self.dype
-                        else base_ntk
-                    )
-                    ntk_factor = max(1.0, ntk_factor)
-
-                    cos, sin = get_1d_rotary_pos_embed(
-                        **common_kwargs, ntk_factor=ntk_factor
-                    )
-
-                else:
-                    cos, sin = get_1d_rotary_pos_embed(**common_kwargs)
-            else:
+            max_pos = axis_pos.max().item()
+            current_patches = max_pos + 1
+
+            if i == 0 or current_patches <= self.base_patches:
                 cos, sin = get_1d_rotary_pos_embed(**common_kwargs)
 
+            elif self.method == "yarn":
+                max_pe_len = torch.tensor(
+                    current_patches, dtype=freqs_dtype, device=pos.device
+                )
+                cos, sin = get_1d_rotary_pos_embed(
+                    **common_kwargs,
+                    yarn=True,
+                    max_pe_len=max_pe_len,
+                    ori_max_pe_len=self.base_patches,
+                    current_timestep=self.current_timestep,
+                )
+
+            elif self.method == "ntk":
+                base_ntk = (current_patches / self.base_patches) ** (
+                    self.axes_dim[i] / (self.axes_dim[i] - 2)
+                )
+                ntk_factor = base_ntk ** (2.0 * (self.current_timestep**2.0))
+                ntk_factor = max(1.0, ntk_factor)
+
+                cos, sin = get_1d_rotary_pos_embed(
+                    **common_kwargs, ntk_factor=ntk_factor
+                )
+
             cos_out.append(cos)
             sin_out.append(sin)
 
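Note: in the rewritten forward pass the "ntk" branch always applies the DyPE exponent before clamping. A worked example of the factor, using hypothetical numbers (an axis dimension of 56 and a grid twice the 64-patch base are assumptions for illustration, not values stated in the diff):

# Hypothetical example of the ntk_factor computation above.
axis_dim = 56              # assumed rotary dimension for one axis
base_patches = 64          # (1024 // 8) // 2, as set in FluxPosEmbed.__init__
current_patches = 128      # assumed: a patch grid twice the base size

base_ntk = (current_patches / base_patches) ** (axis_dim / (axis_dim - 2))
# = 2 ** (56 / 54), roughly 2.05

ntk_factor = max(1.0, base_ntk ** (2.0 * (1.0 ** 2.0)))   # at current_timestep = 1.0
# roughly 4.2; at current_timestep = 0.0 the exponent is 0, so max(1.0, 1.0) == 1.0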
@@ -298,35 +273,26 @@ class FluxPosEmbed(torch.nn.Module):
 
 
 def apply_dype_flux(model: ModelPatcher, method: str) -> ModelPatcher:
-    if getattr(model.model, "_dype", None) == method:
-        return model
-
-    m = model.clone()
-    m.model._dype = method
-
-    _pe_embedder = m.model.diffusion_model.pe_embedder
+    _pe_embedder = model.model.diffusion_model.pe_embedder
     _theta, _axes_dim = _pe_embedder.theta, _pe_embedder.axes_dim
 
-    pos_embedder = FluxPosEmbed(_theta, _axes_dim, method, dype=True)
-    m.add_object_patch("diffusion_model.pe_embedder", pos_embedder)
+    pos_embedder = FluxPosEmbed(_theta, _axes_dim, method)
+    model.add_object_patch("diffusion_model.pe_embedder", pos_embedder)
 
-    sigma_max = m.model.model_sampling.sigma_max.item()
+    sigma_max: float = model.model.model_sampling.sigma_max.item()
 
-    def dype_wrapper_function(model_function, args_dict):
-        timestep_tensor = args_dict.get("timestep")
-        if timestep_tensor is not None and timestep_tensor.numel() > 0:
-            current_sigma = timestep_tensor.flatten()[0].item()
+    def dype_wrapper_function(apply_model: Callable, args: dict):
+        timestep: torch.Tensor = args["timestep"]
+        sigma: float = timestep.item()
 
-            if sigma_max > 0:
-                normalized_timestep = min(max(current_sigma / sigma_max, 0.0), 1.0)
-                pos_embedder.set_timestep(normalized_timestep)
+        normalized_timestep = min(max(sigma / sigma_max, 0.0), 1.0)
+        pos_embedder.set_timestep(normalized_timestep)
+
+        return apply_model(args["input"], timestep, **args["c"])
 
-        input_x, c = args_dict.get("input"), args_dict.get("c", {})
-        return model_function(input_x, args_dict.get("timestep"), **c)
-
-    m.set_model_unet_function_wrapper(dype_wrapper_function)
+    model.set_model_unet_function_wrapper(dype_wrapper_function)
 
-    return m
+    return model
 
 
 class DyPEPatchModelFlux(io.ComfyNode):
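Note: the slimmed-down wrapper leans on the contract of `set_model_unet_function_wrapper` as used in this diff: the wrapper receives the original apply function plus a dict carrying "input", "timestep" and "c", and must forward the call itself. A minimal pass-through sketch of that shape (the DyPE-specific state update is elided; `patcher` is a hypothetical ModelPatcher instance):

def passthrough_wrapper(apply_model, args: dict):
    # args["c"] holds the conditioning kwargs that get expanded into the model call.
    x, timestep, cond = args["input"], args["timestep"], args["c"]
    # ... a real wrapper, like dype_wrapper_function above, updates its state here ...
    return apply_model(x, timestep, **cond)

# patcher.set_model_unet_function_wrapper(passthrough_wrapper)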
@@ -338,11 +304,7 @@ class DyPEPatchModelFlux(io.ComfyNode):
             category="_for_testing",
             inputs=[
                 io.Model.Input("model"),
-                io.Combo.Input(
-                    "method",
-                    options=["yarn", "ntk", "base"],
-                    default="yarn",
-                ),
+                io.Combo.Input("method", options=["yarn", "ntk"], default="yarn"),
             ],
             outputs=[io.Model.Output()],
             is_experimental=True,
@@ -350,8 +312,9 @@ class DyPEPatchModelFlux(io.ComfyNode):
 
     @classmethod
     def execute(cls, model: ModelPatcher, method: str) -> io.NodeOutput:
-        m = apply_dype_flux(model, method)
-        return io.NodeOutput(m)
+        model = model.clone()
+        model = apply_dype_flux(model, method)
+        return io.NodeOutput(model)
 
 
 class DyPEExtension(ComfyExtension):
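Note: since `execute` now clones before delegating, `apply_dype_flux` is free to mutate the patcher it is handed. A rough usage sketch (assumes `flux_model` is an already-loaded Flux ModelPatcher; the name is illustrative):

patched = flux_model.clone()                 # mirror what DyPEPatchModelFlux.execute does
patched = apply_dype_flux(patched, "yarn")   # "ntk" is the other remaining option
# `patched` now carries the FluxPosEmbed object patch and the timestep-aware wrapper.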