From 9d0107ee9079f5e026419c12a61d1e70ebd0e50e Mon Sep 17 00:00:00 2001 From: silveroxides Date: Tue, 11 Nov 2025 16:32:31 +0100 Subject: [PATCH 1/5] Add process_img function from comfy/ldm/flux/model.py to enable the use of ScaleROPE --- comfy/ldm/chroma/model.py | 88 +++++++++++++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 17 deletions(-) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index ad1c523fe..a1604fbb7 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -257,29 +257,83 @@ class Chroma(nn.Module): img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) return img - def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): - return comfy.patcher_extension.WrapperExecutor.new_class_executor( - self._forward, - self, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, guidance, control, transformer_options, **kwargs) - - def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): + def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): bs, c, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size) - - if img.ndim != 3 or context.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") - h_len = ((h + (self.patch_size // 2)) // self.patch_size) w_len = ((w + (self.patch_size // 2)) // self.patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + h_offset = ((h_offset + (self.patch_size // 2)) // self.patch_size) + w_offset = ((w_offset + (self.patch_size // 2)) // self.patch_size) + + steps_h = h_len + steps_w = w_len + + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + index += rope_options.get("shift_t", 0.0) + h_offset += rope_options.get("shift_y", 0.0) + w_offset += rope_options.get("shift_x", 0.0) + + img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 1] + index + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) + return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) + + def forward(self, x, timestep, context, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, guidance, ref_latents, control, transformer_options, **kwargs) + + def _forward(self, 
x, timestep, context, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): + bs, c, h_orig, w_orig = x.shape + + h_len = ((h_orig + (self.patch_size // 2)) // self.patch_size) + w_len = ((w_orig + (self.patch_size // 2)) // self.patch_size) + img, img_ids = self.process_img(x, transformer_options=transformer_options) + if img.ndim != 3 or context.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + img_tokens = img.shape[1] + if ref_latents is not None: + h = 0 + w = 0 + index = 0 + ref_latents_method = kwargs.get("ref_latents_method", "offset") + for ref in ref_latents: + if ref_latents_method == "index": + index += 1 + h_offset = 0 + w_offset = 0 + elif ref_latents_method == "uxo": + index = 0 + h_offset = h_len * self.patch_size + h + w_offset = w_len * self.patch_size + w + h += ref.shape[-2] + w += ref.shape[-1] + else: + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) + + kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + img = torch.cat([img, kontext], dim=1) + img_ids = torch.cat([img_ids, kontext_ids], dim=1) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) - return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w] + out = out[:, :img_tokens] + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] From 7f8d7f296867bea3eac09d118aac19f970978f07 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Fri, 14 Nov 2025 12:03:54 +0100 Subject: [PATCH 2/5] refactor flux/chroma process_img function to comfy.ldm.common_dit to be used as shared function. 
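
The shared helper keeps the token grid fixed while letting rope_options rescale
and shift the RoPE coordinates (ScaleROPE). Below is a minimal standalone sketch
of just the position-id part, using an illustrative make_img_ids helper written
for this message only; the real comfy.ldm.common_dit.process_img introduced here
additionally patchifies x via pad_to_patch_size/rearrange and repeats the ids
per batch element.

    import torch

    def make_img_ids(h, w, patch_size=2, index=0, h_offset=0, w_offset=0, rope_options=None):
        # token grid after patchification, rounded to the patch size
        h_len = (h + patch_size // 2) // patch_size
        w_len = (w + patch_size // 2) // patch_size
        steps_h, steps_w = h_len, w_len  # the token count never changes

        if rope_options is not None:
            # ScaleROPE: stretch/shift the coordinate range, not the grid itself
            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
            index += rope_options.get("shift_t", 0.0)
            h_offset += rope_options.get("shift_y", 0.0)
            w_offset += rope_options.get("shift_x", 0.0)

        ids = torch.zeros((steps_h, steps_w, 3))
        ids[:, :, 0] = index
        ids[:, :, 1] = torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h).unsqueeze(1)
        ids[:, :, 2] = torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w).unsqueeze(0)
        return ids

    # 64x64 latent with scale_x=2.0: still a 32x32 token grid, but the x
    # coordinates run 0, 2, ..., 62 instead of 0, 1, ..., 31
    print(make_img_ids(64, 64, rope_options={"scale_x": 2.0})[0, :3, 2])  # tensor([0., 2., 4.])
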
--- comfy/ldm/chroma/model.py | 33 ++------------------------------- comfy/ldm/common_dit.py | 30 ++++++++++++++++++++++++++++++ comfy/ldm/flux/model.py | 33 ++------------------------------- 3 files changed, 34 insertions(+), 62 deletions(-) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index a1604fbb7..c31524a40 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -257,35 +257,6 @@ class Chroma(nn.Module): img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) return img - def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): - bs, c, h, w = x.shape - x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) - - img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size) - h_len = ((h + (self.patch_size // 2)) // self.patch_size) - w_len = ((w + (self.patch_size // 2)) // self.patch_size) - - h_offset = ((h_offset + (self.patch_size // 2)) // self.patch_size) - w_offset = ((w_offset + (self.patch_size // 2)) // self.patch_size) - - steps_h = h_len - steps_w = w_len - - rope_options = transformer_options.get("rope_options", None) - if rope_options is not None: - h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 - w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 - - index += rope_options.get("shift_t", 0.0) - h_offset += rope_options.get("shift_y", 0.0) - w_offset += rope_options.get("shift_x", 0.0) - - img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) - return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) - def forward(self, x, timestep, context, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, @@ -298,7 +269,7 @@ class Chroma(nn.Module): h_len = ((h_orig + (self.patch_size // 2)) // self.patch_size) w_len = ((w_orig + (self.patch_size // 2)) // self.patch_size) - img, img_ids = self.process_img(x, transformer_options=transformer_options) + img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=self.patch_size, transformer_options=transformer_options) if img.ndim != 3 or context.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") img_tokens = img.shape[1] @@ -329,7 +300,7 @@ class Chroma(nn.Module): h = max(h, ref.shape[-2] + h_offset) w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=self.patch_size) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index f7f56b72c..a31c9f571 100644 --- a/comfy/ldm/common_dit.py +++ b/comfy/ldm/common_dit.py @@ -1,4 +1,5 @@ import torch +from einops import rearrange, repeat import comfy.rmsnorm @@ -14,3 +15,32 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): rms_norm = comfy.rmsnorm.rms_norm + +def process_img(x, index=0, 
h_offset=0, w_offset=0, patch_size=(2, 2), transformer_options={}): + bs, c, h, w = x.shape + x = pad_to_patch_size(x, (patch_size, patch_size)) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + + h_offset = ((h_offset + (patch_size // 2)) // patch_size) + w_offset = ((w_offset + (patch_size // 2)) // patch_size) + + steps_h = h_len + steps_w = w_len + + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + index += rope_options.get("shift_t", 0.0) + h_offset += rope_options.get("shift_y", 0.0) + w_offset += rope_options.get("shift_x", 0.0) + + img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 1] + index + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) + return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index b9d36f202..1333b45e7 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -210,35 +210,6 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): - bs, c, h, w = x.shape - patch_size = self.patch_size - x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) - - img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) - h_len = ((h + (patch_size // 2)) // patch_size) - w_len = ((w + (patch_size // 2)) // patch_size) - - h_offset = ((h_offset + (patch_size // 2)) // patch_size) - w_offset = ((w_offset + (patch_size // 2)) // patch_size) - - steps_h = h_len - steps_w = w_len - - rope_options = transformer_options.get("rope_options", None) - if rope_options is not None: - h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 - w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 - - index += rope_options.get("shift_t", 0.0) - h_offset += rope_options.get("shift_y", 0.0) - w_offset += rope_options.get("shift_x", 0.0) - - img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) - return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -253,7 +224,7 @@ class Flux(nn.Module): h_len = ((h_orig + (patch_size // 2)) // patch_size) w_len = ((w_orig + (patch_size // 2)) // patch_size) - img, img_ids = self.process_img(x, transformer_options=transformer_options) + img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, 
transformer_options=transformer_options) img_tokens = img.shape[1] if ref_latents is not None: h = 0 @@ -282,7 +253,7 @@ class Flux(nn.Module): h = max(h, ref.shape[-2] + h_offset) w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) From c4ebef588207b586b2a053ea651d5c128e3a1147 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Fri, 14 Nov 2025 12:05:54 +0100 Subject: [PATCH 3/5] fix the ruff check error with unused import --- comfy/ldm/common_dit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index a31c9f571..3f6f38917 100644 --- a/comfy/ldm/common_dit.py +++ b/comfy/ldm/common_dit.py @@ -1,5 +1,5 @@ import torch -from einops import rearrange, repeat +from einops import rearrange import comfy.rmsnorm From b9be0b76b23eb386491e3f841a5c7dcc449692c4 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Fri, 14 Nov 2025 12:09:08 +0100 Subject: [PATCH 4/5] ACTUALLY fix the ruff check error with unused import --- comfy/ldm/common_dit.py | 2 +- comfy/ldm/flux/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index 3f6f38917..a31c9f571 100644 --- a/comfy/ldm/common_dit.py +++ b/comfy/ldm/common_dit.py @@ -1,5 +1,5 @@ import torch -from einops import rearrange +from einops import rearrange, repeat import comfy.rmsnorm diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 1333b45e7..fb38d1a73 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -4,7 +4,7 @@ from dataclasses import dataclass import torch from torch import Tensor, nn -from einops import rearrange, repeat +from einops import rearrange import comfy.ldm.common_dit import comfy.patcher_extension From 1afbaa097fb942a3bab3e11fda75107c43a5ed88 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Fri, 14 Nov 2025 12:10:25 +0100 Subject: [PATCH 5/5] Final fix for the ruff check error with unused import --- comfy/ldm/chroma/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index c31524a40..d56dd12d0 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -4,7 +4,7 @@ from dataclasses import dataclass import torch from torch import Tensor, nn -from einops import rearrange, repeat +from einops import rearrange import comfy.patcher_extension import comfy.ldm.common_dit
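
Reviewer note: PATCH 1/5 also copies Flux's ref_latents handling into
Chroma._forward. The sketch below (a hypothetical placements_for helper written
only for this note) mirrors how that loop assigns an (index, h_offset, w_offset)
placement to each reference latent for the three ref_latents_method values;
sizes are latent-space heights/widths before patchification.

    def placements_for(main_hw, ref_shapes, method="offset", patch_size=2):
        h_len = (main_hw[0] + patch_size // 2) // patch_size
        w_len = (main_hw[1] + patch_size // 2) // patch_size
        h = w = index = 0
        out = []
        for rh, rw in ref_shapes:
            if method == "index":
                # every reference gets its own time index, no spatial offset
                index += 1
                h_offset = w_offset = 0
            elif method == "uxo":
                # references are pushed past the main image's extent and keep
                # accumulating down/right as more references are added
                index = 0
                h_offset = h_len * patch_size + h
                w_offset = w_len * patch_size + w
                h += rh
                w += rw
            else:  # "offset" (the default)
                # index 1; placed to the right when the vertical extent would
                # exceed the horizontal one, otherwise stacked below
                index = 1
                h_offset = w_offset = 0
                if rh + h > rw + w:
                    w_offset = w
                else:
                    h_offset = h
                h = max(h, rh + h_offset)
                w = max(w, rw + w_offset)
            out.append((index, h_offset, w_offset))
        return out

    # two references against a 64x64 latent: the first sits at the main image's
    # spatial coordinates on time index 1, the second is shifted 64 rows down
    print(placements_for((64, 64), [(64, 64), (32, 64)]))  # [(1, 0, 0), (1, 64, 0)]
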