diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 902af30ed..00c597535 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis): class QwenTimestepProjEmbeddings(nn.Module): - def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None): + def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) self.timestep_embedder = TimestepEmbedding( @@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module): operations=operations ) - def forward(self, timestep, hidden_states): + self.use_additional_t_cond = use_additional_t_cond + if self.use_additional_t_cond: + self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype) + + def forward(self, timestep, hidden_states, addition_t_cond=None): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) + + if self.use_additional_t_cond: + if addition_t_cond is None: + addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long) + timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype) + return timesteps_emb @@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module): num_attention_heads: int = 24, joint_attention_dim: int = 3584, pooled_projection_dim: int = 768, - guidance_embeds: bool = False, axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), default_ref_method="index", image_model=None, final_layer=True, + use_additional_t_cond=False, dtype=None, device=None, operations=None, @@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module): self.time_text_embed = QwenTimestepProjEmbeddings( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim, + use_additional_t_cond=use_additional_t_cond, dtype=dtype, device=device, operations=operations @@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module): patch_size = self.patch_size hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) orig_shape = hidden_states.shape - hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) - hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) - hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6) + hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + t_len = t h_len = ((h + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size) h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device) - img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) - return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape + img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device) - def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): + if t_len > 1: + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1) + else: + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index + + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2) + return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape + + def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) + ).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs) def _forward( self, @@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module): timesteps, context, attention_mask=None, - guidance: torch.Tensor = None, ref_latents=None, + additional_t_cond=None, transformer_options={}, control=None, **kwargs @@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module): index = 0 ref_method = kwargs.get("ref_latents_method", self.default_ref_method) index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero") + negative_ref_method = ref_method == "negative_index" timestep_zero = ref_method == "index_timestep_zero" for ref in ref_latents: if index_ref_method: index += 1 h_offset = 0 w_offset = 0 + elif negative_ref_method: + index -= 1 + h_offset = 0 + w_offset = 0 else: index = 1 h_offset = 0 @@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - if guidance is not None: - guidance = guidance * 1000 - - temb = ( - self.time_text_embed(timestep, hidden_states) - if guidance is None - else self.time_text_embed(timestep, guidance, hidden_states) - ) + temb = self.time_text_embed(timestep, hidden_states, additional_t_cond) patches_replace = transformer_options.get("patches_replace", {}) patches = transformer_options.get("patches", {}) @@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) - hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) + hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) + hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6) return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index ccbb25822..08315f1a8 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -227,6 +227,7 @@ class Encoder3d(nn.Module): def __init__(self, dim=128, z_dim=4, + input_channels=3, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], @@ -245,7 +246,7 @@ class Encoder3d(nn.Module): scale = 1.0 # init block - self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1) # downsample blocks downsamples = [] @@ -331,6 +332,7 @@ class Decoder3d(nn.Module): def __init__(self, dim=128, z_dim=4, + output_channels=3, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], @@ -378,7 +380,7 @@ class Decoder3d(nn.Module): # output blocks self.head = nn.Sequential( RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, 3, 3, padding=1)) + CausalConv3d(out_dim, output_channels, 3, padding=1)) def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 @@ -449,6 +451,7 @@ class WanVAE(nn.Module): num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], + image_channels=3, dropout=0.0): super().__init__() self.dim = dim @@ -460,11 +463,11 @@ class WanVAE(nn.Module): self.temperal_upsample = temperal_downsample[::-1] # modules - self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) def encode(self, x): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7148c77fd..84fd409fd 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511 dit_config["default_ref_method"] = "index_timestep_zero" + if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered + dit_config["use_additional_t_cond"] = True + dit_config["default_ref_method"] = "negative_index" return dit_config if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 diff --git a/comfy/sd.py b/comfy/sd.py index 1cad98aef..c2a9728f3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -321,6 +321,7 @@ class VAE: self.latent_channels = 4 self.latent_dim = 2 self.output_channels = 3 + self.pad_channel_value = None self.process_input = lambda image: image * 2.0 - 1.0 self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] @@ -435,6 +436,7 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype) self.latent_channels = 64 self.output_channels = 2 + self.pad_channel_value = "replicate" self.upscale_ratio = 2048 self.downscale_ratio = 2048 self.latent_dim = 1 @@ -546,7 +548,9 @@ class VAE: self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 self.latent_channels = 16 - ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + self.output_channels = sd["encoder.conv1.weight"].shape[1] + self.pad_channel_value = 1.0 + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype) @@ -582,6 +586,7 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) self.latent_channels = 8 self.output_channels = 2 + self.pad_channel_value = "replicate" self.upscale_ratio = 4096 self.downscale_ratio = 4096 self.latent_dim = 2 @@ -690,17 +695,28 @@ class VAE: raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") def vae_encode_crop_pixels(self, pixels): - if not self.crop_input: - return pixels + if self.crop_input: + downscale_ratio = self.spacial_compression_encode() - downscale_ratio = self.spacial_compression_encode() + dims = pixels.shape[1:-1] + for d in range(len(dims)): + x = (dims[d] // downscale_ratio) * downscale_ratio + x_offset = (dims[d] % downscale_ratio) // 2 + if x != dims[d]: + pixels = pixels.narrow(d + 1, x_offset, x) - dims = pixels.shape[1:-1] - for d in range(len(dims)): - x = (dims[d] // downscale_ratio) * downscale_ratio - x_offset = (dims[d] % downscale_ratio) // 2 - if x != dims[d]: - pixels = pixels.narrow(d + 1, x_offset, x) + if pixels.shape[-1] > self.output_channels: + pixels = pixels[..., :self.output_channels] + elif pixels.shape[-1] < self.output_channels: + if self.pad_channel_value is not None: + if isinstance(self.pad_channel_value, str): + mode = self.pad_channel_value + value = None + else: + mode = "constant" + value = self.pad_channel_value + + pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value) return pixels def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index e439b18ef..2815c5ffc 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -5,6 +5,7 @@ import nodes from typing_extensions import override from comfy_api.latest import ComfyExtension, io import logging +import math def reshape_latent_to(target_shape, latent, repeat_batch=True): if latent.shape[1:] != target_shape[1:]: @@ -207,6 +208,47 @@ class LatentCut(io.ComfyNode): samples_out["samples"] = torch.narrow(s1, dim, index, amount) return io.NodeOutput(samples_out) +class LatentCutToBatch(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentCutToBatch", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Combo.Input("dim", options=["t", "x", "y"]), + io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, samples, dim, slice_size) -> io.NodeOutput: + samples_out = samples.copy() + + s1 = samples["samples"] + + if "x" in dim: + dim = s1.ndim - 1 + elif "y" in dim: + dim = s1.ndim - 2 + elif "t" in dim: + dim = s1.ndim - 3 + + if dim < 2: + return io.NodeOutput(samples) + + s = s1.movedim(dim, 1) + if s.shape[1] < slice_size: + slice_size = s.shape[1] + elif s.shape[1] % slice_size != 0: + s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size] + new_shape = [-1, slice_size] + list(s.shape[2:]) + samples_out["samples"] = s.reshape(new_shape).movedim(1, dim) + return io.NodeOutput(samples_out) + class LatentBatch(io.ComfyNode): @classmethod def define_schema(cls): @@ -435,6 +477,7 @@ class LatentExtension(ComfyExtension): LatentInterpolate, LatentConcat, LatentCut, + LatentCutToBatch, LatentBatch, LatentBatchSeedBehavior, LatentApplyOperation, diff --git a/nodes.py b/nodes.py index 3fa543294..b13ceb578 100644 --- a/nodes.py +++ b/nodes.py @@ -343,7 +343,7 @@ class VAEEncode: CATEGORY = "latent" def encode(self, vae, pixels): - t = vae.encode(pixels[:,:,:,:3]) + t = vae.encode(pixels) return ({"samples":t}, ) class VAEEncodeTiled: @@ -361,7 +361,7 @@ class VAEEncodeTiled: CATEGORY = "_for_testing" def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): - t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) + t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) return ({"samples": t}, ) class VAEEncodeForInpaint: diff --git a/requirements.txt b/requirements.txt index 9b9e61683..54696395f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.34.9 -comfyui-workflow-templates==0.7.59 +comfyui-workflow-templates==0.7.60 comfyui-embedded-docs==0.3.1 torch torchsde