Merge branch 'master' into dr-support-pip-cm

This commit is contained in:
Dr.Lt.Data 2025-10-02 07:31:37 +09:00
commit 17064a993c
13 changed files with 621 additions and 378 deletions

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
import comfy.ops import comfy.ops
import comfy.ldm.models.autoencoder import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@ -17,11 +17,12 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * self.gamma return F.normalize(x, dim=1) * self.scale * self.gamma
class DnSmpl(nn.Module): class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True): def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
super().__init__() super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2 fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0 assert oc % fct == 0
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3) self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tds = tds self.tds = tds
self.gs = fct * ic // oc self.gs = fct * ic // oc
@ -30,7 +31,7 @@ class DnSmpl(nn.Module):
r1 = 2 if self.tds else 1 r1 = 2 if self.tds else 1
h = self.conv(x) h = self.conv(x)
if self.tds: if self.tds and self.refiner_vae:
hf = h[:, :, :1, :, :] hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2) hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@ -66,6 +67,7 @@ class DnSmpl(nn.Module):
sc = torch.cat([xf, xn], dim=2) sc = torch.cat([xf, xn], dim=2)
else: else:
b, c, frms, ht, wd = h.shape b, c, frms, ht, wd = h.shape
nf = frms // r1 nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
@ -83,10 +85,11 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module): class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True): def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
super().__init__() super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2 fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3) self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tus = tus self.tus = tus
self.rp = fct * oc // ic self.rp = fct * oc // ic
@ -95,7 +98,7 @@ class UpSmpl(nn.Module):
r1 = 2 if self.tus else 1 r1 = 2 if self.tus else 1
h = self.conv(x) h = self.conv(x)
if self.tus: if self.tus and self.refiner_vae:
hf = h[:, :, :1, :, :] hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape b, c, f, ht, wd = hf.shape
nc = c // (2 * 2) nc = c // (2 * 2)
@ -148,43 +151,56 @@ class UpSmpl(nn.Module):
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_): ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
super().__init__() super().__init__()
self.z_channels = z_channels self.z_channels = z_channels
self.block_out_channels = block_out_channels self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1) self.ffactor_temporal = ffactor_temporal
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList() self.down = nn.ModuleList()
ch = block_out_channels[0] ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length() depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length() depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
for i, tgt in enumerate(block_out_channels): for i, tgt in enumerate(block_out_channels):
stage = nn.Module() stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt, out_channels=tgt,
temb_channels=0, temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm) conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)]) for j in range(num_res_blocks)])
ch = tgt ch = tgt
if i < depth: if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal) stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt ch = nxt
self.down.append(stage) self.down.append(stage)
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.norm_out = RMS_norm(ch) self.norm_out = norm_op(ch)
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1) self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer() self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
def forward(self, x): def forward(self, x):
if not self.refiner_vae and x.shape[2] == 1:
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
x = self.conv_in(x) x = self.conv_in(x)
for stage in self.down: for stage in self.down:
@ -200,31 +216,42 @@ class Encoder(nn.Module):
skip = x.view(b, c // grp, grp, t, h, w).mean(2) skip = x.view(b, c // grp, grp, t, h, w).mean(2)
out = self.conv_out(F.silu(self.norm_out(x))) + skip out = self.conv_out(F.silu(self.norm_out(x))) + skip
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2) if self.refiner_vae:
out = out.permute(0, 2, 1, 3, 4) out = self.regul(out)[0]
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4).contiguous() out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out return out
class Decoder(nn.Module): class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_): ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
super().__init__() super().__init__()
block_out_channels = block_out_channels[::-1] block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels self.z_channels = z_channels
self.block_out_channels = block_out_channels self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
ch = block_out_channels[0] ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, 3) self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList() self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length() depth = (ffactor_spatial >> 1).bit_length()
@ -235,25 +262,26 @@ class Decoder(nn.Module):
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt, out_channels=tgt,
temb_channels=0, temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm) conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)]) for j in range(num_res_blocks + 1)])
ch = tgt ch = tgt
if i < depth: if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal) stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt ch = nxt
self.up.append(stage) self.up.append(stage)
self.norm_out = RMS_norm(ch) self.norm_out = norm_op(ch)
self.conv_out = VideoConv3d(ch, out_channels, 3) self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
def forward(self, z): def forward(self, z):
z = z.permute(0, 2, 1, 3, 4) if self.refiner_vae:
b, f, c, h, w = z.shape z = z.permute(0, 2, 1, 3, 4)
z = z.reshape(b, f, 2, c // 2, h, w) b, f, c, h, w = z.shape
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4) z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z[:, :, 1:] z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
@ -264,4 +292,10 @@ class Decoder(nn.Module):
if hasattr(stage, 'upsample'): if hasattr(stage, 'upsample'):
x = stage.upsample(x) x = stage.upsample(x)
return self.conv_out(F.silu(self.norm_out(x))) out = self.conv_out(F.silu(self.norm_out(x)))
if not self.refiner_vae:
if z.shape[-3] == 1:
out = out[:, :, -1:]
return out

