mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 02:07:32 +08:00
Implement NAG on all the models based on the Flux code. (#12500)
Use the Normalized Attention Guidance node. Supported models include Flux, Flux2, Klein, Chroma, Chroma Radiance, Hunyuan Video, etc.
This commit is contained in:
parent
8a6fbc2dc2
commit
18927538a1
@ -152,6 +152,7 @@ class Chroma(nn.Module):
|
|||||||
transformer_options={},
|
transformer_options={},
|
||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|
||||||
# running on sequences img
|
# running on sequences img
|
||||||
@ -228,6 +229,7 @@ class Chroma(nn.Module):
|
|||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if i not in self.skip_dit:
|
if i not in self.skip_dit:
|
||||||
|
|||||||
@ -196,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||||
|
|
||||||
|
transformer_patches = transformer_options.get("patches", {})
|
||||||
|
extra_options = transformer_options.copy()
|
||||||
|
|
||||||
# prepare image for attention
|
# prepare image for attention
|
||||||
img_modulated = self.img_norm1(img)
|
img_modulated = self.img_norm1(img)
|
||||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||||
@ -224,6 +227,12 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
|
if "attn1_output_patch" in transformer_patches:
|
||||||
|
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
||||||
|
patch = transformer_patches["attn1_output_patch"]
|
||||||
|
for p in patch:
|
||||||
|
attn = p(attn, extra_options)
|
||||||
|
|
||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||||
|
|
||||||
# calculate the img bloks
|
# calculate the img bloks
|
||||||
@ -303,6 +312,9 @@ class SingleStreamBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
mod = vec
|
mod = vec
|
||||||
|
|
||||||
|
transformer_patches = transformer_options.get("patches", {})
|
||||||
|
extra_options = transformer_options.copy()
|
||||||
|
|
||||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||||
|
|
||||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
@ -312,6 +324,12 @@ class SingleStreamBlock(nn.Module):
|
|||||||
# compute attention
|
# compute attention
|
||||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
|
if "attn1_output_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn1_output_patch"]
|
||||||
|
for p in patch:
|
||||||
|
attn = p(attn, extra_options)
|
||||||
|
|
||||||
# compute activation in mlp stream, cat again and run second linear layer
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
if self.yak_mlp:
|
if self.yak_mlp:
|
||||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||||
|
|||||||
@ -142,6 +142,7 @@ class Flux(nn.Module):
|
|||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
@ -231,6 +232,7 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
|
|||||||
@ -304,6 +304,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
control=None,
|
control=None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|
||||||
initial_shape = list(img.shape)
|
initial_shape = list(img.shape)
|
||||||
@ -416,6 +417,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
|
|||||||
@ -406,13 +406,16 @@ class ModelPatcher:
|
|||||||
def memory_required(self, input_shape):
    """Return whatever ``self.model.memory_required`` reports for ``input_shape``.

    The shape is forwarded to the wrapped model as the ``input_shape``
    keyword argument; the result is passed back unchanged.
    """
    wrapped_model = self.model
    return wrapped_model.memory_required(input_shape=input_shape)
|
||||||
|
|
||||||
|
def disable_model_cfg1_optimization(self):
    """Set the ``disable_cfg1_optimization`` entry of ``self.model_options`` to True."""
    options = self.model_options
    options["disable_cfg1_optimization"] = True
|
||||||
|
|
||||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
    """Store ``sampler_cfg_function`` under ``model_options["sampler_cfg_function"]``.

    A callable whose signature has exactly three parameters is treated as the
    old calling convention: it is wrapped so that it is invoked with
    ``args["cond"]``, ``args["uncond"]`` and ``args["cond_scale"]`` as
    positional arguments. Any other callable is stored unchanged and receives
    the ``args`` mapping itself.

    When ``disable_cfg1_optimization`` is truthy, the flag is also set via
    ``self.disable_model_cfg1_optimization()``.
    """
    parameter_count = len(inspect.signature(sampler_cfg_function).parameters)
    if parameter_count != 3:
        registered = sampler_cfg_function
    else:
        # Old way: adapt the (cond, uncond, cond_scale) signature to the args-dict call.
        def registered(args):
            return sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])

    self.model_options["sampler_cfg_function"] = registered

    if disable_cfg1_optimization:
        self.disable_model_cfg1_optimization()
|
||||||
|
|
||||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
    """Replace ``self.model_options`` with the result of ``set_model_options_post_cfg_function``.

    The current options, ``post_cfg_function`` and ``disable_cfg1_optimization``
    are passed positionally, in that order, to the module-level helper.
    """
    updated_options = set_model_options_post_cfg_function(
        self.model_options,
        post_cfg_function,
        disable_cfg1_optimization,
    )
    self.model_options = updated_options
