mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 19:17:32 +08:00
resolve hardcoded axes_dim
This commit is contained in:
parent
1715859d65
commit
f9bc8cc781
@ -16,7 +16,7 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
|
||||
rms_norm = comfy.rmsnorm.rms_norm
|
||||
|
||||
def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_options={}):
|
||||
def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_options={}, num_axes=3):
|
||||
bs, c, h, w = x.shape
|
||||
x = pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
@ -39,7 +39,7 @@ def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_op
|
||||
h_offset += rope_options.get("shift_y", 0.0)
|
||||
w_offset += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids = torch.zeros((steps_h, steps_w, num_axes), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
|
||||
@ -324,7 +324,8 @@ class Flux(nn.Module):
|
||||
|
||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, transformer_options=transformer_options)
|
||||
num_axes = len(self.params.axes_dim)
|
||||
img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, transformer_options=transformer_options, num_axes=num_axes)
|
||||
img_tokens = img.shape[1]
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
@ -356,7 +357,7 @@ class Flux(nn.Module):
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size, transformer_options=transformer_options)
|
||||
kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size, transformer_options=transformer_options, num_axes=num_axes)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user