ComfyUI/comfy_extras/nodes/nodes_sag.py
doctorpangloss b0be335d59 Improved support for ControlNet workflows with depth
- ComfyUI can now load EXR files.
 - There are new arithmetic nodes for floats and integers.
 - EXR nodes can load depth maps and be remapped with
   ImageApplyColormap. This allows end users to use ground truth depth
   data from video game engines or 3D graphics tools and recolor it to
   the format expected by depth ControlNets: grayscale inverse depth
   maps and "inferno" colored inverse depth maps.
 - Fixed license notes.
 - Added an additional known ControlNet model.
 - Because CV2 is now used to read OpenEXR files, an environment
   variable must be set early on in the application, before CV2 is
   imported. This file, main_pre, is now imported early on in more
   places.
2024-03-26 22:32:15 -07:00

171 lines
5.8 KiB
Python

import torch
from torch import einsum
import torch.nn.functional as F
import math
from einops import rearrange, repeat
import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
from comfy import samplers
# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
h = heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k
if mask is not None:
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return (out, sim)
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
b, _, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape
mask = (
mask.reshape(b, *mid_shape)
.unsqueeze(1)
.type(attn.dtype)
)
# Upsample
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class SelfAttentionGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01, "round": 0.01}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, scale, blur_sigma):
m = model.clone()
attn_scores = None
# TODO: make this work properly with chunked batches
# currently, we can only save the attn from one UNet call
def attn_and_record(q, k, v, extra_options):
nonlocal attn_scores
# if uncond, save the attention scores
heads = extra_options["n_heads"]
cond_or_uncond = extra_options["cond_or_uncond"]
b = q.shape[0] // len(cond_or_uncond)
if 1 in cond_or_uncond:
uncond_index = cond_or_uncond.index(1)
# do the entire attention operation, but save the attention scores to attn_scores
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
return out
else:
return optimized_attention(q, k, v, heads=heads)
def post_cfg_function(args):
nonlocal attn_scores
uncond_attn = attn_scores
sag_scale = scale
sag_sigma = blur_sigma
sag_threshold = 1.0
model = args["model"]
uncond_pred = args["uncond_denoised"]
uncond = args["uncond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
model_options = args["model_options"]
x = args["input"]
if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
return cfg_result
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred
# call into the UNet
(sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
# from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
return (m, )
NODE_CLASS_MAPPINGS = {
"SelfAttentionGuidance": SelfAttentionGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SelfAttentionGuidance": "Self-Attention Guidance",
}