mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 00:12:33 +08:00
Improve comments, optimize memory in blur routine
This commit is contained in:
parent
f8d719f9e7
commit
ecd098a7fd
@ -257,27 +257,26 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
|||||||
return x - model_options["sampler_cfg_function"](args)
|
return x - model_options["sampler_cfg_function"](args)
|
||||||
# if cfg = 1.0, we can't do sag
|
# if cfg = 1.0, we can't do sag
|
||||||
elif "sag" in model_options:
|
elif "sag" in model_options:
|
||||||
|
assert uncond is not None, "SAG requires uncond guidance"
|
||||||
sag_scale = model_options["sag_scale"]
|
sag_scale = model_options["sag_scale"]
|
||||||
sag_sigma = model_options["sag_sigma"]
|
sag_sigma = model_options["sag_sigma"]
|
||||||
sag_threshold = model_options.get("sag_threshold", 1.0)
|
sag_threshold = model_options.get("sag_threshold", 1.0)
|
||||||
|
|
||||||
# this method is added by the sag patcher
|
# these methods are added by the sag patcher
|
||||||
uncond_attn = model.get_attn_scores()
|
uncond_attn = model.get_attn_scores()
|
||||||
mid_shape = model.get_mid_block_shape()
|
mid_shape = model.get_mid_block_shape()
|
||||||
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
|
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
|
||||||
degraded_noised = degraded + x - uncond_pred
|
degraded_noised = degraded + x - uncond_pred
|
||||||
assert uncond is not None, "SAG requires uncond guidance"
|
# call into the UNet with the adversarially blurred image
|
||||||
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
|
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
|
||||||
# Unless I've misunderstood the paper, this is supposed to be (uncond_pred - sag) * sag_scale.
|
return uncond_pred + (cond_pred - uncond_pred) * cond_scale + (degraded - sag) * sag_scale
|
||||||
# but this is what the automatic1111 implementation does, and it works better??
|
|
||||||
return uncond_pred + (cond_pred - uncond_pred) * cond_scale + (degraded - sag) * sag_scale
|
|
||||||
else:
|
else:
|
||||||
return uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
return uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||||
|
|
||||||
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
|
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
|
||||||
# reshape and GAP the attention map
|
# reshape and GAP the attention map
|
||||||
_, hw1, hw2 = attn.shape
|
_, hw1, hw2 = attn.shape
|
||||||
b, lc, lh, lw = x0.shape
|
b, _, lh, lw = x0.shape
|
||||||
attn = attn.reshape(b, -1, hw1, hw2)
|
attn = attn.reshape(b, -1, hw1, hw2)
|
||||||
# Global Average Pool
|
# Global Average Pool
|
||||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||||
@ -285,7 +284,6 @@ def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
|
|||||||
mask = (
|
mask = (
|
||||||
mask.reshape(b, *mid_shape)
|
mask.reshape(b, *mid_shape)
|
||||||
.unsqueeze(1)
|
.unsqueeze(1)
|
||||||
.repeat(1, lc, 1, 1)
|
|
||||||
.type(attn.dtype)
|
.type(attn.dtype)
|
||||||
)
|
)
|
||||||
# Upsample
|
# Upsample
|
||||||
|
|||||||
@ -101,7 +101,6 @@ class SagNode:
|
|||||||
if name not in to["patches_replace"]:
|
if name not in to["patches_replace"]:
|
||||||
to["patches_replace"][name] = {}
|
to["patches_replace"][name] = {}
|
||||||
to["patches_replace"][name][key] = patch
|
to["patches_replace"][name][key] = patch
|
||||||
# this actually patches 2 attn calls -- confusing, since we only want to get one
|
|
||||||
set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0))
|
set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0))
|
||||||
# from diffusers:
|
# from diffusers:
|
||||||
# unet.mid_block.attentions[0].register_forward_hook()
|
# unet.mid_block.attentions[0].register_forward_hook()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user