resolve hardcoded axes_dim

This commit is contained in:
silveroxides 2026-05-13 19:27:57 +02:00
parent 1715859d65
commit f9bc8cc781
2 changed files with 5 additions and 4 deletions

View File

@ -16,7 +16,7 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
rms_norm = comfy.rmsnorm.rms_norm
def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_options={}):
def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_options={}, num_axes=3):
bs, c, h, w = x.shape
x = pad_to_patch_size(x, (patch_size, patch_size))
@ -39,7 +39,7 @@ def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_op
h_offset += rope_options.get("shift_y", 0.0)
w_offset += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
img_ids = torch.zeros((steps_h, steps_w, num_axes), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)

View File

@ -324,7 +324,8 @@ class Flux(nn.Module):
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, transformer_options=transformer_options)
num_axes = len(self.params.axes_dim)
img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, transformer_options=transformer_options, num_axes=num_axes)
img_tokens = img.shape[1]
timestep_zero_index = None
if ref_latents is not None:
@ -356,7 +357,7 @@ class Flux(nn.Module):
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size, transformer_options=transformer_options)
kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size, transformer_options=transformer_options, num_axes=num_axes)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])