convert nodes_sd3.py and nodes_slg.py to V3 schema (#10162)

This commit is contained in:
Alexander Piskun 2025-10-10 01:18:23 +03:00 committed by GitHub
parent f3d5d328a3
commit fc0fbf141c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 219 additions and 121 deletions

View File

@ -3,64 +3,83 @@ import comfy.sd
import comfy.model_management import comfy.model_management
import nodes import nodes
import torch import torch
import comfy_extras.nodes_slg from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_slg import SkipLayerGuidanceDiT
class TripleCLIPLoader: class TripleCLIPLoader(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), ) return io.Schema(
}} node_id="TripleCLIPLoader",
RETURN_TYPES = ("CLIP",) category="advanced/loaders",
FUNCTION = "load_clip" description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
inputs=[
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/loaders" @classmethod
def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput:
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,) return io.NodeOutput(clip)
load_clip = execute # TODO: remove
class EmptySD3LatentImage: class EmptySD3LatentImage(io.ComfyNode):
def __init__(self): @classmethod
self.device = comfy.model_management.intermediate_device() def define_schema(cls):
return io.Schema(
node_id="EmptySD3LatentImage",
category="latent/sd3",
inputs=[
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), return io.NodeOutput({"samples":latent})
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/sd3" generate = execute # TODO: remove
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
class CLIPTextEncodeSD3: class CLIPTextEncodeSD3(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"clip": ("CLIP", ), node_id="CLIPTextEncodeSD3",
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), category="advanced/conditioning",
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), inputs=[
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), io.Clip.Input("clip"),
"empty_padding": (["none", "empty_prompt"], ) io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
}} io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
RETURN_TYPES = ("CONDITIONING",) io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
FUNCTION = "encode" io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning" @classmethod
def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput:
def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
no_padding = empty_padding == "none" no_padding = empty_padding == "none"
tokens = clip.tokenize(clip_g) tokens = clip.tokenize(clip_g)
@ -82,57 +101,112 @@ class CLIPTextEncodeSD3:
tokens["l"] += empty["l"] tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]): while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"] tokens["g"] += empty["g"]
return (clip.encode_from_tokens_scheduled(tokens), ) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
encode = execute # TODO: remove
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): class ControlNetApplySD3(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls) -> io.Schema:
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="ControlNetApplySD3",
"control_net": ("CONTROL_NET", ), display_name="Apply Controlnet with VAE",
"vae": ("VAE", ), category="conditioning/controlnet",
"image": ("IMAGE", ), inputs=[
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), io.Conditioning.Input("positive"),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), io.Conditioning.Input("negative"),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) io.ControlNet.Input("control_net"),
}} io.Vae.Input("vae"),
CATEGORY = "conditioning/controlnet" io.Image.Input("image"),
DEPRECATED = True io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
is_deprecated=True,
)
@classmethod
def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput:
if strength == 0:
return io.NodeOutput(positive, negative)
control_hint = image.movedim(-1, 1)
cnets = {}
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
prev_cnet = d.get('control', None)
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent),
vae=vae, extra_concat=[])
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
d['control'] = c_net
d['control_apply_to_uncond'] = False
n = [t[0], d]
c.append(n)
out.append(c)
return io.NodeOutput(out[0], out[1])
apply_controlnet = execute # TODO: remove
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT): class SkipLayerGuidanceSD3(io.ComfyNode):
''' '''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Experimental implementation by Dango233@StabilityAI. Experimental implementation by Dango233@StabilityAI.
''' '''
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"model": ("MODEL", ), return io.Schema(
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), node_id="SkipLayerGuidanceSD3",
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), category="advanced/guidance",
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) inputs=[
}} io.Model.Input("model"),
RETURN_TYPES = ("MODEL",) io.String.Input("layers", default="7, 8, 9", multiline=False),
FUNCTION = "skip_guidance_sd3" io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
)
CATEGORY = "advanced/guidance" @classmethod
def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput:
return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent): skip_guidance_sd3 = execute # TODO: remove
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
NODE_CLASS_MAPPINGS = { class SD3Extension(ComfyExtension):
"TripleCLIPLoader": TripleCLIPLoader, @override
"EmptySD3LatentImage": EmptySD3LatentImage, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"CLIPTextEncodeSD3": CLIPTextEncodeSD3, return [
"ControlNetApplySD3": ControlNetApplySD3, TripleCLIPLoader,
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3, EmptySD3LatentImage,
} CLIPTextEncodeSD3,
ControlNetApplySD3,
SkipLayerGuidanceSD3,
]
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling async def comfy_entrypoint() -> SD3Extension:
"ControlNetApplySD3": "Apply Controlnet with VAE", return SD3Extension()
}

