From 20d8a06bbeee08dc1c181c530c9b83fe78b23ace Mon Sep 17 00:00:00 2001 From: Raphael Walker Date: Wed, 6 Dec 2023 17:23:02 +0100 Subject: [PATCH] Use @ashen-uncensored formula, which works better!!! --- comfy/samplers.py | 22 ++++++++-------------- comfy_extras/nodes_sag.py | 4 ++-- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index d75372b25..a23949ae7 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -261,26 +261,22 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option sag_sigma = model_options["sag_sigma"] sag_threshold = model_options.get("sag_threshold", 1.0) - # or is it x - uncond? - # or do I have to use the sigma ? - x0_est = uncond_pred # this method is added by the sag patcher uncond_attn = model.get_attn_scores() - degraded = create_blur_map(x0_est, uncond_attn, x - uncond_pred, sag_sigma, sag_threshold) - # todo, optimize this: doing it this way creates an extra call that we don't even use - (_, sag) = calc_cond_uncond_batch(model, cond, uncond, degraded, timestep, model_options) - - return uncond_pred + (cond_pred - uncond_pred) * cond_scale + (uncond_pred - sag) * sag_scale + degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) + degraded_noised = degraded + x - uncond_pred + # TODO optimize this: doing it this way creates an extra call that we don't even use + (_, sag) = calc_cond_uncond_batch(model, cond, uncond, degraded_noised, timestep, model_options) + # Unless I've misunderstood the paper, this is supposed to be (uncond_pred - sag) * sag_scale. + # but this is what the automatic1111 implementation does, and it works better?? + return uncond_pred + (cond_pred - uncond_pred) * cond_scale + (degraded - sag) * sag_scale else: return uncond_pred + (cond_pred - uncond_pred) * cond_scale -def create_blur_map(x0, attn, noise, sigma=3.0, threshold=1.0): +def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): # reshape and GAP the attention map _, hw1, hw2 = attn.shape b, lc, lh, lw = x0.shape - # I think this depends on the model: - # sdxl has 20 heads and the middle of the unet is 4 times smaller - # sd 1.5 has 8 heads and the middle of the unet is 8 times smaller attn = attn.reshape(b, -1, hw1, hw2) # Global Average Pool mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold @@ -302,8 +298,6 @@ def create_blur_map(x0, attn, noise, sigma=3.0, threshold=1.0): blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) blurred = blurred * mask + x0 * (1 - mask) - blurred = blurred + noise - return blurred def gaussian_blur_2d(img, kernel_size, sigma): diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index ee3c0f733..4f4e3f64b 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -53,8 +53,8 @@ class SagNode: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "scale": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 30.0, "step": 0.1}), - "blur_sigma": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), + "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}), + "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch"