mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-10 19:57:42 +08:00
Merge remote-tracking branch 'origin/master' into group-nodes
This commit is contained in:
commit
08af9c6655
@ -633,6 +633,10 @@ class UNetModel(nn.Module):
|
|||||||
h = p(h, transformer_options)
|
h = p(h, transformer_options)
|
||||||
|
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
|
if "input_block_patch_after_skip" in transformer_patches:
|
||||||
|
patch = transformer_patches["input_block_patch_after_skip"]
|
||||||
|
for p in patch:
|
||||||
|
h = p(h, transformer_options)
|
||||||
|
|
||||||
transformer_options["block"] = ("middle", 0)
|
transformer_options["block"] = ("middle", 0)
|
||||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||||
|
|||||||
@ -37,7 +37,7 @@ class ModelPatcher:
|
|||||||
return size
|
return size
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
||||||
n.patches = {}
|
n.patches = {}
|
||||||
for k in self.patches:
|
for k in self.patches:
|
||||||
n.patches[k] = self.patches[k][:]
|
n.patches[k] = self.patches[k][:]
|
||||||
@ -99,6 +99,9 @@ class ModelPatcher:
|
|||||||
def set_model_input_block_patch(self, patch):
|
def set_model_input_block_patch(self, patch):
|
||||||
self.set_model_patch(patch, "input_block_patch")
|
self.set_model_patch(patch, "input_block_patch")
|
||||||
|
|
||||||
|
def set_model_input_block_patch_after_skip(self, patch):
    """Register *patch* to run on UNet input-block outputs AFTER the skip
    connection is recorded (the "input_block_patch_after_skip" hook).

    patch: callable ``(h, transformer_options) -> h`` applied to the hidden
    state of each input block.
    """
    # Delegates to the generic patch registry; the string key must match the
    # one looked up in UNetModel's forward pass.
    self.set_model_patch(patch, "input_block_patch_after_skip")
|
||||||
|
|
||||||
def set_model_output_block_patch(self, patch):
|
def set_model_output_block_patch(self, patch):
|
||||||
self.set_model_patch(patch, "output_block_patch")
|
self.set_model_patch(patch, "output_block_patch")
|
||||||
|
|
||||||
|
|||||||
@ -258,7 +258,7 @@ def set_attr(obj, attr, value):
|
|||||||
for name in attrs[:-1]:
|
for name in attrs[:-1]:
|
||||||
obj = getattr(obj, name)
|
obj = getattr(obj, name)
|
||||||
prev = getattr(obj, attrs[-1])
|
prev = getattr(obj, attrs[-1])
|
||||||
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
|
||||||
del prev
|
del prev
|
||||||
|
|
||||||
def copy_to_param(obj, attr, value):
|
def copy_to_param(obj, attr, value):
|
||||||
|
|||||||
49
comfy_extras/nodes_model_downscale.py
Normal file
49
comfy_extras/nodes_model_downscale.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class PatchModelAddDownscale:
    """ComfyUI node implementing "Kohya Deep Shrink": downscale the UNet's
    hidden state at a chosen input block during the early part of sampling,
    then upscale output-block activations back so skip connections match.

    Patches a cloned model via the input/output block patch hooks; the
    original model object is left untouched.
    """

    @classmethod
    def INPUT_TYPES(s):
        # Node-graph input schema (ComfyUI convention: classmethod returning
        # {"required": {name: (TYPE, opts)}}).
        return {"required": {"model": ("MODEL",),
                             "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
                             "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
                             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                             "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
                             "downscale_after_skip": ("BOOLEAN", {"default": True}),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"

    def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip):
        """Return a 1-tuple with a clone of *model* patched to downscale at
        input block *block_number* while sigma lies in the schedule window
        [start_percent, end_percent] (given as sampling-progress percentages).
        """
        # Convert schedule percentages to sigma values once, up front.
        # NOTE(review): percent_to_sigma appears to return a tensor here
        # (.item() is called) — confirm against the model_sampling object.
        sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item()
        sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item()

        def input_block_patch(h, transformer_options):
            # Shrink h only at the configured block and only while the current
            # sigma is inside the active window (sigmas decrease as sampling
            # progresses, hence start >= sigma >= end).
            if transformer_options["block"][1] == block_number:
                sigma = transformer_options["sigmas"][0].item()
                if sigma <= sigma_start and sigma >= sigma_end:
                    h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False)
            return h

        def output_block_patch(h, hsp, transformer_options):
            # If h was shrunk on the way down, its spatial size no longer
            # matches the skip tensor hsp; upscale h back to hsp's size.
            if h.shape[2] != hsp.shape[2]:
                h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False)
            return h, hsp

        m = model.clone()
        if downscale_after_skip:
            m.set_model_input_block_patch_after_skip(input_block_patch)
        else:
            m.set_model_input_block_patch(input_block_patch)
        m.set_model_output_block_patch(output_block_patch)
        return (m, )
|
||||||
|
|
||||||
|
# Registry consumed by ComfyUI's node loader: node type name -> node class.
NODE_CLASS_MAPPINGS = {
    "PatchModelAddDownscale": PatchModelAddDownscale,
}

# Human-readable display names shown in the UI for the node types above.
NODE_DISPLAY_NAME_MAPPINGS = {
    # Sampling
    "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
}
|
||||||
1
nodes.py
1
nodes.py
@ -1799,6 +1799,7 @@ def init_custom_nodes():
|
|||||||
"nodes_custom_sampler.py",
|
"nodes_custom_sampler.py",
|
||||||
"nodes_hypertile.py",
|
"nodes_hypertile.py",
|
||||||
"nodes_model_advanced.py",
|
"nodes_model_advanced.py",
|
||||||
|
"nodes_model_downscale.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
for node_file in extras_files:
|
for node_file in extras_files:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user