From f9bc8cc781744686fdfa5f9b6fc14a6987c02f15 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Wed, 13 May 2026 19:27:57 +0200 Subject: [PATCH] resolve hardcoded axes_dim --- comfy/ldm/common_dit.py | 4 ++-- comfy/ldm/flux/model.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index 83f830037..911fb7183 100644 --- a/comfy/ldm/common_dit.py +++ b/comfy/ldm/common_dit.py @@ -16,7 +16,7 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): rms_norm = comfy.rmsnorm.rms_norm -def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_options={}): +def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_options={}, num_axes=3): bs, c, h, w = x.shape x = pad_to_patch_size(x, (patch_size, patch_size)) @@ -39,7 +39,7 @@ def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_op h_offset += rope_options.get("shift_y", 0.0) w_offset += rope_options.get("shift_x", 0.0) - img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) + img_ids = torch.zeros((steps_h, steps_w, num_axes), device=x.device, dtype=x.dtype) img_ids[:, :, 0] = img_ids[:, :, 1] + index img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 8ff1ac627..ca8806fb4 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -324,7 +324,8 @@ class Flux(nn.Module): h_len = ((h_orig + (patch_size // 2)) // patch_size) w_len = ((w_orig + (patch_size // 2)) // patch_size) - img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, transformer_options=transformer_options) + num_axes = len(self.params.axes_dim) + img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, transformer_options=transformer_options, num_axes=num_axes) img_tokens = img.shape[1] timestep_zero_index = None if ref_latents is not None: @@ -356,7 +357,7 @@ class Flux(nn.Module): h = max(h, ref.shape[-2] + h_offset) w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size, transformer_options=transformer_options) + kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size, transformer_options=transformer_options, num_axes=num_axes) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) ref_num_tokens.append(kontext.shape[1])