mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 18:13:01 +08:00
Simplified flipflop setup by adding FlipFlopModule.execute_blocks helper
This commit is contained in:
parent
c4420b6a41
commit
6d3ec9fcf3
@ -48,6 +48,25 @@ class FlipFlopModule(torch.nn.Module):
|
|||||||
def get_block_module_size(self, block_type: str) -> int:
|
def get_block_module_size(self, block_type: str) -> int:
|
||||||
return comfy.model_management.module_size(getattr(self, block_type)[0])
|
return comfy.model_management.module_size(getattr(self, block_type)[0])
|
||||||
|
|
||||||
|
def execute_blocks(self, block_type: str, func, out: torch.Tensor | tuple[torch.Tensor,...], *args, **kwargs):
|
||||||
|
# execute blocks, supporting both single and double (or higher) block types
|
||||||
|
if isinstance(out, torch.Tensor):
|
||||||
|
out = (out,)
|
||||||
|
for i, block in enumerate(self.get_blocks(block_type)):
|
||||||
|
out = func(i, block, *out, *args, **kwargs)
|
||||||
|
if isinstance(out, torch.Tensor):
|
||||||
|
out = (out,)
|
||||||
|
if "transformer_blocks" in self.flipflop:
|
||||||
|
holder = self.flipflop["transformer_blocks"]
|
||||||
|
with holder.context() as ctx:
|
||||||
|
for i, block in enumerate(holder.blocks):
|
||||||
|
out = ctx(func, i, block, *out, *args, **kwargs)
|
||||||
|
if isinstance(out, torch.Tensor):
|
||||||
|
out = (out,)
|
||||||
|
if len(out) == 1:
|
||||||
|
out = out[0]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class FlipFlopContext:
|
class FlipFlopContext:
|
||||||
def __init__(self, holder: FlipFlopHolder):
|
def __init__(self, holder: FlipFlopHolder):
|
||||||
|
|||||||
@ -400,7 +400,7 @@ class QwenImageTransformer2DModel(FlipFlopModule):
|
|||||||
if add is not None:
|
if add is not None:
|
||||||
hidden_states[:, :add.shape[1]] += add
|
hidden_states[:, :add.shape[1]] += add
|
||||||
|
|
||||||
return encoder_hidden_states, hidden_states
|
return hidden_states, encoder_hidden_states
|
||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
self,
|
self,
|
||||||
@ -469,14 +469,8 @@ class QwenImageTransformer2DModel(FlipFlopModule):
|
|||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
for i, block in enumerate(self.get_blocks("transformer_blocks")):
|
out = (hidden_states, encoder_hidden_states)
|
||||||
encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
hidden_states, encoder_hidden_states = self.execute_blocks("transformer_blocks", self.indiv_block_fwd, out, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
||||||
if "transformer_blocks" in self.flipflop:
|
|
||||||
holder = self.flipflop["transformer_blocks"]
|
|
||||||
with holder.context() as ctx:
|
|
||||||
for i, block in enumerate(holder.blocks):
|
|
||||||
encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
|
||||||
|
|
||||||
|
|
||||||
hidden_states = self.norm_out(hidden_states, temb)
|
hidden_states = self.norm_out(hidden_states, temb)
|
||||||
hidden_states = self.proj_out(hidden_states)
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user