diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index b47aac70e..4e85dfd9a 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -336,6 +336,7 @@ class Kandinsky5(nn.Module): return freqs def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs): + patches_replace = transformer_options.get("patches_replace", {}) context = self.text_embeddings(context) time_embed = self.time_embeddings(timestep).to(x.dtype) + self.pooled_text_embeddings(y) @@ -345,8 +346,18 @@ class Kandinsky5(nn.Module): visual_embed = self.visual_embeddings(x) visual_shape = visual_embed.shape[:-1] - for block in self.visual_transformer_blocks: - visual_embed = block(visual_embed.flatten(1, -2), context, time_embed, freqs=freqs, transformer_options=transformer_options) + visual_embed = visual_embed.flatten(1, -2) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.visual_transformer_blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.visual_transformer_blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options")) + visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"] + else: + visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options) visual_embed = visual_embed.reshape(*visual_shape, -1) return self.out_layer(visual_embed, time_embed) diff --git a/comfy/model_base.py b/comfy/model_base.py index ff82929fb..43d7fb281 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1691,5 +1691,5 @@ class Kandinsky5_image(Kandinsky5): def concat_cond(self, **kwargs): return None - def process_latent_out(self, latent): + def process_latent_out(self, latent): # input is still 5D, return single frame to decode with Flux VAE return self.latent_format.process_out(latent)[:, :, 0]