Fix a crash when using weird resolutions. Remove an unnecessary UNet call

This commit is contained in:
Raphael Walker 2023-12-07 11:24:30 +01:00
parent 20d8a06bbe
commit f8d719f9e7
No known key found for this signature in database
GPG Key ID: E6F58BE3395D3AA8
2 changed files with 23 additions and 13 deletions

View File

@ -263,32 +263,27 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
# this method is added by the sag patcher # this method is added by the sag patcher
uncond_attn = model.get_attn_scores() uncond_attn = model.get_attn_scores()
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) mid_shape = model.get_mid_block_shape()
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred degraded_noised = degraded + x - uncond_pred
# TODO optimize this: doing it this way creates an extra call that we don't even use assert uncond is not None, "SAG requires uncond guidance"
(_, sag) = calc_cond_uncond_batch(model, cond, uncond, degraded_noised, timestep, model_options) (sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
# Unless I've misunderstood the paper, this is supposed to be (uncond_pred - sag) * sag_scale. # Unless I've misunderstood the paper, this is supposed to be (uncond_pred - sag) * sag_scale.
# but this is what the automatic1111 implementation does, and it works better?? # but this is what the automatic1111 implementation does, and it works better??
return uncond_pred + (cond_pred - uncond_pred) * cond_scale + (degraded - sag) * sag_scale return uncond_pred + (cond_pred - uncond_pred) * cond_scale + (degraded - sag) * sag_scale
else: else:
return uncond_pred + (cond_pred - uncond_pred) * cond_scale return uncond_pred + (cond_pred - uncond_pred) * cond_scale
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map # reshape and GAP the attention map
_, hw1, hw2 = attn.shape _, hw1, hw2 = attn.shape
b, lc, lh, lw = x0.shape b, lc, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2) attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool # Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
# we want to reshape the mask, which now has shape (b, w*h), to shape (b, 1, h, w).
# if we know the model beforehand, we can just divide lh and wh by the correct factor to size of the latent in the middle of the UNet
# but if we want to be model-agnostic, we can do it this way: just figure out the scale factor by the number of "pixels".
total_size_latent = lh * lw
scale_factor = int(math.sqrt(total_size_latent / mask.shape[1]))
middle_layer_latent_size = [math.ceil(lh/scale_factor), math.ceil(lw/scale_factor)]
# Reshape # Reshape
mask = ( mask = (
mask.reshape(b, *middle_layer_latent_size) mask.reshape(b, *mid_shape)
.unsqueeze(1) .unsqueeze(1)
.repeat(1, lc, 1, 1) .repeat(1, lc, 1, 1)
.type(attn.dtype) .type(attn.dtype)

View File

@ -69,7 +69,9 @@ class SagNode:
m.model_options["sag_sigma"] = blur_sigma m.model_options["sag_sigma"] = blur_sigma
attn_scores = None attn_scores = None
mid_block_shape = None
m.model.get_attn_scores = lambda: attn_scores m.model.get_attn_scores = lambda: attn_scores
m.model.get_mid_block_shape = lambda: mid_block_shape
# TODO: make this work properly with chunked batches # TODO: make this work properly with chunked batches
# currently, we can only save the attn from one UNet call # currently, we can only save the attn from one UNet call
@ -92,8 +94,21 @@ class SagNode:
# from diffusers: # from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
# we might have to patch at different locations depending on sd1.5/2.1 vs sdXL def set_model_patch_replace(patch, name, key):
m.set_model_patch_replace(attn_and_record, "attn1", "middle", 0) to = m.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][key] = patch
# this actually patches 2 attn calls -- confusing, since we only want to get one
set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0))
# from diffusers:
# unet.mid_block.attentions[0].register_forward_hook()
def forward_hook(m, inp, out):
nonlocal mid_block_shape
mid_block_shape = out[0].shape[-2:]
m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook)
return (m, ) return (m, )
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {