From 84e73f2aa5e75f42ceba40a5b124a53aa3dad899 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 25 Sep 2025 16:15:46 -0700 Subject: [PATCH 01/17] Brought over flip flop prototype from contentis' fork, limiting it to only Qwen to ease the process of adapting it to be a native feature --- comfy/ldm/flipflop_transformer.py | 243 ++++++++++++++++++++++++++++++ comfy/ldm/qwen_image/model.py | 69 +++++---- comfy_extras/nodes_flipflop.py | 28 ++++ nodes.py | 1 + 4 files changed, 310 insertions(+), 31 deletions(-) create mode 100644 comfy/ldm/flipflop_transformer.py create mode 100644 comfy_extras/nodes_flipflop.py diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py new file mode 100644 index 000000000..8d8f6565d --- /dev/null +++ b/comfy/ldm/flipflop_transformer.py @@ -0,0 +1,243 @@ +import torch +import torch.cuda as cuda +import copy +from typing import List, Tuple +from dataclasses import dataclass + +FLIPFLOP_REGISTRY = {} + +def register(name): + def decorator(cls): + FLIPFLOP_REGISTRY[name] = cls + return cls + return decorator + + +@dataclass +class FlipFlopConfig: + block_name: str + block_wrap_fn: callable + out_names: Tuple[str] + overwrite_forward: str + pinned_staging: bool = False + inference_device: str = "cuda" + offloading_device: str = "cpu" + + +def patch_model_from_config(model, config: FlipFlopConfig): + block_list = getattr(model, config.block_name) + flip_flop_transformer = FlipFlopTransformer(block_list, + block_wrap_fn=config.block_wrap_fn, + out_names=config.out_names, + offloading_device=config.offloading_device, + inference_device=config.inference_device, + pinned_staging=config.pinned_staging) + delattr(model, config.block_name) + setattr(model, config.block_name, flip_flop_transformer) + setattr(model, config.overwrite_forward, flip_flop_transformer.__call__) + + +class FlipFlopTransformer: + def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"): + self.transformer_blocks = transformer_blocks + self.offloading_device = torch.device(offloading_device) + self.inference_device = torch.device(inference_device) + self.staging = pinned_staging + + self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=self.inference_device) + self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=self.inference_device) + + self._cpy_fn = self._copy_state_dict + if self.staging: + self.staging_buffer = self._pin_module(self.transformer_blocks[0]).state_dict() + self._cpy_fn = self._copy_state_dict_with_staging + + self.compute_stream = cuda.default_stream(self.inference_device) + self.cpy_stream = cuda.Stream(self.inference_device) + + self.event_flip = torch.cuda.Event(enable_timing=False) + self.event_flop = torch.cuda.Event(enable_timing=False) + self.cpy_end_event = torch.cuda.Event(enable_timing=False) + + self.block_wrap_fn = block_wrap_fn + self.out_names = out_names + + self.num_blocks = len(self.transformer_blocks) + self.extra_run = self.num_blocks % 2 + + # INIT + self.compute_stream.record_event(self.cpy_end_event) + + def _copy_state_dict(self, dst, src, cpy_start_event=None, cpy_end_event=None): + if cpy_start_event: + self.cpy_stream.wait_event(cpy_start_event) + + with torch.cuda.stream(self.cpy_stream): + for k, v in src.items(): + dst[k].copy_(v, non_blocking=True) + if cpy_end_event: + cpy_end_event.record(self.cpy_stream) + + def _copy_state_dict_with_staging(self, dst, src, 
cpy_start_event=None, cpy_end_event=None): + if cpy_start_event: + self.cpy_stream.wait_event(cpy_start_event) + + with torch.cuda.stream(self.cpy_stream): + for k, v in src.items(): + self.staging_buffer[k].copy_(v, non_blocking=True) + dst[k].copy_(self.staging_buffer[k], non_blocking=True) + if cpy_end_event: + cpy_end_event.record(self.cpy_stream) + + def _pin_module(self, module): + pinned_module = copy.deepcopy(module) + for param in pinned_module.parameters(): + param.data = param.data.pin_memory() + # Pin all buffers (if any) + for buffer in pinned_module.buffers(): + buffer.data = buffer.data.pin_memory() + return pinned_module + + def _reset(self): + if self.extra_run: + self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop) + self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip) + else: + self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip) + self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop) + + self.compute_stream.record_event(self.cpy_end_event) + + @torch.no_grad() + def __call__(self, **feed_dict): + # contentis' prototype flip flop + # Wait for reset + self.compute_stream.wait_event(self.cpy_end_event) + + with torch.cuda.stream(self.compute_stream): + feed_dict = self.block_wrap_fn(self.flip, **feed_dict) + self.event_flip.record(self.compute_stream) + + for i in range(self.num_blocks // 2 - 1): + + with torch.cuda.stream(self.compute_stream): + feed_dict = self.block_wrap_fn(self.flop, **feed_dict) + self.event_flop.record(self.compute_stream) + + self._cpy_fn(self.flip.state_dict(), self.transformer_blocks[(i + 1) * 2].state_dict(), self.event_flip, + self.cpy_end_event) + + self.compute_stream.wait_event(self.cpy_end_event) + + with torch.cuda.stream(self.compute_stream): + feed_dict = self.block_wrap_fn(self.flip, **feed_dict) + self.event_flip.record(self.compute_stream) + + self._cpy_fn(self.flop.state_dict(), self.transformer_blocks[(i + 1) * 2 + 1].state_dict(), self.event_flop, + self.cpy_end_event) + self.compute_stream.wait_event(self.cpy_end_event) + + + with torch.cuda.stream(self.compute_stream): + feed_dict = self.block_wrap_fn(self.flop, **feed_dict) + self.event_flop.record(self.compute_stream) + + if self.extra_run: + self._cpy_fn(self.flip.state_dict(), self.transformer_blocks[-1].state_dict(), self.event_flip, + self.cpy_end_event) + self.compute_stream.wait_event(self.cpy_end_event) + with torch.cuda.stream(self.compute_stream): + feed_dict = self.block_wrap_fn(self.flip, **feed_dict) + self.event_flip.record(self.compute_stream) + + self._reset() + + outputs = [feed_dict[name] for name in self.out_names] + if len(outputs) == 1: + return outputs[0] + return tuple(outputs) + +# @register("Flux") +# class Flux: +# @staticmethod +# def double_block_wrap(block, **kwargs): +# kwargs["img"], kwargs["txt"] = block(img=kwargs["img"], +# txt=kwargs["txt"], +# vec=kwargs["vec"], +# pe=kwargs["pe"], +# attn_mask=kwargs.get("attn_mask")) +# return kwargs + +# @staticmethod +# def single_block_wrap(block, **kwargs): +# kwargs["img"] = block(kwargs["img"], +# vec=kwargs["vec"], +# pe=kwargs["pe"], +# attn_mask=kwargs.get("attn_mask")) +# return kwargs + +# double_config = FlipFlopConfig(block_name="double_blocks", +# block_wrap_fn=double_block_wrap, +# out_names=("img", "txt"), +# 
overwrite_forward="double_transformer_fwd", +# pinned_staging=False) + +# single_config = FlipFlopConfig(block_name="single_blocks", +# block_wrap_fn=single_block_wrap, +# out_names=("img",), +# overwrite_forward="single_transformer_fwd", +# pinned_staging=False) +# @staticmethod +# def patch(model): +# patch_model_from_config(model, Flux.double_config) +# patch_model_from_config(model, Flux.single_config) +# return model + + +# @register("WanModel") +# class Wan: +# @staticmethod +# def wan_blocks_wrap(block, **kwargs): +# kwargs["x"] = block(x=kwargs["x"], +# context=kwargs["context"], +# e=kwargs["e"], +# freqs=kwargs["freqs"], +# context_img_len=kwargs.get("context_img_len")) +# return kwargs + +# blocks_config = FlipFlopConfig(block_name="blocks", +# block_wrap_fn=wan_blocks_wrap, +# out_names=("x",), +# overwrite_forward="block_fwd", +# pinned_staging=False) + + +# @staticmethod +# def patch(model): +# patch_model_from_config(model, Wan.blocks_config) +# return model + + +@register("QwenImageTransformer2DModel") +class QwenImage: + @staticmethod + def qwen_blocks_wrap(block, **kwargs): + kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"], + encoder_hidden_states=kwargs["encoder_hidden_states"], + encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"], + temb=kwargs["temb"], + image_rotary_emb=kwargs["image_rotary_emb"]) + return kwargs + + blocks_config = FlipFlopConfig(block_name="transformer_blocks", + block_wrap_fn=qwen_blocks_wrap, + out_names=("encoder_hidden_states", "hidden_states"), + overwrite_forward="block_fwd", + pinned_staging=False) + + + @staticmethod + def patch(model): + patch_model_from_config(model, QwenImage.blocks_config) + return model + diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index b9f60c2b7..e2ac461a6 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -366,6 +366,39 @@ class QwenImageTransformer2DModel(nn.Module): comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) + def block_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace): + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"]) + return out + out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not 
None: + hidden_states[:, :add.shape[1]] += add + return encoder_hidden_states, hidden_states + def _forward( self, x, @@ -433,37 +466,11 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) - for i, block in enumerate(self.transformer_blocks): - if ("double_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) - return out - out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) - hidden_states = out["img"] - encoder_hidden_states = out["txt"] - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, - temb=temb, - image_rotary_emb=image_rotary_emb, - transformer_options=transformer_options, - ) - - if "double_block" in patches: - for p in patches["double_block"]: - out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options}) - hidden_states = out["img"] - encoder_hidden_states = out["txt"] - - if control is not None: # Controlnet - control_i = control.get("input") - if i < len(control_i): - add = control_i[i] - if add is not None: - hidden_states[:, :add.shape[1]] += add + encoder_hidden_states, hidden_states = self.block_fwd(hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, image_rotary_emb=image_rotary_emb, + patches=patches, control=control, blocks_replace=blocks_replace) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy_extras/nodes_flipflop.py b/comfy_extras/nodes_flipflop.py new file mode 100644 index 000000000..0406d2441 --- /dev/null +++ b/comfy_extras/nodes_flipflop.py @@ -0,0 +1,28 @@ +from comfy.ldm.flipflop_transformer import FLIPFLOP_REGISTRY + +class FlipFlop: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), }, + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + OUTPUT_NODE = False + + CATEGORY = "_for_testing" + + def patch(self, model): + patch_cls = FLIPFLOP_REGISTRY.get(model.model.diffusion_model.__class__.__name__, None) + if patch_cls is None: + raise ValueError(f"Model {model.model.diffusion_model.__class__.__name__} not supported") + + model.model.diffusion_model = patch_cls.patch(model.model.diffusion_model) + + return (model,) + +NODE_CLASS_MAPPINGS = { + "FlipFlop": FlipFlop +} diff --git a/nodes.py b/nodes.py index 1a6784b68..4fff3eb10 100644 --- a/nodes.py +++ b/nodes.py @@ -2330,6 +2330,7 @@ async def init_builtin_extra_nodes(): "nodes_model_patch.py", "nodes_easycache.py", "nodes_audio_encoder.py", + "nodes_flipflop.py", ] import_failed = [] From f083720eb4b8de9ad56f42545a5d549ae03ca12c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 25 Sep 2025 16:16:51 -0700 Subject: [PATCH 02/17] Refactored FlipFlopTransformer.__call__ to fully separate out actions between flip and flop --- comfy/ldm/flipflop_transformer.py | 52 +++++++++++++++++++++++++++++++ 1 file 
changed, 52 insertions(+) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 8d8f6565d..d8059eafe 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -110,6 +110,58 @@ class FlipFlopTransformer: @torch.no_grad() def __call__(self, **feed_dict): + ''' + Flip accounts for even blocks (0 is first block), flop accounts for odd blocks. + ''' + # separated flip flop refactor + first_flip = True + first_flop = True + last_flip = False + last_flop = False + for i, block in enumerate(self.transformer_blocks): + is_flip = i % 2 == 0 + if is_flip: + # flip + self.compute_stream.wait_event(self.cpy_end_event) + with torch.cuda.stream(self.compute_stream): + feed_dict = self.block_wrap_fn(self.flip, **feed_dict) + self.event_flip.record(self.compute_stream) + # while flip executes, queue flop to copy to its next block + next_flop_i = i + 1 + if next_flop_i >= self.num_blocks: + next_flop_i = next_flop_i - self.num_blocks + last_flip = True + if not first_flip: + self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[next_flop_i].state_dict(), self.event_flop, self.cpy_end_event) + if last_flip: + self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip) + first_flip = False + else: + # flop + if not first_flop: + self.compute_stream.wait_event(self.cpy_end_event) + with torch.cuda.stream(self.compute_stream): + feed_dict = self.block_wrap_fn(self.flop, **feed_dict) + self.event_flop.record(self.compute_stream) + # while flop executes, queue flip to copy to its next block + next_flip_i = i + 1 + if next_flip_i >= self.num_blocks: + next_flip_i = next_flip_i - self.num_blocks + last_flop = True + self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[next_flip_i].state_dict(), self.event_flip, self.cpy_end_event) + if last_flop: + self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop) + first_flop = False + + self.compute_stream.record_event(self.cpy_end_event) + + outputs = [feed_dict[name] for name in self.out_names] + if len(outputs) == 1: + return outputs[0] + return tuple(outputs) + + @torch.no_grad() + def __call__old(self, **feed_dict): # contentis' prototype flip flop # Wait for reset self.compute_stream.wait_event(self.cpy_end_event) From f9fbf902d5f3894c054b0f6e06d3994a2dae84d2 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 25 Sep 2025 17:49:39 -0700 Subject: [PATCH 03/17] Added missing Qwen block params, further subdivided blocks function --- comfy/ldm/flipflop_transformer.py | 5 ++- comfy/ldm/qwen_image/model.py | 65 +++++++++++++++++-------------- 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index d8059eafe..edb4b3f75 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -278,13 +278,14 @@ class QwenImage: encoder_hidden_states=kwargs["encoder_hidden_states"], encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"], temb=kwargs["temb"], - image_rotary_emb=kwargs["image_rotary_emb"]) + image_rotary_emb=kwargs["image_rotary_emb"], + transformer_options=kwargs["transformer_options"]) return kwargs blocks_config = FlipFlopConfig(block_name="transformer_blocks", block_wrap_fn=qwen_blocks_wrap, out_names=("encoder_hidden_states", "hidden_states"), - overwrite_forward="block_fwd", + overwrite_forward="blocks_fwd", 
pinned_staging=False) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index e2ac461a6..fad6440eb 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -366,37 +366,43 @@ class QwenImageTransformer2DModel(nn.Module): comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) - def block_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace): - for i, block in enumerate(self.transformer_blocks): - if ("double_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"]) - return out - out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap}) + def indiv_block_fwd(self, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, + ) + + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options}) hidden_states = out["img"] encoder_hidden_states = out["txt"] - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) - if "double_block" in patches: - for p in patches["double_block"]: - out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i}) - hidden_states = out["img"] - encoder_hidden_states = out["txt"] + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + hidden_states[:, :add.shape[1]] += add - if control is not None: # Controlnet - control_i = control.get("input") - if i < len(control_i): - add = control_i[i] - if add is not None: - hidden_states[:, :add.shape[1]] += add + return encoder_hidden_states, hidden_states + + def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options): 
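+        # Thin wrapper (hedged note): thread (encoder_hidden_states, hidden_states) through every
+        # transformer block via indiv_block_fwd, so the flip-flop path can reuse the exact same
+        # per-block logic (block replacements, double_block patches, controlnet adds).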
+ for i, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) return encoder_hidden_states, hidden_states def _forward( @@ -466,11 +472,12 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) - encoder_hidden_states, hidden_states = self.block_fwd(hidden_states=hidden_states, + encoder_hidden_states, hidden_states = self.blocks_fwd(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, - patches=patches, control=control, blocks_replace=blocks_replace) + patches=patches, control=control, blocks_replace=blocks_replace, x=x, + transformer_options=transformer_options) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From 6b240b0bceac06f8ed46e150c89a90888562a2ad Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 25 Sep 2025 22:41:41 -0700 Subject: [PATCH 04/17] Refactored old flip flop into a new implementation that allows for controlling the percentage of blocks getting flip flopped, converted nodes to v3 schema --- comfy/ldm/flipflop_transformer.py | 106 ++++++++++++++++++++++++++++-- comfy/ldm/qwen_image/model.py | 15 +++++ comfy_extras/nodes_flipflop.py | 77 +++++++++++++++++----- 3 files changed, 177 insertions(+), 21 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index edb4b3f75..242672462 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -1,3 +1,4 @@ +from __future__ import annotations import torch import torch.cuda as cuda import copy @@ -37,6 +38,102 @@ def patch_model_from_config(model, config: FlipFlopConfig): setattr(model, config.overwrite_forward, flip_flop_transformer.__call__) +class FlipFlopContext: + def __init__(self, holder: FlipFlopHolder): + self.holder = holder + self.reset() + + def reset(self): + self.num_blocks = len(self.holder.transformer_blocks) + self.first_flip = True + self.first_flop = True + self.last_flip = False + self.last_flop = False + + def __enter__(self): + self.reset() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.holder.compute_stream.record_event(self.holder.cpy_end_event) + + def do_flip(self, func, i: int, _, *args, **kwargs): + # flip + self.holder.compute_stream.wait_event(self.holder.cpy_end_event) + with torch.cuda.stream(self.holder.compute_stream): + out = func(i, self.holder.flip, *args, **kwargs) + self.holder.event_flip.record(self.holder.compute_stream) + # while flip executes, queue flop to copy to its next block + next_flop_i = i + 1 + if next_flop_i >= self.num_blocks: + next_flop_i = next_flop_i - self.num_blocks + self.last_flip = True + if not self.first_flip: + self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event) + if self.last_flip: + self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[0].state_dict(), cpy_start_event=self.holder.event_flip) + self.first_flip = False + return out + + def do_flop(self, func, i: int, _, *args, **kwargs): + # flop + if not self.first_flop: + 
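+            # only wait once a previous iteration has queued an async copy into self.holder.flop;
+            # the very first flop still holds the weights it was initialized with (holder block index 1)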
self.holder.compute_stream.wait_event(self.holder.cpy_end_event) + with torch.cuda.stream(self.holder.compute_stream): + out = func(i, self.holder.flop, *args, **kwargs) + self.holder.event_flop.record(self.holder.compute_stream) + # while flop executes, queue flip to copy to its next block + next_flip_i = i + 1 + if next_flip_i >= self.num_blocks: + next_flip_i = next_flip_i - self.num_blocks + self.last_flop = True + self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event) + if self.last_flop: + self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[1].state_dict(), cpy_start_event=self.holder.event_flop) + self.first_flop = False + return out + + @torch.no_grad() + def __call__(self, func, i: int, block: torch.nn.Module, *args, **kwargs): + # flips are even indexes, flops are odd indexes + if i % 2 == 0: + return self.do_flip(func, i, block, *args, **kwargs) + else: + return self.do_flop(func, i, block, *args, **kwargs) + + + +class FlipFlopHolder: + def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"): + self.inference_device = torch.device(inference_device) + self.offloading_device = torch.device(offloading_device) + self.transformer_blocks = transformer_blocks + + self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=self.inference_device) + self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=self.inference_device) + + self.compute_stream = cuda.default_stream(self.inference_device) + self.cpy_stream = cuda.Stream(self.inference_device) + + self.event_flip = torch.cuda.Event(enable_timing=False) + self.event_flop = torch.cuda.Event(enable_timing=False) + self.cpy_end_event = torch.cuda.Event(enable_timing=False) + # INIT - is this actually needed? + self.compute_stream.record_event(self.cpy_end_event) + + def _copy_state_dict(self, dst, src, cpy_start_event: torch.cuda.Event=None, cpy_end_event: torch.cuda.Event=None): + if cpy_start_event: + self.cpy_stream.wait_event(cpy_start_event) + + with torch.cuda.stream(self.cpy_stream): + for k, v in src.items(): + dst[k].copy_(v, non_blocking=True) + if cpy_end_event: + cpy_end_event.record(self.cpy_stream) + + def context(self): + return FlipFlopContext(self) + class FlipFlopTransformer: def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"): self.transformer_blocks = transformer_blocks @@ -114,6 +211,7 @@ class FlipFlopTransformer: Flip accounts for even blocks (0 is first block), flop accounts for odd blocks. 
''' # separated flip flop refactor + num_blocks = len(self.transformer_blocks) first_flip = True first_flop = True last_flip = False @@ -128,8 +226,8 @@ class FlipFlopTransformer: self.event_flip.record(self.compute_stream) # while flip executes, queue flop to copy to its next block next_flop_i = i + 1 - if next_flop_i >= self.num_blocks: - next_flop_i = next_flop_i - self.num_blocks + if next_flop_i >= num_blocks: + next_flop_i = next_flop_i - num_blocks last_flip = True if not first_flip: self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[next_flop_i].state_dict(), self.event_flop, self.cpy_end_event) @@ -145,8 +243,8 @@ class FlipFlopTransformer: self.event_flop.record(self.compute_stream) # while flop executes, queue flip to copy to its next block next_flip_i = i + 1 - if next_flip_i >= self.num_blocks: - next_flip_i = next_flip_i - self.num_blocks + if next_flip_i >= num_blocks: + next_flip_i = next_flip_i - num_blocks last_flop = True self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[next_flip_i].state_dict(), self.event_flip, self.cpy_end_event) if last_flop: diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index fad6440eb..a338a6805 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from typing import Optional, Tuple from einops import repeat +from comfy.ldm.flipflop_transformer import FlipFlopHolder from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND @@ -335,10 +336,18 @@ class QwenImageTransformer2DModel(nn.Module): for _ in range(num_layers) ]) + self.flipflop_holders: dict[str, FlipFlopHolder] = {} + if final_layer: self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) + def setup_flipflop_holders(self, block_percentage: float): + # We hackily move any flipflopped blocks into holder so that our model management system does not see them. 
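+        # The first num_blocks entries stay in self.transformer_blocks and keep the normal
+        # load/offload path; the trailing blocks are handed to the FlipFlopHolder, which
+        # streams their weights through two resident GPU copies at inference time.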
+ num_blocks = int(len(self.transformer_blocks) * block_percentage) + self.flipflop_holders["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:]) + self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks]) + def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape patch_size = self.patch_size @@ -403,6 +412,12 @@ class QwenImageTransformer2DModel(nn.Module): def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options): for i, block in enumerate(self.transformer_blocks): encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) + if "blocks_fwd" in self.flipflop_holders: + holder = self.flipflop_holders["blocks_fwd"] + with holder.context() as ctx: + for i, block in enumerate(holder.transformer_blocks): + encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) + return encoder_hidden_states, hidden_states def _forward( diff --git a/comfy_extras/nodes_flipflop.py b/comfy_extras/nodes_flipflop.py index 0406d2441..90ea1f6d5 100644 --- a/comfy_extras/nodes_flipflop.py +++ b/comfy_extras/nodes_flipflop.py @@ -1,28 +1,71 @@ +from __future__ import annotations +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + from comfy.ldm.flipflop_transformer import FLIPFLOP_REGISTRY -class FlipFlop: +class FlipFlopOld(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), }, - } + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="FlipFlop", + display_name="FlipFlop (Old)", + category="_for_testing", + inputs=[ + io.Model.Input(id="model") + ], + outputs=[ + io.Model.Output() + ], + description="Apply FlipFlop transformation to model using registry-based patching" + ) - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - OUTPUT_NODE = False - - CATEGORY = "_for_testing" - - def patch(self, model): + @classmethod + def execute(cls, model) -> io.NodeOutput: patch_cls = FLIPFLOP_REGISTRY.get(model.model.diffusion_model.__class__.__name__, None) if patch_cls is None: raise ValueError(f"Model {model.model.diffusion_model.__class__.__name__} not supported") model.model.diffusion_model = patch_cls.patch(model.model.diffusion_model) - return (model,) + return io.NodeOutput(model) -NODE_CLASS_MAPPINGS = { - "FlipFlop": FlipFlop -} +class FlipFlop(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="FlipFlopNew", + display_name="FlipFlop (New)", + category="_for_testing", + inputs=[ + io.Model.Input(id="model"), + io.Float.Input(id="block_percentage", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Model.Output() + ], + description="Apply FlipFlop transformation to model using setup_flipflop_holders method" + ) + + @classmethod + def execute(cls, model: io.Model.Type, block_percentage: float) -> io.NodeOutput: + # NOTE: this is just a hacky prototype still, this would not be exposed as a node. + # At the moment, this modifies the underlying model with no way to 'unpatch' it. 
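+        # model.clone() only copies the ModelPatcher wrapper; the diffusion_model instance is
+        # shared, so the holders set up below also persist on the original, un-cloned model.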
+ model = model.clone() + if not hasattr(model.model.diffusion_model, "setup_flipflop_holders"): + raise ValueError("Model does not have flipflop holders; FlipFlop not supported") + model.model.diffusion_model.setup_flipflop_holders(block_percentage) + return io.NodeOutput(model) + +class FlipFlopExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + FlipFlopOld, + FlipFlop, + ] + + +async def comfy_entrypoint() -> FlipFlopExtension: + return FlipFlopExtension() From 8a8162e8da860f66f32dc1bfc56d87f6d56ef098 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 29 Sep 2025 22:49:12 -0700 Subject: [PATCH 05/17] Fix percentage logic, begin adding elements to ModelPatcher to track flip flop compatibility --- comfy/ldm/flipflop_transformer.py | 24 ++++++++++++++++++------ comfy/ldm/qwen_image/model.py | 10 +++++----- comfy/model_patcher.py | 3 +++ 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 242672462..754733a86 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -5,6 +5,8 @@ import copy from typing import List, Tuple from dataclasses import dataclass +import comfy.model_management + FLIPFLOP_REGISTRY = {} def register(name): @@ -105,15 +107,21 @@ class FlipFlopContext: class FlipFlopHolder: def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"): - self.inference_device = torch.device(inference_device) - self.offloading_device = torch.device(offloading_device) + self.load_device = torch.device(inference_device) + self.offload_device = torch.device(offloading_device) self.transformer_blocks = transformer_blocks - self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=self.inference_device) - self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=self.inference_device) + self.block_module_size = 0 + if len(self.transformer_blocks) > 0: + self.block_module_size = comfy.model_management.module_size(self.transformer_blocks[0]) - self.compute_stream = cuda.default_stream(self.inference_device) - self.cpy_stream = cuda.Stream(self.inference_device) + self.flip: torch.nn.Module = None + self.flop: torch.nn.Module = None + # TODO: make initialization happen in model management code/model patcher, not here + self.initialize_flipflop_blocks(self.load_device) + + self.compute_stream = cuda.default_stream(self.load_device) + self.cpy_stream = cuda.Stream(self.load_device) self.event_flip = torch.cuda.Event(enable_timing=False) self.event_flop = torch.cuda.Event(enable_timing=False) @@ -134,6 +142,10 @@ class FlipFlopHolder: def context(self): return FlipFlopContext(self) + def initialize_flipflop_blocks(self, load_device: torch.device): + self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=load_device) + self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=load_device) + class FlipFlopTransformer: def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"): self.transformer_blocks = transformer_blocks diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index a338a6805..d1833ffe6 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -336,7 +336,7 @@ class QwenImageTransformer2DModel(nn.Module): for _ in range(num_layers) ]) - self.flipflop_holders: 
dict[str, FlipFlopHolder] = {} + self.flipflop: dict[str, FlipFlopHolder] = {} if final_layer: self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) @@ -344,8 +344,8 @@ class QwenImageTransformer2DModel(nn.Module): def setup_flipflop_holders(self, block_percentage: float): # We hackily move any flipflopped blocks into holder so that our model management system does not see them. - num_blocks = int(len(self.transformer_blocks) * block_percentage) - self.flipflop_holders["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:]) + num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage)) + self.flipflop["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:]) self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks]) def process_img(self, x, index=0, h_offset=0, w_offset=0): @@ -412,8 +412,8 @@ class QwenImageTransformer2DModel(nn.Module): def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options): for i, block in enumerate(self.transformer_blocks): encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) - if "blocks_fwd" in self.flipflop_holders: - holder = self.flipflop_holders["blocks_fwd"] + if "blocks_fwd" in self.flipflop: + holder = self.flipflop["blocks_fwd"] with holder.context() as ctx: for i, block in enumerate(holder.transformer_blocks): encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1fd03d9d1..b9a768203 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -604,6 +604,9 @@ class ModelPatcher: else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + def supports_flipflop(self): + return hasattr(self.model.diffusion_model, "flipflop") + def _load_list(self): loading = [] for n, m in self.model.named_modules(): From 01f4512bf804c9ff2489a6c1558340e58efc2d72 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 30 Sep 2025 23:08:08 -0700 Subject: [PATCH 06/17] In-progress commit on making flipflop async weight streaming native, made loaded partially/loaded completely log messages have labels because having to memorize their meaning for dev work is annoying --- comfy/ldm/flipflop_transformer.py | 154 ++++++++++++++++-------------- comfy/ldm/qwen_image/model.py | 52 ++++++---- comfy/model_patcher.py | 30 +++++- comfy_extras/nodes_flipflop.py | 28 ------ 4 files changed, 145 insertions(+), 119 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 754733a86..03e256c7a 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -3,42 +3,9 @@ import torch import torch.cuda as cuda import copy from typing import List, Tuple -from dataclasses import dataclass import comfy.model_management -FLIPFLOP_REGISTRY = {} - -def register(name): - def decorator(cls): - FLIPFLOP_REGISTRY[name] = cls - return cls - return decorator - - -@dataclass -class FlipFlopConfig: - block_name: str - block_wrap_fn: callable - out_names: Tuple[str] - overwrite_forward: str - 
pinned_staging: bool = False - inference_device: str = "cuda" - offloading_device: str = "cpu" - - -def patch_model_from_config(model, config: FlipFlopConfig): - block_list = getattr(model, config.block_name) - flip_flop_transformer = FlipFlopTransformer(block_list, - block_wrap_fn=config.block_wrap_fn, - out_names=config.out_names, - offloading_device=config.offloading_device, - inference_device=config.inference_device, - pinned_staging=config.pinned_staging) - delattr(model, config.block_name) - setattr(model, config.block_name, flip_flop_transformer) - setattr(model, config.overwrite_forward, flip_flop_transformer.__call__) - class FlipFlopContext: def __init__(self, holder: FlipFlopHolder): @@ -46,11 +13,12 @@ class FlipFlopContext: self.reset() def reset(self): - self.num_blocks = len(self.holder.transformer_blocks) + self.num_blocks = len(self.holder.blocks) self.first_flip = True self.first_flop = True self.last_flip = False self.last_flop = False + # TODO: the 'i' that's passed into func needs to be properly offset to do patches correctly def __enter__(self): self.reset() @@ -71,9 +39,9 @@ class FlipFlopContext: next_flop_i = next_flop_i - self.num_blocks self.last_flip = True if not self.first_flip: - self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event) + self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event) if self.last_flip: - self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[0].state_dict(), cpy_start_event=self.holder.event_flip) + self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[0].state_dict(), cpy_start_event=self.holder.event_flip) self.first_flip = False return out @@ -89,9 +57,9 @@ class FlipFlopContext: if next_flip_i >= self.num_blocks: next_flip_i = next_flip_i - self.num_blocks self.last_flop = True - self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event) + self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event) if self.last_flop: - self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[1].state_dict(), cpy_start_event=self.holder.event_flop) + self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[1].state_dict(), cpy_start_event=self.holder.event_flop) self.first_flop = False return out @@ -106,19 +74,20 @@ class FlipFlopContext: class FlipFlopHolder: - def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"): - self.load_device = torch.device(inference_device) - self.offload_device = torch.device(offloading_device) - self.transformer_blocks = transformer_blocks + def __init__(self, blocks: List[torch.nn.Module], flip_amount: int, load_device="cuda", offload_device="cpu"): + self.load_device = torch.device(load_device) + self.offload_device = torch.device(offload_device) + self.blocks = blocks + self.flip_amount = flip_amount self.block_module_size = 0 - if len(self.transformer_blocks) > 0: - self.block_module_size = comfy.model_management.module_size(self.transformer_blocks[0]) + if len(self.blocks) > 0: + self.block_module_size = 
comfy.model_management.module_size(self.blocks[0]) self.flip: torch.nn.Module = None self.flop: torch.nn.Module = None # TODO: make initialization happen in model management code/model patcher, not here - self.initialize_flipflop_blocks(self.load_device) + self.init_flipflop_blocks(self.load_device) self.compute_stream = cuda.default_stream(self.load_device) self.cpy_stream = cuda.Stream(self.load_device) @@ -142,10 +111,57 @@ class FlipFlopHolder: def context(self): return FlipFlopContext(self) - def initialize_flipflop_blocks(self, load_device: torch.device): - self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=load_device) - self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=load_device) + def init_flipflop_blocks(self, load_device: torch.device): + self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device) + self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device) + def clean_flipflop_blocks(self): + del self.flip + del self.flop + self.flip = None + self.flop = None + + +class FlopFlopModule(torch.nn.Module): + def __init__(self, block_types: tuple[str, ...]): + super().__init__() + self.block_types = block_types + self.flipflop: dict[str, FlipFlopHolder] = {} + + def setup_flipflop_holders(self, block_percentage: float): + for block_type in self.block_types: + if block_type in self.flipflop: + continue + num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage)) + self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks) + + def clean_flipflop_holders(self): + for block_type in self.flipflop.keys(): + self.flipflop[block_type].clean_flipflop_blocks() + del self.flipflop[block_type] + + def get_blocks(self, block_type: str) -> torch.nn.ModuleList: + if block_type not in self.block_types: + raise ValueError(f"Block type {block_type} not found in {self.block_types}") + if block_type in self.flipflop: + return getattr(self, block_type)[:self.flipflop[block_type].flip_amount] + return getattr(self, block_type) + + def get_all_block_module_sizes(self, sort_by_size: bool = False) -> list[tuple[str, int]]: + ''' + Returns a list of (block_type, size). + If sort_by_size is True, the list is sorted by size. 
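+        Sizes are in bytes, as reported by comfy.model_management.module_size for the
+        first block of each type.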
+ ''' + sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types] + if sort_by_size: + sizes.sort(key=lambda x: x[1]) + return sizes + + def get_block_module_size(self, block_type: str) -> int: + return comfy.model_management.module_size(getattr(self, block_type)[0]) + + +# Below is the implementation from contentis' prototype flip flop class FlipFlopTransformer: def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"): self.transformer_blocks = transformer_blocks @@ -379,28 +395,26 @@ class FlipFlopTransformer: # patch_model_from_config(model, Wan.blocks_config) # return model +# @register("QwenImageTransformer2DModel") +# class QwenImage: +# @staticmethod +# def qwen_blocks_wrap(block, **kwargs): +# kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"], +# encoder_hidden_states=kwargs["encoder_hidden_states"], +# encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"], +# temb=kwargs["temb"], +# image_rotary_emb=kwargs["image_rotary_emb"], +# transformer_options=kwargs["transformer_options"]) +# return kwargs -@register("QwenImageTransformer2DModel") -class QwenImage: - @staticmethod - def qwen_blocks_wrap(block, **kwargs): - kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"], - encoder_hidden_states=kwargs["encoder_hidden_states"], - encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"], - temb=kwargs["temb"], - image_rotary_emb=kwargs["image_rotary_emb"], - transformer_options=kwargs["transformer_options"]) - return kwargs - - blocks_config = FlipFlopConfig(block_name="transformer_blocks", - block_wrap_fn=qwen_blocks_wrap, - out_names=("encoder_hidden_states", "hidden_states"), - overwrite_forward="blocks_fwd", - pinned_staging=False) +# blocks_config = FlipFlopConfig(block_name="transformer_blocks", +# block_wrap_fn=qwen_blocks_wrap, +# out_names=("encoder_hidden_states", "hidden_states"), +# overwrite_forward="blocks_fwd", +# pinned_staging=False) - @staticmethod - def patch(model): - patch_model_from_config(model, QwenImage.blocks_config) - return model - +# @staticmethod +# def patch(model): +# patch_model_from_config(model, QwenImage.blocks_config) +# return model diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index d1833ffe6..474d831c4 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -343,11 +343,36 @@ class QwenImageTransformer2DModel(nn.Module): self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) def setup_flipflop_holders(self, block_percentage: float): + if "transformer_blocks" in self.flipflop: + return + import comfy.model_management # We hackily move any flipflopped blocks into holder so that our model management system does not see them. 
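+        # block_percentage is the fraction of blocks that get flip-flopped: e.g. with 60 blocks
+        # and block_percentage=0.75, int(60 * 0.25) = 15 blocks stay resident while the
+        # remaining 45 are streamed through the flip/flop pair.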
num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage)) - self.flipflop["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:]) + loading = [] + for n, m in self.named_modules(): + params = [] + skip = False + for name, param in m.named_parameters(recurse=False): + params.append(name) + for name, param in m.named_parameters(recurse=True): + if name not in params: + skip = True # skip random weights in non leaf modules + break + if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): + loading.append((comfy.model_management.module_size(m), n, m, params)) + self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks) self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks]) + def clean_flipflop_holders(self): + if "transformer_blocks" in self.flipflop: + self.flipflop["transformer_blocks"].clean_flipflop_blocks() + del self.flipflop["transformer_blocks"] + + def get_transformer_blocks(self): + if "transformer_blocks" in self.flipflop: + return self.transformer_blocks[:self.flipflop["transformer_blocks"].flip_amount] + return self.transformer_blocks + def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape patch_size = self.patch_size @@ -409,17 +434,6 @@ class QwenImageTransformer2DModel(nn.Module): return encoder_hidden_states, hidden_states - def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options): - for i, block in enumerate(self.transformer_blocks): - encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) - if "blocks_fwd" in self.flipflop: - holder = self.flipflop["blocks_fwd"] - with holder.context() as ctx: - for i, block in enumerate(holder.transformer_blocks): - encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) - - return encoder_hidden_states, hidden_states - def _forward( self, x, @@ -487,12 +501,14 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) - encoder_hidden_states, hidden_states = self.blocks_fwd(hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, - temb=temb, image_rotary_emb=image_rotary_emb, - patches=patches, control=control, blocks_replace=blocks_replace, x=x, - transformer_options=transformer_options) + for i, block in enumerate(self.get_transformer_blocks()): + encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) + if "transformer_blocks" in self.flipflop: + holder = self.flipflop["transformer_blocks"] + with holder.context() as ctx: + for i, block in enumerate(holder.blocks): + encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) + hidden_states = self.norm_out(hidden_states, temb) 
hidden_states = self.proj_out(hidden_states) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b9a768203..7e7b3a4ed 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -605,7 +605,27 @@ class ModelPatcher: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) def supports_flipflop(self): - return hasattr(self.model.diffusion_model, "flipflop") + # flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM + if not hasattr(self.model, "diffusion_model"): + return False + if not hasattr(self.model.diffusion_model, "flipflop"): + return False + if not comfy.model_management.is_nvidia(): + return False + if comfy.model_management.vram_state in (comfy.model_management.VRAMState.HIGH_VRAM, comfy.model_management.VRAMState.SHARED): + return False + return True + + def init_flipflop(self): + if not self.supports_flipflop(): + return + # figure out how many b + self.model.diffusion_model.setup_flipflop_holders(self.model_options["flipflop_block_percentage"]) + + def clean_flipflop(self): + if not self.supports_flipflop(): + return + self.model.diffusion_model.clean_flipflop_holders() def _load_list(self): loading = [] @@ -628,6 +648,9 @@ class ModelPatcher: mem_counter = 0 patch_counter = 0 lowvram_counter = 0 + lowvram_mem_counter = 0 + if self.supports_flipflop(): + ... loading = self._load_list() load_completely = [] @@ -647,6 +670,7 @@ class ModelPatcher: if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True lowvram_counter += 1 + lowvram_mem_counter += module_mem if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed continue @@ -709,10 +733,10 @@ class ModelPatcher: x[2].to(device_to) if lowvram_counter > 0: - logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) + logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}") self.model.model_lowvram = True else: - logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}") self.model.model_lowvram = False if full_load: self.model.to(device_to) diff --git a/comfy_extras/nodes_flipflop.py b/comfy_extras/nodes_flipflop.py index 90ea1f6d5..4ddd5c479 100644 --- a/comfy_extras/nodes_flipflop.py +++ b/comfy_extras/nodes_flipflop.py @@ -3,33 +3,6 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, io -from comfy.ldm.flipflop_transformer import FLIPFLOP_REGISTRY - -class FlipFlopOld(io.ComfyNode): - @classmethod - def define_schema(cls) -> io.Schema: - return io.Schema( - node_id="FlipFlop", - display_name="FlipFlop (Old)", - category="_for_testing", - inputs=[ - io.Model.Input(id="model") - ], - outputs=[ - io.Model.Output() - ], - description="Apply FlipFlop transformation to model using registry-based patching" - ) - - @classmethod - def execute(cls, model) -> io.NodeOutput: - patch_cls = FLIPFLOP_REGISTRY.get(model.model.diffusion_model.__class__.__name__, None) - if patch_cls is None: - raise ValueError(f"Model {model.model.diffusion_model.__class__.__name__} not supported") - - model.model.diffusion_model 
= patch_cls.patch(model.model.diffusion_model) - - return io.NodeOutput(model) class FlipFlop(io.ComfyNode): @classmethod @@ -62,7 +35,6 @@ class FlipFlopExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - FlipFlopOld, FlipFlop, ] From 7c896c55674fbd6d6f8f5e6de91a3c508bb59cff Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 1 Oct 2025 20:13:50 -0700 Subject: [PATCH 07/17] Initial automatic support for flipflop within ModelPatcher - only Qwen Image diffusion_model uses FlipFlopModule currently --- comfy/ldm/flipflop_transformer.py | 367 +++++------------------------- comfy/ldm/qwen_image/model.py | 44 +--- comfy/model_patcher.py | 101 ++++++-- 3 files changed, 149 insertions(+), 363 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 03e256c7a..a88613c01 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -1,12 +1,54 @@ from __future__ import annotations import torch -import torch.cuda as cuda import copy -from typing import List, Tuple import comfy.model_management +class FlipFlopModule(torch.nn.Module): + def __init__(self, block_types: tuple[str, ...]): + super().__init__() + self.block_types = block_types + self.flipflop: dict[str, FlipFlopHolder] = {} + + def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device): + for block_type, (flipflop_blocks, total_blocks) in block_info.items(): + if block_type in self.flipflop: + continue + self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[total_blocks-flipflop_blocks:], flipflop_blocks, total_blocks, load_device, offload_device) + + def init_flipflop_block_copies(self, device: torch.device): + for holder in self.flipflop.values(): + holder.init_flipflop_block_copies(device) + + def clean_flipflop_holders(self): + for block_type in list(self.flipflop.keys()): + self.flipflop[block_type].clean_flipflop_blocks() + del self.flipflop[block_type] + + def get_all_blocks(self, block_type: str) -> list[torch.nn.Module]: + return getattr(self, block_type) + + def get_blocks(self, block_type: str) -> torch.nn.ModuleList: + if block_type not in self.block_types: + raise ValueError(f"Block type {block_type} not found in {self.block_types}") + if block_type in self.flipflop: + return getattr(self, block_type)[:self.flipflop[block_type].flip_amount] + return getattr(self, block_type) + + def get_all_block_module_sizes(self, reverse_sort_by_size: bool = False) -> list[tuple[str, int]]: + ''' + Returns a list of (block_type, size) sorted by size. + If reverse_sort_by_size is True, the list is sorted by size in reverse order. 
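+        e.g. for Qwen Image, which exposes a single "transformer_blocks" block type, this
+        returns [("transformer_blocks", <size of one block in bytes>)].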
+ ''' + sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types] + sizes.sort(key=lambda x: x[1], reverse=reverse_sort_by_size) + return sizes + + def get_block_module_size(self, block_type: str) -> int: + return comfy.model_management.module_size(getattr(self, block_type)[0]) + + class FlipFlopContext: def __init__(self, holder: FlipFlopHolder): self.holder = holder @@ -18,7 +60,6 @@ class FlipFlopContext: self.first_flop = True self.last_flip = False self.last_flop = False - # TODO: the 'i' that's passed into func needs to be properly offset to do patches correctly def __enter__(self): self.reset() @@ -31,7 +72,7 @@ class FlipFlopContext: # flip self.holder.compute_stream.wait_event(self.holder.cpy_end_event) with torch.cuda.stream(self.holder.compute_stream): - out = func(i, self.holder.flip, *args, **kwargs) + out = func(i+self.holder.i_offset, self.holder.flip, *args, **kwargs) self.holder.event_flip.record(self.holder.compute_stream) # while flip executes, queue flop to copy to its next block next_flop_i = i + 1 @@ -50,7 +91,7 @@ class FlipFlopContext: if not self.first_flop: self.holder.compute_stream.wait_event(self.holder.cpy_end_event) with torch.cuda.stream(self.holder.compute_stream): - out = func(i, self.holder.flop, *args, **kwargs) + out = func(i+self.holder.i_offset, self.holder.flop, *args, **kwargs) self.holder.event_flop.record(self.holder.compute_stream) # while flop executes, queue flip to copy to its next block next_flip_i = i + 1 @@ -72,13 +113,15 @@ class FlipFlopContext: return self.do_flop(func, i, block, *args, **kwargs) - class FlipFlopHolder: - def __init__(self, blocks: List[torch.nn.Module], flip_amount: int, load_device="cuda", offload_device="cpu"): - self.load_device = torch.device(load_device) - self.offload_device = torch.device(offload_device) + def __init__(self, blocks: list[torch.nn.Module], flip_amount: int, total_amount: int, load_device: torch.device, offload_device: torch.device): + self.load_device = load_device + self.offload_device = offload_device self.blocks = blocks self.flip_amount = flip_amount + self.total_amount = total_amount + # NOTE: used to make sure block indexes passed into block functions match expected patch indexes + self.i_offset = total_amount - flip_amount self.block_module_size = 0 if len(self.blocks) > 0: @@ -86,11 +129,9 @@ class FlipFlopHolder: self.flip: torch.nn.Module = None self.flop: torch.nn.Module = None - # TODO: make initialization happen in model management code/model patcher, not here - self.init_flipflop_blocks(self.load_device) - self.compute_stream = cuda.default_stream(self.load_device) - self.cpy_stream = cuda.Stream(self.load_device) + self.compute_stream = torch.cuda.default_stream(self.load_device) + self.cpy_stream = torch.cuda.Stream(self.load_device) self.event_flip = torch.cuda.Event(enable_timing=False) self.event_flop = torch.cuda.Event(enable_timing=False) @@ -111,7 +152,7 @@ class FlipFlopHolder: def context(self): return FlipFlopContext(self) - def init_flipflop_blocks(self, load_device: torch.device): + def init_flipflop_block_copies(self, load_device: torch.device): self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device) self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device) @@ -120,301 +161,3 @@ class FlipFlopHolder: del self.flop self.flip = None self.flop = None - - -class FlopFlopModule(torch.nn.Module): - def __init__(self, block_types: tuple[str, ...]): - super().__init__() - self.block_types = block_types - self.flipflop: 
dict[str, FlipFlopHolder] = {} - - def setup_flipflop_holders(self, block_percentage: float): - for block_type in self.block_types: - if block_type in self.flipflop: - continue - num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage)) - self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks) - - def clean_flipflop_holders(self): - for block_type in self.flipflop.keys(): - self.flipflop[block_type].clean_flipflop_blocks() - del self.flipflop[block_type] - - def get_blocks(self, block_type: str) -> torch.nn.ModuleList: - if block_type not in self.block_types: - raise ValueError(f"Block type {block_type} not found in {self.block_types}") - if block_type in self.flipflop: - return getattr(self, block_type)[:self.flipflop[block_type].flip_amount] - return getattr(self, block_type) - - def get_all_block_module_sizes(self, sort_by_size: bool = False) -> list[tuple[str, int]]: - ''' - Returns a list of (block_type, size). - If sort_by_size is True, the list is sorted by size. - ''' - sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types] - if sort_by_size: - sizes.sort(key=lambda x: x[1]) - return sizes - - def get_block_module_size(self, block_type: str) -> int: - return comfy.model_management.module_size(getattr(self, block_type)[0]) - - -# Below is the implementation from contentis' prototype flip flop -class FlipFlopTransformer: - def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"): - self.transformer_blocks = transformer_blocks - self.offloading_device = torch.device(offloading_device) - self.inference_device = torch.device(inference_device) - self.staging = pinned_staging - - self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=self.inference_device) - self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=self.inference_device) - - self._cpy_fn = self._copy_state_dict - if self.staging: - self.staging_buffer = self._pin_module(self.transformer_blocks[0]).state_dict() - self._cpy_fn = self._copy_state_dict_with_staging - - self.compute_stream = cuda.default_stream(self.inference_device) - self.cpy_stream = cuda.Stream(self.inference_device) - - self.event_flip = torch.cuda.Event(enable_timing=False) - self.event_flop = torch.cuda.Event(enable_timing=False) - self.cpy_end_event = torch.cuda.Event(enable_timing=False) - - self.block_wrap_fn = block_wrap_fn - self.out_names = out_names - - self.num_blocks = len(self.transformer_blocks) - self.extra_run = self.num_blocks % 2 - - # INIT - self.compute_stream.record_event(self.cpy_end_event) - - def _copy_state_dict(self, dst, src, cpy_start_event=None, cpy_end_event=None): - if cpy_start_event: - self.cpy_stream.wait_event(cpy_start_event) - - with torch.cuda.stream(self.cpy_stream): - for k, v in src.items(): - dst[k].copy_(v, non_blocking=True) - if cpy_end_event: - cpy_end_event.record(self.cpy_stream) - - def _copy_state_dict_with_staging(self, dst, src, cpy_start_event=None, cpy_end_event=None): - if cpy_start_event: - self.cpy_stream.wait_event(cpy_start_event) - - with torch.cuda.stream(self.cpy_stream): - for k, v in src.items(): - self.staging_buffer[k].copy_(v, non_blocking=True) - dst[k].copy_(self.staging_buffer[k], non_blocking=True) - if cpy_end_event: - cpy_end_event.record(self.cpy_stream) - - def _pin_module(self, module): - pinned_module = copy.deepcopy(module) - for param in 
pinned_module.parameters(): - param.data = param.data.pin_memory() - # Pin all buffers (if any) - for buffer in pinned_module.buffers(): - buffer.data = buffer.data.pin_memory() - return pinned_module - - def _reset(self): - if self.extra_run: - self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop) - self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip) - else: - self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip) - self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop) - - self.compute_stream.record_event(self.cpy_end_event) - - @torch.no_grad() - def __call__(self, **feed_dict): - ''' - Flip accounts for even blocks (0 is first block), flop accounts for odd blocks. - ''' - # separated flip flop refactor - num_blocks = len(self.transformer_blocks) - first_flip = True - first_flop = True - last_flip = False - last_flop = False - for i, block in enumerate(self.transformer_blocks): - is_flip = i % 2 == 0 - if is_flip: - # flip - self.compute_stream.wait_event(self.cpy_end_event) - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flip, **feed_dict) - self.event_flip.record(self.compute_stream) - # while flip executes, queue flop to copy to its next block - next_flop_i = i + 1 - if next_flop_i >= num_blocks: - next_flop_i = next_flop_i - num_blocks - last_flip = True - if not first_flip: - self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[next_flop_i].state_dict(), self.event_flop, self.cpy_end_event) - if last_flip: - self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip) - first_flip = False - else: - # flop - if not first_flop: - self.compute_stream.wait_event(self.cpy_end_event) - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flop, **feed_dict) - self.event_flop.record(self.compute_stream) - # while flop executes, queue flip to copy to its next block - next_flip_i = i + 1 - if next_flip_i >= num_blocks: - next_flip_i = next_flip_i - num_blocks - last_flop = True - self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[next_flip_i].state_dict(), self.event_flip, self.cpy_end_event) - if last_flop: - self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop) - first_flop = False - - self.compute_stream.record_event(self.cpy_end_event) - - outputs = [feed_dict[name] for name in self.out_names] - if len(outputs) == 1: - return outputs[0] - return tuple(outputs) - - @torch.no_grad() - def __call__old(self, **feed_dict): - # contentis' prototype flip flop - # Wait for reset - self.compute_stream.wait_event(self.cpy_end_event) - - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flip, **feed_dict) - self.event_flip.record(self.compute_stream) - - for i in range(self.num_blocks // 2 - 1): - - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flop, **feed_dict) - self.event_flop.record(self.compute_stream) - - self._cpy_fn(self.flip.state_dict(), self.transformer_blocks[(i + 1) * 2].state_dict(), self.event_flip, - self.cpy_end_event) - - self.compute_stream.wait_event(self.cpy_end_event) - - with torch.cuda.stream(self.compute_stream): - 
feed_dict = self.block_wrap_fn(self.flip, **feed_dict) - self.event_flip.record(self.compute_stream) - - self._cpy_fn(self.flop.state_dict(), self.transformer_blocks[(i + 1) * 2 + 1].state_dict(), self.event_flop, - self.cpy_end_event) - self.compute_stream.wait_event(self.cpy_end_event) - - - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flop, **feed_dict) - self.event_flop.record(self.compute_stream) - - if self.extra_run: - self._cpy_fn(self.flip.state_dict(), self.transformer_blocks[-1].state_dict(), self.event_flip, - self.cpy_end_event) - self.compute_stream.wait_event(self.cpy_end_event) - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flip, **feed_dict) - self.event_flip.record(self.compute_stream) - - self._reset() - - outputs = [feed_dict[name] for name in self.out_names] - if len(outputs) == 1: - return outputs[0] - return tuple(outputs) - -# @register("Flux") -# class Flux: -# @staticmethod -# def double_block_wrap(block, **kwargs): -# kwargs["img"], kwargs["txt"] = block(img=kwargs["img"], -# txt=kwargs["txt"], -# vec=kwargs["vec"], -# pe=kwargs["pe"], -# attn_mask=kwargs.get("attn_mask")) -# return kwargs - -# @staticmethod -# def single_block_wrap(block, **kwargs): -# kwargs["img"] = block(kwargs["img"], -# vec=kwargs["vec"], -# pe=kwargs["pe"], -# attn_mask=kwargs.get("attn_mask")) -# return kwargs - -# double_config = FlipFlopConfig(block_name="double_blocks", -# block_wrap_fn=double_block_wrap, -# out_names=("img", "txt"), -# overwrite_forward="double_transformer_fwd", -# pinned_staging=False) - -# single_config = FlipFlopConfig(block_name="single_blocks", -# block_wrap_fn=single_block_wrap, -# out_names=("img",), -# overwrite_forward="single_transformer_fwd", -# pinned_staging=False) -# @staticmethod -# def patch(model): -# patch_model_from_config(model, Flux.double_config) -# patch_model_from_config(model, Flux.single_config) -# return model - - -# @register("WanModel") -# class Wan: -# @staticmethod -# def wan_blocks_wrap(block, **kwargs): -# kwargs["x"] = block(x=kwargs["x"], -# context=kwargs["context"], -# e=kwargs["e"], -# freqs=kwargs["freqs"], -# context_img_len=kwargs.get("context_img_len")) -# return kwargs - -# blocks_config = FlipFlopConfig(block_name="blocks", -# block_wrap_fn=wan_blocks_wrap, -# out_names=("x",), -# overwrite_forward="block_fwd", -# pinned_staging=False) - - -# @staticmethod -# def patch(model): -# patch_model_from_config(model, Wan.blocks_config) -# return model - -# @register("QwenImageTransformer2DModel") -# class QwenImage: -# @staticmethod -# def qwen_blocks_wrap(block, **kwargs): -# kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"], -# encoder_hidden_states=kwargs["encoder_hidden_states"], -# encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"], -# temb=kwargs["temb"], -# image_rotary_emb=kwargs["image_rotary_emb"], -# transformer_options=kwargs["transformer_options"]) -# return kwargs - -# blocks_config = FlipFlopConfig(block_name="transformer_blocks", -# block_wrap_fn=qwen_blocks_wrap, -# out_names=("encoder_hidden_states", "hidden_states"), -# overwrite_forward="blocks_fwd", -# pinned_staging=False) - - -# @staticmethod -# def patch(model): -# patch_model_from_config(model, QwenImage.blocks_config) -# return model diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 474d831c4..2af4968ac 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ 
-5,12 +5,13 @@ import torch.nn.functional as F from typing import Optional, Tuple from einops import repeat -from comfy.ldm.flipflop_transformer import FlipFlopHolder +from comfy.ldm.flipflop_transformer import FlipFlopModule from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND import comfy.ldm.common_dit import comfy.patcher_extension +import comfy.ops class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): @@ -284,7 +285,7 @@ class LastLayer(nn.Module): return x -class QwenImageTransformer2DModel(nn.Module): +class QwenImageTransformer2DModel(FlipFlopModule): def __init__( self, patch_size: int = 2, @@ -301,9 +302,9 @@ class QwenImageTransformer2DModel(nn.Module): final_layer=True, dtype=None, device=None, - operations=None, + operations: comfy.ops.disable_weight_init=None, ): - super().__init__() + super().__init__(block_types=("transformer_blocks",)) self.dtype = dtype self.patch_size = patch_size self.in_channels = in_channels @@ -336,43 +337,10 @@ class QwenImageTransformer2DModel(nn.Module): for _ in range(num_layers) ]) - self.flipflop: dict[str, FlipFlopHolder] = {} - if final_layer: self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) - def setup_flipflop_holders(self, block_percentage: float): - if "transformer_blocks" in self.flipflop: - return - import comfy.model_management - # We hackily move any flipflopped blocks into holder so that our model management system does not see them. 
- num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage)) - loading = [] - for n, m in self.named_modules(): - params = [] - skip = False - for name, param in m.named_parameters(recurse=False): - params.append(name) - for name, param in m.named_parameters(recurse=True): - if name not in params: - skip = True # skip random weights in non leaf modules - break - if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((comfy.model_management.module_size(m), n, m, params)) - self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks) - self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks]) - - def clean_flipflop_holders(self): - if "transformer_blocks" in self.flipflop: - self.flipflop["transformer_blocks"].clean_flipflop_blocks() - del self.flipflop["transformer_blocks"] - - def get_transformer_blocks(self): - if "transformer_blocks" in self.flipflop: - return self.transformer_blocks[:self.flipflop["transformer_blocks"].flip_amount] - return self.transformer_blocks - def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape patch_size = self.patch_size @@ -501,7 +469,7 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) - for i, block in enumerate(self.get_transformer_blocks()): + for i, block in enumerate(self.get_blocks("transformer_blocks")): encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) if "transformer_blocks" in self.flipflop: holder = self.flipflop["transformer_blocks"] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7e7b3a4ed..67d1cb233 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -616,19 +616,62 @@ class ModelPatcher: return False return True - def init_flipflop(self): + def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]]): if not self.supports_flipflop(): return - # figure out how many b - self.model.diffusion_model.setup_flipflop_holders(self.model_options["flipflop_block_percentage"]) + self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, self.load_device, self.offload_device) + + def init_flipflop_block_copies(self): + if not self.supports_flipflop(): + return + self.model.diffusion_model.init_flipflop_block_copies(self.load_device) def clean_flipflop(self): if not self.supports_flipflop(): return self.model.diffusion_model.clean_flipflop_holders() - def _load_list(self): + def _calc_flipflop_prefixes(self, lowvram_model_memory=0, prepare_flipflop=False): + flipflop_prefixes = [] + flipflop_blocks_per_type: dict[str, tuple[int, int]] = {} + if lowvram_model_memory > 0 and self.supports_flipflop(): + block_buffer = 3 + valid_block_types = [] + # for each block type, check if have enough room to flipflop + for block_info in self.model.diffusion_model.get_all_block_module_sizes(): + block_size: int = block_info[1] + if block_size * block_buffer < lowvram_model_memory: + valid_block_types.append(block_info) + # if have candidates for flipping, see how many of each type we have can flipflop + if len(valid_block_types) > 0: + leftover_memory = lowvram_model_memory + for block_info in valid_block_types: + block_type: str = block_info[0] + block_size: int = block_info[1] + total_blocks = 
len(self.model.diffusion_model.get_all_blocks(block_type)) + n_fit_in_memory = int(leftover_memory // block_size) + # if all (or more) of this block type would fit in memory, no need to flipflop with it + if n_fit_in_memory >= total_blocks: + continue + # if the amount of this block that would fit in memory is less than buffer, skip this block type + if n_fit_in_memory < block_buffer: + continue + # 2 blocks worth of VRAM may be needed for flipflop, so make sure to account for them. + flipflop_blocks = min((total_blocks - n_fit_in_memory) + 2, total_blocks) + flipflop_blocks_per_type[block_type] = (flipflop_blocks, total_blocks) + leftover_memory -= (total_blocks - flipflop_blocks + 2) * block_size + # if there are blocks to flipflop, need to mark their keys + for block_type, (flipflop_blocks, total_blocks) in flipflop_blocks_per_type.items(): + # blocks to flipflop are at the end + for i in range(total_blocks-flipflop_blocks, total_blocks): + flipflop_prefixes.append(f"diffusion_model.{block_type}.{i}") + if prepare_flipflop and len(flipflop_blocks_per_type) > 0: + self.setup_flipflop(flipflop_blocks_per_type) + return flipflop_prefixes + + def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False): loading = [] + flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop) for n, m in self.model.named_modules(): params = [] skip = False @@ -639,7 +682,12 @@ class ModelPatcher: skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((comfy.model_management.module_size(m), n, m, params)) + flipflop = False + for prefix in flipflop_prefixes: + if n.startswith(prefix): + flipflop = True + break + loading.append((comfy.model_management.module_size(m), n, m, params, flipflop)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -649,16 +697,18 @@ class ModelPatcher: patch_counter = 0 lowvram_counter = 0 lowvram_mem_counter = 0 - if self.supports_flipflop(): - ... 
- loading = self._load_list() + flipflop_counter = 0 + flipflop_mem_counter = 0 + loading = self._load_list(lowvram_model_memory, prepare_flipflop=True) load_completely = [] + load_flipflop = [] loading.sort(reverse=True) for x in loading: n = x[1] m = x[2] params = x[3] + flipflop: bool = x[4] module_mem = x[0] lowvram_weight = False @@ -666,7 +716,7 @@ class ModelPatcher: weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) - if not full_load and hasattr(m, "comfy_cast_weights"): + if not full_load and hasattr(m, "comfy_cast_weights") and not flipflop: if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True lowvram_counter += 1 @@ -698,7 +748,11 @@ class ModelPatcher: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) - if full_load or mem_counter + module_mem < lowvram_model_memory: + if flipflop: + flipflop_counter += 1 + flipflop_mem_counter += module_mem + load_flipflop.append((module_mem, n, m, params)) + elif full_load or mem_counter + module_mem < lowvram_model_memory: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) @@ -714,6 +768,7 @@ class ModelPatcher: mem_counter += move_weight_functions(m, device_to) + # handle load completely load_completely.sort(reverse=True) for x in load_completely: n = x[1] @@ -732,11 +787,30 @@ class ModelPatcher: for x in load_completely: x[2].to(device_to) - if lowvram_counter > 0: - logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}") + # handle flipflop + if len(load_flipflop) > 0: + load_flipflop.sort(reverse=True) + for x in load_flipflop: + n = x[1] + m = x[2] + params = x[3] + if hasattr(m, "comfy_patched_weights"): + if m.comfy_patched_weights == True: + continue + for param in params: + self.patch_weight_to_device("{}.{}".format(n, param), device_to=self.offload_device) + + logging.debug("lowvram: loaded module for flipflop {} {}".format(n, m)) + self.init_flipflop_block_copies() + + if lowvram_counter > 0 or flipflop_counter > 0: + if flipflop_counter > 0: + logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {flipflop_mem_counter / (1024 * 1024):.2f} MB to flipflop, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}") + else: + logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}") self.model.model_lowvram = True else: - logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}") + logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}") self.model.model_lowvram = False if full_load: self.model.to(device_to) @@ -773,6 +847,7 @@ class ModelPatcher: self.eject_model() if unpatch_weights: self.unpatch_hooks() + self.clean_flipflop() if self.model.model_lowvram: for m in self.model.modules(): move_weight_functions(m, device_to) From 0df61b5032e6b4d8225783d9742ed5e1cbbb2f11 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 1 Oct 2025 21:21:36 -0700 Subject: [PATCH 08/17] Fix improper index slicing for flipflop 
get blocks, add extra log message --- comfy/ldm/flipflop_transformer.py | 2 +- comfy/model_patcher.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index a88613c01..0e980ecda 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -33,7 +33,7 @@ class FlipFlopModule(torch.nn.Module): if block_type not in self.block_types: raise ValueError(f"Block type {block_type} not found in {self.block_types}") if block_type in self.flipflop: - return getattr(self, block_type)[:self.flipflop[block_type].flip_amount] + return getattr(self, block_type)[:self.flipflop[block_type].i_offset] return getattr(self, block_type) def get_all_block_module_sizes(self, reverse_sort_by_size: bool = False) -> list[tuple[str, int]]: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 67d1cb233..f47e5c157 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -619,6 +619,7 @@ class ModelPatcher: def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]]): if not self.supports_flipflop(): return + logging.info(f"setting up flipflop with {flipflop_blocks_per_type}") self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, self.load_device, self.offload_device) def init_flipflop_block_copies(self): From c4420b6a41470f81795f074d49c86634050dbd0f Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 2 Oct 2025 15:34:35 -0700 Subject: [PATCH 09/17] Change log string slightly --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f47e5c157..aa7fd6cd7 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -806,7 +806,7 @@ class ModelPatcher: if lowvram_counter > 0 or flipflop_counter > 0: if flipflop_counter > 0: - logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {flipflop_mem_counter / (1024 * 1024):.2f} MB to flipflop, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}") + logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {flipflop_mem_counter / (1024 * 1024):.2f} MB flipflop, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}") else: logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}") self.model.model_lowvram = True From 6d3ec9fcf3417b73f1d47ef2af0d1c1292a837dc Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 2 Oct 2025 16:46:37 -0700 Subject: [PATCH 10/17] Simplified flipflop setup by adding FlipFlopModule.execute_blocks helper --- comfy/ldm/flipflop_transformer.py | 19 +++++++++++++++++++ comfy/ldm/qwen_image/model.py | 12 +++--------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 0e980ecda..4018307db 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -48,6 +48,25 @@ class FlipFlopModule(torch.nn.Module): def get_block_module_size(self, block_type: str) -> int: return comfy.model_management.module_size(getattr(self, block_type)[0]) + def execute_blocks(self, block_type: str, func, out: 
torch.Tensor | tuple[torch.Tensor,...], *args, **kwargs): + # execute blocks, supporting both single and double (or higher) block types + if isinstance(out, torch.Tensor): + out = (out,) + for i, block in enumerate(self.get_blocks(block_type)): + out = func(i, block, *out, *args, **kwargs) + if isinstance(out, torch.Tensor): + out = (out,) + if "transformer_blocks" in self.flipflop: + holder = self.flipflop["transformer_blocks"] + with holder.context() as ctx: + for i, block in enumerate(holder.blocks): + out = ctx(func, i, block, *out, *args, **kwargs) + if isinstance(out, torch.Tensor): + out = (out,) + if len(out) == 1: + out = out[0] + return out + class FlipFlopContext: def __init__(self, holder: FlipFlopHolder): diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 2af4968ac..5d24a0029 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -400,7 +400,7 @@ class QwenImageTransformer2DModel(FlipFlopModule): if add is not None: hidden_states[:, :add.shape[1]] += add - return encoder_hidden_states, hidden_states + return hidden_states, encoder_hidden_states def _forward( self, @@ -469,14 +469,8 @@ class QwenImageTransformer2DModel(FlipFlopModule): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) - for i, block in enumerate(self.get_blocks("transformer_blocks")): - encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) - if "transformer_blocks" in self.flipflop: - holder = self.flipflop["transformer_blocks"] - with holder.context() as ctx: - for i, block in enumerate(holder.blocks): - encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) - + out = (hidden_states, encoder_hidden_states) + hidden_states, encoder_hidden_states = self.execute_blocks("transformer_blocks", self.indiv_block_fwd, out, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From 8d7b22b72030e384e11f6813a37ec695c958cbdd Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 2 Oct 2025 17:49:43 -0700 Subject: [PATCH 11/17] Fixed FlipFlopModule.execute_blocks having hardcoded strings from Qwen --- comfy/ldm/flipflop_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 4018307db..7bbb08208 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -56,8 +56,8 @@ class FlipFlopModule(torch.nn.Module): out = func(i, block, *out, *args, **kwargs) if isinstance(out, torch.Tensor): out = (out,) - if "transformer_blocks" in self.flipflop: - holder = self.flipflop["transformer_blocks"] + if block_type in self.flipflop: + holder = self.flipflop[block_type] with holder.context() as ctx: for i, block in enumerate(holder.blocks): out = ctx(func, i, block, *out, *args, **kwargs) From d5001ed90ef81d5095a665a15fd513b920f94832 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 2 Oct 2025 17:53:22 -0700 Subject: [PATCH 12/17] Make flux support flipflop --- comfy/ldm/flux/model.py | 137
+++++++++++++++++++++------------------- comfy/model_patcher.py | 3 +- 2 files changed, 75 insertions(+), 65 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 14f90cea5..5b014fc1b 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -7,6 +7,7 @@ from torch import Tensor, nn from einops import rearrange, repeat import comfy.ldm.common_dit import comfy.patcher_extension +from comfy.ldm.flipflop_transformer import FlipFlopModule from .layers import ( DoubleStreamBlock, @@ -35,13 +36,13 @@ class FluxParams: guidance_embed: bool -class Flux(nn.Module): +class Flux(FlipFlopModule): """ Transformer model for flow matching on sequences. """ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): - super().__init__() + super().__init__(("double_blocks", "single_blocks")) self.dtype = dtype params = FluxParams(**kwargs) self.params = params @@ -89,6 +90,72 @@ class Flux(nn.Module): if final_layer: self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) + def indiv_double_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block(img=args["img"], + txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) + return out + + out = blocks_replace[("double_block", i)]({"img": img, + "txt": txt, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, + {"original_block": block_wrap}) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, + txt=txt, + vec=vec, + pe=pe, + attn_mask=attn_mask, + transformer_options=transformer_options) + + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + img[:, :add.shape[1]] += add + return img, txt + + def indiv_single_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options): + if ("single_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) + return out + + out = blocks_replace[("single_block", i)]({"img": img, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, + {"original_block": block_wrap}) + img = out["img"] + else: + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) + + if control is not None: # Controlnet + control_o = control.get("output") + if i < len(control_o): + add = control_o[i] + if add is not None: + img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] 
+= add + return img + def forward_orig( self, img: Tensor, @@ -136,74 +203,16 @@ class Flux(nn.Module): pe = None blocks_replace = patches_replace.get("dit", {}) - for i, block in enumerate(self.double_blocks): - if ("double_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"], out["txt"] = block(img=args["img"], - txt=args["txt"], - vec=args["vec"], - pe=args["pe"], - attn_mask=args.get("attn_mask"), - transformer_options=args.get("transformer_options")) - return out - - out = blocks_replace[("double_block", i)]({"img": img, - "txt": txt, - "vec": vec, - "pe": pe, - "attn_mask": attn_mask, - "transformer_options": transformer_options}, - {"original_block": block_wrap}) - txt = out["txt"] - img = out["img"] - else: - img, txt = block(img=img, - txt=txt, - vec=vec, - pe=pe, - attn_mask=attn_mask, - transformer_options=transformer_options) - - if control is not None: # Controlnet - control_i = control.get("input") - if i < len(control_i): - add = control_i[i] - if add is not None: - img[:, :add.shape[1]] += add + # execute double blocks + img, txt = self.execute_blocks("double_blocks", self.indiv_double_block_fwd, (img, txt), vec, pe, attn_mask, control, blocks_replace, transformer_options) if img.dtype == torch.float16: img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) img = torch.cat((txt, img), 1) - for i, block in enumerate(self.single_blocks): - if ("single_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"] = block(args["img"], - vec=args["vec"], - pe=args["pe"], - attn_mask=args.get("attn_mask"), - transformer_options=args.get("transformer_options")) - return out - - out = blocks_replace[("single_block", i)]({"img": img, - "vec": vec, - "pe": pe, - "attn_mask": attn_mask, - "transformer_options": transformer_options}, - {"original_block": block_wrap}) - img = out["img"] - else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) - - if control is not None: # Controlnet - control_o = control.get("output") - if i < len(control_o): - add = control_o[i] - if add is not None: - img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add + # execute single blocks + img = self.execute_blocks("single_blocks", self.indiv_single_block_fwd, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options) img = img[:, txt.shape[1] :, ...] 
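For context on the pattern these patches implement through FlipFlopHolder and FlipFlopContext, below is a minimal, self-contained sketch of the same flip/flop double-buffering idea: two resident GPU copies of a block alternate as compute targets on the default stream, while a second CUDA stream refills whichever copy is idle with the weights of an upcoming offloaded block. All names here (run_flipflop, resident, copy_done) are illustrative only and are not part of this codebase; the real implementation additionally handles pinned staging, patch index offsets, odd block counts, and reloading the first two blocks at the end of a pass.

import copy
import torch

@torch.no_grad()
def run_flipflop(cpu_blocks, x, device="cuda"):
    # Two resident GPU copies ("flip"/"flop"); the remaining blocks stay offloaded on CPU.
    compute = torch.cuda.default_stream(device)
    copier = torch.cuda.Stream(device)
    resident = [copy.deepcopy(cpu_blocks[0]).to(device),
                copy.deepcopy(cpu_blocks[1]).to(device)]
    compute_done = [torch.cuda.Event(), torch.cuda.Event()]
    copy_done = [torch.cuda.Event(), torch.cuda.Event()]
    for e in copy_done:
        compute.record_event(e)  # slots 0/1 already hold blocks 0/1

    for i in range(len(cpu_blocks)):
        s = i % 2
        compute.wait_event(copy_done[s])  # weights for block i are present in slot s
        with torch.cuda.stream(compute):
            x = resident[s](x)
            compute_done[s].record(compute)
        nxt = i + 2  # the next block that will reuse slot s
        if nxt < len(cpu_blocks):
            with torch.cuda.stream(copier):
                copier.wait_event(compute_done[s])  # don't overwrite weights still in use
                dst = resident[s].state_dict()
                for k, v in cpu_blocks[nxt].state_dict().items():
                    dst[k].copy_(v, non_blocking=True)  # H2D copy overlaps compute on the other slot
                copy_done[s].record(copier)
    return x

In the actual FlipFlopHolder the copy destinations are the persistent flip/flop modules, the streams and events are reused across calls, and FlipFlopContext additionally applies i_offset so patches keyed by block index still line up with the original block numbering.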
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index aa7fd6cd7..8ffed3efd 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -639,7 +639,7 @@ class ModelPatcher: block_buffer = 3 valid_block_types = [] # for each block type, check if have enough room to flipflop - for block_info in self.model.diffusion_model.get_all_block_module_sizes(): + for block_info in self.model.diffusion_model.get_all_block_module_sizes(reverse_sort_by_size=False): block_size: int = block_info[1] if block_size * block_buffer < lowvram_model_memory: valid_block_types.append(block_info) @@ -653,6 +653,7 @@ class ModelPatcher: n_fit_in_memory = int(leftover_memory // block_size) # if all (or more) of this block type would fit in memory, no need to flipflop with it if n_fit_in_memory >= total_blocks: + leftover_memory -= total_blocks * block_size continue # if the amount of this block that would fit in memory is less than buffer, skip this block type if n_fit_in_memory < block_buffer: continue From 0d8e8abd90e3357782084c3a4a3d2913bda3d280 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 2 Oct 2025 18:00:21 -0700 Subject: [PATCH 13/17] Default to smaller blocks getting flipflopped first --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 8ffed3efd..02572b8ce 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -639,7 +639,7 @@ class ModelPatcher: block_buffer = 3 valid_block_types = [] # for each block type, check if have enough room to flipflop - for block_info in self.model.diffusion_model.get_all_block_module_sizes(reverse_sort_by_size=False): + for block_info in self.model.diffusion_model.get_all_block_module_sizes(reverse_sort_by_size=True): block_size: int = block_info[1] if block_size * block_buffer < lowvram_model_memory: valid_block_types.append(block_info) From 831c3cf05e77ac4c7afb4e9b640801625cb52414 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 2 Oct 2025 20:29:11 -0700 Subject: [PATCH 14/17] Add a temporary workaround for odd number of blocks not producing expected results --- comfy/model_patcher.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 02572b8ce..142b810b3 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -660,6 +660,12 @@ class ModelPatcher: continue # 2 blocks worth of VRAM may be needed for flipflop, so make sure to account for them.
flipflop_blocks = min((total_blocks - n_fit_in_memory) + 2, total_blocks) + # for now, work around odd number issue by making it even + if flipflop_blocks % 2 != 0: + if flipflop_blocks == total_blocks: + flipflop_blocks -= 1 + else: + flipflop_blocks += 1 flipflop_blocks_per_type[block_type] = (flipflop_blocks, total_blocks) leftover_memory -= (total_blocks - flipflop_blocks + 2) * block_size # if there are blocks to flipflop, need to mark their keys From ee01002e6377531506058dfae9915ac051694feb Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 2 Oct 2025 22:02:50 -0700 Subject: [PATCH 15/17] Add flipflop support to (base) WAN, fix issue with applying loras to flipflop weights being done on CPU instead of GPU, left some timing functions as the lora application time could use some reduction --- comfy/ldm/flipflop_transformer.py | 3 ++- comfy/ldm/wan/model.py | 38 ++++++++++++++++++------------- comfy/ldm/wan/model_animate.py | 2 +- comfy/model_patcher.py | 18 +++++++++++---- 4 files changed, 39 insertions(+), 22 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 7bbb08208..9e9c28468 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -6,9 +6,10 @@ import comfy.model_management class FlipFlopModule(torch.nn.Module): - def __init__(self, block_types: tuple[str, ...]): + def __init__(self, block_types: tuple[str, ...], enable_flipflop: bool = True): super().__init__() self.block_types = block_types + self.enable_flipflop = enable_flipflop self.flipflop: dict[str, FlipFlopHolder] = {} def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device): diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 0dc650ced..7f8ca65d7 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -7,6 +7,7 @@ import torch.nn as nn from einops import rearrange from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flipflop_transformer import FlipFlopModule from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope1 import comfy.ldm.common_dit @@ -384,7 +385,7 @@ class MLPProj(torch.nn.Module): return clip_extra_context_tokens -class WanModel(torch.nn.Module): +class WanModel(FlipFlopModule): r""" Wan diffusion backbone supporting both text-to-video and image-to-video. """ @@ -412,6 +413,7 @@ class WanModel(torch.nn.Module): device=None, dtype=None, operations=None, + enable_flipflop=True, ): r""" Initialize the diffusion model backbone. 
@@ -449,7 +451,7 @@ class WanModel(torch.nn.Module): Epsilon value for normalization layers """ - super().__init__() + super().__init__(block_types=("blocks",), enable_flipflop=enable_flipflop) self.dtype = dtype operation_settings = {"operations": operations, "device": device, "dtype": dtype} @@ -506,6 +508,18 @@ class WanModel(torch.nn.Module): else: self.ref_conv = None + def indiv_block_fwd(self, i, block, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + return x + def forward_orig( self, x, @@ -567,16 +581,8 @@ class WanModel(torch.nn.Module): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) - for i, block in enumerate(self.blocks): - if ("double_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) - return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) - x = out["img"] - else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + # execute blocks + x = self.execute_blocks("blocks", self.indiv_block_fwd, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options) # head x = self.head(x, e) @@ -688,7 +694,7 @@ class VaceWanModel(WanModel): operations=None, ): - super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False) operation_settings = {"operations": operations, "device": device, "dtype": dtype} # Vace @@ -808,7 +814,7 @@ class CameraWanModel(WanModel): else: model_type = 't2v' - super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, 
flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False) operation_settings = {"operations": operations, "device": device, "dtype": dtype} self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings) @@ -1211,7 +1217,7 @@ class WanModel_S2V(WanModel): operations=None, ): - super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations) + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False) self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype) @@ -1511,7 +1517,7 @@ class HumoWanModel(WanModel): operations=None, ): - super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations) + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False) self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations) diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py index 7c87835d4..2c09420c5 100644 --- a/comfy/ldm/wan/model_animate.py +++ b/comfy/ldm/wan/model_animate.py @@ -426,7 +426,7 @@ class AnimateWanModel(WanModel): operations=None, ): - super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, 
cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False) self.pose_patch_embedding = operations.Conv3d( 16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 142b810b3..08055b65c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -25,7 +25,7 @@ import logging import math import uuid from typing import Callable, Optional - +import time # TODO remove import torch import comfy.float @@ -577,7 +577,7 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False): + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, device_final=None): if key not in self.patches: return @@ -597,18 +597,22 @@ class ModelPatcher: out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + if device_final is not None: + out_weight = out_weight.to(device_final) if inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: + if device_final is not None: + out_weight = out_weight.to(device_final) set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) def supports_flipflop(self): # flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM if not hasattr(self.model, "diffusion_model"): return False - if not hasattr(self.model.diffusion_model, "flipflop"): + if not getattr(self.model.diffusion_model, "enable_flipflop", False): return False if not comfy.model_management.is_nvidia(): return False @@ -797,6 +801,7 @@ class ModelPatcher: # handle flipflop if len(load_flipflop) > 0: + start_time = time.perf_counter() load_flipflop.sort(reverse=True) for x in load_flipflop: n = x[1] @@ -806,10 +811,15 @@ class ModelPatcher: if m.comfy_patched_weights == True: continue for param in params: - self.patch_weight_to_device("{}.{}".format(n, param), device_to=self.offload_device) + self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to, device_final=self.offload_device) logging.debug("lowvram: loaded module for flipflop {} {}".format(n, m)) + end_time = time.perf_counter() + logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds") + start_time = time.perf_counter() self.init_flipflop_block_copies() + end_time = time.perf_counter() + logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds") if lowvram_counter > 0 or flipflop_counter > 0: if flipflop_counter > 0: From 5329180fce39baac993d6ba65fea2fc7814c2961 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 3 Oct 2025 16:21:01 -0700 Subject: [PATCH 16/17] Made flipflop consider partial_unload, partial_offload, and add flip+flop to mem counters --- comfy/ldm/flipflop_transformer.py 
| 29 ++++++++++++++++++----- comfy/model_patcher.py | 38 ++++++++++++++++++++----------- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 9e9c28468..0379d14ff 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -11,21 +11,31 @@ class FlipFlopModule(torch.nn.Module): self.block_types = block_types self.enable_flipflop = enable_flipflop self.flipflop: dict[str, FlipFlopHolder] = {} + self.block_info: dict[str, tuple[int, int]] = {} + self.flipflop_prefixes: list[str] = [] - def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device): + def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], flipflop_prefixes: list[str], load_device: torch.device, offload_device: torch.device): for block_type, (flipflop_blocks, total_blocks) in block_info.items(): if block_type in self.flipflop: continue self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[total_blocks-flipflop_blocks:], flipflop_blocks, total_blocks, load_device, offload_device) + self.block_info[block_type] = (flipflop_blocks, total_blocks) + self.flipflop_prefixes = flipflop_prefixes.copy() - def init_flipflop_block_copies(self, device: torch.device): + def init_flipflop_block_copies(self, device: torch.device) -> int: + memory_freed = 0 for holder in self.flipflop.values(): - holder.init_flipflop_block_copies(device) + memory_freed += holder.init_flipflop_block_copies(device) + return memory_freed def clean_flipflop_holders(self): + memory_freed = 0 for block_type in list(self.flipflop.keys()): - self.flipflop[block_type].clean_flipflop_blocks() + memory_freed += self.flipflop[block_type].clean_flipflop_blocks() del self.flipflop[block_type] + self.block_info = {} + self.flipflop_prefixes = [] + return memory_freed def get_all_blocks(self, block_type: str) -> list[torch.nn.Module]: return getattr(self, block_type) @@ -71,6 +81,8 @@ class FlipFlopModule(torch.nn.Module): class FlipFlopContext: def __init__(self, holder: FlipFlopHolder): + # NOTE: there is a bug when there are an odd number of blocks to flipflop. + # Worked around right now by always making sure it will be even, but need to resolve. 
self.holder = holder self.reset() @@ -172,12 +184,17 @@ class FlipFlopHolder: def context(self): return FlipFlopContext(self) - def init_flipflop_block_copies(self, load_device: torch.device): + def init_flipflop_block_copies(self, load_device: torch.device) -> int: self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device) self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device) + return comfy.model_management.module_size(self.flip) + comfy.model_management.module_size(self.flop) - def clean_flipflop_blocks(self): + def clean_flipflop_blocks(self) -> int: + memory_freed = 0 + memory_freed += comfy.model_management.module_size(self.flip) + memory_freed += comfy.model_management.module_size(self.flop) del self.flip del self.flop self.flip = None self.flop = None + return memory_freed diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 08055b65c..dfb38d6e8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -620,21 +620,26 @@ class ModelPatcher: return False return True - def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]]): + def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]], flipflop_prefixes: list[str]): if not self.supports_flipflop(): return logging.info(f"setting up flipflop with {flipflop_blocks_per_type}") - self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, self.load_device, self.offload_device) + self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, flipflop_prefixes, self.load_device, self.offload_device) - def init_flipflop_block_copies(self): + def init_flipflop_block_copies(self) -> int: if not self.supports_flipflop(): - return - self.model.diffusion_model.init_flipflop_block_copies(self.load_device) + return 0 + return self.model.diffusion_model.init_flipflop_block_copies(self.load_device) - def clean_flipflop(self): + def clean_flipflop(self) -> int: if not self.supports_flipflop(): - return - self.model.diffusion_model.clean_flipflop_holders() + return 0 + return self.model.diffusion_model.clean_flipflop_holders() + + def _get_existing_flipflop_prefixes(self): + if self.supports_flipflop(): + return self.model.diffusion_model.flipflop_prefixes + return [] def _calc_flipflop_prefixes(self, lowvram_model_memory=0, prepare_flipflop=False): flipflop_prefixes = [] @@ -678,12 +683,15 @@ class ModelPatcher: for i in range(total_blocks-flipflop_blocks, total_blocks): flipflop_prefixes.append(f"diffusion_model.{block_type}.{i}") if prepare_flipflop and len(flipflop_blocks_per_type) > 0: - self.setup_flipflop(flipflop_blocks_per_type) + self.setup_flipflop(flipflop_blocks_per_type, flipflop_prefixes) return flipflop_prefixes - def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False): + def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False, get_existing_flipflop=False): loading = [] - flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop) + if get_existing_flipflop: + flipflop_prefixes = self._get_existing_flipflop_prefixes() + else: + flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop) for n, m in self.model.named_modules(): params = [] skip = False @@ -817,7 +825,7 @@ class ModelPatcher: end_time = time.perf_counter() logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds") start_time = time.perf_counter() - self.init_flipflop_block_copies() + mem_counter += self.init_flipflop_block_copies() end_time = time.perf_counter() 
logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds") @@ -905,8 +913,9 @@ class ModelPatcher: with self.use_ejected(): hooks_unpatched = False memory_freed = 0 + memory_freed += self.clean_flipflop() patch_counter = 0 - unload_list = self._load_list() + unload_list = self._load_list(get_existing_flipflop=True) unload_list.sort() for unload in unload_list: if memory_to_free < memory_freed: @@ -915,7 +924,10 @@ class ModelPatcher: n = unload[1] m = unload[2] params = unload[3] + flipflop: bool = unload[4] + if flipflop: + continue lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: move_weight = True From 61133af77292df630e1e7fceda80ee7ed8a271c6 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 13 Oct 2025 21:10:44 -0700 Subject: [PATCH 17/17] Add '--flipflop-offload' startup argument --- comfy/cli_args.py | 2 ++ comfy/model_management.py | 2 ++ comfy/model_patcher.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cc1f12482..1eddeebd4 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -132,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.") +parser.add_argument("--flipflop-offload", action="store_true", help="Use async flipflop weight offloading for supported DiT models.") + parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 709ebc40b..c0edb251d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1001,6 +1001,8 @@ def force_channels_last(): #TODO return False +def flipflop_enabled(): + return args.flipflop_offload STREAMS = {} NUM_STREAMS = 1 diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d810e5b76..18e6f2e23 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -624,6 +624,8 @@ class ModelPatcher: def supports_flipflop(self): # flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM + if not comfy.model_management.flipflop_enabled(): + return False if not hasattr(self.model, "diffusion_model"): return False if not getattr(self.model.diffusion_model, "enable_flipflop", False):