From 33bd9ed9cb941127b335244c6cc0a8cdc1ac1696 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 11 Sep 2025 21:43:20 -0700 Subject: [PATCH 01/30] Implement hunyuan image refiner model. (#9817) --- comfy/latent_formats.py | 5 + comfy/ldm/hunyuan_video/model.py | 11 +- comfy/ldm/hunyuan_video/vae_refiner.py | 268 ++++++++++++++++++++ comfy/ldm/models/autoencoder.py | 6 + comfy/ldm/modules/diffusionmodules/model.py | 10 +- comfy/model_base.py | 20 ++ comfy/sd.py | 17 +- comfy/supported_models.py | 19 +- comfy_extras/nodes_hunyuan.py | 23 ++ 9 files changed, 367 insertions(+), 12 deletions(-) create mode 100644 comfy/ldm/hunyuan_video/vae_refiner.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f975b5e11..894540879 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -606,6 +606,11 @@ class HunyuanImage21(LatentFormat): latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206] +class HunyuanImage21Refiner(LatentFormat): + latent_channels = 64 + latent_dimensions = 3 + scale_factor = 1.03682 + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 7732182a4..ca86b8bb1 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -278,6 +278,7 @@ class HunyuanVideo(nn.Module): guidance: Tensor = None, guiding_frame_index=None, ref_latent=None, + disable_time_r=False, control=None, transformer_options={}, ) -> Tensor: @@ -288,7 +289,7 @@ class HunyuanVideo(nn.Module): img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) - if self.time_r_in is not None: + if (self.time_r_in is not None) and (not disable_time_r): w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved if len(w) > 0: timesteps_r = transformer_options['sample_sigmas'][w[0] + 1] @@ -428,14 +429,14 @@ class HunyuanVideo(nn.Module): img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) return repeat(img_ids, "h w c -> b (h w) c", b=bs) - def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs) + ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs) - def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, 
transformer_options={}, **kwargs): bs = x.shape[0] if len(self.patch_size) == 3: img_ids = self.img_ids(x) @@ -443,5 +444,5 @@ class HunyuanVideo(nn.Module): else: img_ids = self.img_ids_2d(x) txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options) return out diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py new file mode 100644 index 000000000..e3fff9bbe --- /dev/null +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -0,0 +1,268 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d +import comfy.ops +import comfy.ldm.models.autoencoder +ops = comfy.ops.disable_weight_init + +class RMS_norm(nn.Module): + def __init__(self, dim): + super().__init__() + shape = (dim, 1, 1, 1) + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.empty(shape)) + + def forward(self, x): + return F.normalize(x, dim=1) * self.scale * self.gamma + +class DnSmpl(nn.Module): + def __init__(self, ic, oc, tds=True): + super().__init__() + fct = 2 * 2 * 2 if tds else 1 * 2 * 2 + assert oc % fct == 0 + self.conv = VideoConv3d(ic, oc // fct, kernel_size=3) + + self.tds = tds + self.gs = fct * ic // oc + + def forward(self, x): + r1 = 2 if self.tds else 1 + h = self.conv(x) + + if self.tds: + hf = h[:, :, :1, :, :] + b, c, f, ht, wd = hf.shape + hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2) + hf = hf.permute(0, 4, 6, 1, 2, 3, 5) + hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2) + hf = torch.cat([hf, hf], dim=1) + + hn = h[:, :, 1:, :, :] + b, c, frms, ht, wd = hn.shape + nf = frms // r1 + hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) + hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6) + hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) + + h = torch.cat([hf, hn], dim=2) + + xf = x[:, :, :1, :, :] + b, ci, f, ht, wd = xf.shape + xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2) + xf = xf.permute(0, 4, 6, 1, 2, 3, 5) + xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2) + B, C, T, H, W = xf.shape + xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2) + + xn = x[:, :, 1:, :, :] + b, ci, frms, ht, wd = xn.shape + nf = frms // r1 + xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) + xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6) + xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) + B, C, T, H, W = xn.shape + xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + sc = torch.cat([xf, xn], dim=2) + else: + b, c, frms, ht, wd = h.shape + nf = frms // r1 + h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) + h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) + h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) + + b, ci, frms, ht, wd = x.shape + nf = frms // r1 + sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) + sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6) + sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) + B, C, T, H, W = sc.shape + sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + + return h + sc + + +class UpSmpl(nn.Module): + def __init__(self, ic, oc, tus=True): + super().__init__() + fct = 
2 * 2 * 2 if tus else 1 * 2 * 2 + self.conv = VideoConv3d(ic, oc * fct, kernel_size=3) + + self.tus = tus + self.rp = fct * oc // ic + + def forward(self, x): + r1 = 2 if self.tus else 1 + h = self.conv(x) + + if self.tus: + hf = h[:, :, :1, :, :] + b, c, f, ht, wd = hf.shape + nc = c // (2 * 2) + hf = hf.reshape(b, 2, 2, nc, f, ht, wd) + hf = hf.permute(0, 3, 4, 5, 1, 6, 2) + hf = hf.reshape(b, nc, f, ht * 2, wd * 2) + hf = hf[:, : hf.shape[1] // 2] + + hn = h[:, :, 1:, :, :] + b, c, frms, ht, wd = hn.shape + nc = c // (r1 * 2 * 2) + hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd) + hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3) + hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + h = torch.cat([hf, hn], dim=2) + + xf = x[:, :, :1, :, :] + b, ci, f, ht, wd = xf.shape + xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1) + b, c, f, ht, wd = xf.shape + nc = c // (2 * 2) + xf = xf.reshape(b, 2, 2, nc, f, ht, wd) + xf = xf.permute(0, 3, 4, 5, 1, 6, 2) + xf = xf.reshape(b, nc, f, ht * 2, wd * 2) + + xn = x[:, :, 1:, :, :] + xn = xn.repeat_interleave(repeats=self.rp, dim=1) + b, c, frms, ht, wd = xn.shape + nc = c // (r1 * 2 * 2) + xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd) + xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3) + xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2) + sc = torch.cat([xf, xn], dim=2) + else: + b, c, frms, ht, wd = h.shape + nc = c // (r1 * 2 * 2) + h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) + h = h.permute(0, 4, 5, 1, 6, 2, 7, 3) + h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + sc = x.repeat_interleave(repeats=self.rp, dim=1) + b, c, frms, ht, wd = sc.shape + nc = c // (r1 * 2 * 2) + sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd) + sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3) + sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + return h + sc + +class Encoder(nn.Module): + def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, + ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_): + super().__init__() + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1) + + self.down = nn.ModuleList() + ch = block_out_channels[0] + depth = (ffactor_spatial >> 1).bit_length() + depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=VideoConv3d, norm_op=RMS_norm) + for j in range(num_res_blocks)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch + stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal) + ch = nxt + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + + self.norm_out = RMS_norm(ch) + self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1) + + self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer() + + def forward(self, x): + x = x.unsqueeze(2) + x = self.conv_in(x) + + for stage in self.down: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'downsample'): + x 
= stage.downsample(x) + + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + b, c, t, h, w = x.shape + grp = c // (self.z_channels << 1) + skip = x.view(b, c // grp, grp, t, h, w).mean(2) + + out = self.conv_out(F.silu(self.norm_out(x))) + skip + out = self.regul(out)[0] + + out = torch.cat((out[:, :, :1], out), dim=2) + out = out.permute(0, 2, 1, 3, 4) + b, f_times_2, c, h, w = out.shape + out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) + out = out.permute(0, 2, 1, 3, 4).contiguous() + return out + +class Decoder(nn.Module): + def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, + ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_): + super().__init__() + block_out_channels = block_out_channels[::-1] + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + ch = block_out_channels[0] + self.conv_in = VideoConv3d(z_channels, ch, 3) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + + self.up = nn.ModuleList() + depth = (ffactor_spatial >> 1).bit_length() + depth_temporal = (ffactor_temporal >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=VideoConv3d, norm_op=RMS_norm) + for j in range(num_res_blocks + 1)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch + stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal) + ch = nxt + self.up.append(stage) + + self.norm_out = RMS_norm(ch) + self.conv_out = VideoConv3d(ch, out_channels, 3) + + def forward(self, z): + z = z.permute(0, 2, 1, 3, 4) + b, f, c, h, w = z.shape + z = z.reshape(b, f, 2, c // 2, h, w) + z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) + z = z.permute(0, 2, 1, 3, 4) + z = z[:, :, 1:] + + x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + for stage in self.up: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'upsample'): + x = stage.upsample(x) + + return self.conv_out(F.silu(self.norm_out(x))) diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 13bd6e16b..611d36a1b 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -26,6 +26,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module): z = posterior.mode() return z, None +class EmptyRegularizer(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + return z, None class AbstractAutoencoder(torch.nn.Module): """ diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 8f598a848..4245eedca 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -145,7 +145,7 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout=0.0, temb_channels=512, conv_op=ops.Conv2d): + 
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -153,7 +153,7 @@ class ResnetBlock(nn.Module): self.use_conv_shortcut = conv_shortcut self.swish = torch.nn.SiLU(inplace=True) - self.norm1 = Normalize(in_channels) + self.norm1 = norm_op(in_channels) self.conv1 = conv_op(in_channels, out_channels, kernel_size=3, @@ -162,7 +162,7 @@ class ResnetBlock(nn.Module): if temb_channels > 0: self.temb_proj = ops.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) + self.norm2 = norm_op(out_channels) self.dropout = torch.nn.Dropout(dropout, inplace=True) self.conv2 = conv_op(out_channels, out_channels, @@ -305,11 +305,11 @@ def vae_attention(): return normal_attention class AttnBlock(nn.Module): - def __init__(self, in_channels, conv_op=ops.Conv2d): + def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize): super().__init__() self.in_channels = in_channels - self.norm = Normalize(in_channels) + self.norm = norm_op(in_channels) self.q = conv_op(in_channels, in_channels, kernel_size=1, diff --git a/comfy/model_base.py b/comfy/model_base.py index 993ff65e6..c69a9d1ad 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1432,3 +1432,23 @@ class HunyuanImage21(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out + +class HunyuanImage21Refiner(HunyuanImage21): + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape_image = list(noise.shape) + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = self.process_latent_in(image) + image = utils.resize_to_batch_size(image, noise.shape[0]) + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out['disable_time_r'] = comfy.conds.CONDConstant(True) + return out diff --git a/comfy/sd.py b/comfy/sd.py index 9dd9a74d4..02ddc7239 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -285,6 +285,7 @@ class VAE: self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False + self.not_video = False self.downscale_index_formula = None self.upscale_index_formula = None @@ -409,6 +410,20 @@ class VAE: self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] + elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: + ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] + self.downscale_ratio = 16 + self.upscale_ratio = 16 + self.latent_dim = 3 + self.not_video = True + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"}, + encoder_config={'target': 
"comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True @@ -669,7 +684,7 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) - if self.latent_dim == 3 and pixel_samples.ndim < 5: + if not self.not_video and self.latent_dim == 3 and pixel_samples.ndim < 5: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index aa953b462..ba1b8c313 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1321,6 +1321,23 @@ class HunyuanImage21(HunyuanVideo): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +class HunyuanImage21Refiner(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "patch_size": [1, 1, 1], + "vec_in_dim": None, + } + + sampling_settings = { + "shift": 1.0, + } + + latent_format = latent_formats.HunyuanImage21Refiner + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanImage21Refiner(self, device=device) + return out + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py 
index ce031ceb2..351a7e2cb 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -128,6 +128,28 @@ class EmptyHunyuanImageLatent: latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device()) return ({"samples":latent}, ) +class HunyuanRefinerLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "latent": ("LATENT", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "execute" + + def execute(self, positive, negative, latent): + latent = latent["samples"] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent}) + out_latent = {} + out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + return (positive, negative, out_latent) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, @@ -135,4 +157,5 @@ NODE_CLASS_MAPPINGS = { "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, "HunyuanImageToVideo": HunyuanImageToVideo, "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent, + "HunyuanRefinerLatent": HunyuanRefinerLatent, } From 15ec9ea958d1c5d374add598b571a585541d4863 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 11 Sep 2025 21:44:20 -0700 Subject: [PATCH 02/30] Add Output to V3 Combo type to match what is possible with V1 (#9813) --- comfy_api/latest/_io.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index f770109d5..4826818df 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -331,7 +331,7 @@ class String(ComfyTypeIO): }) @comfytype(io_type="COMBO") -class Combo(ComfyTypeI): +class Combo(ComfyTypeIO): Type = str class Input(WidgetInput): """Combo input (dropdown).""" @@ -360,6 +360,14 @@ class Combo(ComfyTypeI): "remote": self.remote.as_dict() if self.remote else None, }) + class Output(Output): + def __init__(self, id: str=None, display_name: str=None, options: list[str]=None, tooltip: str=None, is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + self.options = options if options is not None else [] + + @property + def io_type(self): + return self.options @comfytype(io_type="COMBO") class MultiCombo(ComfyTypeI): From d6b977b2e680e98ad18a37ee13783da4f30e15f4 Mon Sep 17 00:00:00 2001 From: Benjamin Lu Date: Thu, 11 Sep 2025 21:46:01 -0700 Subject: [PATCH 03/30] Bump frontend to 1.26.11 (#9809) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0e21967ef..de5af5fac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.25.11 +comfyui-frontend-package==1.26.11 comfyui-workflow-templates==0.1.81 comfyui-embedded-docs==0.2.6 torch From fd2b820ec28e9575877dc6c51949b2d28dc78894 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 13:03:08 -0700 Subject: [PATCH 04/30] Add noise augmentation to hunyuan image refiner. (#9831) This was missing and should help with colors being blown out. 
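For reference, the augmentation added in the hunks below amounts to blending a reproducible noise tensor into the refiner's conditioning latent before it is concatenated. The following standalone Python sketch restates that blend; the seed offset (-10) and the linear mix come from the hunk itself, while the function name and the default seed value are illustrative, not part of the patch. Note that a later commit in this series (PATCH 05/30) further caps the image term at 0.75 to avoid blown-out colors at low augmentation strengths.

import torch

def augment_refiner_latent(image: torch.Tensor, noise_augmentation: float, seed: int = 0) -> torch.Tensor:
    # Deterministic noise: seeded from the sampler seed minus 10, generated on CPU
    # and moved to the latent's device, as in the hunk below.
    if noise_augmentation > 0:
        noise = torch.randn(
            image.shape,
            generator=torch.manual_seed(seed - 10),
            dtype=image.dtype,
            device="cpu",
        ).to(image.device)
        # Linear blend: the original latent at strength 0, pure noise at strength 1.
        image = noise_augmentation * noise + (1.0 - noise_augmentation) * image
    return image
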
--- comfy/model_base.py | 4 ++++ comfy_extras/nodes_hunyuan.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index c69a9d1ad..8422051bf 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1437,6 +1437,7 @@ class HunyuanImage21Refiner(HunyuanImage21): def concat_cond(self, **kwargs): noise = kwargs.get("noise", None) image = kwargs.get("concat_latent_image", None) + noise_augmentation = kwargs.get("noise_augmentation", 0.0) device = kwargs["device"] if image is None: @@ -1446,6 +1447,9 @@ class HunyuanImage21Refiner(HunyuanImage21): image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") image = self.process_latent_in(image) image = utils.resize_to_batch_size(image, noise.shape[0]) + if noise_augmentation > 0: + noise = torch.randn(image.shape, generator=torch.manual_seed(kwargs.get("seed", 0) - 10), dtype=image.dtype, device="cpu").to(image.device) + image = noise_augmentation * noise + (1.0 - noise_augmentation) * image return image def extra_conds(self, **kwargs): diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 351a7e2cb..db398cdf1 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -134,6 +134,7 @@ class HunyuanRefinerLatent: return {"required": {"positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), "latent": ("LATENT", ), + "noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @@ -141,11 +142,10 @@ class HunyuanRefinerLatent: FUNCTION = "execute" - def execute(self, positive, negative, latent): + def execute(self, positive, negative, latent, noise_augmentation): latent = latent["samples"] - - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent}) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) out_latent = {} out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) return (positive, negative, out_latent) From e600520f8aa583c79caa286a8d7d584edc3d059b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 13:35:34 -0700 Subject: [PATCH 05/30] Fix hunyuan refiner blownout colors at noise aug less than 0.25 (#9832) --- comfy/model_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8422051bf..4176bca25 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1449,7 +1449,9 @@ class HunyuanImage21Refiner(HunyuanImage21): image = utils.resize_to_batch_size(image, noise.shape[0]) if noise_augmentation > 0: noise = torch.randn(image.shape, generator=torch.manual_seed(kwargs.get("seed", 0) - 10), dtype=image.dtype, device="cpu").to(image.device) - image = noise_augmentation * noise + (1.0 - noise_augmentation) * image + image = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image + else: + image = 0.75 * image return image def extra_conds(self, **kwargs): From 
7757d5a657cbe9b22d1f3538ee0bc5387d3f5459 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 13:40:12 -0700 Subject: [PATCH 06/30] Set default hunyuan refiner shift to 4.0 (#9833) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index ba1b8c313..472ea0ae9 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1329,7 +1329,7 @@ class HunyuanImage21Refiner(HunyuanVideo): } sampling_settings = { - "shift": 1.0, + "shift": 4.0, } latent_format = latent_formats.HunyuanImage21Refiner From 0aa074a420c450fd7793d83c6f8d66009a1ca2a2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:29:03 +0300 Subject: [PATCH 07/30] add kling-v2-1 model to the KlingStartEndFrame node (#9630) --- comfy_api_nodes/nodes_kling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9fa390985..5f55b2cc9 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -846,6 +846,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"), "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"), "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"), + "pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"), + "pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"), } @classmethod From 45bc1f5c00307f3e85871ecfb46acaa2365b0096 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:37:31 +0300 Subject: [PATCH 08/30] convert Minimax API nodes to the V3 schema (#9693) --- comfy_api_nodes/nodes_minimax.py | 732 ++++++++++++++++--------------- 1 file changed, 378 insertions(+), 354 deletions(-) diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index bb3c9e710..bf560661c 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -1,9 +1,10 @@ from inspect import cleandoc -from typing import Union +from typing import Optional import logging import torch -from comfy.comfy_types.node_typing import IO +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api.input_impl.video_types import VideoFromFile from comfy_api_nodes.apis import ( MinimaxVideoGenerationRequest, @@ -11,7 +12,7 @@ from comfy_api_nodes.apis import ( MinimaxFileRetrieveResponse, MinimaxTaskResultResponse, SubjectReferenceItem, - MiniMaxModel + MiniMaxModel, ) from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -31,372 +32,398 @@ from server import PromptServer I2V_AVERAGE_DURATION = 114 T2V_AVERAGE_DURATION = 234 -class MinimaxTextToVideoNode: + +async def _generate_mm_video( + *, + auth: dict[str, str], + node_id: str, + prompt_text: str, + seed: int, + model: str, + image: Optional[torch.Tensor] = None, # used for ImageToVideo + subject: Optional[torch.Tensor] = None, # used for SubjectToVideo + average_duration: Optional[int] = None, +) -> comfy_io.NodeOutput: + if image is None: + validate_string(prompt_text, field_name="prompt_text") + # upload image, if passed in + image_url = None + if image is not None: + image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0] + + # TODO: figure out how to 
deal with subject properly, API returns invalid params when using S2V-01 model + subject_reference = None + if subject is not None: + subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0] + subject_reference = [SubjectReferenceItem(image=subject_url)] + + + video_generate_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/minimax/video_generation", + method=HttpMethod.POST, + request_model=MinimaxVideoGenerationRequest, + response_model=MinimaxVideoGenerationResponse, + ), + request=MinimaxVideoGenerationRequest( + model=MiniMaxModel(model), + prompt=prompt_text, + callback_url=None, + first_frame_image=image_url, + subject_reference=subject_reference, + prompt_optimizer=None, + ), + auth_kwargs=auth, + ) + response = await video_generate_operation.execute() + + task_id = response.task_id + if not task_id: + raise Exception(f"MiniMax generation failed: {response.base_resp}") + + video_generate_operation = PollingOperation( + poll_endpoint=ApiEndpoint( + path="/proxy/minimax/query/video_generation", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=MinimaxTaskResultResponse, + query_params={"task_id": task_id}, + ), + completed_statuses=["Success"], + failed_statuses=["Fail"], + status_extractor=lambda x: x.status.value, + estimated_duration=average_duration, + node_id=node_id, + auth_kwargs=auth, + ) + task_result = await video_generate_operation.execute() + + file_id = task_result.file_id + if file_id is None: + raise Exception("Request was not successful. Missing file ID.") + file_retrieve_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/minimax/files/retrieve", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=MinimaxFileRetrieveResponse, + query_params={"file_id": int(file_id)}, + ), + request=EmptyRequest(), + auth_kwargs=auth, + ) + file_result = await file_retrieve_operation.execute() + + file_url = file_result.file.download_url + if file_url is None: + raise Exception( + f"No video was found in the response. Full response: {file_result.model_dump()}" + ) + logging.info("Generated video URL: %s", file_url) + if node_id: + if hasattr(file_result.file, "backup_download_url"): + message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" + else: + message = f"Result URL: {file_url}" + PromptServer.instance.send_progress_text(message, node_id) + + # Download and return as VideoFromFile + video_io = await download_url_to_bytesio(file_url) + if video_io is None: + error_msg = f"Failed to download video from {file_url}" + logging.error(error_msg) + raise Exception(error_msg) + return comfy_io.NodeOutput(VideoFromFile(video_io)) + + +class MinimaxTextToVideoNode(comfy_io.ComfyNode): """ Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API. 
""" - AVERAGE_DURATION = T2V_AVERAGE_DURATION + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MinimaxTextToVideoNode", + display_name="MiniMax Text to Video", + category="api node/video/MiniMax", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation", + ), + comfy_io.Combo.Input( + "model", + options=["T2V-01", "T2V-01-Director"], + default="T2V-01", + tooltip="Model to use for video generation", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt_text": ( - "STRING", - { - "multiline": True, - "default": "", - "tooltip": "Text prompt to guide the video generation", - }, - ), - "model": ( - [ - "T2V-01", - "T2V-01-Director", - ], - { - "default": "T2V-01", - "tooltip": "Model to use for video generation", - }, - ), + async def execute( + cls, + prompt_text: str, + model: str = "T2V-01", + seed: int = 0, + ) -> comfy_io.NodeOutput: + return await _generate_mm_video( + auth={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO",) - DESCRIPTION = "Generates videos from prompts using MiniMax's API" - FUNCTION = "generate_video" - CATEGORY = "api node/video/MiniMax" - API_NODE = True - - async def generate_video( - self, - prompt_text, - seed=0, - model="T2V-01", - image: torch.Tensor=None, # used for ImageToVideo - subject: torch.Tensor=None, # used for SubjectToVideo - unique_id: Union[str, None]=None, - **kwargs, - ): - ''' - Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments. 
- ''' - if image is None: - validate_string(prompt_text, field_name="prompt_text") - # upload image, if passed in - image_url = None - if image is not None: - image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0] - - # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model - subject_reference = None - if subject is not None: - subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0] - subject_reference = [SubjectReferenceItem(image=subject_url)] - - - video_generate_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/video_generation", - method=HttpMethod.POST, - request_model=MinimaxVideoGenerationRequest, - response_model=MinimaxVideoGenerationResponse, - ), - request=MinimaxVideoGenerationRequest( - model=MiniMaxModel(model), - prompt=prompt_text, - callback_url=None, - first_frame_image=image_url, - subject_reference=subject_reference, - prompt_optimizer=None, - ), - auth_kwargs=kwargs, + node_id=cls.hidden.unique_id, + prompt_text=prompt_text, + seed=seed, + model=model, + image=None, + subject=None, + average_duration=T2V_AVERAGE_DURATION, ) - response = await video_generate_operation.execute() - - task_id = response.task_id - if not task_id: - raise Exception(f"MiniMax generation failed: {response.base_resp}") - - video_generate_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/minimax/query/video_generation", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxTaskResultResponse, - query_params={"task_id": task_id}, - ), - completed_statuses=["Success"], - failed_statuses=["Fail"], - status_extractor=lambda x: x.status.value, - estimated_duration=self.AVERAGE_DURATION, - node_id=unique_id, - auth_kwargs=kwargs, - ) - task_result = await video_generate_operation.execute() - - file_id = task_result.file_id - if file_id is None: - raise Exception("Request was not successful. Missing file ID.") - file_retrieve_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/files/retrieve", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxFileRetrieveResponse, - query_params={"file_id": int(file_id)}, - ), - request=EmptyRequest(), - auth_kwargs=kwargs, - ) - file_result = await file_retrieve_operation.execute() - - file_url = file_result.file.download_url - if file_url is None: - raise Exception( - f"No video was found in the response. Full response: {file_result.model_dump()}" - ) - logging.info(f"Generated video URL: {file_url}") - if unique_id: - if hasattr(file_result.file, "backup_download_url"): - message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" - else: - message = f"Result URL: {file_url}" - PromptServer.instance.send_progress_text(message, unique_id) - - video_io = await download_url_to_bytesio(file_url) - if video_io is None: - error_msg = f"Failed to download video from {file_url}" - logging.error(error_msg) - raise Exception(error_msg) - return (VideoFromFile(video_io),) -class MinimaxImageToVideoNode(MinimaxTextToVideoNode): +class MinimaxImageToVideoNode(comfy_io.ComfyNode): """ Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. 
""" - AVERAGE_DURATION = I2V_AVERAGE_DURATION + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MinimaxImageToVideoNode", + display_name="MiniMax Image to Video", + category="api node/video/MiniMax", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input( + "image", + tooltip="Image to use as first frame of video generation", + ), + comfy_io.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation", + ), + comfy_io.Combo.Input( + "model", + options=["I2V-01-Director", "I2V-01", "I2V-01-live"], + default="I2V-01", + tooltip="Model to use for video generation", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ( - IO.IMAGE, - { - "tooltip": "Image to use as first frame of video generation" - }, - ), - "prompt_text": ( - "STRING", - { - "multiline": True, - "default": "", - "tooltip": "Text prompt to guide the video generation", - }, - ), - "model": ( - [ - "I2V-01-Director", - "I2V-01", - "I2V-01-live", - ], - { - "default": "I2V-01", - "tooltip": "Model to use for video generation", - }, - ), + async def execute( + cls, + image: torch.Tensor, + prompt_text: str, + model: str = "I2V-01", + seed: int = 0, + ) -> comfy_io.NodeOutput: + return await _generate_mm_video( + auth={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO",) - DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API" - FUNCTION = "generate_video" - CATEGORY = "api node/video/MiniMax" - API_NODE = True + node_id=cls.hidden.unique_id, + prompt_text=prompt_text, + seed=seed, + model=model, + image=image, + subject=None, + average_duration=I2V_AVERAGE_DURATION, + ) -class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): +class MinimaxSubjectToVideoNode(comfy_io.ComfyNode): """ Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. 
""" - AVERAGE_DURATION = T2V_AVERAGE_DURATION + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MinimaxSubjectToVideoNode", + display_name="MiniMax Subject to Video", + category="api node/video/MiniMax", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input( + "subject", + tooltip="Image of subject to reference for video generation", + ), + comfy_io.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation", + ), + comfy_io.Combo.Input( + "model", + options=["S2V-01"], + default="S2V-01", + tooltip="Model to use for video generation", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "subject": ( - IO.IMAGE, - { - "tooltip": "Image of subject to reference video generation" - }, - ), - "prompt_text": ( - "STRING", - { - "multiline": True, - "default": "", - "tooltip": "Text prompt to guide the video generation", - }, - ), - "model": ( - [ - "S2V-01", - ], - { - "default": "S2V-01", - "tooltip": "Model to use for video generation", - }, - ), + async def execute( + cls, + subject: torch.Tensor, + prompt_text: str, + model: str = "S2V-01", + seed: int = 0, + ) -> comfy_io.NodeOutput: + return await _generate_mm_video( + auth={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO",) - DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API" - FUNCTION = "generate_video" - CATEGORY = "api node/video/MiniMax" - API_NODE = True + node_id=cls.hidden.unique_id, + prompt_text=prompt_text, + seed=seed, + model=model, + image=None, + subject=subject, + average_duration=T2V_AVERAGE_DURATION, + ) -class MinimaxHailuoVideoNode: +class MinimaxHailuoVideoNode(comfy_io.ComfyNode): """Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt_text": ( - "STRING", - { - "multiline": True, - "default": "", - "tooltip": "Text prompt to guide the video generation.", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MinimaxHailuoVideoNode", + display_name="MiniMax Hailuo Video", + category="api node/video/MiniMax", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation.", ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + 
control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, ), - "first_frame_image": ( - IO.IMAGE, - { - "tooltip": "Optional image to use as the first frame to generate a video." - }, + comfy_io.Image.Input( + "first_frame_image", + tooltip="Optional image to use as the first frame to generate a video.", + optional=True, ), - "prompt_optimizer": ( - IO.BOOLEAN, - { - "tooltip": "Optimize prompt to improve generation quality when needed.", - "default": True, - }, + comfy_io.Boolean.Input( + "prompt_optimizer", + default=True, + tooltip="Optimize prompt to improve generation quality when needed.", + optional=True, ), - "duration": ( - IO.COMBO, - { - "tooltip": "The length of the output video in seconds.", - "default": 6, - "options": [6, 10], - }, + comfy_io.Combo.Input( + "duration", + options=[6, 10], + default=6, + tooltip="The length of the output video in seconds.", + optional=True, ), - "resolution": ( - IO.COMBO, - { - "tooltip": "The dimensions of the video display. " - "1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.", - "default": "768P", - "options": ["768P", "1080P"], - }, + comfy_io.Combo.Input( + "resolution", + options=["768P", "1080P"], + default="768P", + tooltip="The dimensions of the video display. 1080p is 1920x1080, 768p is 1366x768.", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt_text: str, + seed: int = 0, + first_frame_image: Optional[torch.Tensor] = None, # used for ImageToVideo + prompt_optimizer: bool = True, + duration: int = 6, + resolution: str = "768P", + model: str = "MiniMax-Hailuo-02", + ) -> comfy_io.NodeOutput: + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } - - RETURN_TYPES = ("VIDEO",) - DESCRIPTION = cleandoc(__doc__ or "") - FUNCTION = "generate_video" - CATEGORY = "api node/video/MiniMax" - API_NODE = True - - async def generate_video( - self, - prompt_text, - seed=0, - first_frame_image: torch.Tensor=None, # used for ImageToVideo - prompt_optimizer=True, - duration=6, - resolution="768P", - model="MiniMax-Hailuo-02", - unique_id: Union[str, None]=None, - **kwargs, - ): if first_frame_image is None: validate_string(prompt_text, field_name="prompt_text") @@ -408,7 +435,7 @@ class MinimaxHailuoVideoNode: # upload image, if passed in image_url = None if first_frame_image is not None: - image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0] + image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0] video_generate_operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -426,7 +453,7 @@ class MinimaxHailuoVideoNode: duration=duration, resolution=resolution, ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response = await video_generate_operation.execute() @@ -447,8 +474,8 @@ class MinimaxHailuoVideoNode: failed_statuses=["Fail"], status_extractor=lambda x: x.status.value, estimated_duration=average_duration, - node_id=unique_id, - auth_kwargs=kwargs, + node_id=cls.hidden.unique_id, + auth_kwargs=auth, ) task_result = await video_generate_operation.execute() @@ -464,7 +491,7 @@ class 
MinimaxHailuoVideoNode: query_params={"file_id": int(file_id)}, ), request=EmptyRequest(), - auth_kwargs=kwargs, + auth_kwargs=auth, ) file_result = await file_retrieve_operation.execute() @@ -474,34 +501,31 @@ class MinimaxHailuoVideoNode: f"No video was found in the response. Full response: {file_result.model_dump()}" ) logging.info(f"Generated video URL: {file_url}") - if unique_id: + if cls.hidden.unique_id: if hasattr(file_result.file, "backup_download_url"): message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" else: message = f"Result URL: {file_url}" - PromptServer.instance.send_progress_text(message, unique_id) + PromptServer.instance.send_progress_text(message, cls.hidden.unique_id) video_io = await download_url_to_bytesio(file_url) if video_io is None: error_msg = f"Failed to download video from {file_url}" logging.error(error_msg) raise Exception(error_msg) - return (VideoFromFile(video_io),) + return comfy_io.NodeOutput(VideoFromFile(video_io)) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "MinimaxTextToVideoNode": MinimaxTextToVideoNode, - "MinimaxImageToVideoNode": MinimaxImageToVideoNode, - # "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode, - "MinimaxHailuoVideoNode": MinimaxHailuoVideoNode, -} +class MinimaxExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + MinimaxTextToVideoNode, + MinimaxImageToVideoNode, + # MinimaxSubjectToVideoNode, + MinimaxHailuoVideoNode, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "MinimaxTextToVideoNode": "MiniMax Text to Video", - "MinimaxImageToVideoNode": "MiniMax Image to Video", - "MinimaxSubjectToVideoNode": "MiniMax Subject to Video", - "MinimaxHailuoVideoNode": "MiniMax Hailuo Video", -} + +async def comfy_entrypoint() -> MinimaxExtension: + return MinimaxExtension() From f9d2e4b742af9aea3c9ffa822397c1b86cec9304 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:38:12 +0300 Subject: [PATCH 09/30] convert WanCameraEmbedding node to V3 schema (#9714) --- comfy_extras/nodes_camera_trajectory.py | 81 ++++++++++++++++--------- 1 file changed, 51 insertions(+), 30 deletions(-) diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py index 5e0e39f91..eb7ef363c 100644 --- a/comfy_extras/nodes_camera_trajectory.py +++ b/comfy_extras/nodes_camera_trajectory.py @@ -2,12 +2,12 @@ import nodes import torch import numpy as np from einops import rearrange +from typing_extensions import override import comfy.model_management +from comfy_api.latest import ComfyExtension, io -MAX_RESOLUTION = nodes.MAX_RESOLUTION - CAMERA_DICT = { "base_T_norm": 1.5, "base_angle": np.pi/3, @@ -148,32 +148,47 @@ def get_camera_motion(angle, T, speed, n=81): RT = np.stack(RT) return RT -class WanCameraEmbedding: +class WanCameraEmbedding(io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}), - "width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": 
MAX_RESOLUTION, "step": 4}), - }, - "optional":{ - "speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}), - "fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}), - "fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}), - "cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}), - "cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}), - } + def define_schema(cls): + return io.Schema( + node_id="WanCameraEmbedding", + category="camera", + inputs=[ + io.Combo.Input( + "camera_pose", + options=[ + "Static", + "Pan Up", + "Pan Down", + "Pan Left", + "Pan Right", + "Zoom In", + "Zoom Out", + "Anti Clockwise (ACW)", + "ClockWise (CW)", + ], + default="Static", + ), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True), + io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True), + io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True), + io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True), + io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True), + ], + outputs=[ + io.WanCameraEmbedding.Output(display_name="camera_embedding"), + io.Int.Output(display_name="width"), + io.Int.Output(display_name="height"), + io.Int.Output(display_name="length"), + ], + ) - } - - RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT") - RETURN_NAMES = ("camera_embedding","width","height","length") - FUNCTION = "run" - CATEGORY = "camera" - - def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5): + @classmethod + def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput: """ Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021) Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py @@ -210,9 +225,15 @@ class WanCameraEmbedding: control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) - return (control_camera_video, width, height, length) + return io.NodeOutput(control_camera_video, width, height, length) -NODE_CLASS_MAPPINGS = { - "WanCameraEmbedding": WanCameraEmbedding, -} +class CameraTrajectoryExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanCameraEmbedding, + ] + +async def comfy_entrypoint() -> CameraTrajectoryExtension: + return CameraTrajectoryExtension() From dcb883498337bcb2758fa9e7b326ea3b63c6b8d4 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:38:46 +0300 Subject: [PATCH 10/30] convert Cosmos nodes to V3 schema (#9721) --- comfy_extras/nodes_cosmos.py | 129 +++++++++++++++++++---------------- 1 file changed, 72 insertions(+), 57 deletions(-) diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py index 4f4960551..7dd129d19 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -1,25 +1,32 @@ +from typing_extensions import override import nodes import torch import comfy.model_management 
import comfy.utils import comfy.latent_formats +from comfy_api.latest import ComfyExtension, io -class EmptyCosmosLatentVideo: + +class EmptyCosmosLatentVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyCosmosLatentVideo", + category="latent/video", + inputs=[ + io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[io.Latent.Output()], + ) - CATEGORY = "latent/video" - - def generate(self, width, height, length, batch_size=1): + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return ({"samples": latent}, ) + return io.NodeOutput({"samples": latent}) def vae_encode_with_padding(vae, image, width, height, length, padding=0): @@ -33,31 +40,31 @@ def vae_encode_with_padding(vae, image, width, height, length, padding=0): return latent_temp[:, :, :latent_len] -class CosmosImageToVideoLatent: +class CosmosImageToVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"vae": ("VAE", ), - "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CosmosImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[io.Latent.Output()], + ) - - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "conditioning/inpaint" - - def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput: latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is None and end_image is None: out_latent = {} out_latent["samples"] = latent - return (out_latent,) + return io.NodeOutput(out_latent) mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], 
latent.shape[-1]], device=comfy.model_management.intermediate_device()) @@ -74,33 +81,33 @@ class CosmosImageToVideoLatent: out_latent = {} out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) - return (out_latent,) + return io.NodeOutput(out_latent) -class CosmosPredict2ImageToVideoLatent: +class CosmosPredict2ImageToVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"vae": ("VAE", ), - "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CosmosPredict2ImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=93, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[io.Latent.Output()], + ) - - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "conditioning/inpaint" - - def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput: latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is None and end_image is None: out_latent = {} out_latent["samples"] = latent - return (out_latent,) + return io.NodeOutput(out_latent) mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) @@ -119,10 +126,18 @@ class CosmosPredict2ImageToVideoLatent: latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) - return (out_latent,) + return io.NodeOutput(out_latent) -NODE_CLASS_MAPPINGS = { - "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo, - "CosmosImageToVideoLatent": CosmosImageToVideoLatent, - "CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent, -} + +class CosmosExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyCosmosLatentVideo, + CosmosImageToVideoLatent, + CosmosPredict2ImageToVideoLatent, + ] + + +async def comfy_entrypoint() -> CosmosExtension: + return CosmosExtension() From ba68e83f1c103eb4cb57fe01328706a0574fff3c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:39:30 +0300 Subject: [PATCH 11/30] convert nodes_cond.py to V3 schema (#9719) --- comfy_extras/nodes_cond.py | 75 ++++++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 28 deletions(-) diff --git 
a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py index 58c16f621..8b06e3de9 100644 --- a/comfy_extras/nodes_cond.py +++ b/comfy_extras/nodes_cond.py @@ -1,15 +1,25 @@ +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io -class CLIPTextEncodeControlnet: +class CLIPTextEncodeControlnet(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CLIPTextEncodeControlnet", + category="_for_testing/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Conditioning.Input("conditioning"), + io.String.Input("text", multiline=True, dynamic_prompts=True), + ], + outputs=[io.Conditioning.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/conditioning" - - def encode(self, clip, conditioning, text): + @classmethod + def execute(cls, clip, conditioning, text) -> io.NodeOutput: tokens = clip.tokenize(text) cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) c = [] @@ -18,32 +28,41 @@ class CLIPTextEncodeControlnet: n[1]['cross_attn_controlnet'] = cond n[1]['pooled_output_controlnet'] = pooled c.append(n) - return (c, ) + return io.NodeOutput(c) -class T5TokenizerOptions: +class T5TokenizerOptions(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "clip": ("CLIP", ), - "min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}), - "min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}), - } - } + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="T5TokenizerOptions", + category="_for_testing/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Int.Input("min_padding", default=0, min=0, max=10000, step=1), + io.Int.Input("min_length", default=0, min=0, max=10000, step=1), + ], + outputs=[io.Clip.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/conditioning" - RETURN_TYPES = ("CLIP",) - FUNCTION = "set_options" - - def set_options(self, clip, min_padding, min_length): + @classmethod + def execute(cls, clip, min_padding, min_length) -> io.NodeOutput: clip = clip.clone() for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]: clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding) clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length) - return (clip, ) + return io.NodeOutput(clip) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet, - "T5TokenizerOptions": T5TokenizerOptions, -} + +class CondExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeControlnet, + T5TokenizerOptions, + ] + + +async def comfy_entrypoint() -> CondExtension: + return CondExtension() From 53c9c7d39ad8a459e84a29e46a3e053154ef6013 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:39:55 +0300 Subject: [PATCH 12/30] convert CFG nodes to V3 schema (#9717) --- comfy_extras/nodes_cfg.py | 71 +++++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py index 5abdc115a..4ebb4b51e 100644 --- a/comfy_extras/nodes_cfg.py +++ b/comfy_extras/nodes_cfg.py @@ -1,5 +1,10 @@ +from typing_extensions import override + 
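# Illustrative sketch (not part of the patch hunks above or below): the registration tail that
# every converted module in this series gains, with ComfyExtension/comfy_entrypoint standing in
# for the removed NODE_CLASS_MAPPINGS dict. The class names here are hypothetical placeholders.
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io

class ExamplePassThrough(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="ExamplePassThrough",
            category="_for_testing/example",
            inputs=[io.Image.Input("image")],
            outputs=[io.Image.Output()],
        )

    @classmethod
    def execute(cls, image) -> io.NodeOutput:
        return io.NodeOutput(image)

class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # The returned classes replace entries previously keyed in NODE_CLASS_MAPPINGS.
        return [ExamplePassThrough]

async def comfy_entrypoint() -> ExampleExtension:
    # Module-level entrypoint used to pick up the extension, as in the patches above and below.
    return ExampleExtension()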
import torch +from comfy_api.latest import ComfyExtension, io + + # https://github.com/WeichenFan/CFG-Zero-star def optimized_scale(positive, negative): positive_flat = positive.reshape(positive.shape[0], -1) @@ -16,17 +21,20 @@ def optimized_scale(positive, negative): return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1)) -class CFGZeroStar: +class CFGZeroStar(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL",), - }} - RETURN_TYPES = ("MODEL",) - RETURN_NAMES = ("patched_model",) - FUNCTION = "patch" - CATEGORY = "advanced/guidance" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CFGZeroStar", + category="advanced/guidance", + inputs=[ + io.Model.Input("model"), + ], + outputs=[io.Model.Output(display_name="patched_model")], + ) - def patch(self, model): + @classmethod + def execute(cls, model) -> io.NodeOutput: m = model.clone() def cfg_zero_star(args): guidance_scale = args['cond_scale'] @@ -38,21 +46,24 @@ class CFGZeroStar: return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) m.set_model_sampler_post_cfg_function(cfg_zero_star) - return (m, ) + return io.NodeOutput(m) -class CFGNorm: +class CFGNorm(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL",), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - RETURN_NAMES = ("patched_model",) - FUNCTION = "patch" - CATEGORY = "advanced/guidance" - EXPERIMENTAL = True + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CFGNorm", + category="advanced/guidance", + inputs=[ + io.Model.Input("model"), + io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[io.Model.Output(display_name="patched_model")], + is_experimental=True, + ) - def patch(self, model, strength): + @classmethod + def execute(cls, model, strength) -> io.NodeOutput: m = model.clone() def cfg_norm(args): cond_p = args['cond_denoised'] @@ -64,9 +75,17 @@ class CFGNorm: return pred_text_ * scale * strength m.set_model_sampler_post_cfg_function(cfg_norm) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "CFGZeroStar": CFGZeroStar, - "CFGNorm": CFGNorm, -} + +class CfgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CFGZeroStar, + CFGNorm, + ] + + +async def comfy_entrypoint() -> CfgExtension: + return CfgExtension() From af99928f2218fc240dcfb3688ec47317ca146a78 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:40:34 +0300 Subject: [PATCH 13/30] convert Canny node to V3 schema (#9743) --- comfy_extras/nodes_canny.py | 46 +++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index d85e6b856..576f3640a 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -1,25 +1,41 @@ from kornia.filters import canny +from typing_extensions import override + import comfy.model_management +from comfy_api.latest import ComfyExtension, io -class Canny: +class Canny(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE",), - "low_threshold": ("FLOAT", {"default": 0.4, "min": 0.01, "max": 0.99, "step": 0.01}), - "high_threshold": ("FLOAT", {"default": 0.8, "min": 0.01, "max": 0.99, "step": 0.01}) - }} + def define_schema(cls): + 
return io.Schema( + node_id="Canny", + category="image/preprocessors", + inputs=[ + io.Image.Input("image"), + io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01), + io.Float.Input("high_threshold", default=0.8, min=0.01, max=0.99, step=0.01), + ], + outputs=[io.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "detect_edge" + @classmethod + def detect_edge(cls, image, low_threshold, high_threshold): + # Deprecated: use the V3 schema's `execute` method instead of this. + return cls.execute(image, low_threshold, high_threshold) - CATEGORY = "image/preprocessors" - - def detect_edge(self, image, low_threshold, high_threshold): + @classmethod + def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput: output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) - return (img_out,) + return io.NodeOutput(img_out) -NODE_CLASS_MAPPINGS = { - "Canny": Canny, -} + +class CannyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [Canny] + + +async def comfy_entrypoint() -> CannyExtension: + return CannyExtension() From 581bae2af30b0839a39734bd97006c4009f9d70a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:41:26 +0300 Subject: [PATCH 14/30] convert Moonvalley API nodes to the V3 schema (#9698) --- comfy_api_nodes/nodes_moonvalley.py | 572 +++++++++++++++------------- 1 file changed, 298 insertions(+), 274 deletions(-) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 806a70e06..08e838fef 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,6 +1,7 @@ import logging from typing import Any, Callable, Optional, TypeVar import torch +from typing_extensions import override from comfy_api_nodes.util.validation_utils import ( get_image_dimensions, validate_image_dimensions, @@ -26,11 +27,9 @@ from comfy_api_nodes.apinode_utils import ( upload_images_to_comfyapi, upload_video_to_comfyapi, ) -from comfy_api_nodes.mapper_utils import model_field_to_node_input -from comfy_api.input.video_types import VideoInput -from comfy.comfy_types.node_typing import IO -from comfy_api.input_impl import VideoFromFile +from comfy_api.input import VideoInput +from comfy_api.latest import ComfyExtension, InputImpl, io as comfy_io import av import io @@ -362,7 +361,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: # Return as VideoFromFile using the buffer output_buffer.seek(0) - return VideoFromFile(output_buffer) + return InputImpl.VideoFromFile(output_buffer) except Exception as e: # Clean up on error @@ -373,166 +372,150 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: raise RuntimeError(f"Failed to trim video: {str(e)}") from e -# --- BaseMoonvalleyVideoNode --- -class BaseMoonvalleyVideoNode: - def parseWidthHeightFromRes(self, resolution: str): - # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict - res_map = { - "16:9 (1920 x 1080)": {"width": 1920, "height": 1080}, - "9:16 (1080 x 1920)": {"width": 1080, "height": 1920}, - "1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, - "4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, - "3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, - "21:9 (2560 x 1080)": {"width": 2560, 
"height": 1080}, - } - if resolution in res_map: - return res_map[resolution] - else: - # Default to 1920x1080 if unknown - return {"width": 1920, "height": 1080} +def parse_width_height_from_res(resolution: str): + # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict + res_map = { + "16:9 (1920 x 1080)": {"width": 1920, "height": 1080}, + "9:16 (1080 x 1920)": {"width": 1080, "height": 1920}, + "1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, + "4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, + "3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, + "21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, + } + return res_map.get(resolution, {"width": 1920, "height": 1080}) - def parseControlParameter(self, value): - control_map = { - "Motion Transfer": "motion_control", - "Canny": "canny_control", - "Pose Transfer": "pose_control", - "Depth": "depth_control", - } - if value in control_map: - return control_map[value] - else: - return control_map["Motion Transfer"] - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> MoonvalleyPromptResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{API_PROMPTS_ENDPOINT}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MoonvalleyPromptResponse, - ), - result_url_extractor=get_video_url_from_response, - node_id=node_id, - ) +def parse_control_parameter(value): + control_map = { + "Motion Transfer": "motion_control", + "Canny": "canny_control", + "Pose Transfer": "pose_control", + "Depth": "depth_control", + } + return control_map.get(value, control_map["Motion Transfer"]) + + +async def get_response( + task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None +) -> MoonvalleyPromptResponse: + return await poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=f"{API_PROMPTS_ENDPOINT}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=MoonvalleyPromptResponse, + ), + result_url_extractor=get_video_url_from_response, + node_id=node_id, + ) + + +class MoonvalleyImg2VideoNode(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, - MoonvalleyTextToVideoRequest, - "prompt_text", + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MoonvalleyImg2VideoNode", + display_name="Moonvalley Marey Image to Video", + category="api node/video/Moonvalley Marey", + description="Moonvalley Marey Image to Video Node", + inputs=[ + comfy_io.Image.Input( + "image", + tooltip="The reference image used to generate the video", + ), + comfy_io.String.Input( + "prompt", multiline=True, ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - MoonvalleyTextToVideoInferenceParams, + comfy_io.String.Input( "negative_prompt", multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring", + default=" gopro, bright, contrast, static, overexposed, vignette, " + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, 
transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + tooltip="Negative prompt text", ), - "resolution": ( - IO.COMBO, - { - "options": [ - "16:9 (1920 x 1080)", - "9:16 (1080 x 1920)", - "1:1 (1152 x 1152)", - "4:3 (1440 x 1080)", - "3:4 (1080 x 1440)", - "21:9 (2560 x 1080)", - ], - "default": "16:9 (1920 x 1080)", - "tooltip": "Resolution of the output video", - }, + comfy_io.Combo.Input( + "resolution", + options=[ + "16:9 (1920 x 1080)", + "9:16 (1080 x 1920)", + "1:1 (1152 x 1152)", + "4:3 (1536 x 1152)", + "3:4 (1152 x 1536)", + "21:9 (2560 x 1080)", + ], + default="16:9 (1920 x 1080)", + tooltip="Resolution of the output video", ), - "prompt_adherence": model_field_to_node_input( - IO.FLOAT, - MoonvalleyTextToVideoInferenceParams, - "guidance_scale", + comfy_io.Float.Input( + "prompt_adherence", default=10.0, - step=1, - min=1, - max=20, + min=1.0, + max=20.0, + step=1.0, + tooltip="Guidance scale for generation control", ), - "seed": model_field_to_node_input( - IO.INT, - MoonvalleyTextToVideoInferenceParams, + comfy_io.Int.Input( "seed", default=9, min=0, max=4294967295, step=1, - display="number", + display_mode=comfy_io.NumberDisplay.number, tooltip="Random seed value", ), - "steps": model_field_to_node_input( - IO.INT, - MoonvalleyTextToVideoInferenceParams, + comfy_io.Int.Input( "steps", default=100, min=1, max=100, + step=1, + tooltip="Number of denoising steps", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - "optional": { - "image": model_field_to_node_input( - IO.IMAGE, - MoonvalleyTextToVideoRequest, - "image_url", - tooltip="The reference image used to generate the video", - ), - }, - } - - RETURN_TYPES = ("STRING",) - FUNCTION = "generate" - CATEGORY = "api node/video/Moonvalley Marey" - API_NODE = True - - def generate(self, **kwargs): - return None - - -# --- MoonvalleyImg2VideoNode --- -class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls): - return super().INPUT_TYPES() - - RETURN_TYPES = ("VIDEO",) - RETURN_NAMES = ("video",) - DESCRIPTION = "Moonvalley Marey Image to Video Node" - - async def generate( - self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs - ): - image = kwargs.get("image", None) - if image is None: - raise MoonvalleyApiError("image is required") - + async def execute( + cls, + image: torch.Tensor, + prompt: str, + negative_prompt: str, + resolution: str, + prompt_adherence: float, + seed: int, + steps: int, + ) -> comfy_io.NodeOutput: validate_input_image(image, True) validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - width_height = self.parseWidthHeightFromRes(kwargs.get("resolution")) + width_height = parse_width_height_from_res(resolution) + + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, - steps=kwargs.get("steps"), - 
seed=kwargs.get("seed"), - guidance_scale=kwargs.get("prompt_adherence"), + steps=steps, + seed=seed, + guidance_scale=prompt_adherence, num_frames=128, - width=width_height.get("width"), - height=width_height.get("height"), + width=width_height["width"], + height=width_height["height"], use_negative_prompts=True, ) """Upload image to comfy backend to have a URL available for further processing""" @@ -541,7 +524,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): image_url = ( await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type + image, max_images=1, auth_kwargs=auth, mime_type=mime_type ) )[0] @@ -556,127 +539,102 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): response_model=MoonvalleyPromptResponse, ), request=request, - auth_kwargs=kwargs, + auth_kwargs=auth, ) task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await get_response( + task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id ) video = await download_url_to_video_output(final_response.output_url) - return (video,) + return comfy_io.NodeOutput(video) -# --- MoonvalleyVid2VidNode --- -class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): - def __init__(self): - super().__init__() +class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, - MoonvalleyVideoToVideoRequest, - "prompt_text", + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MoonvalleyVideo2VideoNode", + display_name="Moonvalley Marey Video to Video", + category="api node/video/Moonvalley Marey", + description="", + inputs=[ + comfy_io.String.Input( + "prompt", multiline=True, + tooltip="Describes the video to generate", ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - MoonvalleyVideoToVideoInferenceParams, + comfy_io.String.Input( "negative_prompt", multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring", + default=" gopro, bright, contrast, static, overexposed, vignette, " + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + tooltip="Negative prompt text", ), - "seed": model_field_to_node_input( - IO.INT, - MoonvalleyVideoToVideoInferenceParams, + comfy_io.Int.Input( "seed", default=9, min=0, max=4294967295, step=1, - display="number", + display_mode=comfy_io.NumberDisplay.number, tooltip="Random seed value", control_after_generate=False, ), - "prompt_adherence": model_field_to_node_input( - IO.FLOAT, - 
MoonvalleyVideoToVideoInferenceParams, - "guidance_scale", - default=10.0, + comfy_io.Video.Input( + "video", + tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. " + "Videos longer than 5s will be automatically trimmed. Only MP4 format supported.", + ), + comfy_io.Combo.Input( + "control_type", + options=["Motion Transfer", "Pose Transfer"], + default="Motion Transfer", + optional=True, + ), + comfy_io.Int.Input( + "motion_intensity", + default=100, + min=0, + max=100, step=1, - min=1, - max=20, + tooltip="Only used if control_type is 'Motion Transfer'", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - "optional": { - "video": ( - IO.VIDEO, - { - "default": "", - "multiline": False, - "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.", - }, - ), - "control_type": ( - ["Motion Transfer", "Pose Transfer"], - {"default": "Motion Transfer"}, - ), - "motion_intensity": ( - "INT", - { - "default": 100, - "step": 1, - "min": 0, - "max": 100, - "tooltip": "Only used if control_type is 'Motion Transfer'", - }, - ), - "image": model_field_to_node_input( - IO.IMAGE, - MoonvalleyTextToVideoRequest, - "image_url", - tooltip="The reference image used to generate the video", - ), - }, + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + seed: int, + video: Optional[VideoInput] = None, + control_type: str = "Motion Transfer", + motion_intensity: Optional[int] = 100, + ) -> comfy_io.NodeOutput: + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } - RETURN_TYPES = ("VIDEO",) - RETURN_NAMES = ("video",) - - async def generate( - self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs - ): - video = kwargs.get("video") - image = kwargs.get("image", None) - - if not video: - raise MoonvalleyApiError("video is required") - - video_url = "" - if video: - validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi( - validated_video, auth_kwargs=kwargs - ) - mime_type = "image/png" - - if not image is None: - validate_input_image(image, with_frame_conditioning=True) - image_url = await upload_images_to_comfyapi( - image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type - ) - control_type = kwargs.get("control_type") - motion_intensity = kwargs.get("motion_intensity") + validated_video = validate_video_to_video_input(video) + video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth) """Validate prompts and inference input""" validate_prompts(prompt, negative_prompt) @@ -688,11 +646,11 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): inference_params = MoonvalleyVideoToVideoInferenceParams( negative_prompt=negative_prompt, - seed=kwargs.get("seed"), + seed=seed, control_params=control_params, ) - control = self.parseControlParameter(control_type) + control = parse_control_parameter(control_type) request = MoonvalleyVideoToVideoRequest( control_type=control, @@ -700,7 +658,6 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): prompt_text=prompt, 
inference_params=inference_params, ) - request.image_url = image_url if not image is None else None initial_operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -710,58 +667,125 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): response_model=MoonvalleyPromptResponse, ), request=request, - auth_kwargs=kwargs, + auth_kwargs=auth, ) task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await get_response( + task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id ) video = await download_url_to_video_output(final_response.output_url) - - return (video,) + return comfy_io.NodeOutput(video) -# --- MoonvalleyTxt2VideoNode --- -class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): - def __init__(self): - super().__init__() - - RETURN_TYPES = ("VIDEO",) - RETURN_NAMES = ("video",) +class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - input_types = super().INPUT_TYPES() - # Remove image-specific parameters - for param in ["image"]: - if param in input_types["optional"]: - del input_types["optional"][param] - return input_types + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="MoonvalleyTxt2VideoNode", + display_name="Moonvalley Marey Text to Video", + category="api node/video/Moonvalley Marey", + description="", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default=" gopro, bright, contrast, static, overexposed, vignette, " + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + tooltip="Negative prompt text", + ), + comfy_io.Combo.Input( + "resolution", + options=[ + "16:9 (1920 x 1080)", + "9:16 (1080 x 1920)", + "1:1 (1152 x 1152)", + "4:3 (1536 x 1152)", + "3:4 (1152 x 1536)", + "21:9 (2560 x 1080)", + ], + default="16:9 (1920 x 1080)", + tooltip="Resolution of the output video", + ), + comfy_io.Float.Input( + "prompt_adherence", + default=10.0, + min=1.0, + max=20.0, + step=1.0, + tooltip="Guidance scale for generation control", + ), + comfy_io.Int.Input( + "seed", + default=9, + min=0, + max=4294967295, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Random seed value", + ), + comfy_io.Int.Input( + "steps", + default=100, + min=1, + max=100, + step=1, + tooltip="Inference steps", + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - async def generate( - self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs - ): + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + resolution: str, + prompt_adherence: float, + seed: int, + steps: int, + ) -> comfy_io.NodeOutput: validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - width_height = self.parseWidthHeightFromRes(kwargs.get("resolution")) + width_height = 
parse_width_height_from_res(resolution) + + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, - steps=kwargs.get("steps"), - seed=kwargs.get("seed"), - guidance_scale=kwargs.get("prompt_adherence"), + steps=steps, + seed=seed, + guidance_scale=prompt_adherence, num_frames=128, - width=width_height.get("width"), - height=width_height.get("height"), + width=width_height["width"], + height=width_height["height"], ) request = MoonvalleyTextToVideoRequest( prompt_text=prompt, inference_params=inference_params ) - initial_operation = SynchronousOperation( + init_op = SynchronousOperation( endpoint=ApiEndpoint( path=API_TXT2VIDEO_ENDPOINT, method=HttpMethod.POST, @@ -769,29 +793,29 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): response_model=MoonvalleyPromptResponse, ), request=request, - auth_kwargs=kwargs, + auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() + task_creation_response = await init_op.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await get_response( + task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id ) video = await download_url_to_video_output(final_response.output_url) - return (video,) + return comfy_io.NodeOutput(video) -NODE_CLASS_MAPPINGS = { - "MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode, - "MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode, - "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode, -} +class MoonvalleyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + MoonvalleyImg2VideoNode, + MoonvalleyTxt2VideoNode, + MoonvalleyVideo2VideoNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video", - "MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video", - "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video", -} +async def comfy_entrypoint() -> MoonvalleyExtension: + return MoonvalleyExtension() From b149e2e1e302e75ce5b47e9b823b42b304d70b4b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 14:53:15 -0700 Subject: [PATCH 15/30] Better way of doing the generator for the hunyuan image noise aug. 
(#9834) --- comfy/model_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 4176bca25..324d89cff 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1448,7 +1448,9 @@ class HunyuanImage21Refiner(HunyuanImage21): image = self.process_latent_in(image) image = utils.resize_to_batch_size(image, noise.shape[0]) if noise_augmentation > 0: - noise = torch.randn(image.shape, generator=torch.manual_seed(kwargs.get("seed", 0) - 10), dtype=image.dtype, device="cpu").to(image.device) + generator = torch.Generator(device="cpu") + generator.manual_seed(kwargs.get("seed", 0) - 10) + noise = torch.randn(image.shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device) image = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image else: image = 0.75 * image From d7f40442f91a02946cab7445c6204bf154b1e86f Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 12 Sep 2025 15:07:38 -0700 Subject: [PATCH 16/30] Enable Runtime Selection of Attention Functions (#9639) * Looking into a @wrap_attn decorator to look for 'optimized_attention_override' entry in transformer_options * Created logging code for this branch so that it can be used to track down all the code paths where transformer_options would need to be added * Fix memory usage issue with inspect * Made WAN attention receive transformer_options, test node added to wan to test out attention override later * Added **kwargs to all attention functions so transformer_options could potentially be passed through * Make sure wrap_attn doesn't make itself recurse infinitely, attempt to load SageAttention and FlashAttention if not enabled so that they can be marked as available or not, create registry for available attention * Turn off attention logging for now, make AttentionOverrideTestNode have a dropdown with available attention (this is a test node only) * Make flux work with optimized_attention_override * Add logs to verify optimized_attention_override is passed all the way into attention function * Make Qwen work with optimized_attention_override * Made hidream work with optimized_attention_override * Made wan patches_replace work with optimized_attention_override * Made SD3 work with optimized_attention_override * Made HunyuanVideo work with optimized_attention_override * Made Mochi work with optimized_attention_override * Made LTX work with optimized_attention_override * Made StableAudio work with optimized_attention_override * Made optimized_attention_override work with ACE Step * Made Hunyuan3D work with optimized_attention_override * Make CosmosPredict2 work with optimized_attention_override * Made CosmosVideo work with optimized_attention_override * Made Omnigen 2 work with optimized_attention_override * Made StableCascade work with optimized_attention_override * Made AuraFlow work with optimized_attention_override * Made Lumina work with optimized_attention_override * Made Chroma work with optimized_attention_override * Made SVD work with optimized_attention_override * Fix WanI2VCrossAttention so that it expects to receive transformer_options * Fixed Wan2.1 Fun Camera transformer_options passthrough * Fixed WAN 2.1 VACE transformer_options passthrough * Add optimized to get_attention_function * Disable attention logs for now * Remove attention logging code * Remove _register_core_attention_functions, as we wouldn't want someone to call that, just in case * Satisfy ruff * Remove AttentionOverrideTest node, that's something to cook 
up for later --- comfy/ldm/ace/attention.py | 9 +- comfy/ldm/ace/model.py | 4 + comfy/ldm/audio/dit.py | 25 ++-- comfy/ldm/aura/mmdit.py | 29 ++--- comfy/ldm/cascade/common.py | 12 +- comfy/ldm/cascade/stage_b.py | 14 +-- comfy/ldm/cascade/stage_c.py | 14 +-- comfy/ldm/chroma/layers.py | 8 +- comfy/ldm/chroma/model.py | 17 ++- comfy/ldm/cosmos/blocks.py | 10 +- comfy/ldm/cosmos/model.py | 2 + comfy/ldm/cosmos/predict2.py | 17 ++- comfy/ldm/flux/layers.py | 10 +- comfy/ldm/flux/math.py | 4 +- comfy/ldm/flux/model.py | 17 ++- .../genmo/joint_model/asymm_models_joint.py | 11 +- comfy/ldm/hidream/model.py | 18 ++- comfy/ldm/hunyuan3d/model.py | 17 ++- comfy/ldm/hunyuan_video/model.py | 25 ++-- comfy/ldm/lightricks/model.py | 19 +-- comfy/ldm/lumina/model.py | 17 ++- comfy/ldm/modules/attention.py | 114 +++++++++++++----- comfy/ldm/modules/diffusionmodules/mmdit.py | 9 +- comfy/ldm/omnigen/omnigen2.py | 23 ++-- comfy/ldm/qwen_image/model.py | 12 +- comfy/ldm/wan/model.py | 38 +++--- 26 files changed, 316 insertions(+), 179 deletions(-) diff --git a/comfy/ldm/ace/attention.py b/comfy/ldm/ace/attention.py index f20a01669..670eb9783 100644 --- a/comfy/ldm/ace/attention.py +++ b/comfy/ldm/ace/attention.py @@ -133,6 +133,7 @@ class Attention(nn.Module): hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + transformer_options={}, **cross_attention_kwargs, ) -> torch.Tensor: return self.processor( @@ -140,6 +141,7 @@ class Attention(nn.Module): hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + transformer_options=transformer_options, **cross_attention_kwargs, ) @@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0: encoder_attention_mask: Optional[torch.FloatTensor] = None, rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None, rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None, + transformer_options={}, *args, **kwargs, ) -> torch.Tensor: @@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0: # the output of sdp = (batch, num_heads, seq_len, head_dim) hidden_states = optimized_attention( - query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, + query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options, ).to(query.dtype) # linear proj @@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module): rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None, rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None, temb: torch.FloatTensor = None, + transformer_options={}, ): N = hidden_states.shape[0] @@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=encoder_attention_mask, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=rotary_freqs_cis_cross, + transformer_options=transformer_options, ) else: attn_output, _ = self.attn( @@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=None, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=None, + transformer_options=transformer_options, ) if self.use_adaln_single: @@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=encoder_attention_mask, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=rotary_freqs_cis_cross, + transformer_options=transformer_options, ) hidden_states = attn_output + hidden_states diff --git a/comfy/ldm/ace/model.py b/comfy/ldm/ace/model.py index 41d85eeb5..399329853 
100644 --- a/comfy/ldm/ace/model.py +++ b/comfy/ldm/ace/model.py @@ -314,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length: int = 0, block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, controlnet_scale: Union[float, torch.Tensor] = 1.0, + transformer_options={}, ): embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype)) temb = self.t_block(embedded_timestep) @@ -339,6 +340,7 @@ class ACEStepTransformer2DModel(nn.Module): rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=encoder_rotary_freqs_cis, temb=temb, + transformer_options=transformer_options, ) output = self.final_layer(hidden_states, embedded_timestep, output_length) @@ -393,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length = hidden_states.shape[-1] + transformer_options = kwargs.get("transformer_options", {}) output = self.decode( hidden_states=hidden_states, attention_mask=attention_mask, @@ -402,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length=output_length, block_controlnet_hidden_states=block_controlnet_hidden_states, controlnet_scale=controlnet_scale, + transformer_options=transformer_options, ) return output diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index d0d69bbdc..ca865189e 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -298,7 +298,8 @@ class Attention(nn.Module): mask = None, context_mask = None, rotary_pos_emb = None, - causal = None + causal = None, + transformer_options={}, ): h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None @@ -363,7 +364,7 @@ class Attention(nn.Module): heads_per_kv_head = h // kv_h k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) - out = optimized_attention(q, k, v, h, skip_reshape=True) + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) out = self.to_out(out) if mask is not None: @@ -488,7 +489,8 @@ class TransformerBlock(nn.Module): global_cond=None, mask = None, context_mask = None, - rotary_pos_emb = None + rotary_pos_emb = None, + transformer_options={} ): if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: @@ -498,12 +500,12 @@ class TransformerBlock(nn.Module): residual = x x = self.pre_norm(x) x = x * (1 + scale_self) + shift_self - x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) x = x * torch.sigmoid(1 - gate_self) x = x + residual if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: x = x + self.conformer(x) @@ -517,10 +519,10 @@ class TransformerBlock(nn.Module): x = x + residual else: - x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: 
x = x + self.conformer(x) @@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module): return_info = False, **kwargs ): - patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {}) + transformer_options = kwargs.get("transformer_options", {}) + patches_replace = transformer_options.get("patches_replace", {}) batch, seq, device = *x.shape[:2], x.device context = kwargs["context"] @@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"]) + out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context) + x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options) # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) if return_info: diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index d7f32b5e8..66d9613b6 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -85,7 +85,7 @@ class SingleAttention(nn.Module): ) #@torch.compile() - def forward(self, c): + def forward(self, c, transformer_options={}): bsz, seqlen1, _ = c.shape @@ -95,7 +95,7 @@ class SingleAttention(nn.Module): v = v.view(bsz, seqlen1, self.n_heads, self.head_dim) q, k = self.q_norm1(q), self.k_norm1(k) - output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options) c = self.w1o(output) return c @@ -144,7 +144,7 @@ class DoubleAttention(nn.Module): #@torch.compile() - def forward(self, c, x): + def forward(self, c, x, transformer_options={}): bsz, seqlen1, _ = c.shape bsz, seqlen2, _ = x.shape @@ -168,7 +168,7 @@ class DoubleAttention(nn.Module): torch.cat([cv, xv], dim=1), ) - output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options) c, x = output.split([seqlen1, seqlen2], dim=1) c = self.w1o(c) @@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module): self.is_last = is_last #@torch.compile() - def forward(self, c, x, global_cond, **kwargs): + def forward(self, c, x, global_cond, transformer_options={}, **kwargs): cres, xres = c, x @@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module): x = modulate(self.normX1(x), xshift_msa, xscale_msa) # attention - c, x = self.attn(c, x) + c, x = self.attn(c, x, transformer_options=transformer_options) c = self.normC2(cres + cgate_msa.unsqueeze(1) * c) @@ -255,13 +255,13 @@ class DiTBlock(nn.Module): self.mlp = MLP(dim, hidden_dim=dim * 4, 
dtype=dtype, device=device, operations=operations) #@torch.compile() - def forward(self, cx, global_cond, **kwargs): + def forward(self, cx, global_cond, transformer_options={}, **kwargs): cxres = cx shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX( global_cond ).chunk(6, dim=1) cx = modulate(self.norm1(cx), shift_msa, scale_msa) - cx = self.attn(cx) + cx = self.attn(cx, transformer_options=transformer_options) cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx) mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp)) cx = gate_mlp.unsqueeze(1) * mlpout @@ -473,13 +473,14 @@ class MMDiT(nn.Module): out = {} out["txt"], out["img"] = layer(args["txt"], args["img"], - args["vec"]) + args["vec"], + transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap}) c = out["txt"] x = out["img"] else: - c, x = layer(c, x, global_cond, **kwargs) + c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs) if len(self.single_layers) > 0: c_len = c.size(1) @@ -488,13 +489,13 @@ class MMDiT(nn.Module): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = layer(args["img"], args["vec"]) + out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap}) cx = out["img"] else: - cx = layer(cx, global_cond, **kwargs) + cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs) x = cx[:, c_len:] diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py index 3eaa0c821..42ef98c7a 100644 --- a/comfy/ldm/cascade/common.py +++ b/comfy/ldm/cascade/common.py @@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module): self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) - def forward(self, q, k, v): + def forward(self, q, k, v, transformer_options={}): q = self.to_q(q) k = self.to_k(k) v = self.to_v(v) - out = optimized_attention(q, k, v, self.heads) + out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options) return self.out_proj(out) @@ -47,13 +47,13 @@ class Attention2D(nn.Module): self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations) # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device) - def forward(self, x, kv, self_attn=False): + def forward(self, x, kv, self_attn=False, transformer_options={}): orig_shape = x.shape x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 if self_attn: kv = torch.cat([x, kv], dim=1) # x = self.attn(x, kv, kv, need_weights=False)[0] - x = self.attn(x, kv, kv) + x = self.attn(x, kv, kv, transformer_options=transformer_options) x = x.permute(0, 2, 1).view(*orig_shape) return x @@ -114,9 +114,9 @@ class AttnBlock(nn.Module): operations.Linear(c_cond, c, dtype=dtype, device=device) ) - def forward(self, x, kv): + def forward(self, x, kv, transformer_options={}): kv = self.kv_mapper(kv) - x = x + self.attention(self.norm(x), 
kv, self_attn=self.self_attn) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options) return x diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py index 773830956..428c67fdf 100644 --- a/comfy/ldm/cascade/stage_b.py +++ b/comfy/ldm/cascade/stage_b.py @@ -173,7 +173,7 @@ class StageB(nn.Module): clip = self.clip_norm(clip) return clip - def _down_encode(self, x, r_embed, clip): + def _down_encode(self, x, r_embed, clip, transformer_options={}): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: @@ -187,7 +187,7 @@ class StageB(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -199,7 +199,7 @@ class StageB(nn.Module): level_outputs.insert(0, x) return level_outputs - def _up_decode(self, level_outputs, r_embed, clip): + def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) for i, (up_block, upscaler, repmap) in enumerate(block_group): @@ -216,7 +216,7 @@ class StageB(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -228,7 +228,7 @@ class StageB(nn.Module): x = upscaler(x) return x - def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs): if pixels is None: pixels = x.new_zeros(x.size(0), 3, 8, 8) @@ -245,8 +245,8 @@ class StageB(nn.Module): nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True)) x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', align_corners=True) - level_outputs = self._down_encode(x, r_embed, clip) - x = self._up_decode(level_outputs, r_embed, clip) + level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options) + x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options) return self.clf(x) def update_weights_ema(self, src_model, beta=0.999): diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py index b952d0349..ebc4434e2 100644 --- a/comfy/ldm/cascade/stage_c.py +++ b/comfy/ldm/cascade/stage_c.py @@ -182,7 +182,7 @@ class StageC(nn.Module): clip = self.clip_norm(clip) return clip - def _down_encode(self, x, r_embed, clip, cnet=None): + def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: @@ -201,7 +201,7 @@ class StageC(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, 
transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -213,7 +213,7 @@ class StageC(nn.Module): level_outputs.insert(0, x) return level_outputs - def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) for i, (up_block, upscaler, repmap) in enumerate(block_group): @@ -235,7 +235,7 @@ class StageC(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -247,7 +247,7 @@ class StageC(nn.Module): x = upscaler(x) return x - def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs): + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs): # Process the conditioning embeddings r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) for c in self.t_conds: @@ -262,8 +262,8 @@ class StageC(nn.Module): # Model Blocks x = self.embedding(x) - level_outputs = self._down_encode(x, r_embed, clip, cnet) - x = self._up_decode(level_outputs, r_embed, clip, cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options) + x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options) return self.clf(x) def update_weights_ema(self, src_model, beta=0.999): diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 2a0dec606..fc7110cce 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module): ) self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None): + def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}): (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec # prepare image for attention @@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module): attn = attention(torch.cat((txt_q, img_q), dim=2), torch.cat((txt_k, img_k), dim=2), torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] @@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module): self.mlp_act = nn.GELU(approximate="tanh") - def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor: + def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor: mod = vec x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x)) qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module): q, k = self.norm(q, k, v) # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask) + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, 
self.mlp_act(mlp)), 2)) x.addcmul_(mod.gate, output) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 5cff44dc8..4f709f87d 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -193,14 +193,16 @@ class Chroma(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": double_mod, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -209,7 +211,8 @@ class Chroma(nn.Module): txt=txt, vec=double_mod, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -229,17 +232,19 @@ class Chroma(nn.Module): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": single_mod, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask) + img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index 5c4356a3f..afb43d469 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -176,6 +176,7 @@ class Attention(nn.Module): context=None, mask=None, rope_emb=None, + transformer_options={}, **kwargs, ): """ @@ -184,7 +185,7 @@ class Attention(nn.Module): context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) - out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True) + out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options) del q, k, v out = rearrange(out, " b n s c -> s b (n c)") return self.to_out(out) @@ -546,6 +547,7 @@ class VideoAttn(nn.Module): context: Optional[torch.Tensor] = None, crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for video attention. @@ -571,6 +573,7 @@ class VideoAttn(nn.Module): context_M_B_D, crossattn_mask, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) return x_T_H_W_B_D @@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module): crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for dynamically configured blocks with adaptive normalization. 
@@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module): adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), context=None, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) elif self.block_type in ["cross_attn", "ca"]: x = x + gate_1_1_1_B_D * self.block( @@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module): context=crossattn_emb, crossattn_mask=crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) else: raise ValueError(f"Unknown block type: {self.block_type}") @@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module): crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: for block in self.blocks: x = block( @@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module): crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) return x diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py index 53698b758..52ef7ef43 100644 --- a/comfy/ldm/cosmos/model.py +++ b/comfy/ldm/cosmos/model.py @@ -520,6 +520,7 @@ class GeneralDIT(nn.Module): x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + transformer_options = kwargs.get("transformer_options", {}) for _, block in self.blocks.items(): assert ( self.blocks["block0"].x_format == block.x_format @@ -534,6 +535,7 @@ class GeneralDIT(nn.Module): crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index fcc83ba76..07a4fc79f 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -44,7 +44,7 @@ class GPT2FeedForward(nn.Module): return x -def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: """Computes multi-head attention using PyTorch's native implementation. This function provides a PyTorch backend alternative to Transformer Engine's attention operation. @@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... 
v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) - return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True) + return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options) class Attention(nn.Module): @@ -180,8 +180,8 @@ class Attention(nn.Module): return q, k, v - def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - result = self.attn_op(q, k, v) # [B, S, H, D] + def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: + result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D] return self.output_dropout(self.output_proj(result)) def forward( @@ -189,6 +189,7 @@ class Attention(nn.Module): x: torch.Tensor, context: Optional[torch.Tensor] = None, rope_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Args: @@ -196,7 +197,7 @@ class Attention(nn.Module): context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) - return self.compute_attention(q, k, v) + return self.compute_attention(q, k, v, transformer_options=transformer_options) class Timesteps(nn.Module): @@ -459,6 +460,7 @@ class Block(nn.Module): rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb @@ -512,6 +514,7 @@ class Block(nn.Module): rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), None, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ), "b (t h w) d -> b t h w d", t=T, @@ -525,6 +528,7 @@ class Block(nn.Module): layer_norm_cross_attn: Callable, _scale_cross_attn_B_T_1_1_D: torch.Tensor, _shift_cross_attn_B_T_1_1_D: torch.Tensor, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: _normalized_x_B_T_H_W_D = _fn( _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D @@ -534,6 +538,7 @@ class Block(nn.Module): rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), crossattn_emb, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ), "b (t h w) d -> b t h w d", t=T, @@ -547,6 +552,7 @@ class Block(nn.Module): self.layer_norm_cross_attn, scale_cross_attn_B_T_1_1_D, shift_cross_attn_B_T_1_1_D, + transformer_options=transformer_options, ) x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D @@ -865,6 +871,7 @@ class MiniTrainDIT(nn.Module): "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), "adaln_lora_B_T_3D": adaln_lora_B_T_3D, "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + "transformer_options": kwargs.get("transformer_options", {}), } for block in self.blocks: x_B_T_H_W_D = block( diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 113eb2096..ef21b416b 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module): ) self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None): + def 
forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module): attn = attention(torch.cat((img_q, txt_q), dim=2), torch.cat((img_k, txt_k), dim=2), torch.cat((img_v, txt_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] else: @@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module): attn = attention(torch.cat((txt_q, img_q), dim=2), torch.cat((txt_k, img_k), dim=2), torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] @@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module): self.mlp_act = nn.GELU(approximate="tanh") self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) - def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor: + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor: mod, _ = self.modulation(vec) qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module): q, k = self.norm(q, k, v) # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask) + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) x += apply_mod(output, mod.gate, None, modulation_dims) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 3e0978176..4d743cda2 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention import comfy.model_management -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: q_shape = q.shape k_shape = k.shape @@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) heads = q.shape[1] - x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) + x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 8ea7d4f57..14f90cea5 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -144,14 +144,16 @@ class Flux(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -160,7 +162,8 @@ class Flux(nn.Module): 
txt=txt, vec=vec, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -181,17 +184,19 @@ class Flux(nn.Module): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py index 366a8b713..5c1bb4d42 100644 --- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py +++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py @@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module): scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. crop_y, + transformer_options={}, **rope_rotation, ) -> Tuple[torch.Tensor, torch.Tensor]: rope_cos = rope_rotation.get("rope_cos") @@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module): xy = optimized_attention(q, k, - v, self.num_heads, skip_reshape=True) + v, self.num_heads, skip_reshape=True, transformer_options=transformer_options) x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1) x = self.proj_x(x) @@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module): x: torch.Tensor, c: torch.Tensor, y: torch.Tensor, + transformer_options={}, **attn_kwargs, ): """Forward pass of a block. @@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module): y, scale_x=scale_msa_x, scale_y=scale_msa_y, + transformer_options=transformer_options, **attn_kwargs, ) @@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module): args["txt"], rope_cos=args["rope_cos"], rope_sin=args["rope_sin"], - crop_y=args["num_tokens"] + crop_y=args["num_tokens"], + transformer_options=args["transformer_options"] ) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap}) y_feat = out["txt"] x = out["img"] else: @@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module): rope_cos=rope_cos, rope_sin=rope_sin, crop_y=num_tokens, + transformer_options=transformer_options, ) # (B, M, D), (B, L, D) del y_feat # Final layers don't use dense text features. 
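A minimal sketch of how the hook threaded through the calls above can be consumed: wrap_attn (added to comfy/ldm/modules/attention.py later in this same commit) looks up "optimized_attention_override" in transformer_options and, when present, invokes it with the original attention function and its arguments. The override name and body below are hypothetical, shown only to illustrate the expected call shape; they are not part of the patch.

def logging_attention_override(attn_func, q, k, v, heads, **kwargs):
    # Hypothetical override: inspect the call, then defer to the wrapped
    # attention implementation (kwargs still carries mask/transformer_options).
    print(f"attention: q={tuple(q.shape)} k={tuple(k.shape)} heads={heads}")
    return attn_func(q, k, v, heads, **kwargs)

# A caller could enable it through the options dict that now reaches every block, e.g.:
#   transformer_options["optimized_attention_override"] = logging_attention_override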
diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index ae49cf945..28d81c79e 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module): return t_emb -def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): - return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2]) +def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}): + return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options) class HiDreamAttnProcessor_flashattn: @@ -86,6 +86,7 @@ class HiDreamAttnProcessor_flashattn: image_tokens_masks: Optional[torch.FloatTensor] = None, text_tokens: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, + transformer_options={}, *args, **kwargs, ) -> torch.FloatTensor: @@ -133,7 +134,7 @@ class HiDreamAttnProcessor_flashattn: query = torch.cat([query_1, query_2], dim=-1) key = torch.cat([key_1, key_2], dim=-1) - hidden_states = attention(query, key, value) + hidden_states = attention(query, key, value, transformer_options=transformer_options) if not attn.single: hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) @@ -199,6 +200,7 @@ class HiDreamAttention(nn.Module): image_tokens_masks: torch.FloatTensor = None, norm_text_tokens: torch.FloatTensor = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.Tensor: return self.processor( self, @@ -206,6 +208,7 @@ class HiDreamAttention(nn.Module): image_tokens_masks = image_tokens_masks, text_tokens = norm_text_tokens, rope = rope, + transformer_options=transformer_options, ) @@ -406,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, - + transformer_options={}, ) -> torch.FloatTensor: wtype = image_tokens.dtype shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ @@ -419,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module): norm_image_tokens, image_tokens_masks, rope = rope, + transformer_options=transformer_options, ) image_tokens = gate_msa_i * attn_output_i + image_tokens @@ -483,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.FloatTensor: wtype = image_tokens.dtype shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ @@ -500,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module): image_tokens_masks, norm_text_tokens, rope = rope, + transformer_options=transformer_options, ) image_tokens = gate_msa_i * attn_output_i + image_tokens @@ -550,6 +556,7 @@ class HiDreamImageBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: torch.FloatTensor = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.FloatTensor: return self.block( image_tokens, @@ -557,6 +564,7 @@ class HiDreamImageBlock(nn.Module): text_tokens, 
adaln_input, rope, + transformer_options=transformer_options, ) @@ -786,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module): text_tokens = cur_encoder_hidden_states, adaln_input = adaln_input, rope = rope, + transformer_options=transformer_options, ) initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] block_id += 1 @@ -809,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module): text_tokens=None, adaln_input=adaln_input, rope=rope, + transformer_options=transformer_options, ) hidden_states = hidden_states[:, :hidden_states_seq_len] block_id += 1 diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py index 0fa5e78c1..4991b1645 100644 --- a/comfy/ldm/hunyuan3d/model.py +++ b/comfy/ldm/hunyuan3d/model.py @@ -99,14 +99,16 @@ class Hunyuan3Dv2(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args["transformer_options"]) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -115,7 +117,8 @@ class Hunyuan3Dv2(nn.Module): txt=txt, vec=vec, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) img = torch.cat((txt, img), 1) @@ -126,17 +129,19 @@ class Hunyuan3Dv2(nn.Module): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args["transformer_options"]) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) img = img[:, txt.shape[1]:, ...] 
img = self.final_layer(img, vec) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index ca86b8bb1..5132e6c07 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -80,13 +80,13 @@ class TokenRefinerBlock(nn.Module): operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - def forward(self, x, c, mask): + def forward(self, x, c, mask, transformer_options={}): mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1) norm_x = self.norm1(x) qkv = self.self_attn.qkv(norm_x) q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4) - attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True) + attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options) x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1) x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1) @@ -117,14 +117,14 @@ class IndividualTokenRefiner(nn.Module): ] ) - def forward(self, x, c, mask): + def forward(self, x, c, mask, transformer_options={}): m = None if mask is not None: m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1) m = m + m.transpose(2, 3) for block in self.blocks: - x = block(x, c, m) + x = block(x, c, m, transformer_options=transformer_options) return x @@ -152,6 +152,7 @@ class TokenRefiner(nn.Module): x, timesteps, mask, + transformer_options={}, ): t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype)) # m = mask.float().unsqueeze(-1) @@ -160,7 +161,7 @@ class TokenRefiner(nn.Module): c = t + self.c_embedder(c.to(x.dtype)) x = self.input_embedder(x) - x = self.individual_token_refiner(x, c, mask) + x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options) return x @@ -328,7 +329,7 @@ class HunyuanVideo(nn.Module): if txt_mask is not None and not torch.is_floating_point(txt_mask): txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max - txt = self.txt_in(txt, timesteps, txt_mask) + txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options) if self.byt5_in is not None and txt_byt5 is not None: txt_byt5 = self.byt5_in(txt_byt5) @@ -352,14 +353,14 @@ class HunyuanVideo(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) + out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] else: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, 
modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -374,13 +375,13 @@ class HunyuanVideo(nn.Module): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) + out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index aa2ea62b1..def365ba7 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -271,7 +271,7 @@ class CrossAttention(nn.Module): self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - def forward(self, x, context=None, mask=None, pe=None): + def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): q = self.to_q(x) context = x if context is None else context k = self.to_k(context) @@ -285,9 +285,9 @@ class CrossAttention(nn.Module): k = apply_rotary_emb(k, pe) if mask is None: - out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) + out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) + out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) return self.to_out(out) @@ -303,12 +303,12 @@ class BasicTransformerBlock(nn.Module): self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None): + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) - x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa + x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, 
transformer_options=transformer_options) * gate_msa - x += self.attn2(x, context=context, mask=attention_mask) + x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp x += self.ff(y) * gate_mlp @@ -479,10 +479,10 @@ class LTXVModel(torch.nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -490,7 +490,8 @@ class LTXVModel(torch.nn.Module): context=context, attention_mask=attention_mask, timestep=timestep, - pe=pe + pe=pe, + transformer_options=transformer_options, ) # 3. Output diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index e08ed817d..f87d98ac0 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -104,6 +104,7 @@ class JointAttention(nn.Module): x: torch.Tensor, x_mask: torch.Tensor, freqs_cis: torch.Tensor, + transformer_options={}, ) -> torch.Tensor: """ @@ -140,7 +141,7 @@ class JointAttention(nn.Module): if n_rep >= 1: xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True) + output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options) return self.out(output) @@ -268,6 +269,7 @@ class JointTransformerBlock(nn.Module): x_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor]=None, + transformer_options={}, ): """ Perform a forward pass through the TransformerBlock. 
@@ -290,6 +292,7 @@ class JointTransformerBlock(nn.Module): modulate(self.attention_norm1(x), scale_msa), x_mask, freqs_cis, + transformer_options=transformer_options, ) ) x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( @@ -304,6 +307,7 @@ class JointTransformerBlock(nn.Module): self.attention_norm1(x), x_mask, freqs_cis, + transformer_options=transformer_options, ) ) x = x + self.ffn_norm2( @@ -494,7 +498,7 @@ class NextDiT(nn.Module): return imgs def patchify_and_embed( - self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens + self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={} ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: bsz = len(x) pH = pW = self.patch_size @@ -554,7 +558,7 @@ class NextDiT(nn.Module): # refine context for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) # refine image flat_x = [] @@ -573,7 +577,7 @@ class NextDiT(nn.Module): padded_img_embed = self.x_embedder(padded_img_embed) padded_img_mask = padded_img_mask.unsqueeze(1) for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) + padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options) if cap_mask is not None: mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) @@ -616,12 +620,13 @@ class NextDiT(nn.Module): cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute + transformer_options = kwargs.get("transformer_options", {}) x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) + x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(x.device) for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input) + x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options) x = self.final_layer(x, adaln_input) x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 043df28df..bf2553c37 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -5,8 +5,9 @@ import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional +from typing import Optional, Any, Callable, Union import logging +import functools from .diffusionmodules.util import AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention @@ -17,23 +18,45 @@ if model_management.xformers_enabled(): import xformers import xformers.ops -if model_management.sage_attention_enabled(): - try: - from sageattention import sageattn - except ModuleNotFoundError as e: +SAGE_ATTENTION_IS_AVAILABLE = False +try: + from sageattention import sageattn + SAGE_ATTENTION_IS_AVAILABLE = True +except ModuleNotFoundError as e: + if model_management.sage_attention_enabled(): if e.name == "sageattention": logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` 
package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") else: raise e exit(-1) -if model_management.flash_attention_enabled(): - try: - from flash_attn import flash_attn_func - except ModuleNotFoundError: +FLASH_ATTENTION_IS_AVAILABLE = False +try: + from flash_attn import flash_attn_func + FLASH_ATTENTION_IS_AVAILABLE = True +except ModuleNotFoundError: + if model_management.flash_attention_enabled(): logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) +REGISTERED_ATTENTION_FUNCTIONS = {} +def register_attention_function(name: str, func: Callable): + # avoid replacing existing functions + if name not in REGISTERED_ATTENTION_FUNCTIONS: + REGISTERED_ATTENTION_FUNCTIONS[name] = func + else: + logging.warning(f"Attention function {name} already registered, skipping registration.") + +def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]: + if name == "optimized": + return optimized_attention + elif name not in REGISTERED_ATTENTION_FUNCTIONS: + if default is ...: + raise KeyError(f"Attention function {name} not found.") + else: + return default + return REGISTERED_ATTENTION_FUNCTIONS[name] + from comfy.cli_args import args import comfy.ops ops = comfy.ops.disable_weight_init @@ -91,7 +114,27 @@ class FeedForward(nn.Module): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) -def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): + +def wrap_attn(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + remove_attn_wrapper_key = False + try: + if "_inside_attn_wrapper" not in kwargs: + transformer_options = kwargs.get("transformer_options", None) + remove_attn_wrapper_key = True + kwargs["_inside_attn_wrapper"] = True + if transformer_options is not None: + if "optimized_attention_override" in transformer_options: + return transformer_options["optimized_attention_override"](func, *args, **kwargs) + return func(*args, **kwargs) + finally: + if remove_attn_wrapper_key: + del kwargs["_inside_attn_wrapper"] + return wrapper + +@wrap_attn +def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return out - -def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, query.dtype) if skip_reshape: @@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) return hidden_states -def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, 
q.dtype) if skip_reshape: @@ -359,7 +403,8 @@ try: except: pass -def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): b = q.shape[0] dim_head = q.shape[-1] # check to make sure xformers isn't broken @@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh disabled_xformers = True if disabled_xformers: - return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape) + return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs) if skip_reshape: # b h k d -> b k h d @@ -427,8 +472,8 @@ else: #TODO: other GPUs ? SDP_BATCH_LIMIT = 2**31 - -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape else: @@ -470,8 +515,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) return out - -def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape tensor_layout = "HND" @@ -501,7 +546,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= lambda t: t.transpose(1, 2), (q, k, v), ) - return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape) + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs) if tensor_layout == "HND": if not skip_output_reshape: @@ -534,8 +579,8 @@ except AttributeError as error: dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" - -def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape else: @@ -597,6 +642,19 @@ else: optimized_attention_masked = optimized_attention + +# register core-supported attention functions +if SAGE_ATTENTION_IS_AVAILABLE: + register_attention_function("sage", attention_sage) +if FLASH_ATTENTION_IS_AVAILABLE: + register_attention_function("flash", attention_flash) +if model_management.xformers_enabled(): + register_attention_function("xformers", attention_xformers) +register_attention_function("pytorch", attention_pytorch) +register_attention_function("sub_quad", attention_sub_quad) +register_attention_function("split", attention_split) + + def optimized_attention_for_device(device, mask=False, small_input=False): if small_input: if model_management.pytorch_attention_enabled(): @@ -629,7 +687,7 @@ class CrossAttention(nn.Module): self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - def forward(self, x, context=None, value=None, mask=None): + def forward(self, x, 
context=None, value=None, mask=None, transformer_options={}): q = self.to_q(x) context = default(context, x) k = self.to_k(context) @@ -640,9 +698,9 @@ class CrossAttention(nn.Module): v = self.to_v(context) if mask is None: - out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) + out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) + out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) return self.to_out(out) @@ -746,7 +804,7 @@ class BasicTransformerBlock(nn.Module): n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) n = self.attn1.to_out(n) else: - n = self.attn1(n, context=context_attn1, value=value_attn1) + n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options) if "attn1_output_patch" in transformer_patches: patch = transformer_patches["attn1_output_patch"] @@ -786,7 +844,7 @@ class BasicTransformerBlock(nn.Module): n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) n = self.attn2.to_out(n) else: - n = self.attn2(n, context=context_attn2, value=value_attn2) + n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options) if "attn2_output_patch" in transformer_patches: patch = transformer_patches["attn2_output_patch"] @@ -1017,7 +1075,7 @@ class SpatialVideoTransformer(SpatialTransformer): B, S, C = x_mix.shape x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) - x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options + x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options) x_mix = rearrange( x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps ) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 4d6beba2d..42f406f1a 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs): return _block_mixing(*args, **kwargs) -def _block_mixing(context, x, context_block, x_block, c): +def _block_mixing(context, x, context_block, x_block, c, transformer_options={}): context_qkv, context_intermediates = context_block.pre_attention(context, c) if x_block.x_block_self_attn: @@ -622,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c): attn = optimized_attention( qkv[0], qkv[1], qkv[2], heads=x_block.attn.num_heads, + transformer_options=transformer_options, ) context_attn, x_attn = ( attn[:, : context_qkv[0].shape[1]], @@ -637,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c): attn2 = optimized_attention( x_qkv2[0], x_qkv2[1], x_qkv2[2], heads=x_block.attn2.num_heads, + transformer_options=transformer_options, ) x = x_block.post_attention_x(x_attn, attn2, *x_intermediates) else: @@ -958,10 +960,10 @@ class MMDiT(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"]) + out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": 
x, "txt": context, "vec": c_mod}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap}) context = out["txt"] x = out["img"] else: @@ -970,6 +972,7 @@ class MMDiT(nn.Module): x, c=c_mod, use_checkpoint=self.use_checkpoint, + transformer_options=transformer_options, ) if control is not None: control_o = control.get("output") diff --git a/comfy/ldm/omnigen/omnigen2.py b/comfy/ldm/omnigen/omnigen2.py index 4884449f8..82edc92da 100644 --- a/comfy/ldm/omnigen/omnigen2.py +++ b/comfy/ldm/omnigen/omnigen2.py @@ -120,7 +120,7 @@ class Attention(nn.Module): nn.Dropout(0.0) ) - def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape query = self.to_q(hidden_states) @@ -146,7 +146,7 @@ class Attention(nn.Module): key = key.repeat_interleave(self.heads // self.kv_heads, dim=1) value = value.repeat_interleave(self.heads // self.kv_heads, dim=1) - hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True) + hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options) hidden_states = self.to_out[0](hidden_states) return hidden_states @@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module): self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor: if self.modulation: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) - attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) + attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options) hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) else: norm_hidden_states = self.norm1(hidden_states) - attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) + attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options) hidden_states = hidden_states + self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) hidden_states = hidden_states + self.ffn_norm2(mlp_output) @@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module): ref_img_sizes, img_sizes, ) - def img_patch_embed_and_refine(self, hidden_states, 
ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb): + def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}): batch_size = len(hidden_states) hidden_states = self.x_embedder(hidden_states) @@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module): shift += ref_img_len for layer in self.noise_refiner: - hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) + hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options) if ref_image_hidden_states is not None: for layer in self.ref_image_refiner: - ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb) + ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options) hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1) return hidden_states - def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs): + def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs): B, C, H, W = x.shape hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) _, _, H_padded, W_padded = hidden_states.shape @@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module): ) for layer in self.context_refiner: - text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) + text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options) img_len = hidden_states.shape[1] combined_img_hidden_states = self.img_patch_embed_and_refine( @@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module): noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, + transformer_options=transformer_options, ) hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1) attention_mask = None for layer in self.layers: - hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) + hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options) hidden_states = self.norm_out(hidden_states, temb) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 04071f31c..b9f60c2b7 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -132,6 +132,7 @@ class Attention(nn.Module): encoder_hidden_states_mask: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: seq_txt = encoder_hidden_states.shape[1] @@ -159,7 +160,7 @@ class Attention(nn.Module): joint_key = joint_key.flatten(start_dim=2) joint_value = joint_value.flatten(start_dim=2) - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options) txt_attn_output = joint_hidden_states[:, :seq_txt, :] 
img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -226,6 +227,7 @@ class QwenImageTransformerBlock(nn.Module): encoder_hidden_states_mask: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_params = self.img_mod(temb) txt_mod_params = self.txt_mod(temb) @@ -242,6 +244,7 @@ class QwenImageTransformerBlock(nn.Module): encoder_hidden_states=txt_modulated, encoder_hidden_states_mask=encoder_hidden_states_mask, image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, ) hidden_states = hidden_states + img_gate1 * img_attn_output @@ -434,9 +437,9 @@ class QwenImageTransformer2DModel(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"]) + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) hidden_states = out["img"] encoder_hidden_states = out["txt"] else: @@ -446,11 +449,12 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, ) if "double_block" in patches: for p in patches["double_block"]: - out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i}) + out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options}) hidden_states = out["img"] encoder_hidden_states = out["txt"] diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 47857dc2b..63472ada2 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -52,7 +52,7 @@ class WanSelfAttention(nn.Module): self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, freqs): + def forward(self, x, freqs, transformer_options={}): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] @@ -75,6 +75,7 @@ class WanSelfAttention(nn.Module): k.view(b, s, n * d), v, heads=self.num_heads, + transformer_options=transformer_options, ) x = self.o(x) @@ -83,7 +84,7 @@ class WanSelfAttention(nn.Module): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, x, context, **kwargs): + def forward(self, x, context, transformer_options={}, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -95,7 +96,7 @@ class WanT2VCrossAttention(WanSelfAttention): v = self.v(context) # 
compute attention - x = optimized_attention(q, k, v, heads=self.num_heads) + x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options) x = self.o(x) return x @@ -116,7 +117,7 @@ class WanI2VCrossAttention(WanSelfAttention): # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, context, context_img_len): + def forward(self, x, context, context_img_len, transformer_options={}): r""" Args: x(Tensor): Shape [B, L1, C] @@ -131,9 +132,9 @@ class WanI2VCrossAttention(WanSelfAttention): v = self.v(context) k_img = self.norm_k_img(self.k_img(context_img)) v_img = self.v_img(context_img) - img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads) + img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options) # compute attention - x = optimized_attention(q, k, v, heads=self.num_heads) + x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options) # output x = x + img_x @@ -206,6 +207,7 @@ class WanAttentionBlock(nn.Module): freqs, context, context_img_len=257, + transformer_options={}, ): r""" Args: @@ -224,12 +226,12 @@ class WanAttentionBlock(nn.Module): # self-attention y = self.self_attn( torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), - freqs) + freqs, transformer_options=transformer_options) x = torch.addcmul(x, y, repeat_e(e[2], x)) # cross-attention & ffn - x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -559,12 +561,12 @@ class WanModel(torch.nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) # head x = self.head(x, e) @@ -742,17 +744,17 @@ class VaceWanModel(WanModel): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = 
blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) ii = self.vace_layers_mapping.get(i, None) if ii is not None: for iii in range(len(c)): - c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) x += c_skip * vace_strength[iii] del c_skip # head @@ -841,12 +843,12 @@ class CameraWanModel(WanModel): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) # head x = self.head(x, e) From a3b04de7004cc19dee9364bd71e62bab05475810 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:46:46 -0700 Subject: [PATCH 17/30] Hunyuan refiner vae now works with tiled. 
(#9836) --- comfy/ldm/hunyuan_video/vae_refiner.py | 1 - comfy/sd.py | 21 +++++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index e3fff9bbe..c6f742710 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -185,7 +185,6 @@ class Encoder(nn.Module): self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer() def forward(self, x): - x = x.unsqueeze(2) x = self.conv_in(x) for stage in self.down: diff --git a/comfy/sd.py b/comfy/sd.py index 02ddc7239..f8f1a89e8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -412,9 +412,12 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float32] elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] - self.downscale_ratio = 16 - self.upscale_ratio = 16 + ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] + self.latent_channels = 64 + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) self.latent_dim = 3 self.not_video = True self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -684,8 +687,11 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) - if not self.not_video and self.latent_dim == 3 and pixel_samples.ndim < 5: - pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + if self.latent_dim == 3 and pixel_samples.ndim < 5: + if not self.not_video: + pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + else: + pixel_samples = pixel_samples.unsqueeze(2) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -719,7 +725,10 @@ class VAE: dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) if dims == 3: - pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + if not self.not_video: + pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + else: + pixel_samples = pixel_samples.unsqueeze(2) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) From 2559dee49202365bc97218b98121e796f57dfcb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 13 Sep 2025 04:52:58 +0300 Subject: [PATCH 18/30] Support wav2vec base models (#9637) * Support wav2vec base models * trim trailing whitespace * Do interpolation after --- comfy/audio_encoders/audio_encoders.py | 36 ++++++++++- comfy/audio_encoders/wav2vec2.py | 87 +++++++++++++++++++------- 2 files changed, 99 insertions(+), 24 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 538c21bd5..d1ec78f69 100644 --- 
a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -11,7 +11,13 @@ class AudioEncoderModel(): self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) - self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast) + model_config = dict(config) + model_config.update({ + "dtype": self.dtype, + "device": offload_device, + "operations": comfy.ops.manual_cast + }) + self.model = Wav2Vec2Model(**model_config) self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 @@ -25,7 +31,7 @@ class AudioEncoderModel(): def encode_audio(self, audio, sample_rate): comfy.model_management.load_model_gpu(self.patcher) audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate) - out, all_layers = self.model(audio.to(self.load_device)) + out, all_layers = self.model(audio.to(self.load_device), sr=self.model_sample_rate) outputs = {} outputs["encoded_audio"] = out outputs["encoded_audio_all_layers"] = all_layers @@ -33,8 +39,32 @@ class AudioEncoderModel(): def load_audio_encoder_from_sd(sd, prefix=""): - audio_encoder = AudioEncoderModel(None) sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""}) + embed_dim = sd["encoder.layer_norm.bias"].shape[0] + if embed_dim == 1024:# large + config = { + "embed_dim": 1024, + "num_heads": 16, + "num_layers": 24, + "conv_norm": True, + "conv_bias": True, + "do_normalize": True, + "do_stable_layer_norm": True + } + elif embed_dim == 768: # base + config = { + "embed_dim": 768, + "num_heads": 12, + "num_layers": 12, + "conv_norm": False, + "conv_bias": False, + "do_normalize": False, # chinese-wav2vec2-base has this False + "do_stable_layer_norm": False + } + else: + raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim)) + + audio_encoder = AudioEncoderModel(config) m, u = audio_encoder.load_sd(sd) if len(m) > 0: logging.warning("missing audio encoder: {}".format(m)) diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py index de906622a..ef10dcd2a 100644 --- a/comfy/audio_encoders/wav2vec2.py +++ b/comfy/audio_encoders/wav2vec2.py @@ -13,19 +13,49 @@ class LayerNormConv(nn.Module): x = self.conv(x) return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1)) +class LayerGroupNormConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None): + super().__init__() + self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype) + self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype) + + def forward(self, x): + x = self.conv(x) + return torch.nn.functional.gelu(self.layer_norm(x)) + +class ConvNoNorm(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None): + super().__init__() + self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype) + + def forward(self, x): + x = self.conv(x) + return torch.nn.functional.gelu(x) + class 
ConvFeatureEncoder(nn.Module): - def __init__(self, conv_dim, dtype=None, device=None, operations=None): + def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None): super().__init__() - self.conv_layers = nn.ModuleList([ - LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations), - ]) + if conv_norm: + self.conv_layers = nn.ModuleList([ + LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ]) + else: + self.conv_layers = nn.ModuleList([ + LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ]) def forward(self, x): x = x.unsqueeze(1) @@ -76,6 +106,7 @@ class TransformerEncoder(nn.Module): num_heads=12, num_layers=12, mlp_ratio=4.0, + do_stable_layer_norm=True, dtype=None, device=None, operations=None ): super().__init__() @@ -86,20 +117,25 @@ class TransformerEncoder(nn.Module): embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, + do_stable_layer_norm=do_stable_layer_norm, device=device, dtype=dtype, operations=operations ) for _ in range(num_layers) ]) self.layer_norm = operations.LayerNorm(embed_dim, 
eps=1e-05, device=device, dtype=dtype) + self.do_stable_layer_norm = do_stable_layer_norm def forward(self, x, mask=None): x = x + self.pos_conv_embed(x) all_x = () + if not self.do_stable_layer_norm: + x = self.layer_norm(x) for layer in self.layers: all_x += (x,) x = layer(x, mask) - x = self.layer_norm(x) + if self.do_stable_layer_norm: + x = self.layer_norm(x) all_x += (x,) return x, all_x @@ -145,6 +181,7 @@ class TransformerEncoderLayer(nn.Module): embed_dim=768, num_heads=12, mlp_ratio=4.0, + do_stable_layer_norm=True, dtype=None, device=None, operations=None ): super().__init__() @@ -154,15 +191,19 @@ class TransformerEncoderLayer(nn.Module): self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations) self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + self.do_stable_layer_norm = do_stable_layer_norm def forward(self, x, mask=None): residual = x - x = self.layer_norm(x) + if self.do_stable_layer_norm: + x = self.layer_norm(x) x = self.attention(x, mask=mask) x = residual + x - - x = x + self.feed_forward(self.final_layer_norm(x)) - return x + if not self.do_stable_layer_norm: + x = self.layer_norm(x) + return self.final_layer_norm(x + self.feed_forward(x)) + else: + return x + self.feed_forward(self.final_layer_norm(x)) class Wav2Vec2Model(nn.Module): @@ -174,34 +215,38 @@ class Wav2Vec2Model(nn.Module): final_dim=256, num_heads=16, num_layers=24, + conv_norm=True, + conv_bias=True, + do_normalize=True, + do_stable_layer_norm=True, dtype=None, device=None, operations=None ): super().__init__() conv_dim = 512 - self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations) + self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations) self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations) self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype)) + self.do_normalize = do_normalize self.encoder = TransformerEncoder( embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, + do_stable_layer_norm=do_stable_layer_norm, device=device, dtype=dtype, operations=operations ) - def forward(self, x, mask_time_indices=None, return_dict=False): - + def forward(self, x, sr=16000, mask_time_indices=None, return_dict=False): x = torch.mean(x, dim=1) - x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7) + if self.do_normalize: + x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7) features = self.feature_extractor(x) features = self.feature_projection(features) - batch_size, seq_len, _ = features.shape x, all_x = self.encoder(features) - return x, all_x From 29bf807b0e2d89402d555d08bd8e9df15e636f0c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Sep 2025 18:57:04 -0700 Subject: [PATCH 19/30] Cleanup. 
(#9838) --- comfy/audio_encoders/audio_encoders.py | 2 +- comfy/audio_encoders/wav2vec2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index d1ec78f69..6fb5b08e9 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -31,7 +31,7 @@ class AudioEncoderModel(): def encode_audio(self, audio, sample_rate): comfy.model_management.load_model_gpu(self.patcher) audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate) - out, all_layers = self.model(audio.to(self.load_device), sr=self.model_sample_rate) + out, all_layers = self.model(audio.to(self.load_device)) outputs = {} outputs["encoded_audio"] = out outputs["encoded_audio_all_layers"] = all_layers diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py index ef10dcd2a..4e34a40a7 100644 --- a/comfy/audio_encoders/wav2vec2.py +++ b/comfy/audio_encoders/wav2vec2.py @@ -238,7 +238,7 @@ class Wav2Vec2Model(nn.Module): device=device, dtype=dtype, operations=operations ) - def forward(self, x, sr=16000, mask_time_indices=None, return_dict=False): + def forward(self, x, mask_time_indices=None, return_dict=False): x = torch.mean(x, dim=1) if self.do_normalize: From e5e70636e7b7b54695220a88ab036c1607959736 Mon Sep 17 00:00:00 2001 From: Kimbing Ng <50580578+KimbingNg@users.noreply.github.com> Date: Sun, 14 Sep 2025 04:59:19 +0800 Subject: [PATCH 20/30] Remove single quote pattern to avoid wrong matches (#9842) --- comfy/text_encoders/hunyuan_image.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index be396cae7..699eddc33 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -22,17 +22,14 @@ class HunyuanImageTokenizer(QwenImageTokenizer): # ByT5 processing for HunyuanImage text_prompt_texts = [] - pattern_quote_single = r'\'(.*?)\'' pattern_quote_double = r'\"(.*?)\"' pattern_quote_chinese_single = r'‘(.*?)’' pattern_quote_chinese_double = r'“(.*?)”' - matches_quote_single = re.findall(pattern_quote_single, text) matches_quote_double = re.findall(pattern_quote_double, text) matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text) matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text) - text_prompt_texts.extend(matches_quote_single) text_prompt_texts.extend(matches_quote_double) text_prompt_texts.extend(matches_quote_chinese_single) text_prompt_texts.extend(matches_quote_chinese_double) From c1297f4eb38a63e2f99c9fa76e32e3a36c933b85 Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Sat, 13 Sep 2025 15:58:43 -0600 Subject: [PATCH 21/30] Add support for Chroma Radiance (#9682) * Initial Chroma Radiance support * Minor Chroma Radiance cleanups * Update Radiance nodes to ensure latents/images are on the intermediate device * Fix Chroma Radiance memory estimation. * Increase Chroma Radiance memory usage factor * Increase Chroma Radiance memory usage factor once again * Ensure images are multiples of 16 for Chroma Radiance Add batch dimension and fix channels when necessary in ChromaRadianceImageToLatent node * Tile Chroma Radiance NeRF to reduce memory consumption, update memory usage factor * Update Radiance to support conv nerf final head type. 
* Allow setting NeRF embedder dtype for Radiance Bump Radiance nerf tile size to 32 Support EasyCache/LazyCache on Radiance (maybe) * Add ChromaRadianceStubVAE node * Crop Radiance image inputs to multiples of 16 instead of erroring to be in line with existing VAE behavior * Convert Chroma Radiance nodes to V3 schema. * Add ChromaRadianceOptions node and backend support. Cleanups/refactoring to reduce code duplication with Chroma. * Fix overriding the NeRF embedder dtype for Chroma Radiance * Minor Chroma Radiance cleanups * Move Chroma Radiance to its own directory in ldm Minor code cleanups and tooltip improvements * Fix Chroma Radiance embedder dtype overriding * Remove Radiance dynamic nerf_embedder dtype override feature * Unbork Radiance NeRF embedder init * Remove Chroma Radiance image conversion and stub VAE nodes Add a chroma_radiance option to the VAELoader builtin node which uses comfy.sd.PixelspaceConversionVAE Add a PixelspaceConversionVAE to comfy.sd for converting BHWC 0..1 <-> BCHW -1..1 --- comfy/latent_formats.py | 17 ++ comfy/ldm/chroma/model.py | 10 +- comfy/ldm/chroma_radiance/layers.py | 206 ++++++++++++++++ comfy/ldm/chroma_radiance/model.py | 328 ++++++++++++++++++++++++++ comfy/model_base.py | 9 +- comfy/model_detection.py | 14 +- comfy/sd.py | 60 +++++ comfy/supported_models.py | 15 +- comfy_extras/nodes_chroma_radiance.py | 114 +++++++++ nodes.py | 6 +- 10 files changed, 770 insertions(+), 9 deletions(-) create mode 100644 comfy/ldm/chroma_radiance/layers.py create mode 100644 comfy/ldm/chroma_radiance/model.py create mode 100644 comfy_extras/nodes_chroma_radiance.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 894540879..77e642a94 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -629,3 +629,20 @@ class Hunyuan3Dv2mini(LatentFormat): class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 + +class ChromaRadiance(LatentFormat): + latent_channels = 3 + + def __init__(self): + self.latent_rgb_factors = [ + # R G B + [ 1.0, 0.0, 0.0 ], + [ 0.0, 1.0, 0.0 ], + [ 0.0, 0.0, 1.0 ] + ] + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 4f709f87d..ad1c523fe 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -151,8 +151,6 @@ class Chroma(nn.Module): attn_mask: Tensor = None, ) -> Tensor: patches_replace = transformer_options.get("patches_replace", {}) - if img.ndim != 3 or txt.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img img = self.img_in(img) @@ -254,8 +252,9 @@ class Chroma(nn.Module): img[:, txt.shape[1] :, ...] += add img = img[:, txt.shape[1] :, ...] 
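# Chroma Radiance (added later in this patch) reuses this forward pass but does not build
# a final_layer; it decodes the transformer output with its NeRF head instead, so the
# projection below is guarded with hasattr and applied only when the layer actually exists.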
- final_mod = self.get_modulations(mod_vectors, "final") - img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) + if hasattr(self, "final_layer"): + final_mod = self.get_modulations(mod_vectors, "final") + img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) return img def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): @@ -271,6 +270,9 @@ class Chroma(nn.Module): img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size) + if img.ndim != 3 or context.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + h_len = ((h + (self.patch_size // 2)) // self.patch_size) w_len = ((w + (self.patch_size // 2)) // self.patch_size) img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py new file mode 100644 index 000000000..3c7bc9b6b --- /dev/null +++ b/comfy/ldm/chroma_radiance/layers.py @@ -0,0 +1,206 @@ +# Adapted from https://github.com/lodestone-rock/flow +from functools import lru_cache + +import torch +from torch import nn + +from comfy.ldm.flux.layers import RMSNorm + + +class NerfEmbedder(nn.Module): + """ + An embedder module that combines input features with a 2D positional + encoding that mimics the Discrete Cosine Transform (DCT). + + This module takes an input tensor of shape (B, P^2, C), where P is the + patch size, and enriches it with positional information before projecting + it to a new hidden size. + """ + def __init__( + self, + in_channels: int, + hidden_size_input: int, + max_freqs: int, + dtype=None, + device=None, + operations=None, + ): + """ + Initializes the NerfEmbedder. + + Args: + in_channels (int): The number of channels in the input tensor. + hidden_size_input (int): The desired dimension of the output embedding. + max_freqs (int): The number of frequency components to use for both + the x and y dimensions of the positional encoding. + The total number of positional features will be max_freqs^2. + """ + super().__init__() + self.dtype = dtype + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + + # A linear layer to project the concatenated input features and + # positional encodings to the final output dimension. + self.embedder = nn.Sequential( + operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device) + ) + + @lru_cache(maxsize=4) + def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """ + Generates and caches 2D DCT-like positional embeddings for a given patch size. + + The LRU cache is a performance optimization that avoids recomputing the + same positional grid on every forward pass. + + Args: + patch_size (int): The side length of the square input patch. + device: The torch device to create the tensors on. + dtype: The torch dtype for the tensors. + + Returns: + A tensor of shape (1, patch_size^2, max_freqs^2) containing the + positional embeddings. + """ + # Create normalized 1D coordinate grids from 0 to 1. + pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + + # Create a 2D meshgrid of coordinates. + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + + # Reshape positions to be broadcastable with frequencies. + # Shape becomes (patch_size^2, 1, 1). 
+ pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + + # Create a 1D tensor of frequency values from 0 to max_freqs-1. + freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device) + + # Reshape frequencies to be broadcastable for creating 2D basis functions. + # freqs_x shape: (1, max_freqs, 1) + # freqs_y shape: (1, 1, max_freqs) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + + # A custom weighting coefficient, not part of standard DCT. + # This seems to down-weight the contribution of higher-frequency interactions. + coeffs = (1 + freqs_x * freqs_y) ** -1 + + # Calculate the 1D cosine basis functions for x and y coordinates. + # This is the core of the DCT formulation. + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + + # Combine the 1D basis functions to create 2D basis functions by element-wise + # multiplication, and apply the custom coefficients. Broadcasting handles the + # combination of all (pos_x, freqs_x) with all (pos_y, freqs_y). + # The result is flattened into a feature vector for each position. + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) + + return dct + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the embedder. + + Args: + inputs (Tensor): The input tensor of shape (B, P^2, C). + + Returns: + Tensor: The output tensor of shape (B, P^2, hidden_size_input). + """ + # Get the batch size, number of pixels, and number of channels. + B, P2, C = inputs.shape + + # Infer the patch side length from the number of pixels (P^2). + patch_size = int(P2 ** 0.5) + + input_dtype = inputs.dtype + inputs = inputs.to(dtype=self.dtype) + + # Fetch the pre-computed or cached positional embeddings. + dct = self.fetch_pos(patch_size, inputs.device, self.dtype) + + # Repeat the positional embeddings for each item in the batch. + dct = dct.repeat(B, 1, 1) + + # Concatenate the original input features with the positional embeddings + # along the feature dimension. + inputs = torch.cat((inputs, dct), dim=-1) + + # Project the combined tensor to the target hidden size. + return self.embedder(inputs).to(dtype=input_dtype) + + +class NerfGLUBlock(nn.Module): + """ + A NerfBlock using a Gated Linear Unit (GLU) like MLP. + """ + def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None): + super().__init__() + # The total number of parameters for the MLP is increased to accommodate + # the gate, value, and output projection matrices. + # We now need to generate parameters for 3 matrices. + total_params = 3 * hidden_size_x**2 * mlp_ratio + self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device) + self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations) + self.mlp_ratio = mlp_ratio + + + def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor: + batch_size, num_x, hidden_size_x = x.shape + mlp_params = self.param_generator(s) + + # Split the generated parameters into three parts for the gate, value, and output projection. + fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1) + + # Reshape the parameters into matrices for batch matrix multiplication. 
+ fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x) + + # Normalize the generated weight matrices as in the original implementation. + fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2) + fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2) + fc2 = torch.nn.functional.normalize(fc2, dim=-2) + + res_x = x + x = self.norm(x) + + # Apply the final output projection. + x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2) + + return x + res_x + + +class NerfFinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. + # So we temporarily move the channel dimension to the end for the norm operation. + return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1) + + +class NerfFinalLayerConv(nn.Module): + def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.conv = operations.Conv2d( + in_channels=hidden_size, + out_channels=out_channels, + kernel_size=3, + padding=1, + dtype=dtype, + device=device, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. + # So we temporarily move the channel dimension to the end for the norm operation. + return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1)) diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py new file mode 100644 index 000000000..f7eb7a22e --- /dev/null +++ b/comfy/ldm/chroma_radiance/model.py @@ -0,0 +1,328 @@ +# Credits: +# Original Flux code can be found on: https://github.com/black-forest-labs/flux +# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import Tensor, nn +from einops import repeat +import comfy.ldm.common_dit + +from comfy.ldm.flux.layers import EmbedND + +from comfy.ldm.chroma.model import Chroma, ChromaParams +from comfy.ldm.chroma.layers import ( + DoubleStreamBlock, + SingleStreamBlock, + Approximator, +) +from .layers import ( + NerfEmbedder, + NerfGLUBlock, + NerfFinalLayer, + NerfFinalLayerConv, +) + + +@dataclass +class ChromaRadianceParams(ChromaParams): + patch_size: int + nerf_hidden_size: int + nerf_mlp_ratio: int + nerf_depth: int + nerf_max_freqs: int + # Setting nerf_tile_size to 0 disables tiling. + nerf_tile_size: int + # Currently one of linear (legacy) or conv. + nerf_final_head_type: str + # None means use the same dtype as the model. + nerf_embedder_dtype: Optional[torch.dtype] + + +class ChromaRadiance(Chroma): + """ + Transformer model for flow matching on sequences. 
+ """ + + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): + if operations is None: + raise RuntimeError("Attempt to create ChromaRadiance object without setting operations") + nn.Module.__init__(self) + self.dtype = dtype + params = ChromaRadianceParams(**kwargs) + self.params = params + self.patch_size = params.patch_size + self.in_channels = params.in_channels + self.out_channels = params.out_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.in_dim = params.in_dim + self.out_dim = params.out_dim + self.hidden_dim = params.hidden_dim + self.n_layers = params.n_layers + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in_patch = operations.Conv2d( + params.in_channels, + params.hidden_size, + kernel_size=params.patch_size, + stride=params.patch_size, + bias=True, + dtype=dtype, + device=device, + ) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + # set as nn identity for now, will overwrite it later. + self.distilled_guidance_layer = Approximator( + in_dim=self.in_dim, + hidden_dim=self.hidden_dim, + out_dim=self.out_dim, + n_layers=self.n_layers, + dtype=dtype, device=device, operations=operations + ) + + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + dtype=dtype, device=device, operations=operations + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + dtype=dtype, device=device, operations=operations, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + # pixel channel concat with DCT + self.nerf_image_embedder = NerfEmbedder( + in_channels=params.in_channels, + hidden_size_input=params.nerf_hidden_size, + max_freqs=params.nerf_max_freqs, + dtype=params.nerf_embedder_dtype or dtype, + device=device, + operations=operations, + ) + + self.nerf_blocks = nn.ModuleList([ + NerfGLUBlock( + hidden_size_s=params.hidden_size, + hidden_size_x=params.nerf_hidden_size, + mlp_ratio=params.nerf_mlp_ratio, + dtype=dtype, + device=device, + operations=operations, + ) for _ in range(params.nerf_depth) + ]) + + if params.nerf_final_head_type == "linear": + self.nerf_final_layer = NerfFinalLayer( + params.nerf_hidden_size, + out_channels=params.in_channels, + dtype=dtype, + device=device, + operations=operations, + ) + elif params.nerf_final_head_type == "conv": + self.nerf_final_layer_conv = NerfFinalLayerConv( + params.nerf_hidden_size, + out_channels=params.in_channels, + dtype=dtype, + device=device, + operations=operations, + ) + else: + errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}" + raise ValueError(errstr) + + self.skip_mmdit = [] + self.skip_dit = [] + self.lite = False + + @property + def _nerf_final_layer(self) -> nn.Module: + if self.params.nerf_final_head_type == "linear": + return self.nerf_final_layer + if self.params.nerf_final_head_type == "conv": + return self.nerf_final_layer_conv + # 
Impossible to get here as we raise an error on unexpected types on initialization. + raise NotImplementedError + + def img_in(self, img: Tensor) -> Tensor: + img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P] + # flatten into a sequence for the transformer. + return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden] + + def forward_nerf( + self, + img_orig: Tensor, + img_out: Tensor, + params: ChromaRadianceParams, + ) -> Tensor: + B, C, H, W = img_orig.shape + num_patches = img_out.shape[1] + patch_size = params.patch_size + + # Store the raw pixel values of each patch for the NeRF head later. + # unfold creates patches: [B, C * P * P, NumPatches] + nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size) + nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] + + if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size: + # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than + # the tile size. + img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params) + else: + # Reshape for per-patch processing + nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size) + nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2) + + # Get DCT-encoded pixel embeddings [pixel-dct] + img_dct = self.nerf_image_embedder(nerf_pixels) + + # Pass through the dynamic MLP blocks (the NeRF) + for block in self.nerf_blocks: + img_dct = block(img_dct, nerf_hidden) + + # Reassemble the patches into the final image. + img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P] + # Reshape to combine with batch dimension for fold + img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P] + img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches] + img_dct = nn.functional.fold( + img_dct, + output_size=(H, W), + kernel_size=patch_size, + stride=patch_size, + ) + return self._nerf_final_layer(img_dct) + + def forward_tiled_nerf( + self, + nerf_hidden: Tensor, + nerf_pixels: Tensor, + batch: int, + channels: int, + num_patches: int, + patch_size: int, + params: ChromaRadianceParams, + ) -> Tensor: + """ + Processes the NeRF head in tiles to save memory. + nerf_hidden has shape [B, L, D] + nerf_pixels has shape [B, L, C * P * P] + """ + tile_size = params.nerf_tile_size + output_tiles = [] + # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1. 
+ for i in range(0, num_patches, tile_size): + end = min(i + tile_size, num_patches) + + # Slice the current tile from the input tensors + nerf_hidden_tile = nerf_hidden[:, i:end, :] + nerf_pixels_tile = nerf_pixels[:, i:end, :] + + # Get the actual number of patches in this tile (can be smaller for the last tile) + num_patches_tile = nerf_hidden_tile.shape[1] + + # Reshape the tile for per-patch processing + # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D] + nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size) + # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C] + nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2) + + # get DCT-encoded pixel embeddings [pixel-dct] + img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile) + + # pass through the dynamic MLP blocks (the NeRF) + for block in self.nerf_blocks: + img_dct_tile = block(img_dct_tile, nerf_hidden_tile) + + output_tiles.append(img_dct_tile) + + # Concatenate the processed tiles along the patch dimension + return torch.cat(output_tiles, dim=0) + + def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams: + params = self.params + if not overrides: + return params + params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__} + nullable_keys = frozenset(("nerf_embedder_dtype",)) + bad_keys = tuple(k for k in overrides if k not in params_dict) + if bad_keys: + e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" + raise ValueError(e) + bad_keys = tuple( + k + for k, v in overrides.items() + if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys) + ) + if bad_keys: + e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" + raise ValueError(e) + # At this point it's all valid keys and values so we can merge with the existing params. 
+ params_dict |= overrides + return params.__class__(**params_dict) + + def _forward( + self, + x: Tensor, + timestep: Tensor, + context: Tensor, + guidance: Optional[Tensor], + control: Optional[dict]=None, + transformer_options: dict={}, + **kwargs: dict, + ) -> Tensor: + bs, c, h, w = x.shape + img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + if img.ndim != 4: + raise ValueError("Input img tensor must be in [B, C, H, W] format.") + if context.ndim != 3: + raise ValueError("Input txt tensors must have 3 dimensions.") + + params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {})) + + h_len = ((h + (self.patch_size // 2)) // self.patch_size) + w_len = ((w + (self.patch_size // 2)) // self.patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + + img_out = self.forward_orig( + img, + img_ids, + context, + txt_ids, + timestep, + guidance, + control, + transformer_options, + attn_mask=kwargs.get("attention_mask", None), + ) + return self.forward_nerf(img, img_out, params) diff --git a/comfy/model_base.py b/comfy/model_base.py index 324d89cff..252dfcf69 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.wan.model import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model +import comfy.ldm.chroma_radiance.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model @@ -1320,8 +1321,8 @@ class HiDream(BaseModel): return out class Chroma(Flux): - def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma) + def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma): + super().__init__(model_config, model_type, device=device, unet_model=unet_model) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1331,6 +1332,10 @@ class Chroma(Flux): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class ChromaRadiance(Chroma): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance) + class ACEStep(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index fe983cede..03d44f65e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in 
state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} dit_config["image_model"] = "flux" dit_config["in_channels"] = 16 @@ -204,6 +204,18 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["out_dim"] = 3072 dit_config["hidden_dim"] = 5120 dit_config["n_layers"] = 5 + if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance + dit_config["image_model"] = "chroma_radiance" + dit_config["in_channels"] = 3 + dit_config["out_channels"] = 3 + dit_config["patch_size"] = 16 + dit_config["nerf_hidden_size"] = 64 + dit_config["nerf_mlp_ratio"] = 4 + dit_config["nerf_depth"] = 4 + dit_config["nerf_max_freqs"] = 8 + dit_config["nerf_tile_size"] = 32 + dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" + dit_config["nerf_embedder_dtype"] = torch.float32 else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config diff --git a/comfy/sd.py b/comfy/sd.py index f8f1a89e8..cb92802e9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -785,6 +785,66 @@ class VAE: except: return None +# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1 +# to LATENT B, C, H, W and values on the scale of -1..1. +class PixelspaceConversionVAE: + def __init__(self, size_increment: int=16): + self.intermediate_device = comfy.model_management.intermediate_device() + self.size_increment = size_increment + + def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor: + if self.size_increment == 1: + return pixels + dims = pixels.shape[1:-1] + for d in range(len(dims)): + d_adj = (dims[d] // self.size_increment) * self.size_increment + if d_adj == d: + continue + d_offset = (dims[d] % self.size_increment) // 2 + pixels = pixels.narrow(d + 1, d_offset, d_adj) + return pixels + + def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + if pixels.ndim == 3: + pixels = pixels.unsqueeze(0) + elif pixels.ndim != 4: + raise ValueError("Unexpected input image shape") + # Ensure the image has spatial dimensions that are multiples of 16. + pixels = self.vae_encode_crop_pixels(pixels) + h, w, c = pixels.shape[1:] + if h < self.size_increment or w < self.size_increment: + raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).") + pixels= pixels[..., :3] + if c == 1: + pixels = pixels.expand(-1, -1, -1, 3) + elif c != 3: + raise ValueError("Unexpected number of channels in input image") + # Rescale to -1..1 and move the channel dimension to position 1. + latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True) + latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous() + latent -= 0.5 + latent *= 2 + return latent.clamp_(-1, 1) + + def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + # Rescale to 0..1 and move the channel dimension to the end. + img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True) + img = img.clamp_(-1, 1).movedim(1, -1).contiguous() + img += 1.0 + img *= 0.5 + return img.clamp_(0, 1) + + encode_tiled = encode + decode_tiled = decode + + @classmethod + def spacial_compression_decode(cls) -> int: + # This just exists so the tiled VAE nodes don't crash. 
+ return 1 + + spacial_compression_encode = spacial_compression_decode + temporal_compression_decode = spacial_compression_decode + class StyleModel: def __init__(self, model, device="cpu"): self.model = model diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 472ea0ae9..be36b5dfe 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1205,6 +1205,19 @@ class Chroma(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) +class ChromaRadiance(Chroma): + unet_config = { + "image_model": "chroma_radiance", + } + + latent_format = comfy.latent_formats.ChromaRadiance + + # Pixel-space model, no spatial compression for model input. + memory_usage_factor = 0.0325 + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ChromaRadiance(self, device=device) + class ACEStep(supported_models_base.BASE): unet_config = { "audio_model": "ace", @@ -1338,6 +1351,6 @@ class HunyuanImage21Refiner(HunyuanVideo): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py new file mode 100644 index 000000000..381989818 --- /dev/null +++ b/comfy_extras/nodes_chroma_radiance.py @@ -0,0 +1,114 @@ +from typing_extensions import override +from typing import Callable + +import torch + +import comfy.model_management +from comfy_api.latest import ComfyExtension, io + +import nodes + +class EmptyChromaRadianceLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyChromaRadianceLatentImage", + category="latent/chroma_radiance", + inputs=[ + io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, 
step=16), + io.Int.Input(id="batch_size", default=1, min=1, max=4096), + ], + outputs=[io.Latent().Output()], + ) + + @classmethod + def execute(cls, *, width: int, height: int, batch_size: int=1) -> io.NodeOutput: + latent = torch.zeros((batch_size, 3, height, width), device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples":latent}) + + +class ChromaRadianceOptions(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ChromaRadianceOptions", + category="model_patches/chroma_radiance", + description="Allows setting advanced options for the Chroma Radiance model.", + inputs=[ + io.Model.Input(id="model"), + io.Boolean.Input( + id="preserve_wrapper", + default=True, + tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.", + ), + io.Float.Input( + id="start_sigma", + default=1.0, + min=0.0, + max=1.0, + tooltip="First sigma that these options will be in effect.", + ), + io.Float.Input( + id="end_sigma", + default=0.0, + min=0.0, + max=1.0, + tooltip="Last sigma that these options will be in effect.", + ), + io.Int.Input( + id="nerf_tile_size", + default=-1, + min=-1, + tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).", + ), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def execute( + cls, + *, + model: io.Model.Type, + preserve_wrapper: bool, + start_sigma: float, + end_sigma: float, + nerf_tile_size: int, + ) -> io.NodeOutput: + radiance_options = {} + if nerf_tile_size >= 0: + radiance_options["nerf_tile_size"] = nerf_tile_size + + if not radiance_options: + return io.NodeOutput(model) + + old_wrapper = model.model_options.get("model_function_wrapper") + + def model_function_wrapper(apply_model: Callable, args: dict) -> torch.Tensor: + c = args["c"].copy() + sigma = args["timestep"].max().detach().cpu().item() + if end_sigma <= sigma <= start_sigma: + transformer_options = c.get("transformer_options", {}).copy() + transformer_options["chroma_radiance_options"] = radiance_options.copy() + c["transformer_options"] = transformer_options + if not (preserve_wrapper and old_wrapper): + return apply_model(args["input"], args["timestep"], **c) + return old_wrapper(apply_model, args | {"c": c}) + + model = model.clone() + model.set_model_unet_function_wrapper(model_function_wrapper) + return io.NodeOutput(model) + + +class ChromaRadianceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyChromaRadianceLatentImage, + ChromaRadianceOptions, + ] + + +async def comfy_entrypoint() -> ChromaRadianceExtension: + return ChromaRadianceExtension() diff --git a/nodes.py b/nodes.py index 2befb4b75..76b8cbac8 100644 --- a/nodes.py +++ b/nodes.py @@ -730,6 +730,7 @@ class VAELoader: vaes.append("taesd3") if f1_taesd_dec and f1_taesd_enc: vaes.append("taef1") + vaes.append("chroma_radiance") return vaes @staticmethod @@ -772,7 +773,9 @@ class VAELoader: #TODO: scale factor? 
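# The "chroma_radiance" entry appended to the VAE list above does not point at a
# checkpoint on disk; load_vae below returns comfy.sd.PixelspaceConversionVAE() for it,
# since Chroma Radiance works directly on 3-channel pixels and its "latents" only need
# the BHWC 0..1 <-> BCHW -1..1 conversion that class performs.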
def load_vae(self, vae_name): - if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: + if vae_name == "chroma_radiance": + return (comfy.sd.PixelspaceConversionVAE(),) + elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: sd = self.load_taesd(vae_name) else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) @@ -2322,6 +2325,7 @@ async def init_builtin_extra_nodes(): "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py", + "nodes_chroma_radiance.py", "nodes_model_patch.py", "nodes_easycache.py", "nodes_audio_encoder.py", From 80b7c9455bf7afba7a9e95a1eb76b172408ab56c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 13 Sep 2025 15:03:34 -0700 Subject: [PATCH 22/30] Changes to the previous radiance commit. (#9851) --- comfy/ldm/chroma_radiance/model.py | 7 +-- comfy/pixel_space_convert.py | 16 +++++++ comfy/sd.py | 69 +++++------------------------- comfy/supported_models.py | 2 +- nodes.py | 7 +-- 5 files changed, 35 insertions(+), 66 deletions(-) create mode 100644 comfy/pixel_space_convert.py diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index f7eb7a22e..47aa11b04 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -306,8 +306,9 @@ class ChromaRadiance(Chroma): params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {})) - h_len = ((h + (self.patch_size // 2)) // self.patch_size) - w_len = ((w + (self.patch_size // 2)) // self.patch_size) + h_len = (img.shape[-2] // self.patch_size) + w_len = (img.shape[-1] // self.patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) @@ -325,4 +326,4 @@ class ChromaRadiance(Chroma): transformer_options, attn_mask=kwargs.get("attention_mask", None), ) - return self.forward_nerf(img, img_out, params) + return self.forward_nerf(img, img_out, params)[:, :, :h, :w] diff --git a/comfy/pixel_space_convert.py b/comfy/pixel_space_convert.py new file mode 100644 index 000000000..049bbcfb4 --- /dev/null +++ b/comfy/pixel_space_convert.py @@ -0,0 +1,16 @@ +import torch + + +# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1 +# to LATENT B, C, H, W and values on the scale of -1..1. 
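+# Descriptive note on the class below: encode() and decode() are deliberate pass-throughs, and the
+# single dummy parameter only gives this module a "pixel_space_vae" state dict key so the loader
+# branch added to comfy/sd.py in this patch series can detect it; the 0..1 <-> -1..1 rescaling
+# described above is presumably left to the surrounding VAE wrapper rather than done here.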
+class PixelspaceConversionVAE(torch.nn.Module): + def __init__(self): + super().__init__() + self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0)) + + def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + return pixels + + def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + return samples + diff --git a/comfy/sd.py b/comfy/sd.py index cb92802e9..2df340739 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.hunyuan_video.vae +import comfy.pixel_space_convert import yaml import math import os @@ -516,6 +517,15 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.disable_offload = True self.extra_1d_channel = 16 + elif "pixel_space_vae" in sd: + self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE() + self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.downscale_ratio = 1 + self.upscale_ratio = 1 + self.latent_channels = 3 + self.latent_dim = 2 + self.output_channels = 3 else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -785,65 +795,6 @@ class VAE: except: return None -# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1 -# to LATENT B, C, H, W and values on the scale of -1..1. -class PixelspaceConversionVAE: - def __init__(self, size_increment: int=16): - self.intermediate_device = comfy.model_management.intermediate_device() - self.size_increment = size_increment - - def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor: - if self.size_increment == 1: - return pixels - dims = pixels.shape[1:-1] - for d in range(len(dims)): - d_adj = (dims[d] // self.size_increment) * self.size_increment - if d_adj == d: - continue - d_offset = (dims[d] % self.size_increment) // 2 - pixels = pixels.narrow(d + 1, d_offset, d_adj) - return pixels - - def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: - if pixels.ndim == 3: - pixels = pixels.unsqueeze(0) - elif pixels.ndim != 4: - raise ValueError("Unexpected input image shape") - # Ensure the image has spatial dimensions that are multiples of 16. - pixels = self.vae_encode_crop_pixels(pixels) - h, w, c = pixels.shape[1:] - if h < self.size_increment or w < self.size_increment: - raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).") - pixels= pixels[..., :3] - if c == 1: - pixels = pixels.expand(-1, -1, -1, 3) - elif c != 3: - raise ValueError("Unexpected number of channels in input image") - # Rescale to -1..1 and move the channel dimension to position 1. - latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True) - latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous() - latent -= 0.5 - latent *= 2 - return latent.clamp_(-1, 1) - - def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: - # Rescale to 0..1 and move the channel dimension to the end. 
- img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True) - img = img.clamp_(-1, 1).movedim(1, -1).contiguous() - img += 1.0 - img *= 0.5 - return img.clamp_(0, 1) - - encode_tiled = encode - decode_tiled = decode - - @classmethod - def spacial_compression_decode(cls) -> int: - # This just exists so the tiled VAE nodes don't crash. - return 1 - - spacial_compression_encode = spacial_compression_decode - temporal_compression_decode = spacial_compression_decode class StyleModel: def __init__(self, model, device="cpu"): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index be36b5dfe..557902d11 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1213,7 +1213,7 @@ class ChromaRadiance(Chroma): latent_format = comfy.latent_formats.ChromaRadiance # Pixel-space model, no spatial compression for model input. - memory_usage_factor = 0.0325 + memory_usage_factor = 0.038 def get_model(self, state_dict, prefix="", device=None): return model_base.ChromaRadiance(self, device=device) diff --git a/nodes.py b/nodes.py index 76b8cbac8..5a5fdcb8e 100644 --- a/nodes.py +++ b/nodes.py @@ -730,7 +730,7 @@ class VAELoader: vaes.append("taesd3") if f1_taesd_dec and f1_taesd_enc: vaes.append("taef1") - vaes.append("chroma_radiance") + vaes.append("pixel_space") return vaes @staticmethod @@ -773,8 +773,9 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): - if vae_name == "chroma_radiance": - return (comfy.sd.PixelspaceConversionVAE(),) + if vae_name == "pixel_space": + sd = {} + sd["pixel_space_vae"] = torch.tensor(1.0) elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: sd = self.load_taesd(vae_name) else: From f228367c5e3906de194968fa9b6fbe7aa9987bfa Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 13 Sep 2025 18:34:21 -0700 Subject: [PATCH 23/30] Make ModuleNotFoundError ImportError instead (#9850) --- comfy/ldm/modules/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index bf2553c37..9dd1a43c1 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -22,7 +22,7 @@ SAGE_ATTENTION_IS_AVAILABLE = False try: from sageattention import sageattn SAGE_ATTENTION_IS_AVAILABLE = True -except ModuleNotFoundError as e: +except ImportError as e: if model_management.sage_attention_enabled(): if e.name == "sageattention": logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") @@ -34,7 +34,7 @@ FLASH_ATTENTION_IS_AVAILABLE = False try: from flash_attn import flash_attn_func FLASH_ATTENTION_IS_AVAILABLE = True -except ModuleNotFoundError: +except ImportError: if model_management.flash_attention_enabled(): logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) From 4f1f26ac6c11b803bbc83cb347178e2f9b5e421b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 14 Sep 2025 01:05:38 -0700 Subject: [PATCH 24/30] Add that hunyuan image is supported to readme. 
(#9857) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 8024870c2..3f6cfc2ed 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) + - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) From 47a9cde5d3045c42f20baafb9855fb96959124f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 15 Sep 2025 15:10:55 -0700 Subject: [PATCH 25/30] Support the omnigen2 umo lora. (#9886) --- comfy/lora.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index 4a44f1318..36d26293a 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -297,6 +297,12 @@ def model_lora_keys_unet(model, key_map={}): key_lora = k[len("diffusion_model."):-len(".weight")] key_map["{}".format(key_lora)] = k + if isinstance(model, comfy.model_base.Omnigen2): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k + if isinstance(model, comfy.model_base.QwenImage): for k in sdk: if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format From 1a85483da159f2800407ae5a8a45eb0d88ffce2d Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Mon, 15 Sep 2025 18:05:03 -0600 Subject: [PATCH 26/30] Fix depending on asserts to raise an exception in BatchedBrownianTree and Flash attn module (#9884) Correctly handle the case where w0 is passed by kwargs in BatchedBrownianTree --- comfy/k_diffusion/sampling.py | 35 +++++++++++++++++----------------- comfy/ldm/modules/attention.py | 3 ++- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 2d7e09838..0e2cda291 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -86,24 +86,24 @@ class BatchedBrownianTree: """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" def __init__(self, x, t0, t1, seed=None, **kwargs): - self.cpu_tree = True - if "cpu" in kwargs: - self.cpu_tree = kwargs.pop("cpu") + self.cpu_tree = kwargs.pop("cpu", True) t0, t1, self.sign = self.sort(t0, t1) - w0 = kwargs.get('w0', torch.zeros_like(x)) + w0 = kwargs.pop('w0', None) + if w0 is None: + w0 = torch.zeros_like(x) + self.batched = False if seed is None: - seed = torch.randint(0, 2 ** 63 - 1, []).item() - self.batched = True - try: - assert len(seed) == x.shape[0] + seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),) + elif isinstance(seed, (tuple, list)): + if len(seed) != x.shape[0]: + raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.") + self.batched = True w0 = w0[0] - except TypeError: - seed = [seed] - self.batched = False - if self.cpu_tree: - self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed] else: - self.trees = 
[torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + seed = (seed,) + if self.cpu_tree: + t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu() + self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed) @staticmethod def sort(a, b): @@ -111,11 +111,10 @@ class BatchedBrownianTree: def __call__(self, t0, t1): t0, t1, sign = self.sort(t0, t1) + device, dtype = t0.device, t0.dtype if self.cpu_tree: - w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign) - else: - w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) - + t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float() + w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign) return w if self.batched else w[0] diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 9dd1a43c1..7437e0567 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -600,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape mask = mask.unsqueeze(1) try: - assert mask is None + if mask is not None: + raise RuntimeError("Mask must not be set for Flash attention") out = flash_attn_wrapper( q.transpose(1, 2), k.transpose(1, 2), From a39ac59c3e3fddc8b278899814f0bd5371abb11f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 15 Sep 2025 22:19:50 -0700 Subject: [PATCH 27/30] Add encoder part of whisper large v3 as an audio encoder model. (#9894) Not useful yet but some models use it. --- comfy/audio_encoders/audio_encoders.py | 58 +++++--- comfy/audio_encoders/whisper.py | 186 +++++++++++++++++++++++++ 2 files changed, 224 insertions(+), 20 deletions(-) create mode 100755 comfy/audio_encoders/whisper.py diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 6fb5b08e9..0550b2f9b 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -1,4 +1,5 @@ from .wav2vec2 import Wav2Vec2Model +from .whisper import WhisperLargeV3 import comfy.model_management import comfy.ops import comfy.utils @@ -11,13 +12,18 @@ class AudioEncoderModel(): self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + model_type = config.pop("model_type") model_config = dict(config) model_config.update({ "dtype": self.dtype, "device": offload_device, "operations": comfy.ops.manual_cast }) - self.model = Wav2Vec2Model(**model_config) + + if model_type == "wav2vec2": + self.model = Wav2Vec2Model(**model_config) + elif model_type == "whisper3": + self.model = WhisperLargeV3(**model_config) self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 @@ -40,33 +46,45 @@ class AudioEncoderModel(): def load_audio_encoder_from_sd(sd, prefix=""): sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""}) - embed_dim = sd["encoder.layer_norm.bias"].shape[0] - if embed_dim == 1024:# large - config = { - "embed_dim": 1024, - "num_heads": 16, - "num_layers": 24, - "conv_norm": True, - "conv_bias": True, - "do_normalize": True, - "do_stable_layer_norm": True + if "encoder.layer_norm.bias" in sd: 
#wav2vec2 + embed_dim = sd["encoder.layer_norm.bias"].shape[0] + if embed_dim == 1024:# large + config = { + "model_type": "wav2vec2", + "embed_dim": 1024, + "num_heads": 16, + "num_layers": 24, + "conv_norm": True, + "conv_bias": True, + "do_normalize": True, + "do_stable_layer_norm": True + } + elif embed_dim == 768: # base + config = { + "model_type": "wav2vec2", + "embed_dim": 768, + "num_heads": 12, + "num_layers": 12, + "conv_norm": False, + "conv_bias": False, + "do_normalize": False, # chinese-wav2vec2-base has this False + "do_stable_layer_norm": False } - elif embed_dim == 768: # base + else: + raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim)) + elif "model.encoder.embed_positions.weight" in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""}) config = { - "embed_dim": 768, - "num_heads": 12, - "num_layers": 12, - "conv_norm": False, - "conv_bias": False, - "do_normalize": False, # chinese-wav2vec2-base has this False - "do_stable_layer_norm": False + "model_type": "whisper3", } else: - raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim)) + raise RuntimeError("ERROR: audio encoder not supported.") audio_encoder = AudioEncoderModel(config) m, u = audio_encoder.load_sd(sd) if len(m) > 0: logging.warning("missing audio encoder: {}".format(m)) + if len(u) > 0: + logging.warning("unexpected audio encoder: {}".format(u)) return audio_encoder diff --git a/comfy/audio_encoders/whisper.py b/comfy/audio_encoders/whisper.py new file mode 100755 index 000000000..93d3782f1 --- /dev/null +++ b/comfy/audio_encoders/whisper.py @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from typing import Optional +from comfy.ldm.modules.attention import optimized_attention_masked +import comfy.ops + +class WhisperFeatureExtractor(nn.Module): + def __init__(self, n_mels=128, device=None): + super().__init__() + self.sample_rate = 16000 + self.n_fft = 400 + self.hop_length = 160 + self.n_mels = n_mels + self.chunk_length = 30 + self.n_samples = 480000 + + self.mel_spectrogram = torchaudio.transforms.MelSpectrogram( + sample_rate=self.sample_rate, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + f_min=0, + f_max=8000, + norm="slaney", + mel_scale="slaney", + ).to(device) + + def __call__(self, audio): + audio = torch.mean(audio, dim=1) + batch_size = audio.shape[0] + processed_audio = [] + + for i in range(batch_size): + aud = audio[i] + if aud.shape[0] > self.n_samples: + aud = aud[:self.n_samples] + elif aud.shape[0] < self.n_samples: + aud = F.pad(aud, (0, self.n_samples - aud.shape[0])) + processed_audio.append(aud) + + audio = torch.stack(processed_audio) + + mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device) + + log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0) + log_mel_spec = (log_mel_spec + 4.0) / 4.0 + + return log_mel_spec + + +class MultiHeadAttention(nn.Module): + def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None): + super().__init__() + assert d_model % n_heads == 0 + + self.d_model = d_model + self.n_heads = n_heads + self.d_k = d_model // n_heads + + self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, 
device=device) + self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, seq_len, _ = query.shape + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class EncoderLayer(nn.Module): + def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None): + super().__init__() + + self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations) + self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device) + + self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device) + self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device) + self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device) + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn(x, x, x, attention_mask) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + x = residual + x + + return x + + +class AudioEncoder(nn.Module): + def __init__( + self, + n_mels: int = 128, + n_ctx: int = 1500, + n_state: int = 1280, + n_head: int = 20, + n_layer: int = 32, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device) + self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device) + + self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device) + + self.layers = nn.ModuleList([ + EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations) + for _ in range(n_layer) + ]) + + self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + + x = x.transpose(1, 2) + + x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x) + + all_x = () + for layer in self.layers: + all_x += (x,) + x = layer(x) + + x = self.layer_norm(x) + all_x += (x,) + return x, all_x + + +class WhisperLargeV3(nn.Module): + def __init__( + self, + n_mels: int = 128, + n_audio_ctx: int = 1500, + n_audio_state: int = 1280, + n_audio_head: int = 20, + n_audio_layer: int = 32, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device) + + self.encoder = AudioEncoder( + n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, + dtype=dtype, device=device, operations=operations + ) + + def forward(self, audio): + mel = self.feature_extractor(audio) + x, all_x = self.encoder(mel) + return x, all_x From e42682b24ef033a93001ba27cc5c5aa461a61d8d Mon Sep 17 00:00:00 2001 From: rattus128 <46076784+rattus128@users.noreply.github.com> Date: Wed, 17 Sep 2025 09:21:14 +1000 Subject: [PATCH 28/30] Reduce Peak WAN inference VRAM usage (#9898) * flux: Do 
the xq and xk ropes one at a time This was doing independent interleaved tensor math on the q and k tensors, leading to the holding of more than the minimum intermediates in VRAM. On a bad day, it would VRAM OOM on xk intermediates. Do everything q and then everything k, so torch can garbage collect all of q's intermediates before k allocates its intermediates. This reduces peak VRAM usage for some WAN2.2 inferences (at least). * wan: Optimize qkv intermediates on attention As commented. The former logic computed independent pieces of QKV in parallel, which held more inference intermediates in VRAM, spiking VRAM usage. Fully roping Q and garbage collecting the intermediates before touching K reduces the peak inference VRAM usage.
--- comfy/ldm/flux/math.py | 11 +++++------ comfy/ldm/wan/model.py | 22 +++++++++++++--------- 2 files changed, 18 insertions(+), 15 deletions(-)
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 4d743cda2..fb7cd7586 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -35,11 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) +def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1] + return x_out.reshape(*x.shape).type_as(x) def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2) - xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] - xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 63472ada2..67dcf8f1e 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -8,7 +8,7 @@ from einops import rearrange from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.layers import EmbedND -from comfy.ldm.flux.math import apply_rope +from comfy.ldm.flux.math import apply_rope1 import comfy.ldm.common_dit import comfy.model_management import comfy.patcher_extension @@ -60,20 +60,24 @@ class WanSelfAttention(nn.Module): """ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim - # query, key, value function - def qkv_fn(x): + def qkv_fn_q(x): q = self.norm_q(self.q(x)).view(b, s, n, d) - k = self.norm_k(self.k(x)).view(b, s, n, d) - v = self.v(x).view(b, s, n * d) - return q, k, v + return apply_rope1(q, freqs) - q, k, v = qkv_fn(x) - q, k = apply_rope(q, k, freqs) + def qkv_fn_k(x): + k = self.norm_k(self.k(x)).view(b, s, n, d) + return apply_rope1(k, freqs) + + #These two are VRAM hogs, so we want to do all of q computation and + #have pytorch garbage collect the intermediates on the sub function + #return before we touch k + q = qkv_fn_q(x) + k = qkv_fn_k(x) x = optimized_attention( q.view(b, s, n * d), k.view(b, s, n * d), - v, + self.v(x).view(b, s, n * d), heads=self.num_heads, transformer_options=transformer_options, )
From 9288c78fc5fae74d3fa7787736dea442e996303f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 16 Sep 2025 21:12:48 -0700 Subject: [PATCH 29/30] Support the HuMo model.
(#9903) --- comfy/audio_encoders/audio_encoders.py | 1 + comfy/ldm/wan/model.py | 259 ++++++++++++++++++++++++- comfy/model_base.py | 17 ++ comfy/model_detection.py | 2 + comfy/supported_models.py | 12 +- comfy_extras/nodes_wan.py | 98 ++++++++++ 6 files changed, 383 insertions(+), 6 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 0550b2f9b..46ef21c95 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -41,6 +41,7 @@ class AudioEncoderModel(): outputs = {} outputs["encoded_audio"] = out outputs["encoded_audio_all_layers"] = all_layers + outputs["audio_samples"] = audio.shape[2] return outputs diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 67dcf8f1e..b3b7da5d5 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -34,7 +34,9 @@ class WanSelfAttention(nn.Module): num_heads, window_size=(-1, -1), qk_norm=True, - eps=1e-6, operation_settings={}): + eps=1e-6, + kv_dim=None, + operation_settings={}): assert dim % num_heads == 0 super().__init__() self.dim = dim @@ -43,11 +45,13 @@ class WanSelfAttention(nn.Module): self.window_size = window_size self.qk_norm = qk_norm self.eps = eps + if kv_dim is None: + kv_dim = dim # layers self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() @@ -402,6 +406,7 @@ class WanModel(torch.nn.Module): eps=1e-6, flf_pos_embed_token_number=None, in_dim_ref_conv=None, + wan_attn_block_class=WanAttentionBlock, image_model=None, device=None, dtype=None, @@ -479,8 +484,8 @@ class WanModel(torch.nn.Module): # blocks cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' self.blocks = nn.ModuleList([ - WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, - window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) + wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) for _ in range(num_layers) ]) @@ -1325,3 +1330,247 @@ class WanModel_S2V(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + + +class WanT2VCrossAttentionGather(WanSelfAttention): + + def forward(self, x, context, transformer_options={}, **kwargs): + r""" + Args: + x(Tensor): Shape [B, L1, 
C] - video tokens + context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(context)) + v = self.v(context) + + # Handle audio temporal structure (16 tokens per frame) + k = k.reshape(-1, 16, n, d).transpose(1, 2) + v = v.reshape(-1, 16, n, d).transpose(1, 2) + + # Handle video spatial structure + q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2) + + x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options) + + x = x.transpose(1, 2).view(b, -1, n, d).flatten(2) + x = self.o(x) + return x + + +class AudioCrossAttentionWrapper(nn.Module): + def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}): + super().__init__() + + self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm, kv_dim, eps, operation_settings=operation_settings) + self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x, audio, transformer_options={}): + x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options) + return x + + +class WanAttentionBlockAudio(WanAttentionBlock): + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, operation_settings={}): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings) + self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings) + + def forward( + self, + x, + e, + freqs, + context, + context_img_len=257, + audio=None, + transformer_options={}, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + # assert e.dtype == torch.float32 + + if e.ndim < 4: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) + else: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2) + # assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn( + torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + freqs, transformer_options=transformer_options) + + x = torch.addcmul(x, y, repeat_e(e[2], x)) + + # cross-attention & ffn + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + if audio is not None: + x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options) + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) + x = torch.addcmul(x, y, repeat_e(e[5], x)) + return x + +class DummyAdapterLayer(nn.Module): + def __init__(self, layer): + super().__init__() + self.layer = layer + + def forward(self, *args, **kwargs): + return self.layer(*args, **kwargs) + + +class AudioProjModel(nn.Module): + def __init__( + self, + seq_len=5, + blocks=13, # add a new parameter blocks + channels=768, # add a new parameter channels + intermediate_dim=512, + output_dim=1536, + context_tokens=16, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + 
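+        # Descriptive note: this module collapses a window of audio-encoder features
+        # (seq_len frames x blocks layers x channels values) into `context_tokens` tokens of width
+        # `output_dim`, which the audio cross-attention blocks consume as their key/value context.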
+ self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels. + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.output_dim = output_dim + + # define multiple linear layers + self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device)) + self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device)) + self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device)) + + self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device)) + + def forward(self, audio_embeds): + video_length = audio_embeds.shape[1] + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds)) + audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds)) + + context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim) + + context_tokens = self.audio_proj_glob_norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + + +class HumoWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='humo', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + image_model=None, + audio_token_num=16, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations) + + self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations) + + def forward_orig( + self, + x, + t, + context, + freqs=None, + audio_embed=None, + reference_latent=None, + transformer_options={}, + **kwargs, + ): + bs, _, time, height, width = x.shape + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + if reference_latent is not None: + ref = self.patch_embedding(reference_latent.float()).to(x.dtype) + ref = ref.flatten(2).transpose(1, 2) + freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], 
reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype) + x = torch.cat([x, ref], dim=1) + freqs = torch.cat([freqs, freqs_ref], dim=1) + del ref, freqs_ref + + # context + context = self.text_embedding(context) + context_img_len = None + + if audio_embed is not None: + audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2) + else: + audio = None + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 252dfcf69..cf99035da 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1213,6 +1213,23 @@ class WAN21_Camera(WAN21): out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) return out +class WAN21_HuMo(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + audio_embed = kwargs.get("audio_embed", None) + if audio_embed is not None: + out['audio_embed'] = comfy.conds.CONDRegular(audio_embed) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + return out + class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 03d44f65e..72621bed6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -402,6 +402,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "camera_2.2" elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "s2v" + elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "humo" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 557902d11..213b5b92c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1073,6 +1073,16 @@ class WAN21_Vace(WAN21_T2V): out = model_base.WAN21_Vace(self, image_to_video=False, device=device) return out +class WAN21_HuMo(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "humo", + } + + def get_model(self, state_dict, prefix="", 
device=None): + out = model_base.WAN21_HuMo(self, image_to_video=False, device=device) + return out + class WAN22_S2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1351,6 +1361,6 @@ class HunyuanImage21Refiner(HunyuanVideo): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 4f73369f5..0b8b55813 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1015,6 +1015,103 @@ class WanSoundImageToVideoExtend(io.ComfyNode): return io.NodeOutput(positive, negative, out_latent) +def get_audio_emb_window(audio_emb, frame_num, frame0_idx, audio_shift=2): + zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) + zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device + iter_ = 1 + (frame_num - 1) // 4 + audio_emb_wind = [] + for lt_i in range(iter_): + if lt_i == 0: + st = frame0_idx + lt_i - 2 + ed = frame0_idx + lt_i + 3 + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0) + else: + st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift + ed = frame0_idx + 1 + 4 * lt_i + audio_shift + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + audio_emb_wind.append(wind_feat) + audio_emb_wind = torch.stack(audio_emb_wind, dim=0) + + return audio_emb_wind, ed - audio_shift + + +class WanHuMoImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanHuMoImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, 
max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.AudioEncoderOutput.Input("audio_encoder_output", optional=True), + io.Image.Input("ref_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None) -> io.NodeOutput: + latent_t = ((length - 1) // 4) + 1 + latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True) + else: + zero_latent = torch.zeros([batch_size, 16, 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [zero_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [zero_latent]}, append=True) + + if audio_encoder_output is not None: + audio_emb = torch.stack(audio_encoder_output["encoded_audio_all_layers"], dim=2) + audio_len = audio_encoder_output["audio_samples"] // 640 + audio_emb = audio_emb[:, :audio_len * 2] + + feat0 = linear_interpolation(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25) + feat1 = linear_interpolation(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25) + feat2 = linear_interpolation(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25) + feat3 = linear_interpolation(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25) + feat4 = linear_interpolation(audio_emb[:, :, 32], 50, 25) + audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280] + audio_emb, _ = get_audio_emb_window(audio_emb, length, frame0_idx=0) + + # pad for ref latent + zero_audio_pad = torch.zeros(ref_latent.shape[2], *audio_emb.shape[1:], device=audio_emb.device, dtype=audio_emb.dtype) + audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0) + + audio_emb = audio_emb.unsqueeze(0) + audio_emb_neg = torch.zeros_like(audio_emb) + positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_emb}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_emb_neg}) + else: + zero_audio = torch.zeros([batch_size, latent_t + 1, 8, 5, 1280], device=comfy.model_management.intermediate_device()) + positive = node_helpers.conditioning_set_values(positive, {"audio_embed": zero_audio}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": zero_audio}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + class Wan22ImageToVideoLatent(io.ComfyNode): @classmethod def define_schema(cls): @@ -1075,6 +1172,7 @@ class WanExtension(ComfyExtension): WanPhantomSubjectToVideo, WanSoundImageToVideo, WanSoundImageToVideoExtend, + 
WanHuMoImageToVideo, Wan22ImageToVideoLatent, ] From dd611a7700956f45f393dee32fb8505de176dc66 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:39:24 -0700 Subject: [PATCH 30/30] Support the HuMo 17B model. (#9912) --- comfy/ldm/wan/model.py | 2 +- comfy/model_base.py | 29 ++++++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index b3b7da5d5..9cf3c171d 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1364,7 +1364,7 @@ class AudioCrossAttentionWrapper(nn.Module): def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}): super().__init__() - self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm, kv_dim, eps, operation_settings=operation_settings) + self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, kv_dim=kv_dim, eps=eps, operation_settings=operation_settings) self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) def forward(self, x, audio, transformer_options={}): diff --git a/comfy/model_base.py b/comfy/model_base.py index cf99035da..70b67b7c1 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1220,14 +1220,37 @@ class WAN21_HuMo(WAN21): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) + noise = kwargs.get("noise", None) audio_embed = kwargs.get("audio_embed", None) if audio_embed is not None: out['audio_embed'] = comfy.conds.CONDRegular(audio_embed) - reference_latents = kwargs.get("reference_latents", None) - if reference_latents is not None: - out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + if "c_concat" not in out: # 1.7B model + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + else: + noise_shape = list(noise.shape) + noise_shape[1] += 4 + concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype) + zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1) + zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1) + zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1) + concat_latent[:, 4:] = zero_vae_values + concat_latent[:, 4:, :1] = zero_vae_values_first + concat_latent[:, 4:, 1:2] = zero_vae_values_second + out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent) + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + ref_latent = self.process_latent_in(reference_latents[-1]) + ref_latent_shape = list(ref_latent.shape) + ref_latent_shape[1] += 4 + ref_latent_shape[1] + ref_latent_full = torch.zeros(ref_latent_shape, device=ref_latent.device, dtype=ref_latent.dtype) + ref_latent_full[:, 20:] = ref_latent + ref_latent_full[:, 16:20] = 1.0 + out['reference_latent'] = 
comfy.conds.CONDRegular(ref_latent_full) + return out class WAN22_S2V(WAN21):