model loaded and able to run however vector direction still wrong tho

This commit is contained in:
lodestone-rock 2026-02-28 21:06:14 +07:00
parent 54099360bd
commit 033a534f7c
4 changed files with 73 additions and 73 deletions

View File

@ -991,14 +991,12 @@ class SimpleMLPAdaLN(nn.Module):
out_channels: int,
z_channels: int,
num_res_blocks: int,
patch_size: int,
max_freqs: int = 8,
):
super().__init__()
self.patch_size = patch_size
# Project backbone hidden state → per-position conditioning
self.cond_embed = nn.Linear(z_channels, patch_size ** 2 * model_channels)
# Project backbone hidden state → per-patch conditioning
self.cond_embed = nn.Linear(z_channels, model_channels)
nn.init.xavier_uniform_(self.cond_embed.weight)
nn.init.constant_(self.cond_embed.bias, 0)
@ -1018,12 +1016,15 @@ class SimpleMLPAdaLN(nn.Module):
self.final_layer = DCTFinalLayer(model_channels, out_channels)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
# x: [B*N, P^2, C], c: [B*N, dim]
x = self.input_embedder(x) # [B*N, P^2, model_channels]
y = self.cond_embed(c).reshape(c.shape[0], self.patch_size ** 2, -1) # [B*N, P^2, model_channels]
# x: [B*N, 1, P^2*C], c: [B*N, dim]
original_dtype = x.dtype
weight_dtype = self.cond_embed.weight.dtype
x = self.input_embedder(x) # [B*N, 1, model_channels]
y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels]
x = x.to(weight_dtype)
for block in self.res_blocks:
x = block(x, y)
return self.final_layer(x) # [B*N, P^2, C]
return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C]
#############################################################################
@ -1052,6 +1053,7 @@ class NextDiTPixelSpace(NextDiT):
decoder_hidden_size: int = 3840,
decoder_num_res_blocks: int = 4,
decoder_max_freqs: int = 8,
decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels)
use_x0: bool = False,
# all NextDiT args forwarded unchanged
**kwargs,
@ -1065,13 +1067,15 @@ class NextDiTPixelSpace(NextDiT):
in_channels = kwargs.get("in_channels", 4)
dim = kwargs.get("dim", 4096)
# decoder_in_channels is the full flattened patch: patch_size^2 * in_channels
dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels
self.dec_net = SimpleMLPAdaLN(
in_channels=in_channels,
in_channels=dec_in_ch,
model_channels=decoder_hidden_size,
out_channels=in_channels,
out_channels=dec_in_ch,
z_channels=dim,
num_res_blocks=decoder_num_res_blocks,
patch_size=patch_size,
max_freqs=decoder_max_freqs,
)
@ -1079,99 +1083,99 @@ class NextDiTPixelSpace(NextDiT):
self.register_buffer("__x0__", torch.tensor([]))
# ------------------------------------------------------------------
# Override patchify_and_embed to also return the raw pixel patches
# Forward — mirrors NextDiT._forward exactly, replacing final_layer
# with the pixel-space dec_net decoder.
# ------------------------------------------------------------------
def patchify_and_embed(self, x, cap_feats, cap_mask, t, num_tokens, transformer_options={}):
# Run the parent implementation unchanged; we capture pixel values
# separately in _forward before calling this.
return super().patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options)
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
omni = len(ref_latents) > 0
if omni:
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
# ------------------------------------------------------------------
# Forward
# ------------------------------------------------------------------
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
t = self.t_embedder(t * self.time_scale, dtype=x.dtype)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats)
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
# ---- capture raw pixel patches before backbone embedding ----
# ---- capture raw pixel patches before patchify_and_embed embeds them ----
pH = pW = self.patch_size
B, C, H, W = x.shape
# [B, N, P*P*C] (same layout as what x_embedder receives)
pixel_patches = (
x.view(B, C, H // pH, pH, W // pW, pW)
.permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C]
.flatten(3) # [B, Ht, Wt, pH*pW*C]
.flatten(1, 2) # [B, N, pH*pW*C]
)
# reshape to [B*N, P^2, C] for the decoder
N = pixel_patches.shape[1]
pixel_values = pixel_patches.reshape(B * N, pH * pW, C)
# decoder sees one token per patch: [B*N, 1, P^2*C]
pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C)
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(
x, cap_feats, cap_mask, adaln_input, num_tokens,
ref_latents=ref_latents, ref_contexts=ref_contexts,
siglip_feats=siglip_feats, transformer_options=transformer_options
)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers):
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out:
img[:, cap_size[0]:] = out["img"]
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
# ---- pixel-space decoder ----
# img: [B, txt_len+N, dim] → extract image tokens → [B, N, dim]
img_hidden = img[:, cap_size[0]:, :] # [B, N, dim]
# per-patch conditioning: [B*N, dim]
decoder_cond = img_hidden.reshape(B * N, self.dim)
# ---- pixel-space decoder (replaces final_layer + unpatchify) ----
# img may have padding tokens beyond N; only the first N are real image patches
img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim]
decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim]
# decode: [B*N, P^2, C]
output = self.dec_net(pixel_values, decoder_cond)
output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C]
output = output.reshape(B, N, -1) # [B, N, P^2*C]
# reshape back: [B*N, P^2, C] → [B, N, P^2*C]
output = output.reshape(B, N, -1)
# unpatchify expects [B, txt_len+N, P^2*C] with cap tokens prepended
# re-prepend a zero placeholder for the cap positions so unpatchify works
# prepend zero cap placeholder so unpatchify indexing works unchanged
cap_placeholder = torch.zeros(
B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype
)
img_out = torch.cat([cap_placeholder, output], dim=1)
img_out = self.unpatchify(img_out, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
img_out = self.unpatchify(
torch.cat([cap_placeholder, output], dim=1),
img_size, cap_size, return_tensor=x_is_tensor
)[:, :, :h, :w]
return -img_out
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
# _forward returns -x0 (negated decoder output, matching the latent-space convention).
# Reference x0→v conversion: v = (noisy - out) / t, where out = -x0
# → v = (noisy - (-x0)) / t = (noisy + x0) / t
# Since neg_x0 = -x0: v = (x - neg_x0) / t
# _forward returns -x0 (negated decoder output, same convention as NextDiT).
#
# ComfyUI uses CONST sampling: calculate_denoised = x - model_output * sigma
# We need calculate_denoised to return x0, so model_output must equal (x - x0) / sigma.
#
# _forward returns neg_x0 = -x0, so:
# model_output = (x + neg_x0) / sigma → (x + (-x0)) / sigma = (x - x0) / sigma ✓
# Then: x - model_output * sigma = x - (x - x0) = x0 ✓
neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
return (x - neg_x0) / timesteps.view(-1, 1, 1, 1)
return (x + neg_x0) / timesteps.view(-1, 1, 1, 1)

View File

@ -1217,7 +1217,8 @@ class Lumina2(BaseModel):
class ZImagePixelSpace(Lumina2):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)

