diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 90a6632f0..c433c8834 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -991,14 +991,12 @@ class SimpleMLPAdaLN(nn.Module): out_channels: int, z_channels: int, num_res_blocks: int, - patch_size: int, max_freqs: int = 8, ): super().__init__() - self.patch_size = patch_size - # Project backbone hidden state → per-position conditioning - self.cond_embed = nn.Linear(z_channels, patch_size ** 2 * model_channels) + # Project backbone hidden state → per-patch conditioning + self.cond_embed = nn.Linear(z_channels, model_channels) nn.init.xavier_uniform_(self.cond_embed.weight) nn.init.constant_(self.cond_embed.bias, 0) @@ -1018,12 +1016,15 @@ class SimpleMLPAdaLN(nn.Module): self.final_layer = DCTFinalLayer(model_channels, out_channels) def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: - # x: [B*N, P^2, C], c: [B*N, dim] - x = self.input_embedder(x) # [B*N, P^2, model_channels] - y = self.cond_embed(c).reshape(c.shape[0], self.patch_size ** 2, -1) # [B*N, P^2, model_channels] + # x: [B*N, 1, P^2*C], c: [B*N, dim] + original_dtype = x.dtype + weight_dtype = self.cond_embed.weight.dtype + x = self.input_embedder(x) # [B*N, 1, model_channels] + y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels] + x = x.to(weight_dtype) for block in self.res_blocks: x = block(x, y) - return self.final_layer(x) # [B*N, P^2, C] + return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C] ############################################################################# @@ -1052,6 +1053,7 @@ class NextDiTPixelSpace(NextDiT): decoder_hidden_size: int = 3840, decoder_num_res_blocks: int = 4, decoder_max_freqs: int = 8, + decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels) use_x0: bool = False, # all NextDiT args forwarded unchanged **kwargs, @@ -1065,13 +1067,15 @@ class NextDiTPixelSpace(NextDiT): in_channels = kwargs.get("in_channels", 4) dim = kwargs.get("dim", 4096) + # decoder_in_channels is the full flattened patch: patch_size^2 * in_channels + dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels + self.dec_net = SimpleMLPAdaLN( - in_channels=in_channels, + in_channels=dec_in_ch, model_channels=decoder_hidden_size, - out_channels=in_channels, + out_channels=dec_in_ch, z_channels=dim, num_res_blocks=decoder_num_res_blocks, - patch_size=patch_size, max_freqs=decoder_max_freqs, ) @@ -1079,99 +1083,99 @@ class NextDiTPixelSpace(NextDiT): self.register_buffer("__x0__", torch.tensor([])) # ------------------------------------------------------------------ - # Override patchify_and_embed to also return the raw pixel patches + # Forward — mirrors NextDiT._forward exactly, replacing final_layer + # with the pixel-space dec_net decoder. # ------------------------------------------------------------------ - def patchify_and_embed(self, x, cap_feats, cap_mask, t, num_tokens, transformer_options={}): - # Run the parent implementation unchanged; we capture pixel values - # separately in _forward before calling this. - return super().patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options) + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) - # ------------------------------------------------------------------ - # Forward - # ------------------------------------------------------------------ - def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask bs, c, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) - t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) adaln_input = t - cap_feats = self.cap_embedder(cap_feats) - if self.clip_text_pooled_proj is not None: pooled = kwargs.get("clip_text_pooled", None) if pooled is not None: pooled = self.clip_text_pooled_proj(pooled) else: - pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) - # ---- capture raw pixel patches before backbone embedding ---- + # ---- capture raw pixel patches before patchify_and_embed embeds them ---- pH = pW = self.patch_size B, C, H, W = x.shape - # [B, N, P*P*C] (same layout as what x_embedder receives) pixel_patches = ( x.view(B, C, H // pH, pH, W // pW, pW) .permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C] .flatten(3) # [B, Ht, Wt, pH*pW*C] .flatten(1, 2) # [B, N, pH*pW*C] ) - # reshape to [B*N, P^2, C] for the decoder N = pixel_patches.shape[1] - pixel_values = pixel_patches.reshape(B * N, pH * pW, C) + # decoder sees one token per patch: [B*N, 1, P^2*C] + pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C) patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) - img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed( - x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed( + x, cap_feats, cap_mask, adaln_input, num_tokens, + ref_latents=ref_latents, ref_contexts=ref_contexts, + siglip_feats=siglip_feats, transformer_options=transformer_options ) freqs_cis = freqs_cis.to(img.device) + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" + img_input = img for i, layer in enumerate(self.layers): - img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + transformer_options["block_index"] = i + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: - out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) if "img" in out: img[:, cap_size[0]:] = out["img"] if "txt" in out: img[:, :cap_size[0]] = out["txt"] - # ---- pixel-space decoder ---- - # img: [B, txt_len+N, dim] → extract image tokens → [B, N, dim] - img_hidden = img[:, cap_size[0]:, :] # [B, N, dim] - # per-patch conditioning: [B*N, dim] - decoder_cond = img_hidden.reshape(B * N, self.dim) + # ---- pixel-space decoder (replaces final_layer + unpatchify) ---- + # img may have padding tokens beyond N; only the first N are real image patches + img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim] + decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim] - # decode: [B*N, P^2, C] - output = self.dec_net(pixel_values, decoder_cond) + output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C] + output = output.reshape(B, N, -1) # [B, N, P^2*C] - # reshape back: [B*N, P^2, C] → [B, N, P^2*C] - output = output.reshape(B, N, -1) - - # unpatchify expects [B, txt_len+N, P^2*C] with cap tokens prepended - # re-prepend a zero placeholder for the cap positions so unpatchify works + # prepend zero cap placeholder so unpatchify indexing works unchanged cap_placeholder = torch.zeros( B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype ) - img_out = torch.cat([cap_placeholder, output], dim=1) - - img_out = self.unpatchify(img_out, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] + img_out = self.unpatchify( + torch.cat([cap_placeholder, output], dim=1), + img_size, cap_size, return_tensor=x_is_tensor + )[:, :, :h, :w] return -img_out def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): - # _forward returns -x0 (negated decoder output, matching the latent-space convention). - # Reference x0→v conversion: v = (noisy - out) / t, where out = -x0 - # → v = (noisy - (-x0)) / t = (noisy + x0) / t - # Since neg_x0 = -x0: v = (x - neg_x0) / t + # _forward returns -x0 (negated decoder output, same convention as NextDiT). + # + # ComfyUI uses CONST sampling: calculate_denoised = x - model_output * sigma + # We need calculate_denoised to return x0, so model_output must equal (x - x0) / sigma. + # + # _forward returns neg_x0 = -x0, so: + # model_output = (x + neg_x0) / sigma → (x + (-x0)) / sigma = (x - x0) / sigma ✓ + # Then: x - model_output * sigma = x - (x - x0) = x0 ✓ neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) - return (x - neg_x0) / timesteps.view(-1, 1, 1, 1) + return (x + neg_x0) / timesteps.view(-1, 1, 1, 1) diff --git a/comfy/model_base.py b/comfy/model_base.py index 1842b3b08..a06d8afd9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1217,7 +1217,8 @@ class Lumina2(BaseModel): class ZImagePixelSpace(Lumina2): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) + BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 2f3e185f7..92e6f1bde 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -467,28 +467,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix) if dec_cond_key in state_dict_keys: # pixel-space variant dit_config["image_model"] = "zimage_pixel" - w = state_dict[dec_cond_key] # [patch_size^2 * decoder_hidden_size, dim] - dit_config["decoder_hidden_size"] = w.shape[0] // (dit_config["patch_size"] ** 2) + # patch_size and in_channels are derived from x_embedder: + # x_embedder: Linear(patch_size * patch_size * in_channels, dim) + # The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim. + x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] + dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0] + # patch_size: infer from decoder final layer output matching x_embedder input + # in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2) + embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)] + dec_in_ch = dec_out # decoder in == decoder out (same pixel space) + dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3) + dit_config["in_channels"] = 3 + dit_config["decoder_in_channels"] = dec_in_ch + dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0] dit_config["decoder_num_res_blocks"] = count_blocks( state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.' ) - dit_config["decoder_max_freqs"] = 8 # fixed in NerfEmbedder - dit_config["in_channels"] = w.shape[1] // dit_config["dim"] if False else \ - state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] // (dit_config["patch_size"] ** 2) - if '{}__x0__'.format(key_prefix) in state_dict_keys: - dit_config["use_x0"] = True - - dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix) - if dec_cond_key in state_dict_keys: # pixel-space variant - dit_config["image_model"] = "zimage_pixel" - w = state_dict[dec_cond_key] # [patch_size^2 * decoder_hidden_size, dim] - dit_config["decoder_hidden_size"] = w.shape[0] // (dit_config["patch_size"] ** 2) - dit_config["decoder_num_res_blocks"] = count_blocks( - state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.' - ) - dit_config["decoder_max_freqs"] = 8 # fixed in NerfEmbedder - dit_config["in_channels"] = w.shape[1] // dit_config["dim"] if False else \ - state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] // (dit_config["patch_size"] ** 2) + dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5) if '{}__x0__'.format(key_prefix) in state_dict_keys: dit_config["use_x0"] = True diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 5259e143f..c0d3f387f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1734,6 +1734,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid]