From 7f8d7f296867bea3eac09d118aac19f970978f07 Mon Sep 17 00:00:00 2001
From: silveroxides
Date: Fri, 14 Nov 2025 12:03:54 +0100
Subject: [PATCH] refactor flux/chroma process_img function to
 comfy.ldm.common_dit to be used as a shared function.

---
 comfy/ldm/chroma/model.py | 33 ++-------------------------------
 comfy/ldm/common_dit.py   | 30 ++++++++++++++++++++++++++++++
 comfy/ldm/flux/model.py   | 33 ++-------------------------------
 3 files changed, 34 insertions(+), 62 deletions(-)

diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index a1604fbb7..c31524a40 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -257,35 +257,6 @@ class Chroma(nn.Module):
         img = self.final_layer(img, vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
-        bs, c, h, w = x.shape
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
-
-        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
-        h_len = ((h + (self.patch_size // 2)) // self.patch_size)
-        w_len = ((w + (self.patch_size // 2)) // self.patch_size)
-
-        h_offset = ((h_offset + (self.patch_size // 2)) // self.patch_size)
-        w_offset = ((w_offset + (self.patch_size // 2)) // self.patch_size)
-
-        steps_h = h_len
-        steps_w = w_len
-
-        rope_options = transformer_options.get("rope_options", None)
-        if rope_options is not None:
-            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
-            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
-
-            index += rope_options.get("shift_t", 0.0)
-            h_offset += rope_options.get("shift_y", 0.0)
-            w_offset += rope_options.get("shift_x", 0.0)
-
-        img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
-        return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
-
     def forward(self, x, timestep, context, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
@@ -298,7 +269,7 @@ class Chroma(nn.Module):
         h_len = ((h_orig + (self.patch_size // 2)) // self.patch_size)
         w_len = ((w_orig + (self.patch_size // 2)) // self.patch_size)
 
-        img, img_ids = self.process_img(x, transformer_options=transformer_options)
+        img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=self.patch_size, transformer_options=transformer_options)
         if img.ndim != 3 or context.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
         img_tokens = img.shape[1]
@@ -329,7 +300,7 @@ class Chroma(nn.Module):
                     h = max(h, ref.shape[-2] + h_offset)
                     w = max(w, ref.shape[-1] + w_offset)
 
-                kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
+                kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=self.patch_size)
                 img = torch.cat([img, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
 
diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py
index f7f56b72c..a31c9f571 100644
--- a/comfy/ldm/common_dit.py
+++ b/comfy/ldm/common_dit.py
@@ -1,4 +1,5 @@
 import torch
+from einops import rearrange, repeat
 
 import comfy.rmsnorm
 
@@ -14,3 +15,32 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
 
 
 rms_norm = comfy.rmsnorm.rms_norm
+
+def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=2, transformer_options={}):
+    bs, c, h, w = x.shape
+    x = pad_to_patch_size(x, (patch_size, patch_size))
+
+    img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+    h_len = ((h + (patch_size // 2)) // patch_size)
+    w_len = ((w + (patch_size // 2)) // patch_size)
+
+    h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+    w_offset = ((w_offset + (patch_size // 2)) // patch_size)
+
+    steps_h = h_len
+    steps_w = w_len
+
+    rope_options = transformer_options.get("rope_options", None)
+    if rope_options is not None:
+        h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+        w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+
+        index += rope_options.get("shift_t", 0.0)
+        h_offset += rope_options.get("shift_y", 0.0)
+        w_offset += rope_options.get("shift_x", 0.0)
+
+    img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
+    img_ids[:, :, 0] = img_ids[:, :, 1] + index
+    img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
+    img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
+    return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index b9d36f202..1333b45e7 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -210,35 +210,6 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
-        bs, c, h, w = x.shape
-        patch_size = self.patch_size
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
-
-        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
-        h_len = ((h + (patch_size // 2)) // patch_size)
-        w_len = ((w + (patch_size // 2)) // patch_size)
-
-        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
-        w_offset = ((w_offset + (patch_size // 2)) // patch_size)
-
-        steps_h = h_len
-        steps_w = w_len
-
-        rope_options = transformer_options.get("rope_options", None)
-        if rope_options is not None:
-            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
-            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
-
-            index += rope_options.get("shift_t", 0.0)
-            h_offset += rope_options.get("shift_y", 0.0)
-            w_offset += rope_options.get("shift_x", 0.0)
-
-        img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
-        return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
     def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -253,7 +224,7 @@ class Flux(nn.Module):
         h_len = ((h_orig + (patch_size // 2)) // patch_size)
         w_len = ((w_orig + (patch_size // 2)) // patch_size)
 
-        img, img_ids = self.process_img(x, transformer_options=transformer_options)
+        img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, transformer_options=transformer_options)
         img_tokens = img.shape[1]
         if ref_latents is not None:
             h = 0
@@ -282,7 +253,7 @@ class Flux(nn.Module):
                     h = max(h, ref.shape[-2] + h_offset)
                     w = max(w, ref.shape[-1] + w_offset)
 
-                kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
+                kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size)
                 img = torch.cat([img, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
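
Reviewer note (not part of the patch to apply): below is a minimal smoke-test sketch for the extracted helper, assuming a ComfyUI checkout on PYTHONPATH. The tensor sizes and the index/offset values are illustrative assumptions, not values taken from the diff. Since the shared function no longer reads self.patch_size, both forward paths must pass patch_size explicitly, which the updated call sites above do.

import torch

import comfy.ldm.common_dit

# A (bs, c, h, w) latent; with patch_size=2 a 64x64 latent becomes a 32x32 token grid.
x = torch.zeros(1, 16, 64, 64)
img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=2)
assert img.shape == (1, 32 * 32, 16 * 2 * 2)  # each token is a flattened 2x2 patch
assert img_ids.shape == (1, 32 * 32, 3)       # (index, y, x) RoPE position per token

# Reference latents take an index and/or spatial offset so their RoPE positions
# do not collide with the base image, mirroring the kontext path in the diff.
# index=1 and h_offset=64 are hypothetical values chosen for illustration.
ref = torch.zeros(1, 16, 32, 32)
kontext, kontext_ids = comfy.ldm.common_dit.process_img(
    ref, index=1, h_offset=64, w_offset=0, patch_size=2)
assert kontext.shape == (1, 16 * 16, 16 * 2 * 2)
assert kontext_ids[..., 0].max() == 1   # index channel marks the reference
assert kontext_ids[..., 1].min() == 32  # y positions start past the base image rows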