mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
convert nodes_sd3.py and nodes_slg.py to V3 schema (#10162)
This commit is contained in:
parent
f3d5d328a3
commit
fc0fbf141c
@ -3,64 +3,83 @@ import comfy.sd
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import nodes
|
import nodes
|
||||||
import torch
|
import torch
|
||||||
import comfy_extras.nodes_slg
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy_extras.nodes_slg import SkipLayerGuidanceDiT
|
||||||
|
|
||||||
|
|
||||||
class TripleCLIPLoader:
|
class TripleCLIPLoader(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
|
return io.Schema(
|
||||||
}}
|
node_id="TripleCLIPLoader",
|
||||||
RETURN_TYPES = ("CLIP",)
|
category="advanced/loaders",
|
||||||
FUNCTION = "load_clip"
|
description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
|
||||||
|
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
|
||||||
|
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Clip.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
@classmethod
|
||||||
|
def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput:
|
||||||
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
|
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
|
||||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||||
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (clip,)
|
return io.NodeOutput(clip)
|
||||||
|
|
||||||
|
load_clip = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class EmptySD3LatentImage:
|
class EmptySD3LatentImage(io.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.device = comfy.model_management.intermediate_device()
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="EmptySD3LatentImage",
|
||||||
|
category="latent/sd3",
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
|
||||||
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
return io.NodeOutput({"samples":latent})
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
|
||||||
FUNCTION = "generate"
|
|
||||||
|
|
||||||
CATEGORY = "latent/sd3"
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
def generate(self, width, height, batch_size=1):
|
|
||||||
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
|
|
||||||
return ({"samples":latent}, )
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncodeSD3:
|
class CLIPTextEncodeSD3(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"clip": ("CLIP", ),
|
node_id="CLIPTextEncodeSD3",
|
||||||
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
category="advanced/conditioning",
|
||||||
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
inputs=[
|
||||||
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
io.Clip.Input("clip"),
|
||||||
"empty_padding": (["none", "empty_prompt"], )
|
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
||||||
}}
|
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
|
||||||
FUNCTION = "encode"
|
io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning"
|
@classmethod
|
||||||
|
def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput:
|
||||||
def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
|
|
||||||
no_padding = empty_padding == "none"
|
no_padding = empty_padding == "none"
|
||||||
|
|
||||||
tokens = clip.tokenize(clip_g)
|
tokens = clip.tokenize(clip_g)
|
||||||
@ -82,57 +101,112 @@ class CLIPTextEncodeSD3:
|
|||||||
tokens["l"] += empty["l"]
|
tokens["l"] += empty["l"]
|
||||||
while len(tokens["l"]) > len(tokens["g"]):
|
while len(tokens["l"]) > len(tokens["g"]):
|
||||||
tokens["g"] += empty["g"]
|
tokens["g"] += empty["g"]
|
||||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||||
|
|
||||||
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
class ControlNetApplySD3(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": {"positive": ("CONDITIONING", ),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING", ),
|
node_id="ControlNetApplySD3",
|
||||||
"control_net": ("CONTROL_NET", ),
|
display_name="Apply Controlnet with VAE",
|
||||||
"vae": ("VAE", ),
|
category="conditioning/controlnet",
|
||||||
"image": ("IMAGE", ),
|
inputs=[
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Conditioning.Input("positive"),
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
io.Conditioning.Input("negative"),
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
io.ControlNet.Input("control_net"),
|
||||||
}}
|
io.Vae.Input("vae"),
|
||||||
CATEGORY = "conditioning/controlnet"
|
io.Image.Input("image"),
|
||||||
DEPRECATED = True
|
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
],
|
||||||
|
is_deprecated=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput:
|
||||||
|
if strength == 0:
|
||||||
|
return io.NodeOutput(positive, negative)
|
||||||
|
|
||||||
|
control_hint = image.movedim(-1, 1)
|
||||||
|
cnets = {}
|
||||||
|
|
||||||
|
out = []
|
||||||
|
for conditioning in [positive, negative]:
|
||||||
|
c = []
|
||||||
|
for t in conditioning:
|
||||||
|
d = t[1].copy()
|
||||||
|
|
||||||
|
prev_cnet = d.get('control', None)
|
||||||
|
if prev_cnet in cnets:
|
||||||
|
c_net = cnets[prev_cnet]
|
||||||
|
else:
|
||||||
|
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent),
|
||||||
|
vae=vae, extra_concat=[])
|
||||||
|
c_net.set_previous_controlnet(prev_cnet)
|
||||||
|
cnets[prev_cnet] = c_net
|
||||||
|
|
||||||
|
d['control'] = c_net
|
||||||
|
d['control_apply_to_uncond'] = False
|
||||||
|
n = [t[0], d]
|
||||||
|
c.append(n)
|
||||||
|
out.append(c)
|
||||||
|
return io.NodeOutput(out[0], out[1])
|
||||||
|
|
||||||
|
apply_controlnet = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
|
class SkipLayerGuidanceSD3(io.ComfyNode):
|
||||||
'''
|
'''
|
||||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||||
Experimental implementation by Dango233@StabilityAI.
|
Experimental implementation by Dango233@StabilityAI.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"model": ("MODEL", ),
|
return io.Schema(
|
||||||
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
node_id="SkipLayerGuidanceSD3",
|
||||||
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
category="advanced/guidance",
|
||||||
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
|
||||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
inputs=[
|
||||||
}}
|
io.Model.Input("model"),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.String.Input("layers", default="7, 8, 9", multiline=False),
|
||||||
FUNCTION = "skip_guidance_sd3"
|
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
|
||||||
|
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/guidance"
|
@classmethod
|
||||||
|
def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput:
|
||||||
|
return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
|
||||||
|
|
||||||
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
|
skip_guidance_sd3 = execute # TODO: remove
|
||||||
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class SD3Extension(ComfyExtension):
|
||||||
"TripleCLIPLoader": TripleCLIPLoader,
|
@override
|
||||||
"EmptySD3LatentImage": EmptySD3LatentImage,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
return [
|
||||||
"ControlNetApplySD3": ControlNetApplySD3,
|
TripleCLIPLoader,
|
||||||
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
|
EmptySD3LatentImage,
|
||||||
}
|
CLIPTextEncodeSD3,
|
||||||
|
ControlNetApplySD3,
|
||||||
|
SkipLayerGuidanceSD3,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
# Sampling
|
async def comfy_entrypoint() -> SD3Extension:
|
||||||
"ControlNetApplySD3": "Apply Controlnet with VAE",
|
return SD3Extension()
|
||||||
}
|
|
||||||
|
|||||||
@ -1,33 +1,40 @@
|
|||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import re
|
import re
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
class SkipLayerGuidanceDiT:
|
class SkipLayerGuidanceDiT(io.ComfyNode):
|
||||||
'''
|
'''
|
||||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||||
Original experimental implementation for SD3 by Dango233@StabilityAI.
|
Original experimental implementation for SD3 by Dango233@StabilityAI.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"model": ("MODEL", ),
|
return io.Schema(
|
||||||
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
node_id="SkipLayerGuidanceDiT",
|
||||||
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
category="advanced/guidance",
|
||||||
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
|
||||||
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
is_experimental=True,
|
||||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
|
inputs=[
|
||||||
"rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Model.Input("model"),
|
||||||
}}
|
io.String.Input("double_layers", default="7, 8, 9"),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.String.Input("single_layers", default="7, 8, 9"),
|
||||||
FUNCTION = "skip_guidance"
|
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
|
||||||
EXPERIMENTAL = True
|
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."
|
@classmethod
|
||||||
|
def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0) -> io.NodeOutput:
|
||||||
CATEGORY = "advanced/guidance"
|
|
||||||
|
|
||||||
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
|
|
||||||
# check if layer is comma separated integers
|
# check if layer is comma separated integers
|
||||||
def skip(args, extra_args):
|
def skip(args, extra_args):
|
||||||
return args
|
return args
|
||||||
@ -43,7 +50,7 @@ class SkipLayerGuidanceDiT:
|
|||||||
single_layers = [int(i) for i in single_layers]
|
single_layers = [int(i) for i in single_layers]
|
||||||
|
|
||||||
if len(double_layers) == 0 and len(single_layers) == 0:
|
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||||
return (model, )
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
def post_cfg_function(args):
|
def post_cfg_function(args):
|
||||||
model = args["model"]
|
model = args["model"]
|
||||||
@ -76,29 +83,36 @@ class SkipLayerGuidanceDiT:
|
|||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||||
|
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class SkipLayerGuidanceDiTSimple:
|
skip_guidance = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class SkipLayerGuidanceDiTSimple(io.ComfyNode):
|
||||||
'''
|
'''
|
||||||
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
|
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
|
||||||
'''
|
'''
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"model": ("MODEL", ),
|
return io.Schema(
|
||||||
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
node_id="SkipLayerGuidanceDiTSimple",
|
||||||
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
category="advanced/guidance",
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.",
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
is_experimental=True,
|
||||||
}}
|
inputs=[
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Model.Input("model"),
|
||||||
FUNCTION = "skip_guidance"
|
io.String.Input("double_layers", default="7, 8, 9"),
|
||||||
EXPERIMENTAL = True
|
io.String.Input("single_layers", default="7, 8, 9"),
|
||||||
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."
|
@classmethod
|
||||||
|
def execute(cls, model, start_percent, end_percent, double_layers="", single_layers="") -> io.NodeOutput:
|
||||||
CATEGORY = "advanced/guidance"
|
|
||||||
|
|
||||||
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
|
|
||||||
def skip(args, extra_args):
|
def skip(args, extra_args):
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@ -113,7 +127,7 @@ class SkipLayerGuidanceDiTSimple:
|
|||||||
single_layers = [int(i) for i in single_layers]
|
single_layers = [int(i) for i in single_layers]
|
||||||
|
|
||||||
if len(double_layers) == 0 and len(single_layers) == 0:
|
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||||
return (model, )
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
def calc_cond_batch_function(args):
|
def calc_cond_batch_function(args):
|
||||||
x = args["input"]
|
x = args["input"]
|
||||||
@ -144,9 +158,19 @@ class SkipLayerGuidanceDiTSimple:
|
|||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
|
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
|
||||||
|
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
skip_guidance = execute # TODO: remove
|
||||||
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
|
|
||||||
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
|
|
||||||
}
|
class SkipLayerGuidanceExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
SkipLayerGuidanceDiT,
|
||||||
|
SkipLayerGuidanceDiTSimple,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> SkipLayerGuidanceExtension:
|
||||||
|
return SkipLayerGuidanceExtension()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user