View File

@ -467,28 +467,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
if dec_cond_key in state_dict_keys: # pixel-space variant
dit_config["image_model"] = "zimage_pixel"
w = state_dict[dec_cond_key] # [patch_size^2 * decoder_hidden_size, dim]
dit_config["decoder_hidden_size"] = w.shape[0] // (dit_config["patch_size"] ** 2)
# patch_size and in_channels are derived from x_embedder:
# x_embedder: Linear(patch_size * patch_size * in_channels, dim)
# The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim.
x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1]
dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0]
# patch_size: infer from decoder final layer output matching x_embedder input
# in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2)
embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)]
dec_in_ch = dec_out # decoder in == decoder out (same pixel space)
dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3)
dit_config["in_channels"] = 3
dit_config["decoder_in_channels"] = dec_in_ch
dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0]
dit_config["decoder_num_res_blocks"] = count_blocks(
state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
)
dit_config["decoder_max_freqs"] = 8 # fixed in NerfEmbedder
dit_config["in_channels"] = w.shape[1] // dit_config["dim"] if False else \
state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] // (dit_config["patch_size"] ** 2)
if '{}__x0__'.format(key_prefix) in state_dict_keys:
dit_config["use_x0"] = True
dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
if dec_cond_key in state_dict_keys: # pixel-space variant
dit_config["image_model"] = "zimage_pixel"
w = state_dict[dec_cond_key] # [patch_size^2 * decoder_hidden_size, dim]
dit_config["decoder_hidden_size"] = w.shape[0] // (dit_config["patch_size"] ** 2)
dit_config["decoder_num_res_blocks"] = count_blocks(
state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
)
dit_config["decoder_max_freqs"] = 8 # fixed in NerfEmbedder
dit_config["in_channels"] = w.shape[1] // dit_config["dim"] if False else \
state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] // (dit_config["patch_size"] ** 2)
dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5)
if '{}__x0__'.format(key_prefix) in state_dict_keys:
dit_config["use_x0"] = True

View File

@ -1734,6 +1734,6 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid]