From bd07ad1861949007139de7dd5c6bcdb77426919c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Nov 2023 13:23:25 -0500 Subject: [PATCH 1/3] Add PatchModelAddDownscale (Kohya Deep Shrink) node. By adding a downscale to the unet in the first timesteps this node lets you generate images at higher resolutions with less consistency issues. --- comfy_extras/nodes_model_downscale.py | 45 +++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 46 insertions(+) create mode 100644 comfy_extras/nodes_model_downscale.py diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py new file mode 100644 index 000000000..f1b2d3ff2 --- /dev/null +++ b/comfy_extras/nodes_model_downscale.py @@ -0,0 +1,45 @@ +import torch + +class PatchModelAddDownscale: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}), + "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, block_number, downscale_factor, start_percent, end_percent): + sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item() + sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item() + + def input_block_patch(h, transformer_options): + if transformer_options["block"][1] == block_number: + sigma = transformer_options["sigmas"][0].item() + if sigma <= sigma_start and sigma >= sigma_end: + h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False) + return h + + def output_block_patch(h, hsp, transformer_options): + if h.shape[2] != hsp.shape[2]: + h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False) + return h, hsp + + m = model.clone() + m.set_model_input_block_patch(input_block_patch) + m.set_model_output_block_patch(output_block_patch) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "PatchModelAddDownscale": PatchModelAddDownscale, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + # Sampling + "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)", +} diff --git a/nodes.py b/nodes.py index e8cfb5e6a..f9d2d7f6c 100644 --- a/nodes.py +++ b/nodes.py @@ -1799,6 +1799,7 @@ def init_custom_nodes(): "nodes_custom_sampler.py", "nodes_hypertile.py", "nodes_model_advanced.py", + "nodes_model_downscale.py", ] for node_file in extras_files: From 9f00a18095e5f8ef114525bc19db035756501959 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Nov 2023 14:59:54 -0500 Subject: [PATCH 2/3] Fix potential issues. --- comfy/model_patcher.py | 2 +- comfy/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 023684331..7f5ed45fe 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -37,7 +37,7 @@ class ModelPatcher: return size def clone(self): - n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device) + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] diff --git a/comfy/utils.py b/comfy/utils.py index 1985012e0..f4c0ab419 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -258,7 +258,7 @@ def set_attr(obj, attr, value): for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], torch.nn.Parameter(value)) + setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False)) del prev def copy_to_param(obj, attr, value): From 7e3fe3ad28fad4dede2893d77093a086344b81b6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Nov 2023 15:26:28 -0500 Subject: [PATCH 3/3] Make deep shrink behave like it should. --- comfy/ldm/modules/diffusionmodules/openaimodel.py | 4 ++++ comfy/model_patcher.py | 3 +++ comfy_extras/nodes_model_downscale.py | 8 ++++++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 504b79ede..10eb68d73 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -633,6 +633,10 @@ class UNetModel(nn.Module): h = p(h, transformer_options) hs.append(h) + if "input_block_patch_after_skip" in transformer_patches: + patch = transformer_patches["input_block_patch_after_skip"] + for p in patch: + h = p(h, transformer_options) transformer_options["block"] = ("middle", 0) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7f5ed45fe..a3cffc3be 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -99,6 +99,9 @@ class ModelPatcher: def set_model_input_block_patch(self, patch): self.set_model_patch(patch, "input_block_patch") + def set_model_input_block_patch_after_skip(self, patch): + self.set_model_patch(patch, "input_block_patch_after_skip") + def set_model_output_block_patch(self, patch): self.set_model_patch(patch, "output_block_patch") diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index f1b2d3ff2..8850d0948 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -8,13 +8,14 @@ class PatchModelAddDownscale: "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), + "downscale_after_skip": ("BOOLEAN", {"default": True}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "_for_testing" - def patch(self, model, block_number, downscale_factor, start_percent, end_percent): + def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip): sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item() sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item() @@ -31,7 +32,10 @@ class PatchModelAddDownscale: return h, hsp m = model.clone() - m.set_model_input_block_patch(input_block_patch) + if downscale_after_skip: + m.set_model_input_block_patch_after_skip(input_block_patch) + else: + m.set_model_input_block_patch(input_block_patch) m.set_model_output_block_patch(output_block_patch) return (m, )