mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 17:42:58 +08:00
Simplified flipflop setup by adding FlipFlopModule.execute_blocks helper
This commit is contained in:
parent
c4420b6a41
commit
6d3ec9fcf3
@ -48,6 +48,25 @@ class FlipFlopModule(torch.nn.Module):
|
||||
def get_block_module_size(self, block_type: str) -> int:
    """Return the size (per ``comfy.model_management.module_size``) of one block.

    Only the first block of ``block_type`` is measured.
    # NOTE(review): assumes all blocks of a given type are the same size — confirm.
    """
    first_block = getattr(self, block_type)[0]
    return comfy.model_management.module_size(first_block)
|
||||
|
||||
def execute_blocks(self, block_type: str, func, out: torch.Tensor | tuple[torch.Tensor,...], *args, **kwargs):
|
||||
# execute blocks, supporting both single and double (or higher) block types
|
||||
if isinstance(out, torch.Tensor):
|
||||
out = (out,)
|
||||
for i, block in enumerate(self.get_blocks(block_type)):
|
||||
out = func(i, block, *out, *args, **kwargs)
|
||||
if isinstance(out, torch.Tensor):
|
||||
out = (out,)
|
||||
if "transformer_blocks" in self.flipflop:
|
||||
holder = self.flipflop["transformer_blocks"]
|
||||
with holder.context() as ctx:
|
||||
for i, block in enumerate(holder.blocks):
|
||||
out = ctx(func, i, block, *out, *args, **kwargs)
|
||||
if isinstance(out, torch.Tensor):
|
||||
out = (out,)
|
||||
if len(out) == 1:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
|
||||
class FlipFlopContext:
|
||||
def __init__(self, holder: FlipFlopHolder):
|
||||
|
||||
@ -400,7 +400,7 @@ class QwenImageTransformer2DModel(FlipFlopModule):
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
@ -469,14 +469,8 @@ class QwenImageTransformer2DModel(FlipFlopModule):
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.get_blocks("transformer_blocks")):
|
||||
encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
||||
if "transformer_blocks" in self.flipflop:
|
||||
holder = self.flipflop["transformer_blocks"]
|
||||
with holder.context() as ctx:
|
||||
for i, block in enumerate(holder.blocks):
|
||||
encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
||||
|
||||
out = (hidden_states, encoder_hidden_states)
|
||||
hidden_states, encoder_hidden_states = self.execute_blocks("transformer_blocks", self.indiv_block_fwd, out, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user