First SAG test

This commit is contained in:
Raphael Walker 2023-12-06 14:46:16 +01:00
parent e134547341
commit 2920eab33b
No known key found for this signature in database
GPG Key ID: E6F58BE3395D3AA8
3 changed files with 158 additions and 3 deletions

View File

@ -1,6 +1,7 @@
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
import torch
import torch.nn.functional as F
import enum
from comfy import model_management
import math
@ -60,10 +61,10 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = None
if 'control' in conds:
@ -82,7 +83,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
patches['middle_patch'] = [gligen_patch]
return (input_x, mult, conditionning, area, control, patches)
return (input_x, mult, conditioning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
@ -253,9 +254,67 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
return x - model_options["sampler_cfg_function"](args)
elif "sag" in model_options:
sag_scale = model_options["sag_scale"]
sag_sigma = model_options["sag_sigma"]
sag_threshold = model_options.get("sag_threshold", 1.0)
# or is it x - uncond?
# or do I have to use the sigma ?
x0_est = uncond
# this method is added by the sag patcher
uncond_attn = model.get_attn_scores()
degraded = create_blur_map(x0_est, uncond_attn, (x - uncond), sag_sigma, sag_threshold)
# todo, optimize this: doing it this way creates an extra call that we don't even use
(_, sag) = calc_cond_uncond_batch(model, cond, uncond, degraded, timestep, model_options)
return uncond + (cond - uncond) * cond_scale + (uncond - sag) * sag_scale
else:
return uncond + (cond - uncond) * cond_scale
def create_blur_map(x0, attn, noise, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
b, lc, lh, lw = x0.shape
middle_layer_latent_size = [math.ceil(lh/8), math.ceil(lw/8)]
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=True) > threshold
# Reshape
mask = (
mask.reshape(b, **middle_layer_latent_size)
.unsqueeze(1)
.repeat(1, lc, 1, 1)
.type(attn.dtype)
)
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
blurred = blurred + noise
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
super().__init__()

95
comfy_extras/nodes_sag.py Normal file
View File

@ -0,0 +1,95 @@
import torch
from torch import einsum
from einops import rearrange, repeat
import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
def attention_basic(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
h = heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k
if mask is not None:
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return (out, sim)
class SagNode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"scale": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 30.0}),
"blur_sigma": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, scale, blur_sigma):
m = model.clone()
# set extra options on the model
m.extra_options["sag"] = True
m.extra_options["sag_scale"] = scale
m.extra_options["sag_sigma"] = blur_sigma
attn_scores = None
m.get_attn_scores = lambda: attn_scores
def attn_and_record(q, k, v, extra_options):
nonlocal attn_scores
# if uncond, save the attention scores
cond_or_uncond = extra_options["cond_or_uncond"]
if 1 in cond_or_uncond:
uncond_index = cond_or_uncond.index(1)
# do the entire attention operation, but save the attention scores to attn_scores
(out, sim) = attention_basic(q, k, v, heads=extra_options["n_heads"])
attn_scores = sim[uncond_index]
return out
else:
return optimized_attention(q, k, v, heads = extra_options["n_heads"])
# from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
# we might have to patch at different locations depending on sd1.5/2.1 vs sdXL
m.set_model_patch_replace(attn_and_record, "attn1", "middle", 0)
return (m, )
NODE_CLASS_MAPPINGS = {
"Self-Attention Guidance": SagNode,
}

View File

@ -1867,6 +1867,7 @@ def init_custom_nodes():
"nodes_model_downscale.py",
"nodes_images.py",
"nodes_video_model.py",
"nodes_sag.py",
]
for node_file in extras_files: