In-progress commit on making flipflop async weight streaming native; also made the "loaded partially"/"loaded completely" log messages use labels, because having to memorize what each number means is annoying for dev work

Jedrzej Kosinski 2025-09-30 23:08:08 -07:00
parent d0bd221495
commit 01f4512bf8
4 changed files with 145 additions and 119 deletions

View File

@@ -3,42 +3,9 @@ import torch
import torch.cuda as cuda
import copy
from typing import List, Tuple
from dataclasses import dataclass
import comfy.model_management
FLIPFLOP_REGISTRY = {}
def register(name):
def decorator(cls):
FLIPFLOP_REGISTRY[name] = cls
return cls
return decorator
@dataclass
class FlipFlopConfig:
block_name: str
block_wrap_fn: callable
out_names: Tuple[str]
overwrite_forward: str
pinned_staging: bool = False
inference_device: str = "cuda"
offloading_device: str = "cpu"
def patch_model_from_config(model, config: FlipFlopConfig):
block_list = getattr(model, config.block_name)
flip_flop_transformer = FlipFlopTransformer(block_list,
block_wrap_fn=config.block_wrap_fn,
out_names=config.out_names,
offloading_device=config.offloading_device,
inference_device=config.inference_device,
pinned_staging=config.pinned_staging)
delattr(model, config.block_name)
setattr(model, config.block_name, flip_flop_transformer)
setattr(model, config.overwrite_forward, flip_flop_transformer.__call__)
class FlipFlopContext:
def __init__(self, holder: FlipFlopHolder):
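For context, a minimal standalone sketch of how the registry being removed here was used (mirroring the old FlipFlop node deleted in the last file of this diff); the DummyModel/DummyPatcher names are illustrative only:

# Standalone sketch: register() keys the registry by class name, and callers look the
# diffusion model's class up to get a patcher (as the old FlipFlop node did).
FLIPFLOP_REGISTRY = {}

def register(name):
    def decorator(cls):
        FLIPFLOP_REGISTRY[name] = cls
        return cls
    return decorator

@register("DummyModel")          # real entries use the diffusion model's class name
class DummyPatcher:
    @staticmethod
    def patch(model):
        return model             # real entries call patch_model_from_config(model, cls.blocks_config)

class DummyModel: ...

model = DummyModel()
patch_cls = FLIPFLOP_REGISTRY.get(type(model).__name__, None)
if patch_cls is not None:
    model = patch_cls.patch(model)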
@@ -46,11 +13,12 @@ class FlipFlopContext:
self.reset()
def reset(self):
self.num_blocks = len(self.holder.transformer_blocks)
self.num_blocks = len(self.holder.blocks)
self.first_flip = True
self.first_flop = True
self.last_flip = False
self.last_flop = False
# TODO: the 'i' that's passed into func needs to be properly offset to do patches correctly
def __enter__(self):
self.reset()
@@ -71,9 +39,9 @@ class FlipFlopContext:
next_flop_i = next_flop_i - self.num_blocks
self.last_flip = True
if not self.first_flip:
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
if self.last_flip:
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
self.first_flip = False
return out
@@ -89,9 +57,9 @@ class FlipFlopContext:
if next_flip_i >= self.num_blocks:
next_flip_i = next_flip_i - self.num_blocks
self.last_flop = True
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
if self.last_flop:
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
self.first_flop = False
return out
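To make the control flow above easier to follow, here is a simplified, pure-Python sketch of the schedule the flip/flop branches implement. It only prints the overlap; the real code expresses the same ordering with CUDA streams and the event_flip/event_flop/cpy_end_event events, and the last_flip/last_flop bookkeeping additionally re-primes both buffers for the next forward pass.

# Pure-Python sketch (no CUDA needed): the overlap pattern the branches above encode.
def flipflop_schedule(num_blocks: int) -> None:
    for i in range(num_blocks):
        buf, other = ("flip", "flop") if i % 2 == 0 else ("flop", "flip")
        line = f"compute block {i} on {buf}"
        if i == 0:
            pass  # flop was pre-loaded with block 1 at init, nothing to prefetch yet
        elif i < num_blocks - 1:
            line += f"   | overlapped: copy block {i + 1} -> {other}"
        else:
            line += "   | overlapped: restore blocks 0/1 -> flip/flop for the next pass"
        print(line)

flipflop_schedule(6)
# compute block 0 on flip
# compute block 1 on flop   | overlapped: copy block 2 -> flip
# compute block 2 on flip   | overlapped: copy block 3 -> flop
# compute block 3 on flop   | overlapped: copy block 4 -> flip
# compute block 4 on flip   | overlapped: copy block 5 -> flop
# compute block 5 on flop   | overlapped: restore blocks 0/1 -> flip/flop for the next pass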
@@ -106,19 +74,20 @@ class FlipFlopContext:
class FlipFlopHolder:
def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"):
self.load_device = torch.device(inference_device)
self.offload_device = torch.device(offloading_device)
self.transformer_blocks = transformer_blocks
def __init__(self, blocks: List[torch.nn.Module], flip_amount: int, load_device="cuda", offload_device="cpu"):
self.load_device = torch.device(load_device)
self.offload_device = torch.device(offload_device)
self.blocks = blocks
self.flip_amount = flip_amount
self.block_module_size = 0
if len(self.transformer_blocks) > 0:
self.block_module_size = comfy.model_management.module_size(self.transformer_blocks[0])
if len(self.blocks) > 0:
self.block_module_size = comfy.model_management.module_size(self.blocks[0])
self.flip: torch.nn.Module = None
self.flop: torch.nn.Module = None
# TODO: make initialization happen in model management code/model patcher, not here
self.initialize_flipflop_blocks(self.load_device)
self.init_flipflop_blocks(self.load_device)
self.compute_stream = cuda.default_stream(self.load_device)
self.cpy_stream = cuda.Stream(self.load_device)
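_copy_state_dict itself is not shown in this diff; purely as a guess at its shape, based on the call sites above and the compute_stream/cpy_stream/event setup in __init__, it presumably does something like the hypothetical helper below: wait for the buffer's previous compute, copy tensor by tensor on the side stream, and record an end event for the compute stream to wait on.

# Hypothetical sketch only -- not the method from this commit. Assumes dst_sd/src_sd are
# state_dicts of same-shaped blocks; true overlap also needs pinned CPU source tensors
# (cf. the pinned_staging flag in the prototype further down).
import torch

def copy_state_dict_sketch(dst_sd, src_sd, cpy_stream, cpy_start_event=None, cpy_end_event=None):
    with torch.cuda.stream(cpy_stream):
        if cpy_start_event is not None:
            cpy_stream.wait_event(cpy_start_event)      # buffer's previous compute must be done
        for key, dst in dst_sd.items():
            dst.copy_(src_sd[key], non_blocking=True)   # enqueue H2D copies on the side stream
        if cpy_end_event is not None:
            cpy_end_event.record(cpy_stream)            # compute stream waits on this before reuse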
@@ -142,10 +111,57 @@ class FlipFlopHolder:
def context(self):
return FlipFlopContext(self)
def initialize_flipflop_blocks(self, load_device: torch.device):
self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=load_device)
self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=load_device)
def init_flipflop_blocks(self, load_device: torch.device):
self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
def clean_flipflop_blocks(self):
del self.flip
del self.flop
self.flip = None
self.flop = None
class FlopFlopModule(torch.nn.Module):
def __init__(self, block_types: tuple[str, ...]):
super().__init__()
self.block_types = block_types
self.flipflop: dict[str, FlipFlopHolder] = {}
def setup_flipflop_holders(self, block_percentage: float):
for block_type in self.block_types:
if block_type in self.flipflop:
continue
num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage))
self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks)
def clean_flipflop_holders(self):
for block_type in self.flipflop.keys():
self.flipflop[block_type].clean_flipflop_blocks()
del self.flipflop[block_type]
def get_blocks(self, block_type: str) -> torch.nn.ModuleList:
if block_type not in self.block_types:
raise ValueError(f"Block type {block_type} not found in {self.block_types}")
if block_type in self.flipflop:
return getattr(self, block_type)[:self.flipflop[block_type].flip_amount]
return getattr(self, block_type)
def get_all_block_module_sizes(self, sort_by_size: bool = False) -> list[tuple[str, int]]:
'''
Returns a list of (block_type, size).
If sort_by_size is True, the list is sorted by size.
'''
sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types]
if sort_by_size:
sizes.sort(key=lambda x: x[1])
return sizes
def get_block_module_size(self, block_type: str) -> int:
return comfy.model_management.module_size(getattr(self, block_type)[0])
# Below is the implementation from contentis' prototype flip flop
class FlipFlopTransformer:
def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
self.transformer_blocks = transformer_blocks
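A worked example of the split that setup_flipflop_holders performs (same formula as in the Qwen model file below): flipflop_block_percentage is roughly the fraction of blocks handed to the holder for streaming, and the rest stay resident. Numbers are illustrative.

# Worked example of the num_blocks split (illustrative numbers).
num_total = 60
flipflop_block_percentage = 0.25
num_resident = int(num_total * (1.0 - flipflop_block_percentage))
num_streamed = num_total - num_resident
print(num_resident, num_streamed)   # 45 blocks stay resident, 15 are flip/flopped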
@@ -379,28 +395,26 @@ class FlipFlopTransformer:
# patch_model_from_config(model, Wan.blocks_config)
# return model
# @register("QwenImageTransformer2DModel")
# class QwenImage:
# @staticmethod
# def qwen_blocks_wrap(block, **kwargs):
# kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
# encoder_hidden_states=kwargs["encoder_hidden_states"],
# encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
# temb=kwargs["temb"],
# image_rotary_emb=kwargs["image_rotary_emb"],
# transformer_options=kwargs["transformer_options"])
# return kwargs
@register("QwenImageTransformer2DModel")
class QwenImage:
@staticmethod
def qwen_blocks_wrap(block, **kwargs):
kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
encoder_hidden_states=kwargs["encoder_hidden_states"],
encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
temb=kwargs["temb"],
image_rotary_emb=kwargs["image_rotary_emb"],
transformer_options=kwargs["transformer_options"])
return kwargs
blocks_config = FlipFlopConfig(block_name="transformer_blocks",
block_wrap_fn=qwen_blocks_wrap,
out_names=("encoder_hidden_states", "hidden_states"),
overwrite_forward="blocks_fwd",
pinned_staging=False)
# blocks_config = FlipFlopConfig(block_name="transformer_blocks",
# block_wrap_fn=qwen_blocks_wrap,
# out_names=("encoder_hidden_states", "hidden_states"),
# overwrite_forward="blocks_fwd",
# pinned_staging=False)
@staticmethod
def patch(model):
patch_model_from_config(model, QwenImage.blocks_config)
return model
# @staticmethod
# def patch(model):
# patch_model_from_config(model, QwenImage.blocks_config)
# return model

View File

@@ -343,11 +343,36 @@ class QwenImageTransformer2DModel(nn.Module):
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
def setup_flipflop_holders(self, block_percentage: float):
if "transformer_blocks" in self.flipflop:
return
import comfy.model_management
# We hackily move any flipflopped blocks into holder so that our model management system does not see them.
num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage))
self.flipflop["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
loading = []
for n, m in self.named_modules():
params = []
skip = False
for name, param in m.named_parameters(recurse=False):
params.append(name)
for name, param in m.named_parameters(recurse=True):
if name not in params:
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
loading.append((comfy.model_management.module_size(m), n, m, params))
self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks)
self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks])
def clean_flipflop_holders(self):
if "transformer_blocks" in self.flipflop:
self.flipflop["transformer_blocks"].clean_flipflop_blocks()
del self.flipflop["transformer_blocks"]
def get_transformer_blocks(self):
if "transformer_blocks" in self.flipflop:
return self.transformer_blocks[:self.flipflop["transformer_blocks"].flip_amount]
return self.transformer_blocks
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
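A standalone illustration of the "hide the streamed blocks from model management" trick described in the comment above: submodules held in a plain dict (like self.flipflop) are not registered with nn.Module, so parameter/size accounting only sees the resident blocks. Toy sizes below are made up.

# Standalone demo: blocks moved into a plain dict are invisible to nn.Module accounting.
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, num_blocks=4, num_resident=2):
        super().__init__()
        blocks = [nn.Linear(8, 8) for _ in range(num_blocks)]
        self.transformer_blocks = nn.ModuleList(blocks[:num_resident])  # registered, counted
        self.flipflop = {"transformer_blocks": blocks[num_resident:]}   # plain dict, not counted

toy = Toy()
visible = sum(p.numel() for p in toy.parameters())
hidden = sum(p.numel() for b in toy.flipflop["transformer_blocks"] for p in b.parameters())
print(visible, hidden)   # 144 144 -- only the resident half shows up in parameters()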
@@ -409,17 +434,6 @@ class QwenImageTransformer2DModel(nn.Module):
return encoder_hidden_states, hidden_states
def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
for i, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
if "blocks_fwd" in self.flipflop:
holder = self.flipflop["blocks_fwd"]
with holder.context() as ctx:
for i, block in enumerate(holder.transformer_blocks):
encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
return encoder_hidden_states, hidden_states
def _forward(
self,
x,
@@ -487,12 +501,14 @@ class QwenImageTransformer2DModel(nn.Module):
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})
encoder_hidden_states, hidden_states = self.blocks_fwd(hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb, image_rotary_emb=image_rotary_emb,
patches=patches, control=control, blocks_replace=blocks_replace, x=x,
transformer_options=transformer_options)
for i, block in enumerate(self.get_transformer_blocks()):
encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
if "transformer_blocks" in self.flipflop:
holder = self.flipflop["transformer_blocks"]
with holder.context() as ctx:
for i, block in enumerate(holder.blocks):
encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
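One thing worth noting about the new loops above, and the subject of the TODO at the top of FlipFlopContext: the enumerate over holder.blocks restarts at 0, while the patch/replace lookups inside indiv_block_fwd are presumably keyed on the original block index. Purely as an illustration of what that TODO means, not the planned fix:

# Illustrative only -- the TODO in FlipFlopContext says this offset still needs handling.
# Assuming the resident blocks come first, a streamed block's original index would be its
# position within holder.blocks plus the number of resident blocks (holder.flip_amount).
num_resident = 45          # e.g. holder.flip_amount
for i in range(3):         # e.g. enumerate(holder.blocks)
    original_index = i + num_resident
    print(i, "->", original_index)   # 0 -> 45, 1 -> 46, 2 -> 47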

View File

@@ -605,7 +605,27 @@ class ModelPatcher:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def supports_flipflop(self):
return hasattr(self.model.diffusion_model, "flipflop")
# flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM
if not hasattr(self.model, "diffusion_model"):
return False
if not hasattr(self.model.diffusion_model, "flipflop"):
return False
if not comfy.model_management.is_nvidia():
return False
if comfy.model_management.vram_state in (comfy.model_management.VRAMState.HIGH_VRAM, comfy.model_management.VRAMState.SHARED):
return False
return True
def init_flipflop(self):
if not self.supports_flipflop():
return
# figure out how many b
self.model.diffusion_model.setup_flipflop_holders(self.model_options["flipflop_block_percentage"])
def clean_flipflop(self):
if not self.supports_flipflop():
return
self.model.diffusion_model.clean_flipflop_holders()
def _load_list(self):
loading = []
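The actual policy for choosing flipflop_block_percentage is still open here (init_flipflop above stops at a truncated comment). Purely as an illustration of the kind of arithmetic involved, and not something from this commit, a budget-based guess could look like the sketch below, using a per-block size such as the one FlopFlopModule.get_block_module_size() reports.

# Purely illustrative sketch -- not the commit's policy; block size and budget are made up.
def guess_block_percentage(block_size_bytes: int, num_blocks: int, budget_bytes: int) -> float:
    # keep as many whole blocks resident as the budget allows, but leave at least
    # two blocks to stream so both the flip and flop buffers have work to do
    resident = max(0, min(num_blocks - 2, budget_bytes // max(block_size_bytes, 1)))
    return 1.0 - (resident / num_blocks)

# e.g. 60 blocks of 300 MB each against a 12 GB budget -> 40 resident, ~0.33 streamed
print(guess_block_percentage(300 * 1024**2, 60, 12 * 1024**3))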
@@ -628,6 +648,9 @@ class ModelPatcher:
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
lowvram_mem_counter = 0
if self.supports_flipflop():
...
loading = self._load_list()
load_completely = []
@@ -647,6 +670,7 @@ class ModelPatcher:
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
lowvram_mem_counter += module_mem
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
@@ -709,10 +733,10 @@ class ModelPatcher:
x[2].to(device_to)
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
self.model.model_lowvram = True
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}")
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
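With the relabeled messages above, the "loaded partially" line now reads as in the example below (sample numbers only, reproduced with the same format string):

# Sample output of the new "loaded partially" message (numbers are made up).
lowvram_model_memory = 8 * 1024**3        # 8 GB usable
mem_counter = int(6.5 * 1024**3)          # 6.5 GB actually loaded
lowvram_mem_counter = int(3.2 * 1024**3)  # 3.2 GB left offloaded
patch_counter = 0
print(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, "
      f"{mem_counter / (1024 * 1024):.2f} MB loaded, "
      f"{lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
# -> loaded partially; 8192.00 MB usable memory, 6656.00 MB loaded, 3276.80 MB offloaded, lowvram patches: 0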

View File

@@ -3,33 +3,6 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy.ldm.flipflop_transformer import FLIPFLOP_REGISTRY
class FlipFlopOld(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="FlipFlop",
display_name="FlipFlop (Old)",
category="_for_testing",
inputs=[
io.Model.Input(id="model")
],
outputs=[
io.Model.Output()
],
description="Apply FlipFlop transformation to model using registry-based patching"
)
@classmethod
def execute(cls, model) -> io.NodeOutput:
patch_cls = FLIPFLOP_REGISTRY.get(model.model.diffusion_model.__class__.__name__, None)
if patch_cls is None:
raise ValueError(f"Model {model.model.diffusion_model.__class__.__name__} not supported")
model.model.diffusion_model = patch_cls.patch(model.model.diffusion_model)
return io.NodeOutput(model)
class FlipFlop(io.ComfyNode):
@classmethod
@@ -62,7 +35,6 @@ class FlipFlopExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
FlipFlopOld,
FlipFlop,
]