View File

@ -332,35 +332,51 @@ class VAE:
self.first_stage_model = StageC_coder() self.first_stage_model = StageC_coder()
self.downscale_ratio = 32 self.downscale_ratio = 32
self.latent_channels = 16 self.latent_channels = 16
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif "decoder.conv_in.weight" in sd: elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE self.downscale_ratio = 32
ddconfig['ch_mult'] = [1, 2, 4] self.upscale_ratio = 32
self.downscale_ratio = 4 self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd: elif "decoder.layers.1.layers.0.beta" in sd:
self.first_stage_model = AudioOobleckVAE() self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)

View File

@ -1605,6 +1605,7 @@ class _IO:
Model = Model Model = Model
ClipVision = ClipVision ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput ClipVisionOutput = ClipVisionOutput
AudioEncoder = AudioEncoder
AudioEncoderOutput = AudioEncoderOutput AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel StyleModel = StyleModel
Gligen = Gligen Gligen = Gligen

View File

@ -1,44 +1,62 @@
import folder_paths import folder_paths
import comfy.audio_encoders.audio_encoders import comfy.audio_encoders.audio_encoders
import comfy.utils import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class AudioEncoderLoader: class AudioEncoderLoader(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls) -> io.Schema:
return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ), return io.Schema(
}} node_id="AudioEncoderLoader",
RETURN_TYPES = ("AUDIO_ENCODER",) category="loaders",
FUNCTION = "load_model" inputs=[
io.Combo.Input(
"audio_encoder_name",
options=folder_paths.get_filename_list("audio_encoders"),
),
],
outputs=[io.AudioEncoder.Output()],
)
CATEGORY = "loaders" @classmethod
def execute(cls, audio_encoder_name) -> io.NodeOutput:
def load_model(self, audio_encoder_name):
audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name) audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True) sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd) audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
if audio_encoder is None: if audio_encoder is None:
raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.") raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
return (audio_encoder,) return io.NodeOutput(audio_encoder)
class AudioEncoderEncode: class AudioEncoderEncode(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls) -> io.Schema:
return {"required": { "audio_encoder": ("AUDIO_ENCODER",), return io.Schema(
"audio": ("AUDIO",), node_id="AudioEncoderEncode",
}} category="conditioning",
RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",) inputs=[
FUNCTION = "encode" io.AudioEncoder.Input("audio_encoder"),
io.Audio.Input("audio"),
],
outputs=[io.AudioEncoderOutput.Output()],
)
CATEGORY = "conditioning" @classmethod
def execute(cls, audio_encoder, audio) -> io.NodeOutput:
def encode(self, audio_encoder, audio):
output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"]) output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
return (output,) return io.NodeOutput(output)
NODE_CLASS_MAPPINGS = { class AudioEncoder(ComfyExtension):
"AudioEncoderLoader": AudioEncoderLoader, @override
"AudioEncoderEncode": AudioEncoderEncode, async def get_node_list(self) -> list[type[io.ComfyNode]]:
} return [
AudioEncoderLoader,
AudioEncoderEncode,
]
async def comfy_entrypoint() -> AudioEncoder:
return AudioEncoder()

View File

