Refactored old flip flop into a new implementation that allows for controlling the percentage of blocks getting flip flopped, converted nodes to v3 schema

This commit is contained in:
Jedrzej Kosinski 2025-09-25 22:41:41 -07:00
parent f9fbf902d5
commit 6b240b0bce
3 changed files with 177 additions and 21 deletions

View File

@ -1,3 +1,4 @@
from __future__ import annotations
import torch import torch
import torch.cuda as cuda import torch.cuda as cuda
import copy import copy
@ -37,6 +38,102 @@ def patch_model_from_config(model, config: FlipFlopConfig):
setattr(model, config.overwrite_forward, flip_flop_transformer.__call__) setattr(model, config.overwrite_forward, flip_flop_transformer.__call__)
class FlipFlopContext:
def __init__(self, holder: FlipFlopHolder):
self.holder = holder
self.reset()
def reset(self):
self.num_blocks = len(self.holder.transformer_blocks)
self.first_flip = True
self.first_flop = True
self.last_flip = False
self.last_flop = False
def __enter__(self):
self.reset()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.holder.compute_stream.record_event(self.holder.cpy_end_event)
def do_flip(self, func, i: int, _, *args, **kwargs):
# flip
self.holder.compute_stream.wait_event(self.holder.cpy_end_event)
with torch.cuda.stream(self.holder.compute_stream):
out = func(i, self.holder.flip, *args, **kwargs)
self.holder.event_flip.record(self.holder.compute_stream)
# while flip executes, queue flop to copy to its next block
next_flop_i = i + 1
if next_flop_i >= self.num_blocks:
next_flop_i = next_flop_i - self.num_blocks
self.last_flip = True
if not self.first_flip:
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
if self.last_flip:
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
self.first_flip = False
return out
def do_flop(self, func, i: int, _, *args, **kwargs):
# flop
if not self.first_flop:
self.holder.compute_stream.wait_event(self.holder.cpy_end_event)
with torch.cuda.stream(self.holder.compute_stream):
out = func(i, self.holder.flop, *args, **kwargs)
self.holder.event_flop.record(self.holder.compute_stream)
# while flop executes, queue flip to copy to its next block
next_flip_i = i + 1
if next_flip_i >= self.num_blocks:
next_flip_i = next_flip_i - self.num_blocks
self.last_flop = True
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
if self.last_flop:
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
self.first_flop = False
return out
@torch.no_grad()
def __call__(self, func, i: int, block: torch.nn.Module, *args, **kwargs):
# flips are even indexes, flops are odd indexes
if i % 2 == 0:
return self.do_flip(func, i, block, *args, **kwargs)
else:
return self.do_flop(func, i, block, *args, **kwargs)
class FlipFlopHolder:
def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"):
self.inference_device = torch.device(inference_device)
self.offloading_device = torch.device(offloading_device)
self.transformer_blocks = transformer_blocks
self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=self.inference_device)
self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=self.inference_device)
self.compute_stream = cuda.default_stream(self.inference_device)
self.cpy_stream = cuda.Stream(self.inference_device)
self.event_flip = torch.cuda.Event(enable_timing=False)
self.event_flop = torch.cuda.Event(enable_timing=False)
self.cpy_end_event = torch.cuda.Event(enable_timing=False)
# INIT - is this actually needed?
self.compute_stream.record_event(self.cpy_end_event)
def _copy_state_dict(self, dst, src, cpy_start_event: torch.cuda.Event=None, cpy_end_event: torch.cuda.Event=None):
if cpy_start_event:
self.cpy_stream.wait_event(cpy_start_event)
with torch.cuda.stream(self.cpy_stream):
for k, v in src.items():
dst[k].copy_(v, non_blocking=True)
if cpy_end_event:
cpy_end_event.record(self.cpy_stream)
def context(self):
return FlipFlopContext(self)
class FlipFlopTransformer: class FlipFlopTransformer:
def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"): def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
self.transformer_blocks = transformer_blocks self.transformer_blocks = transformer_blocks
@ -114,6 +211,7 @@ class FlipFlopTransformer:
Flip accounts for even blocks (0 is first block), flop accounts for odd blocks. Flip accounts for even blocks (0 is first block), flop accounts for odd blocks.
''' '''
# separated flip flop refactor # separated flip flop refactor
num_blocks = len(self.transformer_blocks)
first_flip = True first_flip = True
first_flop = True first_flop = True
last_flip = False last_flip = False
@ -128,8 +226,8 @@ class FlipFlopTransformer:
self.event_flip.record(self.compute_stream) self.event_flip.record(self.compute_stream)
# while flip executes, queue flop to copy to its next block # while flip executes, queue flop to copy to its next block
next_flop_i = i + 1 next_flop_i = i + 1
if next_flop_i >= self.num_blocks: if next_flop_i >= num_blocks:
next_flop_i = next_flop_i - self.num_blocks next_flop_i = next_flop_i - num_blocks
last_flip = True last_flip = True
if not first_flip: if not first_flip:
self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[next_flop_i].state_dict(), self.event_flop, self.cpy_end_event) self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[next_flop_i].state_dict(), self.event_flop, self.cpy_end_event)
@ -145,8 +243,8 @@ class FlipFlopTransformer:
self.event_flop.record(self.compute_stream) self.event_flop.record(self.compute_stream)
# while flop executes, queue flip to copy to its next block # while flop executes, queue flip to copy to its next block
next_flip_i = i + 1 next_flip_i = i + 1
if next_flip_i >= self.num_blocks: if next_flip_i >= num_blocks:
next_flip_i = next_flip_i - self.num_blocks next_flip_i = next_flip_i - num_blocks
last_flop = True last_flop = True
self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[next_flip_i].state_dict(), self.event_flip, self.cpy_end_event) self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[next_flip_i].state_dict(), self.event_flip, self.cpy_end_event)
if last_flop: if last_flop:

View File

@ -5,6 +5,7 @@ import torch.nn.functional as F
from typing import Optional, Tuple from typing import Optional, Tuple
from einops import repeat from einops import repeat
from comfy.ldm.flipflop_transformer import FlipFlopHolder
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.layers import EmbedND
@ -335,10 +336,18 @@ class QwenImageTransformer2DModel(nn.Module):
for _ in range(num_layers) for _ in range(num_layers)
]) ])
self.flipflop_holders: dict[str, FlipFlopHolder] = {}
if final_layer: if final_layer:
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
def setup_flipflop_holders(self, block_percentage: float):
# We hackily move any flipflopped blocks into holder so that our model management system does not see them.
num_blocks = int(len(self.transformer_blocks) * block_percentage)
self.flipflop_holders["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks])
def process_img(self, x, index=0, h_offset=0, w_offset=0): def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, t, h, w = x.shape bs, c, t, h, w = x.shape
patch_size = self.patch_size patch_size = self.patch_size
@ -403,6 +412,12 @@ class QwenImageTransformer2DModel(nn.Module):
def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options): def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
if "blocks_fwd" in self.flipflop_holders:
holder = self.flipflop_holders["blocks_fwd"]
with holder.context() as ctx:
for i, block in enumerate(holder.transformer_blocks):
encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
return encoder_hidden_states, hidden_states return encoder_hidden_states, hidden_states
def _forward( def _forward(

View File

@ -1,28 +1,71 @@
from __future__ import annotations
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy.ldm.flipflop_transformer import FLIPFLOP_REGISTRY from comfy.ldm.flipflop_transformer import FLIPFLOP_REGISTRY
class FlipFlop: class FlipFlopOld(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls) -> io.Schema:
return {"required": return io.Schema(
{"model": ("MODEL",), }, node_id="FlipFlop",
} display_name="FlipFlop (Old)",
category="_for_testing",
inputs=[
io.Model.Input(id="model")
],
outputs=[
io.Model.Output()
],
description="Apply FlipFlop transformation to model using registry-based patching"
)
RETURN_TYPES = ("MODEL",) @classmethod
FUNCTION = "patch" def execute(cls, model) -> io.NodeOutput:
OUTPUT_NODE = False
CATEGORY = "_for_testing"
def patch(self, model):
patch_cls = FLIPFLOP_REGISTRY.get(model.model.diffusion_model.__class__.__name__, None) patch_cls = FLIPFLOP_REGISTRY.get(model.model.diffusion_model.__class__.__name__, None)
if patch_cls is None: if patch_cls is None:
raise ValueError(f"Model {model.model.diffusion_model.__class__.__name__} not supported") raise ValueError(f"Model {model.model.diffusion_model.__class__.__name__} not supported")
model.model.diffusion_model = patch_cls.patch(model.model.diffusion_model) model.model.diffusion_model = patch_cls.patch(model.model.diffusion_model)
return (model,) return io.NodeOutput(model)
NODE_CLASS_MAPPINGS = { class FlipFlop(io.ComfyNode):
"FlipFlop": FlipFlop @classmethod
} def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="FlipFlopNew",
display_name="FlipFlop (New)",
category="_for_testing",
inputs=[
io.Model.Input(id="model"),
io.Float.Input(id="block_percentage", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output()
],
description="Apply FlipFlop transformation to model using setup_flipflop_holders method"
)
@classmethod
def execute(cls, model: io.Model.Type, block_percentage: float) -> io.NodeOutput:
# NOTE: this is just a hacky prototype still, this would not be exposed as a node.
# At the moment, this modifies the underlying model with no way to 'unpatch' it.
model = model.clone()
if not hasattr(model.model.diffusion_model, "setup_flipflop_holders"):
raise ValueError("Model does not have flipflop holders; FlipFlop not supported")
model.model.diffusion_model.setup_flipflop_holders(block_percentage)
return io.NodeOutput(model)
class FlipFlopExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
FlipFlopOld,
FlipFlop,
]
async def comfy_entrypoint() -> FlipFlopExtension:
return FlipFlopExtension()