mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-15 03:27:24 +08:00
Resolve hardcoded axes_dim: size img_ids by len(params.axes_dim) (passed to process_img as num_axes) instead of a fixed 3
This commit is contained in:
parent
1715859d65
commit
f9bc8cc781
@ -16,7 +16,7 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
|||||||
|
|
||||||
rms_norm = comfy.rmsnorm.rms_norm
|
rms_norm = comfy.rmsnorm.rms_norm
|
||||||
|
|
||||||
def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_options={}):
|
def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_options={}, num_axes=3):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
x = pad_to_patch_size(x, (patch_size, patch_size))
|
x = pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_op
|
|||||||
h_offset += rope_options.get("shift_y", 0.0)
|
h_offset += rope_options.get("shift_y", 0.0)
|
||||||
w_offset += rope_options.get("shift_x", 0.0)
|
w_offset += rope_options.get("shift_x", 0.0)
|
||||||
|
|
||||||
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
|
img_ids = torch.zeros((steps_h, steps_w, num_axes), device=x.device, dtype=x.dtype)
|
||||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
|
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
|
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||||
|
|||||||
@ -324,7 +324,8 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||||
img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, transformer_options=transformer_options)
|
num_axes = len(self.params.axes_dim)
|
||||||
|
img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, transformer_options=transformer_options, num_axes=num_axes)
|
||||||
img_tokens = img.shape[1]
|
img_tokens = img.shape[1]
|
||||||
timestep_zero_index = None
|
timestep_zero_index = None
|
||||||
if ref_latents is not None:
|
if ref_latents is not None:
|
||||||
@ -356,7 +357,7 @@ class Flux(nn.Module):
|
|||||||
h = max(h, ref.shape[-2] + h_offset)
|
h = max(h, ref.shape[-2] + h_offset)
|
||||||
w = max(w, ref.shape[-1] + w_offset)
|
w = max(w, ref.shape[-1] + w_offset)
|
||||||
|
|
||||||
kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size, transformer_options=transformer_options)
|
kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size, transformer_options=transformer_options, num_axes=num_axes)
|
||||||
img = torch.cat([img, kontext], dim=1)
|
img = torch.cat([img, kontext], dim=1)
|
||||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
ref_num_tokens.append(kontext.shape[1])
|
ref_num_tokens.append(kontext.shape[1])
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user