@ -1,34 +1,41 @@
# code adapted from https://github.com/exx8/differential-diffusion # code adapted from https://github.com/exx8/differential-diffusion
from typing_extensions import override
import torch import torch
from comfy_api.latest import ComfyExtension, io
class DifferentialDiffusion():
class DifferentialDiffusion(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="DifferentialDiffusion",
"model": ("MODEL", ), display_name="Differential Diffusion",
}, category="_for_testing",
"optional": { inputs=[
"strength": ("FLOAT", { io.Model.Input("model"),
"default": 1.0, io.Float.Input(
"min": 0.0, "strength",
"max": 1.0, default=1.0,
"step": 0.01, min=0.0,
}), max=1.0,
} step=0.01,
} optional=True,
RETURN_TYPES = ("MODEL",) ),
FUNCTION = "apply" ],
CATEGORY = "_for_testing" outputs=[io.Model.Output()],
INIT = False is_experimental=True,
)
def apply(self, model, strength=1.0): @classmethod
def execute(cls, model, strength=1.0) -> io.NodeOutput:
model = model.clone() model = model.clone()
model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength)) model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength))
return (model, ) return io.NodeOutput(model)
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float): @classmethod
def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
model = extra_options["model"] model = extra_options["model"]
step_sigmas = extra_options["sigmas"] step_sigmas = extra_options["sigmas"]
sigma_to = model.inner_model.model_sampling.sigma_min sigma_to = model.inner_model.model_sampling.sigma_min
@ -53,9 +60,13 @@ class DifferentialDiffusion():
return binary_mask return binary_mask
NODE_CLASS_MAPPINGS = { class DifferentialDiffusionExtension(ComfyExtension):
"DifferentialDiffusion": DifferentialDiffusion, @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
NODE_DISPLAY_NAME_MAPPINGS = { return [
"DifferentialDiffusion": "Differential Diffusion", DifferentialDiffusion,
} ]
async def comfy_entrypoint() -> DifferentialDiffusionExtension:
return DifferentialDiffusionExtension()

60
comfy_extras/nodes_eps.py Normal file
View File

@ -0,0 +1,60 @@
class EpsilonScaling:
"""
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
(https://arxiv.org/abs/2308.15321v6).
This method mitigates exposure bias by scaling the predicted noise during sampling,
which can significantly improve sample quality. This implementation uses the "uniform schedule"
recommended by the paper for its practicality and effectiveness.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scaling_factor": ("FLOAT", {
"default": 1.005,
"min": 0.5,
"max": 1.5,
"step": 0.001,
"display": "number"
}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "model_patches/unet"
def patch(self, model, scaling_factor):
# Prevent division by zero, though the UI's min value should prevent this.
if scaling_factor == 0:
scaling_factor = 1e-9
def epsilon_scaling_function(args):
"""
This function is applied after the CFG guidance has been calculated.
It recalculates the denoised latent by scaling the predicted noise.
"""
denoised = args["denoised"]
x = args["input"]
noise_pred = x - denoised
scaled_noise_pred = noise_pred / scaling_factor
new_denoised = x - scaled_noise_pred
return new_denoised
# Clone the model patcher to avoid modifying the original model in place
model_clone = model.clone()
model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)
return (model_clone,)
NODE_CLASS_MAPPINGS = {
"Epsilon Scaling": EpsilonScaling
}

View File

@ -1,6 +1,8 @@
# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main # from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
import numpy as np import numpy as np
import torch import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps): def loglinear_interp(t_steps, num_steps):
""" """
@ -333,25 +335,28 @@ NOISE_LEVELS = {
], ],
} }
class GITSScheduler: class GITSScheduler(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": return io.Schema(
{"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}), node_id="GITSScheduler",
"steps": ("INT", {"default": 10, "min": 2, "max": 1000}), category="sampling/custom_sampling/schedulers",
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), inputs=[
} io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
} io.Int.Input("steps", default=10, min=2, max=1000),
RETURN_TYPES = ("SIGMAS",) io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
CATEGORY = "sampling/custom_sampling/schedulers" ],
outputs=[
io.Sigmas.Output(),
],
)
FUNCTION = "get_sigmas" @classmethod
def execute(cls, coeff, steps, denoise):
def get_sigmas(self, coeff, steps, denoise):
total_steps = steps total_steps = steps
if denoise < 1.0: if denoise < 1.0:
if denoise <= 0.0: if denoise <= 0.0:
return (torch.FloatTensor([]),) return io.NodeOutput(torch.FloatTensor([]))
total_steps = round(steps * denoise) total_steps = round(steps * denoise)
if steps <= 20: if steps <= 20:
@ -362,8 +367,16 @@ class GITSScheduler:
sigmas = sigmas[-(total_steps + 1):] sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0 sigmas[-1] = 0
return (torch.FloatTensor(sigmas), ) return io.NodeOutput(torch.FloatTensor(sigmas))
NODE_CLASS_MAPPINGS = {
"GITSScheduler": GITSScheduler, class GITSSchedulerExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
GITSScheduler,
]
async def comfy_entrypoint() -> GITSSchedulerExtension:
return GITSSchedulerExtension()

View File

@ -1,21 +1,30 @@
import torch import torch
class InstructPixToPixConditioning: from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class InstructPixToPixConditioning(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="InstructPixToPixConditioning",
"vae": ("VAE", ), category="conditioning/instructpix2pix",
"pixels": ("IMAGE", ), inputs=[
}} io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Image.Input("pixels"),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def execute(cls, positive, negative, pixels, vae) -> io.NodeOutput:
FUNCTION = "encode"
CATEGORY = "conditioning/instructpix2pix"
def encode(self, positive, negative, pixels, vae):
x = (pixels.shape[1] // 8) * 8 x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8 y = (pixels.shape[2] // 8) * 8
@ -38,8 +47,17 @@ class InstructPixToPixConditioning:
n = [t[0], d] n = [t[0], d]
c.append(n) c.append(n)
out.append(c) out.append(c)
return (out[0], out[1], out_latent) return io.NodeOutput(out[0], out[1], out_latent)
class InstructPix2PixExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
InstructPixToPixConditioning,
]
async def comfy_entrypoint() -> InstructPix2PixExtension:
return InstructPix2PixExtension()
NODE_CLASS_MAPPINGS = {
"InstructPixToPixConditioning": InstructPixToPixConditioning,
}

View File

@ -1,4 +1,3 @@
import io
import nodes import nodes
import node_helpers import node_helpers
import torch import torch
@ -8,46 +7,60 @@ import comfy.utils
import math import math
import numpy as np import numpy as np
import av import av
from io import BytesIO
from typing_extensions import override
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy_api.latest import ComfyExtension, io
class EmptyLTXVLatentVideo: class EmptyLTXVLatentVideo(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), return io.Schema(
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), node_id="EmptyLTXVLatentVideo",
"length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), category="latent/video/ltxv",
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} inputs=[
RETURN_TYPES = ("LATENT",) io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
FUNCTION = "generate" io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
CATEGORY = "latent/video/ltxv" @classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples": latent}, ) return io.NodeOutput({"samples": latent})
class LTXVImgToVideo: class LTXVImgToVideo(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVImgToVideo",
"vae": ("VAE",), category="conditioning/video_models",
"image": ("IMAGE",), inputs=[
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), io.Conditioning.Input("positive"),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), io.Conditioning.Input("negative"),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), io.Vae.Input("vae"),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Image.Input("image"),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
}} io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength) -> io.NodeOutput:
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3] encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
@ -62,7 +75,7 @@ class LTXVImgToVideo:
) )
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, ) return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask})
def conditioning_get_any_value(conditioning, key, default=None): def conditioning_get_any_value(conditioning, key, default=None):
@ -93,35 +106,46 @@ def get_keyframe_idxs(cond):
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
return keyframe_idxs, num_keyframes return keyframe_idxs, num_keyframes
class LTXVAddGuide: class LTXVAddGuide(io.ComfyNode):
NUM_PREFIX_FRAMES = 2
PATCHIFIER = SymmetricPatchifier(1)
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVAddGuide",
"vae": ("VAE",), category="conditioning/video_models",
"latent": ("LATENT",), inputs=[
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." io.Conditioning.Input("positive"),
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}), io.Conditioning.Input("negative"),
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999, io.Vae.Input("vae"),
"tooltip": "Frame index to start the conditioning at. For single-frame images or " io.Latent.Input("latent"),
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ " io.Image.Input(
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to " "image",
"the nearest multiple of 8. Negative values are counted from the end of the video."}), tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.",
} ),
} io.Int.Input(
"frame_idx",
default=0,
min=-9999,
max=9999,
tooltip="Frame index to start the conditioning at. "
"For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. "
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def encode(cls, vae, latent_width, latent_height, images, scale_factors):
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def __init__(self):
self._num_prefix_frames = 2
self._patchifier = SymmetricPatchifier(1)
def encode(self, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1) pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
@ -129,7 +153,8 @@ class LTXVAddGuide:
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
return encode_pixels, t return encode_pixels, t
def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors): @classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond) _, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes latent_count = latent_length - num_keyframes
@ -141,9 +166,10 @@ class LTXVAddGuide:
return frame_idx, latent_idx return frame_idx, latent_idx
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors): @classmethod
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
keyframe_idxs, _ = get_keyframe_idxs(cond) keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = self._patchifier.patchify(guiding_latent) _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
pixel_coords[:, 0] += frame_idx pixel_coords[:, 0] += frame_idx
if keyframe_idxs is None: if keyframe_idxs is None:
@ -152,8 +178,9 @@ class LTXVAddGuide:
keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2) keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): @classmethod
_, latent_idx = self.get_latent_index( def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
_, latent_idx = cls.get_latent_index(
cond=positive, cond=positive,
latent_length=latent_image.shape[2], latent_length=latent_image.shape[2],
guide_length=guiding_latent.shape[2], guide_length=guiding_latent.shape[2],
@ -162,8 +189,8 @@ class LTXVAddGuide:
) )
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full( mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
@ -176,7 +203,8 @@ class LTXVAddGuide:
noise_mask = torch.cat([noise_mask, mask], dim=2) noise_mask = torch.cat([noise_mask, mask], dim=2)
return positive, negative, latent_image, noise_mask return positive, negative, latent_image, noise_mask
def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength): @classmethod
def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength):
cond_length = guiding_latent.shape[2] cond_length = guiding_latent.shape[2]
assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence." assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
@ -195,20 +223,21 @@ class LTXVAddGuide:
return latent_image, noise_mask return latent_image, noise_mask
def generate(self, positive, negative, vae, latent, image, frame_idx, strength): @classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
scale_factors = vae.downscale_index_formula scale_factors = vae.downscale_index_formula
latent_image = latent["samples"] latent_image = latent["samples"]
noise_mask = get_noise_mask(latent) noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape _, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors) image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
positive, negative, latent_image, noise_mask = self.append_keyframe( positive, negative, latent_image, noise_mask = cls.append_keyframe(
positive, positive,
negative, negative,
frame_idx, frame_idx,
@ -223,9 +252,9 @@ class LTXVAddGuide:
t = t[:, :, num_prefix_frames:] t = t[:, :, num_prefix_frames:]
if t.shape[2] == 0: if t.shape[2] == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
latent_image, noise_mask = self.replace_latent_frames( latent_image, noise_mask = cls.replace_latent_frames(
latent_image, latent_image,
noise_mask, noise_mask,
t, t,
@ -233,34 +262,35 @@ class LTXVAddGuide:
strength, strength,
) )
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
class LTXVCropGuides: class LTXVCropGuides(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVCropGuides",
"latent": ("LATENT",), category="conditioning/video_models",
} inputs=[
} io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Latent.Input("latent"),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def execute(cls, positive, negative, latent) -> io.NodeOutput:
CATEGORY = "conditioning/video_models"
FUNCTION = "crop"
def __init__(self):
self._patchifier = SymmetricPatchifier(1)
def crop(self, positive, negative, latent):
latent_image = latent["samples"].clone() latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent) noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive) _, num_keyframes = get_keyframe_idxs(positive)
if num_keyframes == 0: if num_keyframes == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image = latent_image[:, :, :-num_keyframes] latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes]
@ -268,44 +298,52 @@ class LTXVCropGuides:
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
class LTXVConditioning: class LTXVConditioning(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVConditioning",
"frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}), category="conditioning/video_models",
}} inputs=[
RETURN_TYPES = ("CONDITIONING", "CONDITIONING") io.Conditioning.Input("positive"),
RETURN_NAMES = ("positive", "negative") io.Conditioning.Input("negative"),
FUNCTION = "append" io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
CATEGORY = "conditioning/video_models" @classmethod
def execute(cls, positive, negative, frame_rate) -> io.NodeOutput:
def append(self, positive, negative, frame_rate):
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate}) positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate}) negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
return (positive, negative) return io.NodeOutput(positive, negative)
class ModelSamplingLTXV: class ModelSamplingLTXV(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), node_id="ModelSamplingLTXV",
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), category="advanced/model",
}, inputs=[
"optional": {"latent": ("LATENT",), } io.Model.Input("model"),
} io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Model.Output(),
],
)
RETURN_TYPES = ("MODEL",) @classmethod
FUNCTION = "patch" def execute(cls, model, max_shift, base_shift, latent=None) -> io.NodeOutput:
CATEGORY = "advanced/model"
def patch(self, model, max_shift, base_shift, latent=None):
m = model.clone() m = model.clone()
if latent is None: if latent is None:
@ -329,37 +367,41 @@ class ModelSamplingLTXV:
model_sampling.set_parameters(shift=shift) model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
return (m, ) return io.NodeOutput(m)
class LTXVScheduler: class LTXVScheduler(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": return io.Schema(
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), node_id="LTXVScheduler",
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), category="sampling/custom_sampling/schedulers",
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), inputs=[
"stretch": ("BOOLEAN", { io.Int.Input("steps", default=20, min=1, max=10000),
"default": True, io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
"tooltip": "Stretch the sigmas to be in the range [terminal, 1]." io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
}), io.Boolean.Input(
"terminal": ( id="stretch",
"FLOAT", default=True,
{ tooltip="Stretch the sigmas to be in the range [terminal, 1].",
"default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01, ),
"tooltip": "The terminal value of the sigmas after stretching." io.Float.Input(
}, id="terminal",
), default=0.1,
}, min=0.0,
"optional": {"latent": ("LATENT",), } max=0.99,
} step=0.01,
tooltip="The terminal value of the sigmas after stretching.",
),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Sigmas.Output(),
],
)
RETURN_TYPES = ("SIGMAS",) @classmethod
CATEGORY = "sampling/custom_sampling/schedulers" def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None) -> io.NodeOutput:
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
if latent is None: if latent is None:
tokens = 4096 tokens = 4096
else: else:
@ -389,7 +431,7 @@ class LTXVScheduler:
stretched = 1.0 - (one_minus_z / scale_factor) stretched = 1.0 - (one_minus_z / scale_factor)
sigmas[non_zero_mask] = stretched sigmas[non_zero_mask] = stretched
return (sigmas,) return io.NodeOutput(sigmas)
def encode_single_frame(output_file, image_array: np.ndarray, crf): def encode_single_frame(output_file, image_array: np.ndarray, crf):
container = av.open(output_file, "w", format="mp4") container = av.open(output_file, "w", format="mp4")
@ -423,52 +465,54 @@ def preprocess(image: torch.Tensor, crf=29):
return image return image
image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy() image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
with io.BytesIO() as output_file: with BytesIO() as output_file:
encode_single_frame(output_file, image_array, crf) encode_single_frame(output_file, image_array, crf)
video_bytes = output_file.getvalue() video_bytes = output_file.getvalue()
with io.BytesIO(video_bytes) as video_file: with BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file) image_array = decode_single_frame(video_file)
tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0 tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
return tensor return tensor
class LTXVPreprocess: class LTXVPreprocess(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="LTXVPreprocess",
"image": ("IMAGE",), category="image",
"img_compression": ( inputs=[
"INT", io.Image.Input("image"),
{ io.Int.Input(
"default": 35, id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image."
"min": 0,
"max": 100,
"tooltip": "Amount of compression to apply on image.",
},
), ),
} ],
} outputs=[
io.Image.Output(display_name="output_image"),
],
)
FUNCTION = "preprocess" @classmethod
RETURN_TYPES = ("IMAGE",) def execute(cls, image, img_compression) -> io.NodeOutput:
RETURN_NAMES = ("output_image",)
CATEGORY = "image"
def preprocess(self, image, img_compression):
output_images = [] output_images = []
for i in range(image.shape[0]): for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression)) output_images.append(preprocess(image[i], img_compression))
return (torch.stack(output_images),) return io.NodeOutput(torch.stack(output_images))
NODE_CLASS_MAPPINGS = { class LtxvExtension(ComfyExtension):
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo, @override
"LTXVImgToVideo": LTXVImgToVideo, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"ModelSamplingLTXV": ModelSamplingLTXV, return [
"LTXVConditioning": LTXVConditioning, EmptyLTXVLatentVideo,
"LTXVScheduler": LTXVScheduler, LTXVImgToVideo,
"LTXVAddGuide": LTXVAddGuide, ModelSamplingLTXV,
"LTXVPreprocess": LTXVPreprocess, LTXVConditioning,
"LTXVCropGuides": LTXVCropGuides, LTXVScheduler,
} LTXVAddGuide,
LTXVPreprocess,
LTXVCropGuides,
]
async def comfy_entrypoint() -> LtxvExtension:
return LtxvExtension()

View File

@ -1,9 +1,12 @@
# from https://github.com/bebebe666/OptimalSteps # from https://github.com/bebebe666/OptimalSteps
import numpy as np import numpy as np
import torch import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps): def loglinear_interp(t_steps, num_steps):
""" """
Performs log-linear interpolation of a given array of decreasing numbers. Performs log-linear interpolation of a given array of decreasing numbers.
@ -23,25 +26,28 @@ NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0
"Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001], "Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001],
} }
class OptimalStepsScheduler: class OptimalStepsScheduler(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": return io.Schema(
{"model_type": (["FLUX", "Wan", "Chroma"], ), node_id="OptimalStepsScheduler",
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}), category="sampling/custom_sampling/schedulers",
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), inputs=[
} io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]),
} io.Int.Input("steps", default=20, min=3, max=1000),
RETURN_TYPES = ("SIGMAS",) io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
CATEGORY = "sampling/custom_sampling/schedulers" ],
outputs=[
io.Sigmas.Output(),
],
)
FUNCTION = "get_sigmas" @classmethod
def execute(cls, model_type, steps, denoise) ->io.NodeOutput:
def get_sigmas(self, model_type, steps, denoise):
total_steps = steps total_steps = steps
if denoise < 1.0: if denoise < 1.0:
if denoise <= 0.0: if denoise <= 0.0:
return (torch.FloatTensor([]),) return io.NodeOutput(torch.FloatTensor([]))
total_steps = round(steps * denoise) total_steps = round(steps * denoise)
sigmas = NOISE_LEVELS[model_type][:] sigmas = NOISE_LEVELS[model_type][:]
@ -50,8 +56,16 @@ class OptimalStepsScheduler:
sigmas = sigmas[-(total_steps + 1):] sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0 sigmas[-1] = 0
return (torch.FloatTensor(sigmas), ) return io.NodeOutput(torch.FloatTensor(sigmas))
NODE_CLASS_MAPPINGS = {
"OptimalStepsScheduler": OptimalStepsScheduler, class OptimalStepsExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
OptimalStepsScheduler,
]
async def comfy_entrypoint() -> OptimalStepsExtension:
return OptimalStepsExtension()

View File

@ -3,25 +3,30 @@
#My modified one here is more basic but has less chances of breaking with ComfyUI updates. #My modified one here is more basic but has less chances of breaking with ComfyUI updates.
from typing_extensions import override
import comfy.model_patcher import comfy.model_patcher
import comfy.samplers import comfy.samplers
from comfy_api.latest import ComfyExtension, io
class PerturbedAttentionGuidance:
class PerturbedAttentionGuidance(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="PerturbedAttentionGuidance",
"model": ("MODEL",), category="model_patches/unet",
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}), inputs=[
} io.Model.Input("model"),
} io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01),
],
outputs=[
io.Model.Output(),
],
)
RETURN_TYPES = ("MODEL",) @classmethod
FUNCTION = "patch" def execute(cls, model, scale) -> io.NodeOutput:
CATEGORY = "model_patches/unet"
def patch(self, model, scale):
unet_block = "middle" unet_block = "middle"
unet_block_id = 0 unet_block_id = 0
m = model.clone() m = model.clone()
@ -49,8 +54,16 @@ class PerturbedAttentionGuidance:
m.set_model_sampler_post_cfg_function(post_cfg_function) m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m,) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"PerturbedAttentionGuidance": PerturbedAttentionGuidance, class PAGExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PerturbedAttentionGuidance,
]
async def comfy_entrypoint() -> PAGExtension:
return PAGExtension()

View File

@ -2306,6 +2306,7 @@ async def init_builtin_extra_nodes():
"nodes_gits.py", "nodes_gits.py",
"nodes_controlnet.py", "nodes_controlnet.py",
"nodes_hunyuan.py", "nodes_hunyuan.py",
"nodes_eps.py",
"nodes_flux.py", "nodes_flux.py",
"nodes_lora_extract.py", "nodes_lora_extract.py",
"nodes_torch_compile.py", "nodes_torch_compile.py",

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.26.13 comfyui-frontend-package==1.27.7
comfyui-workflow-templates==0.1.91 comfyui-workflow-templates==0.1.91
comfyui-embedded-docs==0.2.6 comfyui-embedded-docs==0.2.6
comfyui_manager==4.0.2 comfyui_manager==4.0.2