From 954683d0dbd8f098c5485422a1e27f33fe951c32 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Tue, 29 Oct 2024 21:59:21 +0800 Subject: [PATCH 1/2] SLG first implementation for SD3.5 (#5404) * SLG first implementation for SD3.5 * * Simplify and align with comfy style --- comfy_extras/nodes_sd3.py | 61 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index ddf538deb..6bd06f4a3 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -3,7 +3,7 @@ import comfy.sd import comfy.model_management import nodes import torch - +import re class TripleCLIPLoader: @classmethod def INPUT_TYPES(s): @@ -95,11 +95,70 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): CATEGORY = "conditioning/controlnet" DEPRECATED = True +class SkipLayerGuidanceSD3: + ''' + Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. + Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) + Experimental implementation by Dango233@StabilityAI. + ''' + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "layers": ("STRING", {"default": "7,8,9", "multiline": False}), + "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), + "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "skip_guidance" + + CATEGORY = "advanced/guidance" + + + def skip_guidance(self, model, layers, scale, start_percent, end_percent): + if layers == "" or layers == None: + return (model, ) + # check if layer is comma separated integers + assert layers.replace(",", "").isdigit(), "Layers must be comma separated integers" + def skip(args, extra_args): + return args + + model_sampling = model.get_model_object("model_sampling") + + def post_cfg_function(args): + model = args["model"] + cond_pred = args["cond_denoised"] + cond = args["cond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + x = args["input"] + model_options = args["model_options"].copy() + + for layer in layers: + model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer) + model_sampling.percent_to_sigma(start_percent) + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) + sigma_ = sigma[0].item() + if scale > 0 and sigma_ > sigma_end and sigma_ < sigma_start: + (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) + cfg_result = cfg_result + (cond_pred - slg) * scale + return cfg_result + + layers = re.findall(r'\d+', layers) + layers = [int(i) for i in layers] + m = model.clone() + m.set_model_sampler_post_cfg_function(post_cfg_function) + + return (m, ) + + NODE_CLASS_MAPPINGS = { "TripleCLIPLoader": TripleCLIPLoader, "EmptySD3LatentImage": EmptySD3LatentImage, "CLIPTextEncodeSD3": CLIPTextEncodeSD3, "ControlNetApplySD3": ControlNetApplySD3, + "SkipLayerGuidanceSD3": SkipLayerGuidanceSD3, } NODE_DISPLAY_NAME_MAPPINGS = { From 770ab200f296d8d0269d37fdca84bb742cee38b1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 29 Oct 2024 10:11:46 -0400 Subject: [PATCH 2/2] Cleanup SkipLayerGuidanceSD3 node. --- comfy_extras/nodes_sd3.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 6bd06f4a3..4d664093c 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -104,7 +104,7 @@ class SkipLayerGuidanceSD3: @classmethod def INPUT_TYPES(s): return {"required": {"model": ("MODEL", ), - "layers": ("STRING", {"default": "7,8,9", "multiline": False}), + "layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) @@ -119,11 +119,12 @@ class SkipLayerGuidanceSD3: if layers == "" or layers == None: return (model, ) # check if layer is comma separated integers - assert layers.replace(",", "").isdigit(), "Layers must be comma separated integers" def skip(args, extra_args): return args model_sampling = model.get_model_object("model_sampling") + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) def post_cfg_function(args): model = args["model"] @@ -137,10 +138,9 @@ class SkipLayerGuidanceSD3: for layer in layers: model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer) model_sampling.percent_to_sigma(start_percent) - sigma_start = model_sampling.percent_to_sigma(start_percent) - sigma_end = model_sampling.percent_to_sigma(end_percent) + sigma_ = sigma[0].item() - if scale > 0 and sigma_ > sigma_end and sigma_ < sigma_start: + if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start: (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) cfg_result = cfg_result + (cond_pred - slg) * scale return cfg_result