From f8d719f9e7e1ed7fc7e617640951dcf7e7ae9c68 Mon Sep 17 00:00:00 2001
From: Raphael Walker
Date: Thu, 7 Dec 2023 11:24:30 +0100
Subject: [PATCH] Fix a crash when using weird resolutions. Remove an
 unnecessary UNet call

---
 comfy/samplers.py         | 17 ++++++-----------
 comfy_extras/nodes_sag.py | 19 +++++++++++++++++--
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/comfy/samplers.py b/comfy/samplers.py
index a23949ae7..25dc22f1b 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -263,32 +263,27 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
         # this method is added by the sag patcher
         uncond_attn = model.get_attn_scores()
-        degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
+        mid_shape = model.get_mid_block_shape()
+        degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
         degraded_noised = degraded + x - uncond_pred
-        # TODO optimize this: doing it this way creates an extra call that we don't even use
-        (_, sag) = calc_cond_uncond_batch(model, cond, uncond, degraded_noised, timestep, model_options)
+        assert uncond is not None, "SAG requires uncond guidance"
+        (sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
         # Unless I've misunderstood the paper, this is supposed to be (uncond_pred - sag) * sag_scale.
         # but this is what the automatic1111 implementation does, and it works better??
         return uncond_pred + (cond_pred - uncond_pred) * cond_scale + (degraded - sag) * sag_scale
     else:
         return uncond_pred + (cond_pred - uncond_pred) * cond_scale
 
 
-def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
+def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
     # reshape and GAP the attention map
     _, hw1, hw2 = attn.shape
     b, lc, lh, lw = x0.shape
     attn = attn.reshape(b, -1, hw1, hw2)
     # Global Average Pool
     mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
-    # we want to reshape the mask, which now has shape (b, w*h), to shape (b, 1, h, w).
-    # if we know the model beforehand, we can just divide lh and wh by the correct factor to size of the latent in the middle of the UNet
-    # but if we want to be model-agnostic, we can do it this way: just figure out the scale factor by the number of "pixels".
-    total_size_latent = lh * lw
-    scale_factor = int(math.sqrt(total_size_latent / mask.shape[1]))
-    middle_layer_latent_size = [math.ceil(lh/scale_factor), math.ceil(lw/scale_factor)]
     # Reshape
     mask = (
-        mask.reshape(b, *middle_layer_latent_size)
+        mask.reshape(b, *mid_shape)
         .unsqueeze(1)
         .repeat(1, lc, 1, 1)
         .type(attn.dtype)
diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py
index 4f4e3f64b..910257cff 100644
--- a/comfy_extras/nodes_sag.py
+++ b/comfy_extras/nodes_sag.py
@@ -69,7 +69,9 @@ class SagNode:
         m.model_options["sag_sigma"] = blur_sigma
 
         attn_scores = None
+        mid_block_shape = None
         m.model.get_attn_scores = lambda: attn_scores
+        m.model.get_mid_block_shape = lambda: mid_block_shape
 
         # TODO: make this work properly with chunked batches
         # currently, we can only save the attn from one UNet call
@@ -92,8 +94,21 @@ class SagNode:
         # from diffusers:
         # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
-        # we might have to patch at different locations depending on sd1.5/2.1 vs sdXL
-        m.set_model_patch_replace(attn_and_record, "attn1", "middle", 0)
+        def set_model_patch_replace(patch, name, key):
+            to = m.model_options["transformer_options"]
+            if "patches_replace" not in to:
+                to["patches_replace"] = {}
+            if name not in to["patches_replace"]:
+                to["patches_replace"][name] = {}
+            to["patches_replace"][name][key] = patch
+        # this actually patches 2 attn calls -- confusing, since we only want to get one
+        set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0))
+        # from diffusers:
+        # unet.mid_block.attentions[0].register_forward_hook()
+        def forward_hook(m, inp, out):
+            nonlocal mid_block_shape
+            mid_block_shape = out[0].shape[-2:]
+        m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook)
 
         return (m, )
 
 NODE_CLASS_MAPPINGS = {
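
Why the old code crashed: the removed heuristic in create_blur_map guessed the
mid-block latent size as sqrt(lh*lw / n_tokens) and then ceil-divided lh and lw,
which only recovers the true shape when the latent dimensions divide cleanly
through the UNet's downsamples. Below is a minimal sketch of the failure mode,
assuming the SD1.x UNet's three stride-2 (ceil-rounding) downsamples before the
mid block; heuristic_mid_shape and actual_mid_shape are illustrative names, not
ComfyUI APIs:

    import math

    def heuristic_mid_shape(lh, lw, n_tokens):
        # the removed logic: infer the mid-block size from the token count alone
        scale = int(math.sqrt(lh * lw / n_tokens))
        return math.ceil(lh / scale), math.ceil(lw / scale)

    def actual_mid_shape(lh, lw, n_down=3):
        # each stride-2, padding-1 conv maps n -> ceil(n / 2)
        for _ in range(n_down):
            lh, lw = math.ceil(lh / 2), math.ceil(lw / 2)
        return lh, lw

    # 600x512 image -> 75x64 latent: the mid block really produces
    # 10x8 = 80 tokens, but the heuristic reconstructs 11x10 = 110,
    # so the subsequent mask.reshape() raises a size-mismatch error.
    mh, mw = actual_mid_shape(75, 64)            # (10, 8)
    print(heuristic_mid_shape(75, 64, mh * mw))  # (11, 10) -- mismatch, crash

The patch sidesteps the guesswork entirely: a forward hook records the real
spatial shape of the mid block's output, and create_blur_map reshapes with it.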
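The hook mechanism in isolation: register_forward_hook fires after a module's
forward pass and receives (module, input, output), so recording the output's
trailing two dimensions yields the true (h, w). A toy sketch of the same idea
with a stand-in module (a plain nn.Conv2d here, not ComfyUI's middle_block):

    import torch
    import torch.nn as nn

    mid_block_shape = None

    def forward_hook(module, inp, out):
        # record the spatial size of the module's output
        global mid_block_shape
        mid_block_shape = out.shape[-2:]

    # stand-in for m.model.diffusion_model.middle_block[0]
    toy_mid = nn.Conv2d(4, 4, kernel_size=3, stride=2, padding=1)
    toy_mid.register_forward_hook(forward_hook)
    toy_mid(torch.zeros(1, 4, 19, 16))
    print(mid_block_shape)  # torch.Size([10, 8])

Because the hook runs on every UNet call, mid_block_shape always reflects the
latent actually being sampled, with no assumptions about the model's
downsampling factor.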