Merge branch 'master' into flipflop-stream

This commit is contained in:
Jedrzej Kosinski 2025-10-02 15:03:26 -07:00
commit a282586995
17 changed files with 742 additions and 440 deletions

View File

@@ -3,10 +3,13 @@ https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOW
 HOW TO RUN:
-if you have a AMD gpu:
+If you have a AMD gpu:
 run_amd_gpu.bat
+
+If you have memory issues you can try disabling the smart memory management by running comfyui with:
+run_amd_gpu_disable_smart_memory.bat
 IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints

View File

@@ -0,0 +1,2 @@
+.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
+pause

View File

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
 import comfy.ops
 import comfy.ldm.models.autoencoder

 ops = comfy.ops.disable_weight_init
@@ -17,11 +17,12 @@ class RMS_norm(nn.Module):
         return F.normalize(x, dim=1) * self.scale * self.gamma

 class DnSmpl(nn.Module):
-    def __init__(self, ic, oc, tds=True):
+    def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
         super().__init__()
         fct = 2 * 2 * 2 if tds else 1 * 2 * 2
         assert oc % fct == 0
-        self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
+        self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
+        self.refiner_vae = refiner_vae
         self.tds = tds
         self.gs = fct * ic // oc
@@ -30,7 +31,7 @@ class DnSmpl(nn.Module):
         r1 = 2 if self.tds else 1
         h = self.conv(x)

-        if self.tds:
+        if self.tds and self.refiner_vae:
             hf = h[:, :, :1, :, :]
             b, c, f, ht, wd = hf.shape
             hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@@ -66,6 +67,7 @@ class DnSmpl(nn.Module):
             sc = torch.cat([xf, xn], dim=2)
         else:
             b, c, frms, ht, wd = h.shape
             nf = frms // r1
             h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
             h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
@@ -83,10 +85,11 @@ class DnSmpl(nn.Module):

 class UpSmpl(nn.Module):
-    def __init__(self, ic, oc, tus=True):
+    def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
         super().__init__()
         fct = 2 * 2 * 2 if tus else 1 * 2 * 2
-        self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
+        self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
+        self.refiner_vae = refiner_vae
         self.tus = tus
         self.rp = fct * oc // ic
@@ -95,7 +98,7 @@ class UpSmpl(nn.Module):
         r1 = 2 if self.tus else 1
         h = self.conv(x)

-        if self.tus:
+        if self.tus and self.refiner_vae:
             hf = h[:, :, :1, :, :]
             b, c, f, ht, wd = hf.shape
             nc = c // (2 * 2)
@@ -148,43 +151,56 @@ class UpSmpl(nn.Module):

 class Encoder(nn.Module):
     def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
-                 ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
+                 ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
         super().__init__()
         self.z_channels = z_channels
         self.block_out_channels = block_out_channels
         self.num_res_blocks = num_res_blocks
-        self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
+        self.ffactor_temporal = ffactor_temporal
+        self.refiner_vae = refiner_vae
+
+        if self.refiner_vae:
+            conv_op = VideoConv3d
+            norm_op = RMS_norm
+        else:
+            conv_op = ops.Conv3d
+            norm_op = Normalize
+
+        self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
         self.down = nn.ModuleList()
         ch = block_out_channels[0]
         depth = (ffactor_spatial >> 1).bit_length()
-        depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
+        depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()

         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
             stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
                                                      temb_channels=0,
-                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
+                                                     conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks)])
             ch = tgt
             if i < depth:
                 nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
-                stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
+                stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                 ch = nxt
             self.down.append(stage)

         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
-        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
-        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)

-        self.norm_out = RMS_norm(ch)
-        self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
+        self.norm_out = norm_op(ch)
+        self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
         self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()

     def forward(self, x):
+        if not self.refiner_vae and x.shape[2] == 1:
+            x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
+
         x = self.conv_in(x)

         for stage in self.down:
@@ -200,31 +216,42 @@ class Encoder(nn.Module):
         skip = x.view(b, c // grp, grp, t, h, w).mean(2)
         out = self.conv_out(F.silu(self.norm_out(x))) + skip

-        out = self.regul(out)[0]
-
-        out = torch.cat((out[:, :, :1], out), dim=2)
-        out = out.permute(0, 2, 1, 3, 4)
-        b, f_times_2, c, h, w = out.shape
-        out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
-        out = out.permute(0, 2, 1, 3, 4).contiguous()
+        if self.refiner_vae:
+            out = self.regul(out)[0]
+
+            out = torch.cat((out[:, :, :1], out), dim=2)
+            out = out.permute(0, 2, 1, 3, 4)
+            b, f_times_2, c, h, w = out.shape
+            out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
+            out = out.permute(0, 2, 1, 3, 4).contiguous()
         return out

 class Decoder(nn.Module):
     def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
-                 ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
+                 ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
         super().__init__()
         block_out_channels = block_out_channels[::-1]
         self.z_channels = z_channels
         self.block_out_channels = block_out_channels
         self.num_res_blocks = num_res_blocks
+        self.refiner_vae = refiner_vae
+
+        if self.refiner_vae:
+            conv_op = VideoConv3d
+            norm_op = RMS_norm
+        else:
+            conv_op = ops.Conv3d
+            norm_op = Normalize

         ch = block_out_channels[0]
-        self.conv_in = VideoConv3d(z_channels, ch, 3)
+        self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)

         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
-        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
-        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)

         self.up = nn.ModuleList()
         depth = (ffactor_spatial >> 1).bit_length()
@@ -235,25 +262,26 @@ class Decoder(nn.Module):
             stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
                                                      temb_channels=0,
-                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
+                                                     conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks + 1)])
             ch = tgt
             if i < depth:
                 nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
-                stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
+                stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                 ch = nxt
             self.up.append(stage)

-        self.norm_out = RMS_norm(ch)
-        self.conv_out = VideoConv3d(ch, out_channels, 3)
+        self.norm_out = norm_op(ch)
+        self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)

     def forward(self, z):
-        z = z.permute(0, 2, 1, 3, 4)
-        b, f, c, h, w = z.shape
-        z = z.reshape(b, f, 2, c // 2, h, w)
-        z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
-        z = z.permute(0, 2, 1, 3, 4)
-        z = z[:, :, 1:]
+        if self.refiner_vae:
+            z = z.permute(0, 2, 1, 3, 4)
+            b, f, c, h, w = z.shape
+            z = z.reshape(b, f, 2, c // 2, h, w)
+            z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
+            z = z.permute(0, 2, 1, 3, 4)
+            z = z[:, :, 1:]

         x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
         x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
@@ -264,4 +292,10 @@ class Decoder(nn.Module):
             if hasattr(stage, 'upsample'):
                 x = stage.upsample(x)

-        return self.conv_out(F.silu(self.norm_out(x)))
+        out = self.conv_out(F.silu(self.norm_out(x)))
+
+        if not self.refiner_vae:
+            if z.shape[-3] == 1:
+                out = out[:, :, -1:]
+
+        return out

View File

@@ -468,55 +468,46 @@ class WanVAE(nn.Module):
                              attn_scales, self.temperal_upsample, dropout)

     def encode(self, x):
-        self.clear_cache()
+        conv_idx = [0]
+        feat_map = [None] * count_conv3d(self.decoder)
         ## cache
         t = x.shape[2]
         iter_ = 1 + (t - 1) // 4
         ## split the x fed to encode along the time axis into chunks of 1, 4, 4, 4, ...
         for i in range(iter_):
-            self._enc_conv_idx = [0]
+            conv_idx = [0]
             if i == 0:
                 out = self.encoder(
                     x[:, :, :1, :, :],
-                    feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx)
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx)
             else:
                 out_ = self.encoder(
                     x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
-                    feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx)
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx)
                 out = torch.cat([out, out_], 2)
         mu, log_var = self.conv1(out).chunk(2, dim=1)
-        self.clear_cache()
         return mu

     def decode(self, z):
-        self.clear_cache()
+        conv_idx = [0]
+        feat_map = [None] * count_conv3d(self.decoder)
         # z: [b,c,t,h,w]
         iter_ = z.shape[2]
         x = self.conv2(z)
         for i in range(iter_):
-            self._conv_idx = [0]
+            conv_idx = [0]
             if i == 0:
                 out = self.decoder(
                     x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx)
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx)
             else:
                 out_ = self.decoder(
                     x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx)
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx)
                 out = torch.cat([out, out_], 2)
-        self.clear_cache()
         return out

-    def clear_cache(self):
-        self._conv_num = count_conv3d(self.decoder)
-        self._conv_idx = [0]
-        self._feat_map = [None] * self._conv_num
-        #cache encode
-        self._enc_conv_num = count_conv3d(self.encoder)
-        self._enc_conv_idx = [0]
-        self._enc_feat_map = [None] * self._enc_conv_num
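The change above moves the per-call cache state (feature map and conv cursor) from shared instance attributes into locals, so each encode/decode call gets a fresh cache. A minimal runnable sketch of that pattern, with run_blocks as a hypothetical stand-in for the cached encoder/decoder call:

import torch

def run_blocks(z_slice, feat_cache, feat_idx):
    """Hypothetical stand-in: caches its input per conv index so later
    temporal slices can see features from earlier ones."""
    feat_cache[feat_idx[0]] = z_slice
    feat_idx[0] += 1
    return z_slice * 2.0

def decode_sketch(z, num_cached_convs=1):
    # Fresh, call-local cache state (replaces self._feat_map / self._conv_idx):
    feat_map = [None] * num_cached_convs
    out = None
    for t in range(z.shape[2]):
        conv_idx = [0]  # rewind the cursor for every temporal slice
        chunk = run_blocks(z[:, :, t:t + 1], feat_cache=feat_map, feat_idx=conv_idx)
        out = chunk if out is None else torch.cat([out, chunk], 2)
    return out

print(decode_sketch(torch.randn(1, 4, 3, 8, 8)).shape)  # torch.Size([1, 4, 3, 8, 8])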

View File

@@ -332,35 +332,51 @@ class VAE:
             self.first_stage_model = StageC_coder()
             self.downscale_ratio = 32
             self.latent_channels = 16
-        elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
-            ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
-            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
-            self.downscale_ratio = 32
-            self.upscale_ratio = 32
-            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
-            self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
-                                                        encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
-                                                        decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
-            self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
-            self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
         elif "decoder.conv_in.weight" in sd:
-            #default SD1.x/SD2.x VAE parameters
-            ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
-
-            if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
-                ddconfig['ch_mult'] = [1, 2, 4]
-                self.downscale_ratio = 4
-                self.upscale_ratio = 4
-
-            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
-            if 'post_quant_conv.weight' in sd:
-                self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
-            else:
-                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
-                                                            encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
-                                                            decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
+            if sd['decoder.conv_in.weight'].shape[1] == 64:
+                ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
+                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                self.downscale_ratio = 32
+                self.upscale_ratio = 32
+                self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                            encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
+                                                            decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
+                self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
+            elif sd['decoder.conv_in.weight'].shape[1] == 32:
+                ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
+                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+                self.upscale_index_formula = (4, 16, 16)
+                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+                self.downscale_index_formula = (4, 16, 16)
+                self.latent_dim = 3
+                self.not_video = True
+                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                            encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
+                                                            decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
+                self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
+            else:
+                #default SD1.x/SD2.x VAE parameters
+                ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+
+                if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
+                    ddconfig['ch_mult'] = [1, 2, 4]
+                    self.downscale_ratio = 4
+                    self.upscale_ratio = 4
+
+                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                if 'post_quant_conv.weight' in sd:
+                    self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
+                else:
+                    self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                                encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
+                                                                decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
         elif "decoder.layers.1.layers.0.beta" in sd:
             self.first_stage_model = AudioOobleckVAE()
             self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
@@ -636,6 +652,7 @@ class VAE:
     def decode(self, samples_in, vae_options={}):
         self.throw_exception_if_invalid()
         pixel_samples = None
+        do_tile = False
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -651,6 +668,13 @@ class VAE:
                 pixel_samples[x:x+batch_number] = out
         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+            #exception and the exception itself refs them all until we get out of this except block.
+            #So we just set a flag for tiler fallback so that tensor gc can happen once the
+            #exception is fully off the books.
+            do_tile = True
+
+        if do_tile:
             dims = samples_in.ndim - 2
             if dims == 1 or self.extra_1d_channel is not None:
                 pixel_samples = self.decode_tiled_1d(samples_in)
@@ -697,6 +721,7 @@ class VAE:
         self.throw_exception_if_invalid()
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1, 1)
+        do_tile = False
         if self.latent_dim == 3 and pixel_samples.ndim < 5:
             if not self.not_video:
                 pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@@ -718,6 +743,13 @@ class VAE:
         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+            #exception and the exception itself refs them all until we get out of this except block.
+            #So we just set a flag for tiler fallback so that tensor gc can happen once the
+            #exception is fully off the books.
+            do_tile = True
+
+        if do_tile:
             if self.latent_dim == 3:
                 tile = 256
                 overlap = tile // 4
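A minimal sketch of the flag-based fallback the do_tile hunks introduce: the tiled retry runs outside the except block, so tensors still referenced by the in-flight exception can be garbage-collected before the retry allocates new ones. decode_fn and decode_tiled_fn are hypothetical stand-ins, and torch.cuda.OutOfMemoryError stands in for model_management.OOM_EXCEPTION.

import torch

def decode_with_fallback(decode_fn, decode_tiled_fn, samples):
    do_tile = False
    try:
        return decode_fn(samples)
    except torch.cuda.OutOfMemoryError:
        # Don't retry here: the exception object still references the frames
        # (and their tensors) until we leave this block.
        do_tile = True
    if do_tile:
        # Exception is fully handled; its tensors are now collectible.
        return decode_tiled_fn(samples)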

View File

@@ -95,6 +95,7 @@ import aiohttp
 import asyncio
 import logging
 import io
+import os
 import socket
 from aiohttp.client_exceptions import ClientError, ClientResponseError
 from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
@@ -499,7 +500,9 @@ class ApiClient:
         else:
             raise ValueError("File must be BytesIO or str path")

-        operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}"
+        parsed = urlparse(upload_url)
+        basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
+        operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"

         request_logger.log_request_response(
             operation_id=operation_id,
             request_method="PUT",
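Why the urlparse change above helps, shown on a hypothetical signed upload URL: the old split('/')[-1] drags the query string (e.g. a signing token) into the operation id, while parsing the path first keeps only the file component.

import os
from urllib.parse import urlparse

url = "https://storage.example.com/bucket/model.safetensors?X-Sig=abc%2F123"
print(url.split('/')[-1])                    # model.safetensors?X-Sig=abc%2F123
print(os.path.basename(urlparse(url).path))  # model.safetensors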

View File

@@ -4,16 +4,18 @@ import os
 import datetime
 import json
 import logging
+import re
+import hashlib
+from typing import Any

 import folder_paths

 # Get the logger instance
 logger = logging.getLogger(__name__)

 def get_log_directory():
-    """
-    Ensures the API log directory exists within ComfyUI's temp directory
-    and returns its path.
-    """
+    """Ensures the API log directory exists within ComfyUI's temp directory and returns its path."""
     base_temp_dir = folder_paths.get_temp_directory()
     log_dir = os.path.join(base_temp_dir, "api_logs")
     try:
@@ -24,42 +26,77 @@ def get_log_directory():
             return base_temp_dir
     return log_dir

-def _format_data_for_logging(data):
+
+def _sanitize_filename_component(name: str) -> str:
+    if not name:
+        return "log"
+    sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", name)  # Replace disallowed characters with underscore
+    sanitized = sanitized.strip(" ._")  # Windows: trailing dots or spaces are not allowed
+    if not sanitized:
+        sanitized = "log"
+    return sanitized
+
+
+def _short_hash(*parts: str, length: int = 10) -> str:
+    return hashlib.sha1(("|".join(parts)).encode("utf-8")).hexdigest()[:length]
+
+
+def _build_log_filepath(log_dir: str, operation_id: str, request_url: str) -> str:
+    """Build log filepath. We keep it well under common path length limits aiming for <= 240 characters total."""
+    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+    slug = _sanitize_filename_component(operation_id)  # Best-effort human-readable slug from operation_id
+    h = _short_hash(operation_id or "", request_url or "")  # Short hash ties log to the full operation and URL
+    # Compute how much room we have for the slug given the directory length.
+    # Keep total path length reasonably below ~260 on Windows.
+    max_total_path = 240
+    prefix = f"{timestamp}_"
+    suffix = f"_{h}.log"
+    if not slug:
+        slug = "op"
+    max_filename_len = max(60, max_total_path - len(log_dir) - 1)
+    max_slug_len = max(8, max_filename_len - len(prefix) - len(suffix))
+    if len(slug) > max_slug_len:
+        slug = slug[:max_slug_len].rstrip(" ._-")
+    return os.path.join(log_dir, f"{prefix}{slug}{suffix}")
+
+
+def _format_data_for_logging(data: Any) -> str:
     """Helper to format data (dict, str, bytes) for logging."""
     if isinstance(data, bytes):
         try:
-            return data.decode('utf-8')  # Try to decode as text
+            return data.decode("utf-8")  # Try to decode as text
         except UnicodeDecodeError:
             return f"[Binary data of length {len(data)} bytes]"
     elif isinstance(data, (dict, list)):
         try:
             return json.dumps(data, indent=2, ensure_ascii=False)
         except TypeError:
             return str(data)  # Fallback for non-serializable objects
     return str(data)

+
 def log_request_response(
     operation_id: str,
     request_method: str,
     request_url: str,
     request_headers: dict | None = None,
     request_params: dict | None = None,
-    request_data: any = None,
+    request_data: Any = None,
     response_status_code: int | None = None,
     response_headers: dict | None = None,
-    response_content: any = None,
-    error_message: str | None = None
+    response_content: Any = None,
+    error_message: str | None = None,
 ):
     """
     Logs API request and response details to a file in the temp/api_logs directory.
+    Filenames are sanitized and length-limited for cross-platform safety.
+    If we still fail to write, we fall back to appending into api.log.
     """
     log_dir = get_log_directory()
-    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
-    filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
-    filepath = os.path.join(log_dir, filename)
+    filepath = _build_log_filepath(log_dir, operation_id, request_url)

-    log_content = []
+    log_content: list[str] = []
     log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
     log_content.append(f"Operation ID: {operation_id}")
     log_content.append("-" * 30 + " REQUEST " + "-" * 30)
@@ -69,7 +106,7 @@ def log_request_response(
         log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
     if request_params:
         log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
-    if request_data:
+    if request_data is not None:
         log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")

     log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
@@ -77,7 +114,7 @@ def log_request_response(
         log_content.append(f"Status Code: {response_status_code}")
     if response_headers:
         log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
-    if response_content:
+    if response_content is not None:
         log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
     if error_message:
         log_content.append(f"Error:\n{error_message}")
@@ -89,6 +126,7 @@ def log_request_response(
     except Exception as e:
         logger.error(f"Error writing API log to {filepath}: {e}")

+
 if __name__ == '__main__':
     # Example usage (for testing the logger directly)
     logger.setLevel(logging.DEBUG)
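A rough usage sketch of the filename budgeting added above (hypothetical directory and operation_id; assumes _build_log_filepath from this hunk is in scope): a pathologically long operation_id gets truncated to fit the ~240-character path budget while the short hash keeps the name unique and traceable.

import os

log_dir = "/tmp/comfyui/api_logs"
op_id = "upload_" + "x" * 500  # hypothetical, absurdly long operation_id
path = _build_log_filepath(log_dir, op_id, "https://example.com/upload")
assert len(path) <= 240
print(os.path.basename(path))  # e.g. 20251002_150326_123456_upload_xxx..._<hash>.log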

View File

@@ -1,34 +1,41 @@
 # code adapted from https://github.com/exx8/differential-diffusion

+from typing_extensions import override
 import torch
+from comfy_api.latest import ComfyExtension, io

-class DifferentialDiffusion():
+
+class DifferentialDiffusion(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "model": ("MODEL", ),
-            },
-            "optional": {
-                "strength": ("FLOAT", {
-                    "default": 1.0,
-                    "min": 0.0,
-                    "max": 1.0,
-                    "step": 0.01,
-                }),
-            }
-        }
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "apply"
-    CATEGORY = "_for_testing"
-    INIT = False
+    def define_schema(cls):
+        return io.Schema(
+            node_id="DifferentialDiffusion",
+            display_name="Differential Diffusion",
+            category="_for_testing",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input(
+                    "strength",
+                    default=1.0,
+                    min=0.0,
+                    max=1.0,
+                    step=0.01,
+                    optional=True,
+                ),
+            ],
+            outputs=[io.Model.Output()],
+            is_experimental=True,
+        )

-    def apply(self, model, strength=1.0):
+    @classmethod
+    def execute(cls, model, strength=1.0) -> io.NodeOutput:
         model = model.clone()
-        model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength))
-        return (model, )
+        model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength))
+        return io.NodeOutput(model)

-    def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
+    @classmethod
+    def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
         model = extra_options["model"]
         step_sigmas = extra_options["sigmas"]
         sigma_to = model.inner_model.model_sampling.sigma_min
@@ -53,9 +60,13 @@ class DifferentialDiffusion():
         return binary_mask

-NODE_CLASS_MAPPINGS = {
-    "DifferentialDiffusion": DifferentialDiffusion,
-}
-NODE_DISPLAY_NAME_MAPPINGS = {
-    "DifferentialDiffusion": "Differential Diffusion",
-}
+
+class DifferentialDiffusionExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            DifferentialDiffusion,
+        ]
+
+
+async def comfy_entrypoint() -> DifferentialDiffusionExtension:
+    return DifferentialDiffusionExtension()
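The same V3 migration shape recurs in the node files below (INPUT_TYPES/RETURN_TYPES/FUNCTION replaced by define_schema/execute, plus an extension entrypoint). A minimal skeleton of the pattern, with ExampleNode and ExampleExtension as hypothetical names:

from typing_extensions import override
from comfy_api.latest import ComfyExtension, io

class ExampleNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ExampleNode",       # hypothetical node id
            category="_for_testing",
            inputs=[io.Model.Input("model")],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, model) -> io.NodeOutput:
        # Stateless: classmethods replace instance methods, so per-node
        # state moves to class attributes (see LTXVAddGuide below).
        return io.NodeOutput(model)

class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [ExampleNode]

async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()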

comfy_extras/nodes_eps.py (new file)
View File

@@ -0,0 +1,60 @@
+class EpsilonScaling:
+    """
+    Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
+    (https://arxiv.org/abs/2308.15321v6).
+
+    This method mitigates exposure bias by scaling the predicted noise during sampling,
+    which can significantly improve sample quality. This implementation uses the "uniform schedule"
+    recommended by the paper for its practicality and effectiveness.
+    """
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "scaling_factor": ("FLOAT", {
+                    "default": 1.005,
+                    "min": 0.5,
+                    "max": 1.5,
+                    "step": 0.001,
+                    "display": "number"
+                }),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    CATEGORY = "model_patches/unet"
+
+    def patch(self, model, scaling_factor):
+        # Prevent division by zero, though the UI's min value should prevent this.
+        if scaling_factor == 0:
+            scaling_factor = 1e-9
+
+        def epsilon_scaling_function(args):
+            """
+            This function is applied after the CFG guidance has been calculated.
+            It recalculates the denoised latent by scaling the predicted noise.
+            """
+            denoised = args["denoised"]
+            x = args["input"]
+
+            noise_pred = x - denoised
+            scaled_noise_pred = noise_pred / scaling_factor
+            new_denoised = x - scaled_noise_pred
+
+            return new_denoised
+
+        # Clone the model patcher to avoid modifying the original model in place
+        model_clone = model.clone()
+        model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)
+
+        return (model_clone,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "Epsilon Scaling": EpsilonScaling
+}
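A quick numeric check of the post-CFG rescaling implemented above: with scaling_factor slightly above 1, the predicted noise shrinks and the denoised estimate moves slightly toward the noisy input.

import torch

x = torch.tensor([1.0])         # noisy latent (args["input"])
denoised = torch.tensor([0.2])  # CFG output (args["denoised"])
s = 1.005                       # scaling_factor

noise_pred = x - denoised           # 0.8
new_denoised = x - noise_pred / s   # 1.0 - 0.8/1.005
print(float(new_denoised))          # ~0.2040, nudged toward x from 0.2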

View File

@@ -1,6 +1,8 @@
 # from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
 import numpy as np
 import torch
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io

 def loglinear_interp(t_steps, num_steps):
     """
@@ -333,25 +335,28 @@ NOISE_LEVELS = {
     ],
 }

-class GITSScheduler:
+class GITSScheduler(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required":
-                    {"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}),
-                     "steps": ("INT", {"default": 10, "min": 2, "max": 1000}),
-                     "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
-                     }
-                }
-    RETURN_TYPES = ("SIGMAS",)
-    CATEGORY = "sampling/custom_sampling/schedulers"
-
-    FUNCTION = "get_sigmas"
-
-    def get_sigmas(self, coeff, steps, denoise):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="GITSScheduler",
+            category="sampling/custom_sampling/schedulers",
+            inputs=[
+                io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
+                io.Int.Input("steps", default=10, min=2, max=1000),
+                io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
+            ],
+            outputs=[
+                io.Sigmas.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, coeff, steps, denoise):
         total_steps = steps
         if denoise < 1.0:
             if denoise <= 0.0:
-                return (torch.FloatTensor([]),)
+                return io.NodeOutput(torch.FloatTensor([]))
             total_steps = round(steps * denoise)

         if steps <= 20:
@@ -362,8 +367,16 @@ class GITSScheduler:
         sigmas = sigmas[-(total_steps + 1):]
         sigmas[-1] = 0

-        return (torch.FloatTensor(sigmas), )
+        return io.NodeOutput(torch.FloatTensor(sigmas))

-NODE_CLASS_MAPPINGS = {
-    "GITSScheduler": GITSScheduler,
-}
+
+class GITSSchedulerExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            GITSScheduler,
+        ]
+
+
+async def comfy_entrypoint() -> GITSSchedulerExtension:
+    return GITSSchedulerExtension()

View File

@@ -1,21 +1,30 @@
 import torch
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io

-class InstructPixToPixConditioning:
+
+class InstructPixToPixConditioning(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "pixels": ("IMAGE", ),
-                             }}
-
-    RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/instructpix2pix"
-
-    def encode(self, positive, negative, pixels, vae):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="InstructPixToPixConditioning",
+            category="conditioning/instructpix2pix",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Image.Input("pixels"),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, positive, negative, pixels, vae) -> io.NodeOutput:
         x = (pixels.shape[1] // 8) * 8
         y = (pixels.shape[2] // 8) * 8
@@ -38,8 +47,17 @@ class InstructPixToPixConditioning:
             n = [t[0], d]
             c.append(n)
         out.append(c)
-        return (out[0], out[1], out_latent)
+        return io.NodeOutput(out[0], out[1], out_latent)

-NODE_CLASS_MAPPINGS = {
-    "InstructPixToPixConditioning": InstructPixToPixConditioning,
-}
+
+class InstructPix2PixExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            InstructPixToPixConditioning,
+        ]
+
+
+async def comfy_entrypoint() -> InstructPix2PixExtension:
+    return InstructPix2PixExtension()

View File

@ -1,4 +1,3 @@
import io
import nodes import nodes
import node_helpers import node_helpers
import torch import torch
@ -8,46 +7,60 @@ import comfy.utils
import math import math
import numpy as np import numpy as np
import av import av
from io import BytesIO
from typing_extensions import override
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy_api.latest import ComfyExtension, io
class EmptyLTXVLatentVideo: class EmptyLTXVLatentVideo(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), return io.Schema(
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), node_id="EmptyLTXVLatentVideo",
"length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), category="latent/video/ltxv",
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} inputs=[
RETURN_TYPES = ("LATENT",) io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
FUNCTION = "generate" io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
CATEGORY = "latent/video/ltxv" @classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples": latent}, ) return io.NodeOutput({"samples": latent})
class LTXVImgToVideo: class LTXVImgToVideo(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVImgToVideo",
"vae": ("VAE",), category="conditioning/video_models",
"image": ("IMAGE",), inputs=[
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), io.Conditioning.Input("positive"),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), io.Conditioning.Input("negative"),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), io.Vae.Input("vae"),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Image.Input("image"),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
}} io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength) -> io.NodeOutput:
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3] encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
@ -62,7 +75,7 @@ class LTXVImgToVideo:
) )
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, ) return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask})
def conditioning_get_any_value(conditioning, key, default=None): def conditioning_get_any_value(conditioning, key, default=None):
@ -93,35 +106,46 @@ def get_keyframe_idxs(cond):
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
return keyframe_idxs, num_keyframes return keyframe_idxs, num_keyframes
class LTXVAddGuide: class LTXVAddGuide(io.ComfyNode):
NUM_PREFIX_FRAMES = 2
PATCHIFIER = SymmetricPatchifier(1)
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVAddGuide",
"vae": ("VAE",), category="conditioning/video_models",
"latent": ("LATENT",), inputs=[
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." io.Conditioning.Input("positive"),
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}), io.Conditioning.Input("negative"),
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999, io.Vae.Input("vae"),
"tooltip": "Frame index to start the conditioning at. For single-frame images or " io.Latent.Input("latent"),
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ " io.Image.Input(
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to " "image",
"the nearest multiple of 8. Negative values are counted from the end of the video."}), tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.",
} ),
} io.Int.Input(
"frame_idx",
default=0,
min=-9999,
max=9999,
tooltip="Frame index to start the conditioning at. "
"For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. "
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def encode(cls, vae, latent_width, latent_height, images, scale_factors):
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def __init__(self):
self._num_prefix_frames = 2
self._patchifier = SymmetricPatchifier(1)
def encode(self, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1) pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
@ -129,7 +153,8 @@ class LTXVAddGuide:
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
return encode_pixels, t return encode_pixels, t
def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors): @classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond) _, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes latent_count = latent_length - num_keyframes
@ -141,9 +166,10 @@ class LTXVAddGuide:
return frame_idx, latent_idx return frame_idx, latent_idx
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors): @classmethod
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
keyframe_idxs, _ = get_keyframe_idxs(cond) keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = self._patchifier.patchify(guiding_latent) _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
pixel_coords[:, 0] += frame_idx pixel_coords[:, 0] += frame_idx
if keyframe_idxs is None: if keyframe_idxs is None:
@ -152,8 +178,9 @@ class LTXVAddGuide:
keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2) keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): @classmethod
_, latent_idx = self.get_latent_index( def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
_, latent_idx = cls.get_latent_index(
cond=positive, cond=positive,
latent_length=latent_image.shape[2], latent_length=latent_image.shape[2],
guide_length=guiding_latent.shape[2], guide_length=guiding_latent.shape[2],
@ -162,8 +189,8 @@ class LTXVAddGuide:
) )
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full( mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
@ -176,7 +203,8 @@ class LTXVAddGuide:
noise_mask = torch.cat([noise_mask, mask], dim=2) noise_mask = torch.cat([noise_mask, mask], dim=2)
return positive, negative, latent_image, noise_mask return positive, negative, latent_image, noise_mask
def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength): @classmethod
def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength):
cond_length = guiding_latent.shape[2] cond_length = guiding_latent.shape[2]
assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence." assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
@ -195,20 +223,21 @@ class LTXVAddGuide:
return latent_image, noise_mask return latent_image, noise_mask
def generate(self, positive, negative, vae, latent, image, frame_idx, strength): @classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
scale_factors = vae.downscale_index_formula scale_factors = vae.downscale_index_formula
latent_image = latent["samples"] latent_image = latent["samples"]
noise_mask = get_noise_mask(latent) noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape _, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors) image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
positive, negative, latent_image, noise_mask = self.append_keyframe( positive, negative, latent_image, noise_mask = cls.append_keyframe(
positive, positive,
negative, negative,
frame_idx, frame_idx,
@ -223,9 +252,9 @@ class LTXVAddGuide:
t = t[:, :, num_prefix_frames:] t = t[:, :, num_prefix_frames:]
if t.shape[2] == 0: if t.shape[2] == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
latent_image, noise_mask = self.replace_latent_frames( latent_image, noise_mask = cls.replace_latent_frames(
latent_image, latent_image,
noise_mask, noise_mask,
t, t,
@ -233,34 +262,35 @@ class LTXVAddGuide:
strength, strength,
) )
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
class LTXVCropGuides: class LTXVCropGuides(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVCropGuides",
"latent": ("LATENT",), category="conditioning/video_models",
} inputs=[
} io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Latent.Input("latent"),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def execute(cls, positive, negative, latent) -> io.NodeOutput:
CATEGORY = "conditioning/video_models"
FUNCTION = "crop"
def __init__(self):
self._patchifier = SymmetricPatchifier(1)
def crop(self, positive, negative, latent):
latent_image = latent["samples"].clone() latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent) noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive) _, num_keyframes = get_keyframe_idxs(positive)
if num_keyframes == 0: if num_keyframes == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image = latent_image[:, :, :-num_keyframes] latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes]
@ -268,44 +298,52 @@ class LTXVCropGuides:
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
class LTXVConditioning: class LTXVConditioning(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVConditioning",
"frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}), category="conditioning/video_models",
}} inputs=[
RETURN_TYPES = ("CONDITIONING", "CONDITIONING") io.Conditioning.Input("positive"),
RETURN_NAMES = ("positive", "negative") io.Conditioning.Input("negative"),
FUNCTION = "append" io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
CATEGORY = "conditioning/video_models" @classmethod
def execute(cls, positive, negative, frame_rate) -> io.NodeOutput:
def append(self, positive, negative, frame_rate):
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate}) positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate}) negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
return (positive, negative) return io.NodeOutput(positive, negative)
class ModelSamplingLTXV(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ModelSamplingLTXV",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
                io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
                io.Latent.Input("latent", optional=True),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, max_shift, base_shift, latent=None) -> io.NodeOutput:
        m = model.clone()

        if latent is None:
@@ -329,37 +367,41 @@ class ModelSamplingLTXV:
        model_sampling.set_parameters(shift=shift)
        m.add_object_patch("model_sampling", model_sampling)

        return io.NodeOutput(m)
class LTXVScheduler(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVScheduler",
            category="sampling/custom_sampling/schedulers",
            inputs=[
                io.Int.Input("steps", default=20, min=1, max=10000),
                io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
                io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
                io.Boolean.Input(
                    id="stretch",
                    default=True,
                    tooltip="Stretch the sigmas to be in the range [terminal, 1].",
                ),
                io.Float.Input(
                    id="terminal",
                    default=0.1,
                    min=0.0,
                    max=0.99,
                    step=0.01,
                    tooltip="The terminal value of the sigmas after stretching.",
                ),
                io.Latent.Input("latent", optional=True),
            ],
            outputs=[
                io.Sigmas.Output(),
            ],
        )

    @classmethod
    def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None) -> io.NodeOutput:
        if latent is None:
            tokens = 4096
        else:
@@ -389,7 +431,7 @@ class LTXVScheduler:
            stretched = 1.0 - (one_minus_z / scale_factor)
            sigmas[non_zero_mask] = stretched

        return io.NodeOutput(sigmas)
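The stretch step above keeps sigma = 1.0 fixed and rescales every other nonzero sigma so the schedule bottoms out at terminal instead of near zero. A minimal sketch of that arithmetic, assuming a plain 1-D tensor of decreasing sigmas (the helper name and sample values are illustrative, not from this file):

    import torch

    def stretch_sigmas(sigmas: torch.Tensor, terminal: float) -> torch.Tensor:
        # Rescale all nonzero sigmas so the smallest one lands on `terminal`.
        sigmas = sigmas.clone()
        non_zero_mask = sigmas != 0
        one_minus_z = 1.0 - sigmas[non_zero_mask]
        scale_factor = one_minus_z.max() / (1.0 - terminal)
        sigmas[non_zero_mask] = 1.0 - (one_minus_z / scale_factor)
        return sigmas

    # stretch_sigmas(torch.tensor([1.0, 0.5, 0.05, 0.0]), terminal=0.1)
    # -> tensor([1.0000, 0.5263, 0.1000, 0.0000]); 1.0 stays put, 0.05 -> 0.1.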
def encode_single_frame(output_file, image_array: np.ndarray, crf):
    container = av.open(output_file, "w", format="mp4")

@@ -423,52 +465,54 @@ def preprocess(image: torch.Tensor, crf=29):
        return image

    image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()

    with BytesIO() as output_file:
        encode_single_frame(output_file, image_array, crf)
        video_bytes = output_file.getvalue()

    with BytesIO(video_bytes) as video_file:
        image_array = decode_single_frame(video_file)

    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
    return tensor
class LTXVPreprocess(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVPreprocess",
            category="image",
            inputs=[
                io.Image.Input("image"),
                io.Int.Input(
                    id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image."
                ),
            ],
            outputs=[
                io.Image.Output(display_name="output_image"),
            ],
        )

    @classmethod
    def execute(cls, image, img_compression) -> io.NodeOutput:
        output_images = []
        for i in range(image.shape[0]):
            output_images.append(preprocess(image[i], img_compression))
        return io.NodeOutput(torch.stack(output_images))
class LtxvExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            EmptyLTXVLatentVideo,
            LTXVImgToVideo,
            ModelSamplingLTXV,
            LTXVConditioning,
            LTXVScheduler,
            LTXVAddGuide,
            LTXVPreprocess,
            LTXVCropGuides,
        ]

async def comfy_entrypoint() -> LtxvExtension:
    return LtxvExtension()
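Every file in this merge follows the same migration shape: a node subclasses io.ComfyNode, declares its interface once in define_schema, computes in a classmethod execute returning io.NodeOutput, and the old NODE_CLASS_MAPPINGS dict becomes a ComfyExtension plus comfy_entrypoint. A minimal self-contained sketch of the pattern (node id, input name, and the doubling logic are invented for illustration):

    from typing_extensions import override
    from comfy_api.latest import ComfyExtension, io

    class ExampleDouble(io.ComfyNode):
        @classmethod
        def define_schema(cls):
            return io.Schema(
                node_id="ExampleDouble",  # hypothetical node id
                category="example",
                inputs=[io.Float.Input("value", default=1.0, min=0.0, max=100.0, step=0.01)],
                outputs=[io.Float.Output()],
            )

        @classmethod
        def execute(cls, value) -> io.NodeOutput:
            return io.NodeOutput(value * 2.0)  # placeholder computation

    class ExampleExtension(ComfyExtension):
        @override
        async def get_node_list(self) -> list[type[io.ComfyNode]]:
            return [ExampleDouble]

    async def comfy_entrypoint() -> ExampleExtension:
        return ExampleExtension()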
View File

@@ -1,24 +1,34 @@
import torch
import comfy.model_management

from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
import kornia.color


class Morphology(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Morphology",
            display_name="ImageMorphology",
            category="image/postprocessing",
            inputs=[
                io.Image.Input("image"),
                io.Combo.Input(
                    "operation",
                    options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],
                ),
                io.Int.Input("kernel_size", default=3, min=3, max=999, step=1),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, image, operation, kernel_size) -> io.NodeOutput:
        device = comfy.model_management.get_torch_device()
        kernel = torch.ones(kernel_size, kernel_size, device=device)
        image_k = image.to(device).movedim(-1, 1)
@@ -39,49 +49,63 @@ class Morphology:
        else:
            raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
        img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
        return io.NodeOutput(img_out)
class ImageRGBToYUV(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageRGBToYUV",
            category="image/batch",
            inputs=[
                io.Image.Input("image"),
            ],
            outputs=[
                io.Image.Output(display_name="Y"),
                io.Image.Output(display_name="U"),
                io.Image.Output(display_name="V"),
            ],
        )

    @classmethod
    def execute(cls, image) -> io.NodeOutput:
        out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1)
        return io.NodeOutput(out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
class ImageYUVToRGB(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageYUVToRGB",
            category="image/batch",
            inputs=[
                io.Image.Input("Y"),
                io.Image.Input("U"),
                io.Image.Input("V"),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, Y, U, V) -> io.NodeOutput:
        image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1)
        out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1)
        return io.NodeOutput(out)
class MorphologyExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            Morphology,
            ImageRGBToYUV,
            ImageYUVToRGB,
        ]

async def comfy_entrypoint() -> MorphologyExtension:
    return MorphologyExtension()
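The kornia.morphology calls used by Morphology all share one contract: a BCHW float tensor plus a 2-D structuring element, returned at the same shape. A short sketch (tensor sizes are arbitrary):

    import torch
    from kornia.morphology import dilation, erosion

    img = torch.rand(1, 3, 64, 64)  # BCHW, values in [0, 1]
    kernel = torch.ones(3, 3)       # structuring element, as in execute above
    grown = dilation(img, kernel)   # bright regions expand
    shrunk = erosion(img, kernel)   # bright regions contract
    # gradient(img, kernel) is their difference: an edge outline.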
View File

@@ -1,9 +1,12 @@
# from https://github.com/bebebe666/OptimalSteps

import numpy as np
import torch

from typing_extensions import override
from comfy_api.latest import ComfyExtension, io

def loglinear_interp(t_steps, num_steps):
    """
    Performs log-linear interpolation of a given array of decreasing numbers.
@@ -23,25 +26,28 @@ NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0
                "Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001],
                }
class OptimalStepsScheduler(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="OptimalStepsScheduler",
            category="sampling/custom_sampling/schedulers",
            inputs=[
                io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]),
                io.Int.Input("steps", default=20, min=3, max=1000),
                io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
            ],
            outputs=[
                io.Sigmas.Output(),
            ],
        )

    @classmethod
    def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
        total_steps = steps
        if denoise < 1.0:
            if denoise <= 0.0:
                return io.NodeOutput(torch.FloatTensor([]))
            total_steps = round(steps * denoise)

        sigmas = NOISE_LEVELS[model_type][:]
@@ -50,8 +56,16 @@ class OptimalStepsScheduler:
        sigmas = sigmas[-(total_steps + 1):]
        sigmas[-1] = 0

        return io.NodeOutput(torch.FloatTensor(sigmas))
class OptimalStepsExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            OptimalStepsScheduler,
        ]

async def comfy_entrypoint() -> OptimalStepsExtension:
    return OptimalStepsExtension()
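loglinear_interp's body is elided by the hunk above; its docstring only promises log-linear interpolation of a decreasing array. A sketch of that idea under the usual reading, interpolating in log space over a uniform grid (this is an assumption, not necessarily the exact elided body):

    import numpy as np

    def loglinear_interp_sketch(t_steps, num_steps):
        # Resample a strictly positive, decreasing schedule to num_steps
        # points, interpolating linearly in log space to preserve ratios.
        xs = np.linspace(0.0, 1.0, len(t_steps))
        new_xs = np.linspace(0.0, 1.0, num_steps)
        log_ys = np.log(np.asarray(t_steps, dtype=np.float64))
        return np.exp(np.interp(new_xs, xs, log_ys))

    # loglinear_interp_sketch([1.0, 0.5, 0.1], 5)
    # -> approx [1.0, 0.707, 0.5, 0.224, 0.1]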
View File

@@ -3,25 +3,30 @@
#My modified one here is more basic but has less chances of breaking with ComfyUI updates.

from typing_extensions import override

import comfy.model_patcher
import comfy.samplers
from comfy_api.latest import ComfyExtension, io


class PerturbedAttentionGuidance(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="PerturbedAttentionGuidance",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, scale) -> io.NodeOutput:
        unet_block = "middle"
        unet_block_id = 0
        m = model.clone()
@@ -49,8 +54,16 @@ class PerturbedAttentionGuidance:
        m.set_model_sampler_post_cfg_function(post_cfg_function)

        return io.NodeOutput(m)
class PAGExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            PerturbedAttentionGuidance,
        ]

async def comfy_entrypoint() -> PAGExtension:
    return PAGExtension()
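The middle of PerturbedAttentionGuidance.execute is elided above; the only hook visible is set_model_sampler_post_cfg_function. A sketch of what such a post-CFG callback looks like (the args keys named below are assumptions about ComfyUI's sampler callback contract, and the body is a pass-through placeholder, not the node's real math):

    def post_cfg_function(args):
        # Runs once per sampling step, after classifier-free guidance.
        # Assumed keys: "denoised" (post-CFG prediction) and
        # "cond_denoised" (conditional-only prediction).
        denoised = args["denoised"]
        cond_denoised = args["cond_denoised"]
        # A real PAG hook would add scale * (cond_denoised - perturbed_denoised),
        # where the perturbed pass degrades the chosen unet block's attention;
        # that pass is elided in the hunk above, so return denoised unchanged.
        return denoised

    # Usage, as in execute above (m is the cloned ModelPatcher):
    # m.set_model_sampler_post_cfg_function(post_cfg_function)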
View File

@@ -115,6 +115,7 @@ if os.name == "nt":
    os.environ['MIMALLOC_PURGE_DELAY'] = '0'

if __name__ == "__main__":
    os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
    if args.default_device is not None:
        default_dev = args.default_device
        devices = list(range(32))
View File

@@ -2297,6 +2297,7 @@ async def init_builtin_extra_nodes():
"nodes_gits.py", "nodes_gits.py",
"nodes_controlnet.py", "nodes_controlnet.py",
"nodes_hunyuan.py", "nodes_hunyuan.py",
"nodes_eps.py",
"nodes_flux.py", "nodes_flux.py",
"nodes_lora_extract.py", "nodes_lora_extract.py",
"nodes_torch_compile.py", "nodes_torch_compile.py",