|
||||||
|
|||||||
99
comfy_extras/nodes_nag.py
Normal file
99
comfy_extras/nodes_nag.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import torch
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
|
class NAGuidance(io.ComfyNode):
    """Experimental node that installs a Normalized Attention Guidance (NAG) patch.

    ``execute`` clones the incoming model, registers an ``attn1`` output patch
    on the clone and disables the cfg1 optimization on it.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: one model input, three float parameters, one model output."""
        return io.Schema(
            node_id="NAGuidance",
            display_name="Normalized Attention Guidance",
            description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
            category="",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to apply NAG to."),
                io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
                io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
                io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01, tooltip="Upper bound on the ratio between the L1 norm of the guided attention output and that of the positive attention output. Higher values allow stronger guidance."),
                # NOTE: start/end gating is not wired up yet; kept for reference.
                # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
                # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
            ],
            outputs=[
                io.Model.Output(tooltip="The patched model with NAG enabled."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
        """Return a clone of ``model`` with the NAG attention-output patch installed.

        Args:
            model: Model to clone; the input itself is not patched.
            nag_scale: Extrapolation factor applied to the positive output.
            nag_alpha: Blend weight of the guided output against the positive output.
            nag_tau: Cap on the guided/positive L1-norm ratio.
        """
        m = model.clone()

        # NOTE: sigma gating is disabled together with the start/end inputs.
        # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
        # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)

        def nag_attention_output_patch(out, extra_options):
            """Blend the positive attention output towards a norm-limited guided output.

            ``out`` is returned untouched unless ``extra_options["cond_or_uncond"]``
            contains both a 0 (treated as the positive branch) and a 1 (treated as
            the negative branch). Otherwise ``out`` is modified in place and returned.
            """
            cond_or_uncond = extra_options.get("cond_or_uncond", None)
            if cond_or_uncond is None:
                return out

            if not (1 in cond_or_uncond and 0 in cond_or_uncond):
                return out

            # sigma = extra_options.get("sigmas", None)
            # if sigma is not None and len(sigma) > 0:
            #     sigma = sigma[0].item()
            #     if sigma > sigma_start or sigma < sigma_end:
            #         return out

            img_slice = extra_options.get("img_slice", None)

            full_out = out
            if img_slice is not None:
                # Only apply on the img part; this is a view, so ``full_out`` keeps
                # the remaining positions untouched.
                out = out[:, img_slice[0]:img_slice[1]]

            # The batch dimension is split into equal chunks, one per
            # cond_or_uncond entry.
            chunk = out.shape[0] // len(cond_or_uncond)
            ind_neg = cond_or_uncond.index(1)
            ind_pos = cond_or_uncond.index(0)
            pos_rows = slice(chunk * ind_pos, chunk * (ind_pos + 1))
            neg_rows = slice(chunk * ind_neg, chunk * (ind_neg + 1))

            z_pos = out[pos_rows]
            z_neg = out[neg_rows]

            # Extrapolate away from the negative output.
            guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)

            # L1 norms over the last dimension, floored to avoid division by zero.
            eps = 1e-6
            norm_pos = torch.linalg.vector_norm(z_pos, ord=1, dim=-1, keepdim=True).clamp_min(eps)
            norm_guided = torch.linalg.vector_norm(guided, ord=1, dim=-1, keepdim=True).clamp_min(eps)

            # Rescale ``guided`` so its norm is at most ``nag_tau`` times the positive norm.
            ratio = norm_guided / norm_pos
            scale_factor = ratio.clamp(max=nag_tau) / ratio
            guided_normalized = guided * scale_factor

            z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)

            if img_slice is not None:
                # With a slice, both the negative and the positive rows receive the result.
                full_out[neg_rows, img_slice[0]:img_slice[1]] = z_final
                full_out[pos_rows, img_slice[0]:img_slice[1]] = z_final
                return full_out

            # Without a slice, only the positive rows are overwritten.
            out[pos_rows] = z_final
            return out

        m.set_model_attn1_output_patch(nag_attention_output_patch)
        # NOTE(review): presumably needed so the negative branch is still
        # evaluated when cfg is 1 — confirm against the sampler.
        m.disable_model_cfg1_optimization()

        return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|
||||||
|
class NagExtension(ComfyExtension):
    """Extension whose node list consists of the ``NAGuidance`` node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes contributed by this extension."""
        node_classes: list[type[io.ComfyNode]] = [NAGuidance]
        return node_classes
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> NagExtension:
    """Create and return a new ``NagExtension`` instance."""
    extension = NagExtension()
    return extension
|
||||||
Loading…
Reference in New Issue
Block a user