import torch
from torch import einsum
import torch.nn.functional as F
import math

from einops import rearrange, repeat
import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
from comfy import samplers


# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None):
    """Plain multi-head attention that also returns the softmaxed attention map."""
    b, _, dim_head = q.shape
    dim_head //= heads
    scale = dim_head ** -0.5

    h = heads
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    # force cast to fp32 to avoid overflowing
    if _ATTN_PRECISION == "fp32":
        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return (out, sim)
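
# Note on attention_basic_with_sim above -- a shape sketch, assuming q, k, v
# arrive as [batch, tokens, heads * dim_head] (the layout the reshape expects):
#   q, k, v after the reshape: [batch * heads, tokens, dim_head]
#   sim (softmaxed scores):    [batch * heads, q_tokens, k_tokens]
#   out:                       [batch, q_tokens, heads * dim_head]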


def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    """Build the SAG mask from the saved attention map and blur x0 only where the mask is active."""
    # reshape and GAP the attention map
    _, hw1, hw2 = attn.shape
    b, _, lh, lw = x0.shape
    attn = attn.reshape(b, -1, hw1, hw2)
    # Global Average Pool
    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
    ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
    mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]

    # Reshape
    mask = (
        mask.reshape(b, *mid_shape)
        .unsqueeze(1)
        .type(attn.dtype)
    )
    # Upsample
    mask = F.interpolate(mask, (lh, lw))

    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    blurred = blurred * mask + x0 * (1 - mask)
    return blurred
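
# Worked example for the ratio / mid_shape computation above (an illustration,
# assuming a 512x512 image, i.e. a 64x64 latent and an 8x8 middle-block attention map):
#   hw1 = 8 * 8 = 64, lh = lw = 64
#   ratio = 2 ** ((ceil(sqrt(64 * 64 / 64)) - 1).bit_length()) = 2 ** (7).bit_length() = 2 ** 3 = 8
#   mid_shape = [ceil(64 / 8), ceil(64 / 8)] = [8, 8], matching the attention map resolution.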


def gaussian_blur_2d(img, kernel_size, sigma):
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)

    pdf = torch.exp(-0.5 * (x / sigma).pow(2))

    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)

    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])

    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]

    img = F.pad(img, padding, mode="reflect")
    img = F.conv2d(img, kernel2d, groups=img.shape[-3])
    return img
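
# Note on gaussian_blur_2d above: the 2D kernel is the outer product of a normalized
# 1D Gaussian with itself, expanded to one kernel per channel and applied as a
# depthwise convolution (groups = channel count) with reflect padding, so each
# latent channel is blurred independently and the spatial size is preserved.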


class SelfAttentionGuidance:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             # strength of the SAG correction added after CFG
                             "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01, "round": 0.01}),
                             # sigma of the Gaussian blur applied to the attended regions
                             "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": 0.01}),
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"

    def patch(self, model, scale, blur_sigma):
        m = model.clone()

        attn_scores = None

        # TODO: make this work properly with chunked batches
        # currently, we can only save the attn from one UNet call
        def attn_and_record(q, k, v, extra_options):
            nonlocal attn_scores
            # if uncond, save the attention scores
            heads = extra_options["n_heads"]
            cond_or_uncond = extra_options["cond_or_uncond"]
            b = q.shape[0] // len(cond_or_uncond)
            if 1 in cond_or_uncond:
                uncond_index = cond_or_uncond.index(1)
                # do the entire attention operation, but save the attention scores to attn_scores
                (out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
                # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
                n_slices = heads * b
                attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index + 1)]
                return out
            else:
                return optimized_attention(q, k, v, heads=heads)
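
        # Worked example for the slicing above (an illustration, assuming the batch
        # chunks appear in the order listed by cond_or_uncond): with heads = 8,
        # per-chunk batch b = 2 and cond_or_uncond == [0, 1], q holds 4 batch items
        # (cond chunk first), so sim has 4 * 8 = 32 rows grouped per batch item;
        # n_slices = 16 and uncond_index = 1 selects rows 16:32, i.e. the attention
        # maps of the two uncond items.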

        def post_cfg_function(args):
            nonlocal attn_scores
            uncond_attn = attn_scores

            sag_scale = scale
            sag_sigma = blur_sigma
            sag_threshold = 1.0
            model = args["model"]
            uncond_pred = args["uncond_denoised"]
            uncond = args["uncond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"]
            x = args["input"]
            if min(cfg_result.shape[2:]) <= 4:  # skip when too small to add padding
                return cfg_result

            # create the adversarially blurred image
            degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
            degraded_noised = degraded + x - uncond_pred
            # call into the UNet
            (sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
            # SAG term: the difference between the blurred prediction and the model's
            # denoised output for the re-noised blurred latent, scaled and added on top of CFG
            return cfg_result + (degraded - sag) * sag_scale

        m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
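
        # disable_cfg1_optimization=True keeps the uncond model call (and therefore
        # uncond_denoised and the recorded uncond attention) available even when
        # cfg == 1, which SAG needs to build its blur map.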

        # from diffusers:
        # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
        m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)

        return (m, )


NODE_CLASS_MAPPINGS = {
    "SelfAttentionGuidance": SelfAttentionGuidance,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "SelfAttentionGuidance": "Self-Attention Guidance",
}
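

if __name__ == "__main__":
    # Minimal smoke test for the tensor helpers above -- a sketch, not part of the node,
    # and only runnable where the comfy imports at the top of this file succeed.
    # Assumes a 64x64 latent (a 512x512 image) and an 8x8 middle-block self-attention
    # map; both are faked with random data here.
    b, heads, lh, lw = 1, 8, 64, 64
    x0 = torch.randn(b, 4, lh, lw)
    fake_attn = torch.rand(b * heads, 8 * 8, 8 * 8).softmax(dim=-1)

    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=2.0)
    degraded = create_blur_map(x0, fake_attn, sigma=2.0, threshold=1.0)
    print(blurred.shape, degraded.shape)  # both should equal x0.shape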