Simplified flipflop setup by adding FlipFlopModule.execute_blocks helper

Jedrzej Kosinski 2025-10-02 16:46:37 -07:00
parent c4420b6a41
commit 6d3ec9fcf3
2 changed files with 22 additions and 9 deletions

@@ -48,6 +48,25 @@ class FlipFlopModule(torch.nn.Module):
 
     def get_block_module_size(self, block_type: str) -> int:
         return comfy.model_management.module_size(getattr(self, block_type)[0])
 
+    def execute_blocks(self, block_type: str, func, out: torch.Tensor | tuple[torch.Tensor, ...], *args, **kwargs):
+        # execute blocks, supporting both single and double (or higher) block types
+        if isinstance(out, torch.Tensor):
+            out = (out,)
+        for i, block in enumerate(self.get_blocks(block_type)):
+            out = func(i, block, *out, *args, **kwargs)
+            if isinstance(out, torch.Tensor):
+                out = (out,)
+        if block_type in self.flipflop:
+            holder = self.flipflop[block_type]
+            with holder.context() as ctx:
+                for i, block in enumerate(holder.blocks):
+                    out = ctx(func, i, block, *out, *args, **kwargs)
+                    if isinstance(out, torch.Tensor):
+                        out = (out,)
+        if len(out) == 1:
+            out = out[0]
+        return out
+
 class FlipFlopContext:
     def __init__(self, holder: FlipFlopHolder):
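
For reference, a minimal standalone sketch of the tuple-threading pattern that execute_blocks implements, using plain Python values in place of tensors. run_blocks and the lambda block functions below are hypothetical stand-ins, not the comfy API; the point is how single-value returns are normalized into 1-tuples so `*out` unpacking works for both single-stream and double-stream block types:

from typing import Any, Callable

def run_blocks(blocks: list, func: Callable, out: Any, *args, **kwargs) -> Any:
    # Thread `out` through each block, normalizing single returns into
    # 1-tuples so `*out` unpacking works uniformly.
    if not isinstance(out, tuple):
        out = (out,)
    for i, block in enumerate(blocks):
        out = func(i, block, *out, *args, **kwargs)
        if not isinstance(out, tuple):
            out = (out,)
    # Collapse a 1-tuple back to a bare value for single-stream callers.
    if len(out) == 1:
        out = out[0]
    return out

# Single-stream: func returns one value per step.
assert run_blocks([None, None], lambda i, b, x: x + 1, 0) == 2
# Double-stream: func returns a 2-tuple, e.g. (hidden, encoder_hidden).
assert run_blocks([None, None], lambda i, b, h, e: (h + 1, e * 2), (0, 1)) == (2, 4)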

@@ -400,7 +400,7 @@ class QwenImageTransformer2DModel(FlipFlopModule):
 
         if add is not None:
             hidden_states[:, :add.shape[1]] += add
-        return encoder_hidden_states, hidden_states
+        return hidden_states, encoder_hidden_states
 
     def _forward(
         self,
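
Note on the swap above: indiv_block_fwd now returns (hidden_states, encoder_hidden_states), matching the order in which execute_blocks threads the tuple back in via `*out` on the next iteration, and the unpacking order at the call site in the next hunk.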
@@ -469,14 +469,8 @@ class QwenImageTransformer2DModel(FlipFlopModule):
 
         patches = transformer_options.get("patches", {})
         blocks_replace = patches_replace.get("dit", {})
-        for i, block in enumerate(self.get_blocks("transformer_blocks")):
-            encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-
-        if "transformer_blocks" in self.flipflop:
-            holder = self.flipflop["transformer_blocks"]
-            with holder.context() as ctx:
-                for i, block in enumerate(holder.blocks):
-                    encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
+        out = (hidden_states, encoder_hidden_states)
+        hidden_states, encoder_hidden_states = self.execute_blocks("transformer_blocks", self.indiv_block_fwd, out, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
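
A runnable sketch of the double-stream contract the new call site relies on, with hypothetical names (indiv_block_fwd_sketch, nn.Identity stand-in blocks) rather than the real Qwen transformer blocks: each step consumes and returns (hidden_states, encoder_hidden_states) in the same order, so the seed tuple can be threaded straight through.

import torch

def indiv_block_fwd_sketch(i, block, hidden_states, encoder_hidden_states):
    # Stand-in per-block work; the real indiv_block_fwd runs a transformer block.
    hidden_states = block(hidden_states) + i
    # Return order must match the `*out` unpacking on the next step.
    return hidden_states, encoder_hidden_states

blocks = [torch.nn.Identity() for _ in range(3)]
out = (torch.zeros(2, 4), torch.ones(2, 4))  # seed tuple, as in _forward
for i, block in enumerate(blocks):
    out = indiv_block_fwd_sketch(i, block, *out)
hidden_states, encoder_hidden_states = out
assert torch.equal(hidden_states, torch.full((2, 4), 3.0))  # 0 + 0 + 1 + 2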