convert nodes_latent.py to V3 schema (#10160)

This commit is contained in:
Alexander Piskun 2025-10-09 09:14:00 +03:00 committed by GitHub
parent 6732014a0a
commit cbee7d3390
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,8 @@ import comfy.utils
import comfy_extras.nodes_post_processing
import torch
import nodes
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def reshape_latent_to(target_shape, latent, repeat_batch=True):
@ -13,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True):
return latent
class LatentAdd:
class LatentAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
def define_schema(cls):
return io.Schema(
node_id="LatentAdd",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2):
@classmethod
def execute(cls, samples1, samples2) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@ -31,19 +39,25 @@ class LatentAdd:
s2 = reshape_latent_to(s1.shape, s2)
samples_out["samples"] = s1 + s2
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentSubtract:
class LatentSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
def define_schema(cls):
return io.Schema(
node_id="LatentSubtract",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2):
@classmethod
def execute(cls, samples1, samples2) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@ -51,41 +65,49 @@ class LatentSubtract:
s2 = reshape_latent_to(s1.shape, s2)
samples_out["samples"] = s1 - s2
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentMultiply:
class LatentMultiply(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentMultiply",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, multiplier):
@classmethod
def execute(cls, samples, multiplier) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
samples_out["samples"] = s1 * multiplier
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentInterpolate:
class LatentInterpolate(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentInterpolate",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
@classmethod
def execute(cls, samples1, samples2, ratio) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@ -104,19 +126,26 @@ class LatentInterpolate:
st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentConcat:
class LatentConcat(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
def define_schema(cls):
return io.Schema(
node_id="LatentConcat",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, dim):
@classmethod
def execute(cls, samples1, samples2, dim) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@ -136,22 +165,27 @@ class LatentConcat:
dim = -3
samples_out["samples"] = torch.cat(c, dim=dim)
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentCut:
class LatentCut(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT",),
"dim": (["x", "y", "t"], ),
"index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}),
"amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}}
def define_schema(cls):
return io.Schema(
node_id="LatentCut",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("dim", options=["x", "y", "t"]),
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, dim, index, amount):
@classmethod
def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
@ -171,19 +205,25 @@ class LatentCut:
amount = min(-index, amount)
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentBatch:
class LatentBatch(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
def define_schema(cls):
return io.Schema(
node_id="LatentBatch",
category="latent/batch",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "batch"
CATEGORY = "latent/batch"
def batch(self, samples1, samples2):
@classmethod
def execute(cls, samples1, samples2) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
@ -192,20 +232,25 @@ class LatentBatch:
s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentBatchSeedBehavior:
class LatentBatchSeedBehavior(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
def define_schema(cls):
return io.Schema(
node_id="LatentBatchSeedBehavior",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, seed_behavior):
@classmethod
def execute(cls, samples, seed_behavior) -> io.NodeOutput:
samples_out = samples.copy()
latent = samples["samples"]
if seed_behavior == "random":
@ -215,41 +260,50 @@ class LatentBatchSeedBehavior:
batch_number = samples_out.get("batch_index", [0])[0]
samples_out["batch_index"] = [batch_number] * latent.shape[0]
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentApplyOperation:
class LatentApplyOperation(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"operation": ("LATENT_OPERATION",),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentApplyOperation",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Latent.Input("samples"),
io.LatentOperation.Input("operation"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, samples, operation):
@classmethod
def execute(cls, samples, operation) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
samples_out["samples"] = operation(latent=s1)
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentApplyOperationCFG:
class LatentApplyOperationCFG(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"operation": ("LATENT_OPERATION",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="LatentApplyOperationCFG",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.LatentOperation.Input("operation"),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def patch(self, model, operation):
@classmethod
def execute(cls, model, operation) -> io.NodeOutput:
m = model.clone()
def pre_cfg_function(args):
@ -261,21 +315,25 @@ class LatentApplyOperationCFG:
return conds_out
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m, )
return io.NodeOutput(m)
class LatentOperationTonemapReinhard:
class LatentOperationTonemapReinhard(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentOperationTonemapReinhard",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.LatentOperation.Output(),
],
)
RETURN_TYPES = ("LATENT_OPERATION",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, multiplier):
@classmethod
def execute(cls, multiplier) -> io.NodeOutput:
def tonemap_reinhard(latent, **kwargs):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude
@ -291,39 +349,27 @@ class LatentOperationTonemapReinhard:
new_magnitude *= top
return normalized_latent * new_magnitude
return (tonemap_reinhard,)
return io.NodeOutput(tonemap_reinhard)
class LatentOperationSharpen:
class LatentOperationSharpen(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"sharpen_radius": ("INT", {
"default": 9,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
"alpha": ("FLOAT", {
"default": 0.1,
"min": 0.0,
"max": 5.0,
"step": 0.01
}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentOperationSharpen",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
],
outputs=[
io.LatentOperation.Output(),
],
)
RETURN_TYPES = ("LATENT_OPERATION",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, sharpen_radius, sigma, alpha):
@classmethod
def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput:
def sharpen(latent, **kwargs):
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
normalized_latent = latent / luminance
@ -340,19 +386,27 @@ class LatentOperationSharpen:
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
return luminance * sharpened
return (sharpen,)
return io.NodeOutput(sharpen)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
"LatentConcat": LatentConcat,
"LatentCut": LatentCut,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
"LatentApplyOperation": LatentApplyOperation,
"LatentApplyOperationCFG": LatentApplyOperationCFG,
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
"LatentOperationSharpen": LatentOperationSharpen,
}
class LatentExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LatentAdd,
LatentSubtract,
LatentMultiply,
LatentInterpolate,
LatentConcat,
LatentCut,
LatentBatch,
LatentBatchSeedBehavior,
LatentApplyOperation,
LatentApplyOperationCFG,
LatentOperationTonemapReinhard,
LatentOperationSharpen,
]
async def comfy_entrypoint() -> LatentExtension:
return LatentExtension()