diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py
index 0e980ecda..4018307db 100644
--- a/comfy/ldm/flipflop_transformer.py
+++ b/comfy/ldm/flipflop_transformer.py
@@ -48,6 +48,25 @@ class FlipFlopModule(torch.nn.Module):
     def get_block_module_size(self, block_type: str) -> int:
         return comfy.model_management.module_size(getattr(self, block_type)[0])

+    def execute_blocks(self, block_type: str, func, out: torch.Tensor | tuple[torch.Tensor, ...], *args, **kwargs):
+        # execute blocks, supporting block funcs that take/return either a single tensor or a tuple of tensors
+        if isinstance(out, torch.Tensor):
+            out = (out,)
+        for i, block in enumerate(self.get_blocks(block_type)):
+            out = func(i, block, *out, *args, **kwargs)
+            if isinstance(out, torch.Tensor):
+                out = (out,)
+        if block_type in self.flipflop:
+            holder = self.flipflop[block_type]
+            with holder.context() as ctx:
+                for i, block in enumerate(holder.blocks):
+                    out = ctx(func, i, block, *out, *args, **kwargs)
+                    if isinstance(out, torch.Tensor):
+                        out = (out,)
+        if len(out) == 1:
+            out = out[0]
+        return out
+

 class FlipFlopContext:
     def __init__(self, holder: FlipFlopHolder):
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 2af4968ac..5d24a0029 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -400,7 +400,7 @@ class QwenImageTransformer2DModel(FlipFlopModule):
                 if add is not None:
                     hidden_states[:, :add.shape[1]] += add

-        return encoder_hidden_states, hidden_states
+        return hidden_states, encoder_hidden_states

     def _forward(
         self,
@@ -469,14 +469,8 @@ class QwenImageTransformer2DModel(FlipFlopModule):
         patches = transformer_options.get("patches", {})
         blocks_replace = patches_replace.get("dit", {})

-        for i, block in enumerate(self.get_blocks("transformer_blocks")):
-            encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-        if "transformer_blocks" in self.flipflop:
-            holder = self.flipflop["transformer_blocks"]
-            with holder.context() as ctx:
-                for i, block in enumerate(holder.blocks):
-                    encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-
+        out = (hidden_states, encoder_hidden_states)
+        hidden_states, encoder_hidden_states = self.execute_blocks("transformer_blocks", self.indiv_block_fwd, out, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)

         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
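
For reference, a minimal usage sketch of the new helper (not part of the diff). The double-stream call is the one the QwenImage forward above now makes; the single-stream names "single_blocks" and "single_block_fwd" are hypothetical, shown only to illustrate the tensor/tuple normalization execute_blocks performs:

    # double-stream blocks: pass the state as a tuple and unpack the tuple result
    out = (hidden_states, encoder_hidden_states)
    hidden_states, encoder_hidden_states = self.execute_blocks(
        "transformer_blocks", self.indiv_block_fwd, out,
        encoder_hidden_states_mask, temb, image_rotary_emb,
        patches, control, blocks_replace, x, transformer_options)

    # single-stream blocks (hypothetical names): a bare tensor in and out;
    # execute_blocks wraps the tensor into a one-element tuple between block
    # calls and unwraps it again on return
    hidden_states = self.execute_blocks(
        "single_blocks", self.single_block_fwd, hidden_states, temb)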