Improve comments, optimize memory in blur routine

This commit is contained in:
Raphael Walker 2023-12-07 12:02:40 +01:00
parent f8d719f9e7
commit ecd098a7fd
No known key found for this signature in database
GPG Key ID: E6F58BE3395D3AA8
2 changed files with 5 additions and 8 deletions

View File

@@ -257,19 +257,18 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
return x - model_options["sampler_cfg_function"](args) return x - model_options["sampler_cfg_function"](args)
# if cfg = 1.0, we can't do sag # if cfg = 1.0, we can't do sag
elif "sag" in model_options: elif "sag" in model_options:
assert uncond is not None, "SAG requires uncond guidance"
sag_scale = model_options["sag_scale"] sag_scale = model_options["sag_scale"]
sag_sigma = model_options["sag_sigma"] sag_sigma = model_options["sag_sigma"]
sag_threshold = model_options.get("sag_threshold", 1.0) sag_threshold = model_options.get("sag_threshold", 1.0)
# this method is added by the sag patcher # these methods are added by the sag patcher
uncond_attn = model.get_attn_scores() uncond_attn = model.get_attn_scores()
mid_shape = model.get_mid_block_shape() mid_shape = model.get_mid_block_shape()
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold) degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred degraded_noised = degraded + x - uncond_pred
assert uncond is not None, "SAG requires uncond guidance" # call into the UNet with the adversarially blurred image
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options) (sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
# Unless I've misunderstood the paper, this is supposed to be (uncond_pred - sag) * sag_scale.
# but this is what the automatic1111 implementation does, and it works better??
return uncond_pred + (cond_pred - uncond_pred) * cond_scale + (degraded - sag) * sag_scale return uncond_pred + (cond_pred - uncond_pred) * cond_scale + (degraded - sag) * sag_scale
else: else:
return uncond_pred + (cond_pred - uncond_pred) * cond_scale return uncond_pred + (cond_pred - uncond_pred) * cond_scale
@@ -277,7 +276,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0): def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map # reshape and GAP the attention map
_, hw1, hw2 = attn.shape _, hw1, hw2 = attn.shape
b, lc, lh, lw = x0.shape b, _, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2) attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool # Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
@@ -285,7 +284,6 @@ def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
mask = ( mask = (
mask.reshape(b, *mid_shape) mask.reshape(b, *mid_shape)
.unsqueeze(1) .unsqueeze(1)
.repeat(1, lc, 1, 1)
.type(attn.dtype) .type(attn.dtype)
) )
# Upsample # Upsample

View File

@@ -101,7 +101,6 @@ class SagNode:
if name not in to["patches_replace"]: if name not in to["patches_replace"]:
to["patches_replace"][name] = {} to["patches_replace"][name] = {}
to["patches_replace"][name][key] = patch to["patches_replace"][name][key] = patch
# this actually patches 2 attn calls -- confusing, since we only want to get one
set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0)) set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0))
# from diffusers: # from diffusers:
# unet.mid_block.attentions[0].register_forward_hook() # unet.mid_block.attentions[0].register_forward_hook()