diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index ad1c523fe..a1604fbb7 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -257,29 +257,83 @@ class Chroma(nn.Module):
         img = self.final_layer(img, vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
-        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
-            self._forward,
-            self,
-            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
-
-    def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
         bs, c, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
-
-        if img.ndim != 3 or context.ndim != 3:
-            raise ValueError("Input img and txt tensors must have 3 dimensions.")
-
         h_len = ((h + (self.patch_size // 2)) // self.patch_size)
         w_len = ((w + (self.patch_size // 2)) // self.patch_size)
-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        h_offset = ((h_offset + (self.patch_size // 2)) // self.patch_size)
+        w_offset = ((w_offset + (self.patch_size // 2)) // self.patch_size)
+
+        steps_h = h_len
+        steps_w = w_len
+
+        rope_options = transformer_options.get("rope_options", None)
+        if rope_options is not None:
+            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+
+            index += rope_options.get("shift_t", 0.0)
+            h_offset += rope_options.get("shift_y", 0.0)
+            w_offset += rope_options.get("shift_x", 0.0)
+
+        img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
+        img_ids[:, :, 0] = img_ids[:, :, 1] + index
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    def forward(self, x, timestep, context, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, guidance, ref_latents, control, transformer_options, **kwargs)
+
+    def _forward(self, x, timestep, context, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
+        bs, c, h_orig, w_orig = x.shape
+
+        h_len = ((h_orig + (self.patch_size // 2)) // self.patch_size)
+        w_len = ((w_orig + (self.patch_size // 2)) // self.patch_size)
+        img, img_ids = self.process_img(x, transformer_options=transformer_options)
+        if img.ndim != 3 or context.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+        img_tokens = img.shape[1]
+        if ref_latents is not None:
+            h = 0
+            w = 0
+            index = 0
+            ref_latents_method = kwargs.get("ref_latents_method", "offset")
+            for ref in ref_latents:
+                if ref_latents_method == "index":
+                    index += 1
+                    h_offset = 0
+                    w_offset = 0
+                elif ref_latents_method == "uxo":
+                    index = 0
+                    h_offset = h_len * self.patch_size + h
+                    w_offset = w_len * self.patch_size + w
+                    h += ref.shape[-2]
+                    w += ref.shape[-1]
+                else:
+                    index = 1
+                    h_offset = 0
+                    w_offset = 0
+                    if ref.shape[-2] + h > ref.shape[-1] + w:
+                        w_offset = w
+                    else:
+                        h_offset = h
+                    h = max(h, ref.shape[-2] + h_offset)
+                    w = max(w, ref.shape[-1] + w_offset)
+
+                kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
+                img = torch.cat([img, kontext], dim=1)
+                img_ids = torch.cat([img_ids, kontext_ids], dim=1)
 
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
-        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]
+        out = out[:, :img_tokens]
+        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
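As a standalone illustration of the default "offset" ref_latents_method in the hunk above: each reference latent gets index 1 in its RoPE ids (separating it from the main image at index 0), and successive references are tiled spatially so their coordinates do not overlap one another. This is only a sketch; the helper name plan_offsets and the 64x64 sizes are made up for the example, and only the if/else placement mirrors the patch.

def plan_offsets(ref_shapes):
    # ref_shapes: list of (height, width) for each reference latent, in latent pixels.
    h = w = 0
    offsets = []
    for rh, rw in ref_shapes:
        h_offset = 0
        w_offset = 0
        # Same comparison as the "offset" branch in _forward: stack the new
        # reference below unless that would make the occupied height exceed
        # the occupied width, in which case place it to the right instead.
        if rh + h > rw + w:
            w_offset = w
        else:
            h_offset = h
        h = max(h, rh + h_offset)
        w = max(w, rw + w_offset)
        offsets.append((h_offset, w_offset))
    return offsets

# Two 64x64 references: the second is placed below the first.
print(plan_offsets([(64, 64), (64, 64)]))  # [(0, 0), (64, 0)]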