View File

@ -1,33 +1,40 @@
import comfy.model_patcher import comfy.model_patcher
import comfy.samplers import comfy.samplers
import re import re
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class SkipLayerGuidanceDiT: class SkipLayerGuidanceDiT(io.ComfyNode):
''' '''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Original experimental implementation for SD3 by Dango233@StabilityAI. Original experimental implementation for SD3 by Dango233@StabilityAI.
''' '''
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"model": ("MODEL", ), return io.Schema(
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), node_id="SkipLayerGuidanceDiT",
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), category="advanced/guidance",
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), is_experimental=True,
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), inputs=[
"rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), io.Model.Input("model"),
}} io.String.Input("double_layers", default="7, 8, 9"),
RETURN_TYPES = ("MODEL",) io.String.Input("single_layers", default="7, 8, 9"),
FUNCTION = "skip_guidance" io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
EXPERIMENTAL = True io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model." @classmethod
def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0) -> io.NodeOutput:
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
# check if layer is comma separated integers # check if layer is comma separated integers
def skip(args, extra_args): def skip(args, extra_args):
return args return args
@ -43,7 +50,7 @@ class SkipLayerGuidanceDiT:
single_layers = [int(i) for i in single_layers] single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0: if len(double_layers) == 0 and len(single_layers) == 0:
return (model, ) return io.NodeOutput(model)
def post_cfg_function(args): def post_cfg_function(args):
model = args["model"] model = args["model"]
@ -76,29 +83,36 @@ class SkipLayerGuidanceDiT:
m = model.clone() m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function) m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m, ) return io.NodeOutput(m)
class SkipLayerGuidanceDiTSimple: skip_guidance = execute # TODO: remove
class SkipLayerGuidanceDiTSimple(io.ComfyNode):
''' '''
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass. Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
''' '''
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"model": ("MODEL", ), return io.Schema(
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), node_id="SkipLayerGuidanceDiTSimple",
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), category="advanced/guidance",
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.",
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), is_experimental=True,
}} inputs=[
RETURN_TYPES = ("MODEL",) io.Model.Input("model"),
FUNCTION = "skip_guidance" io.String.Input("double_layers", default="7, 8, 9"),
EXPERIMENTAL = True io.String.Input("single_layers", default="7, 8, 9"),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Model.Output(),
],
)
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass." @classmethod
def execute(cls, model, start_percent, end_percent, double_layers="", single_layers="") -> io.NodeOutput:
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
def skip(args, extra_args): def skip(args, extra_args):
return args return args
@ -113,7 +127,7 @@ class SkipLayerGuidanceDiTSimple:
single_layers = [int(i) for i in single_layers] single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0: if len(double_layers) == 0 and len(single_layers) == 0:
return (model, ) return io.NodeOutput(model)
def calc_cond_batch_function(args): def calc_cond_batch_function(args):
x = args["input"] x = args["input"]
@ -144,9 +158,19 @@ class SkipLayerGuidanceDiTSimple:
m = model.clone() m = model.clone()
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function) m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
return (m, ) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = { skip_guidance = execute # TODO: remove
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
} class SkipLayerGuidanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SkipLayerGuidanceDiT,
SkipLayerGuidanceDiTSimple,
]
async def comfy_entrypoint() -> SkipLayerGuidanceExtension:
return SkipLayerGuidanceExtension()