From 6b240b0bceac06f8ed46e150c89a90888562a2ad Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Thu, 25 Sep 2025 22:41:41 -0700
Subject: [PATCH] Refactored old flip flop into a new implementation that
 allows for controlling the percentage of blocks getting flip flopped,
 converted nodes to v3 schema

---
 comfy/ldm/flipflop_transformer.py | 106 ++++++++++++++++++++++++++++--
 comfy/ldm/qwen_image/model.py     |  15 +++++
 comfy_extras/nodes_flipflop.py    |  77 +++++++++++++++++-----
 3 files changed, 177 insertions(+), 21 deletions(-)

diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py
index edb4b3f75..242672462 100644
--- a/comfy/ldm/flipflop_transformer.py
+++ b/comfy/ldm/flipflop_transformer.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 import torch
 import torch.cuda as cuda
 import copy
@@ -37,6 +38,102 @@ def patch_model_from_config(model, config: FlipFlopConfig):
     setattr(model, config.overwrite_forward, flip_flop_transformer.__call__)
 
 
+class FlipFlopContext:
+    def __init__(self, holder: FlipFlopHolder):
+        self.holder = holder
+        self.reset()
+
+    def reset(self):
+        self.num_blocks = len(self.holder.transformer_blocks)
+        self.first_flip = True
+        self.first_flop = True
+        self.last_flip = False
+        self.last_flop = False
+
+    def __enter__(self):
+        self.reset()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.holder.compute_stream.record_event(self.holder.cpy_end_event)
+
+    def do_flip(self, func, i: int, _, *args, **kwargs):
+        # flip
+        self.holder.compute_stream.wait_event(self.holder.cpy_end_event)
+        with torch.cuda.stream(self.holder.compute_stream):
+            out = func(i, self.holder.flip, *args, **kwargs)
+            self.holder.event_flip.record(self.holder.compute_stream)
+        # while flip executes, queue flop to copy to its next block
+        next_flop_i = i + 1
+        if next_flop_i >= self.num_blocks:
+            next_flop_i = next_flop_i - self.num_blocks
+            self.last_flip = True
+        if not self.first_flip:
+            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
+        if self.last_flip:
+            self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
+        self.first_flip = False
+        return out
+
+    def do_flop(self, func, i: int, _, *args, **kwargs):
+        # flop
+        if not self.first_flop:
+            self.holder.compute_stream.wait_event(self.holder.cpy_end_event)
+        with torch.cuda.stream(self.holder.compute_stream):
+            out = func(i, self.holder.flop, *args, **kwargs)
+            self.holder.event_flop.record(self.holder.compute_stream)
+        # while flop executes, queue flip to copy to its next block
+        next_flip_i = i + 1
+        if next_flip_i >= self.num_blocks:
+            next_flip_i = next_flip_i - self.num_blocks
+            self.last_flop = True
+        self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
+        if self.last_flop:
+            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
+        self.first_flop = False
+        return out
+
+    @torch.no_grad()
+    def __call__(self, func, i: int, block: torch.nn.Module, *args, **kwargs):
+        # flips are even indexes, flops are odd indexes
+        if i % 2 == 0:
+            return self.do_flip(func, i, block, *args, **kwargs)
+        else:
+            return self.do_flop(func, i, block, *args, **kwargs)
+
+
+class FlipFlopHolder:
+    def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"):
+        self.inference_device = torch.device(inference_device)
+        self.offloading_device = torch.device(offloading_device)
+        self.transformer_blocks = transformer_blocks
+
+        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=self.inference_device)
+        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=self.inference_device)
+
+        self.compute_stream = cuda.default_stream(self.inference_device)
+        self.cpy_stream = cuda.Stream(self.inference_device)
+
+        self.event_flip = torch.cuda.Event(enable_timing=False)
+        self.event_flop = torch.cuda.Event(enable_timing=False)
+        self.cpy_end_event = torch.cuda.Event(enable_timing=False)
+        # INIT - is this actually needed?
+        self.compute_stream.record_event(self.cpy_end_event)
+
+    def _copy_state_dict(self, dst, src, cpy_start_event: torch.cuda.Event=None, cpy_end_event: torch.cuda.Event=None):
+        if cpy_start_event:
+            self.cpy_stream.wait_event(cpy_start_event)
+
+        with torch.cuda.stream(self.cpy_stream):
+            for k, v in src.items():
+                dst[k].copy_(v, non_blocking=True)
+            if cpy_end_event:
+                cpy_end_event.record(self.cpy_stream)
+
+    def context(self):
+        return FlipFlopContext(self)
+
 class FlipFlopTransformer:
     def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
         self.transformer_blocks = transformer_blocks
@@ -114,6 +211,7 @@ class FlipFlopTransformer:
         Flip accounts for even blocks (0 is first block), flop accounts for odd blocks.
         '''
         # separated flip flop refactor
+        num_blocks = len(self.transformer_blocks)
         first_flip = True
         first_flop = True
         last_flip = False
@@ -128,8 +226,8 @@ class FlipFlopTransformer:
                 self.event_flip.record(self.compute_stream)
             # while flip executes, queue flop to copy to its next block
             next_flop_i = i + 1
-            if next_flop_i >= self.num_blocks:
-                next_flop_i = next_flop_i - self.num_blocks
+            if next_flop_i >= num_blocks:
+                next_flop_i = next_flop_i - num_blocks
                 last_flip = True
             if not first_flip:
                 self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[next_flop_i].state_dict(), self.event_flop, self.cpy_end_event)
@@ -145,8 +243,8 @@ class FlipFlopTransformer:
                 self.event_flop.record(self.compute_stream)
             # while flop executes, queue flip to copy to its next block
             next_flip_i = i + 1
-            if next_flip_i >= self.num_blocks:
-                next_flip_i = next_flip_i - self.num_blocks
+            if next_flip_i >= num_blocks:
+                next_flip_i = next_flip_i - num_blocks
                 last_flop = True
             self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[next_flip_i].state_dict(), self.event_flip, self.cpy_end_event)
             if last_flop:
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index fad6440eb..a338a6805 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 from typing import Optional, Tuple
 from einops import repeat
 
+from comfy.ldm.flipflop_transformer import FlipFlopHolder
 from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
@@ -335,10 +336,18 @@ class QwenImageTransformer2DModel(nn.Module):
             for _ in range(num_layers)
         ])
 
+        self.flipflop_holders: dict[str, FlipFlopHolder] = {}
+
         if final_layer:
             self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
             self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
 
+    def setup_flipflop_holders(self, block_percentage: float):
+        # We hackily move any flipflopped blocks into a holder so that our model management system does not see them.
+        num_blocks = int(len(self.transformer_blocks) * block_percentage)
+        self.flipflop_holders["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
+        self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks])
+
     def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
@@ -403,6 +412,12 @@ class QwenImageTransformer2DModel(nn.Module):
     def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
         for i, block in enumerate(self.transformer_blocks):
             encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
+        if "blocks_fwd" in self.flipflop_holders:
+            holder = self.flipflop_holders["blocks_fwd"]
+            with holder.context() as ctx:
+                for i, block in enumerate(holder.transformer_blocks):
+                    encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
+
         return encoder_hidden_states, hidden_states
 
     def _forward(
diff --git a/comfy_extras/nodes_flipflop.py b/comfy_extras/nodes_flipflop.py
index 0406d2441..90ea1f6d5 100644
--- a/comfy_extras/nodes_flipflop.py
+++ b/comfy_extras/nodes_flipflop.py
@@ -1,28 +1,71 @@
+from __future__ import annotations
+from typing_extensions import override
+
+from comfy_api.latest import ComfyExtension, io
+
 from comfy.ldm.flipflop_transformer import FLIPFLOP_REGISTRY
 
 
-class FlipFlop:
+class FlipFlopOld(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required":
-                    {"model": ("MODEL",), },
-                }
+    def define_schema(cls) -> io.Schema:
+        return io.Schema(
+            node_id="FlipFlop",
+            display_name="FlipFlop (Old)",
+            category="_for_testing",
+            inputs=[
+                io.Model.Input(id="model")
+            ],
+            outputs=[
+                io.Model.Output()
+            ],
+            description="Apply FlipFlop transformation to model using registry-based patching"
+        )
 
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "patch"
-
-    OUTPUT_NODE = False
-
-    CATEGORY = "_for_testing"
-
-    def patch(self, model):
+    @classmethod
+    def execute(cls, model) -> io.NodeOutput:
         patch_cls = FLIPFLOP_REGISTRY.get(model.model.diffusion_model.__class__.__name__, None)
         if patch_cls is None:
             raise ValueError(f"Model {model.model.diffusion_model.__class__.__name__} not supported")
         model.model.diffusion_model = patch_cls.patch(model.model.diffusion_model)
-        return (model,)
+        return io.NodeOutput(model)
 
 
-NODE_CLASS_MAPPINGS = {
-    "FlipFlop": FlipFlop
-}
+class FlipFlop(io.ComfyNode):
+    @classmethod
+    def define_schema(cls) -> io.Schema:
+        return io.Schema(
+            node_id="FlipFlopNew",
+            display_name="FlipFlop (New)",
+            category="_for_testing",
+            inputs=[
+                io.Model.Input(id="model"),
+                io.Float.Input(id="block_percentage", default=1.0, min=0.0, max=1.0, step=0.01),
+            ],
+            outputs=[
+                io.Model.Output()
+            ],
+            description="Apply FlipFlop transformation to model using setup_flipflop_holders method"
+        )
+
+    @classmethod
+    def execute(cls, model: io.Model.Type, block_percentage: float) -> io.NodeOutput:
+        # NOTE: this is still just a hacky prototype; it would not be exposed as a node in this form.
+        # At the moment, this modifies the underlying model with no way to 'unpatch' it.
+        model = model.clone()
+        if not hasattr(model.model.diffusion_model, "setup_flipflop_holders"):
+            raise ValueError("Model does not have flipflop holders; FlipFlop not supported")
+        model.model.diffusion_model.setup_flipflop_holders(block_percentage)
+        return io.NodeOutput(model)
+
+
+class FlipFlopExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            FlipFlopOld,
+            FlipFlop,
+        ]
+
+
+async def comfy_entrypoint() -> FlipFlopExtension:
+    return FlipFlopExtension()
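
Reviewer note (illustration only, not part of the patch): the FlipFlopHolder/FlipFlopContext pair above is a two-buffer weight-prefetch scheme. Two staging copies of a transformer block stay resident on the GPU ("flip" serves even block indexes, "flop" serves odd ones) while a dedicated copy stream loads the weights of the next offloaded block into whichever buffer just finished computing. The self-contained sketch below shows the same overlap pattern in a simplified, single-loop form; the names (prefetch_pipeline, staging, copier, block_done, copy_done) are hypothetical, the blocks are assumed to be plain modules mapping a tensor to a tensor, CUDA is assumed to be available, and the offloaded blocks' CPU weights are assumed to live in pinned memory so that the non_blocking copies can actually overlap compute.

import copy
import torch

@torch.no_grad()
def prefetch_pipeline(blocks, x):
    # Simplified sketch of the flip/flop idea: two GPU-resident staging copies of a block,
    # buffer i % 2 serves block i, and a separate copy stream refills the idle buffer
    # with the weights of the next offloaded block while the other buffer computes.
    device = torch.device("cuda")
    staging = [copy.deepcopy(blocks[0]).to(device), copy.deepcopy(blocks[1]).to(device)]
    compute = torch.cuda.default_stream(device)
    copier = torch.cuda.Stream(device)
    block_done = [torch.cuda.Event(), torch.cuda.Event()]  # compute finished in buffer 0 / 1
    copy_done = [torch.cuda.Event(), torch.cuda.Event()]   # weight refill finished for buffer 0 / 1
    for ev in copy_done:
        compute.record_event(ev)  # buffers start out already filled with blocks 0 and 1

    for i in range(len(blocks)):
        buf = i % 2
        compute.wait_event(copy_done[buf])      # this buffer's weights must be in place
        with torch.cuda.stream(compute):
            x = staging[buf](x)                 # run block i with the staged weights
            block_done[buf].record(compute)
        nxt = i + 2                             # next block that will reuse this buffer
        if nxt < len(blocks):
            copier.wait_event(block_done[buf])  # don't overwrite weights still in use
            with torch.cuda.stream(copier):
                dst = staging[buf].state_dict()
                for k, v in blocks[nxt].state_dict().items():
                    dst[k].copy_(v, non_blocking=True)  # CPU -> GPU while the other buffer computes
                copy_done[buf].record(copier)
    return x

For comparison with the patch itself: FlipFlopHolder owns the two staging buffers, streams, and events for the lifetime of the model, FlipFlopContext threads the per-step event handshakes through do_flip/do_flop (including re-staging blocks 0 and 1 at the end of a pass), and setup_flipflop_holders decides which slice of transformer_blocks is handed to the holder versus kept resident; the sketch collapses all of that into one loop purely for readability.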