Fix percentage logic, begin adding elements to ModelPatcher to track flip flop compatibility

Jedrzej Kosinski 2025-09-29 22:49:12 -07:00
parent ff789c8beb
commit 8a8162e8da
3 changed files with 26 additions and 11 deletions

@@ -5,6 +5,8 @@ import copy
 from typing import List, Tuple
 from dataclasses import dataclass
+import comfy.model_management
 FLIPFLOP_REGISTRY = {}
 def register(name):
@@ -105,15 +107,21 @@ class FlipFlopContext:
 class FlipFlopHolder:
     def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"):
-        self.inference_device = torch.device(inference_device)
-        self.offloading_device = torch.device(offloading_device)
+        self.load_device = torch.device(inference_device)
+        self.offload_device = torch.device(offloading_device)
         self.transformer_blocks = transformer_blocks
-        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=self.inference_device)
-        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=self.inference_device)
+        self.block_module_size = 0
+        if len(self.transformer_blocks) > 0:
+            self.block_module_size = comfy.model_management.module_size(self.transformer_blocks[0])
-        self.compute_stream = cuda.default_stream(self.inference_device)
-        self.cpy_stream = cuda.Stream(self.inference_device)
+        self.flip: torch.nn.Module = None
+        self.flop: torch.nn.Module = None
+        # TODO: make initialization happen in model management code/model patcher, not here
+        self.initialize_flipflop_blocks(self.load_device)
+        self.compute_stream = cuda.default_stream(self.load_device)
+        self.cpy_stream = cuda.Stream(self.load_device)
         self.event_flip = torch.cuda.Event(enable_timing=False)
         self.event_flop = torch.cuda.Event(enable_timing=False)
@@ -134,6 +142,10 @@ class FlipFlopHolder:
     def context(self):
         return FlipFlopContext(self)
 
+    def initialize_flipflop_blocks(self, load_device: torch.device):
+        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=load_device)
+        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=load_device)
+
 class FlipFlopTransformer:
     def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
         self.transformer_blocks = transformer_blocks
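
For orientation, here is a minimal, hypothetical sketch of the flip/flop (double-buffered) offloading pattern that FlipFlopHolder's streams and events are built around: two GPU-resident copies of a block are reused for every block, and the copy stream prefetches the next block's CPU weights into the idle slot while the compute stream runs the current one. The class name TwoSlotOffloader and all details below are illustrative assumptions, not the actual FlipFlopHolder/FlipFlopContext implementation.

# Illustrative only -- not the ComfyUI implementation. Assumes a CUDA device and
# that all cpu_blocks share the same architecture, so their parameters line up.
import copy
import torch

class TwoSlotOffloader:
    def __init__(self, cpu_blocks, load_device="cuda"):
        self.cpu_blocks = cpu_blocks                          # full weights stay on CPU
        self.load_device = torch.device(load_device)
        # Two reusable GPU slots ("flip" and "flop"), seeded with the first two blocks.
        self.slots = [copy.deepcopy(cpu_blocks[i]).to(self.load_device) for i in range(2)]
        self.compute_stream = torch.cuda.default_stream(self.load_device)
        self.copy_stream = torch.cuda.Stream(self.load_device)
        # One event per slot, signalling "this slot's weights are ready to use".
        self.ready = [torch.cuda.Event(enable_timing=False) for _ in range(2)]

    @torch.no_grad()
    def _prefetch(self, slot, block_idx):
        # Overwrite the idle slot's weights on the side stream. For the copy to be
        # truly asynchronous the CPU tensors need to be in pinned memory
        # (cf. the pinned_staging flag above); buffers are omitted for brevity.
        with torch.cuda.stream(self.copy_stream):
            dst_block, src_block = self.slots[slot], self.cpu_blocks[block_idx]
            for dst, src in zip(dst_block.parameters(), src_block.parameters()):
                dst.copy_(src, non_blocking=True)
            self.ready[slot].record(self.copy_stream)

    def __call__(self, x):
        for slot in range(2):
            self.ready[slot].record(self.compute_stream)      # both slots start out ready
        for i in range(len(self.cpu_blocks)):
            slot = i % 2
            self.compute_stream.wait_event(self.ready[slot])  # wait for this slot's prefetch
            x = self.slots[slot](x)                           # runs on the compute stream
            nxt = i + 2
            if nxt < len(self.cpu_blocks):
                # Don't clobber this slot's weights until the work queued above has run.
                self.copy_stream.wait_stream(self.compute_stream)
                self._prefetch(slot, nxt)
        return x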

@@ -336,7 +336,7 @@ class QwenImageTransformer2DModel(nn.Module):
             for _ in range(num_layers)
         ])
-        self.flipflop_holders: dict[str, FlipFlopHolder] = {}
+        self.flipflop: dict[str, FlipFlopHolder] = {}
         if final_layer:
             self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
@@ -344,8 +344,8 @@
     def setup_flipflop_holders(self, block_percentage: float):
         # We hackily move any flipflopped blocks into holder so that our model management system does not see them.
-        num_blocks = int(len(self.transformer_blocks) * block_percentage)
-        self.flipflop_holders["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
+        num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage))
+        self.flipflop["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
         self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks])
     def process_img(self, x, index=0, h_offset=0, w_offset=0):
@@ -412,8 +412,8 @@
     def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
         for i, block in enumerate(self.transformer_blocks):
             encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-        if "blocks_fwd" in self.flipflop_holders:
-            holder = self.flipflop_holders["blocks_fwd"]
+        if "blocks_fwd" in self.flipflop:
+            holder = self.flipflop["blocks_fwd"]
             with holder.context() as ctx:
                 for i, block in enumerate(holder.transformer_blocks):
                     encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
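
The percentage fix in setup_flipflop_holders above changes which end of the block range block_percentage controls. Assuming block_percentage is meant as the fraction of blocks to flip-flop, the old expression kept only that fraction resident and flip-flopped everything else; the corrected (1.0-block_percentage) keeps the complement resident. A small illustration with made-up numbers:

blocks = list(range(60))          # stand-in for transformer_blocks
block_percentage = 0.25           # fraction of blocks to flip-flop

old_split = int(len(blocks) * block_percentage)            # 15 -> 15 resident, 45 flip-flopped
new_split = int(len(blocks) * (1.0 - block_percentage))    # 45 -> 45 resident, 15 flip-flopped

resident, flipflopped = blocks[:new_split], blocks[new_split:]
assert len(resident) == 45 and len(flipflopped) == 15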

@@ -604,6 +604,9 @@ class ModelPatcher:
        else:
            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
+    def supports_flipflop(self):
+        return hasattr(self.model.diffusion_model, "flipflop")
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
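
The new ModelPatcher.supports_flipflop() only reports whether the wrapped diffusion model exposes a flipflop dict (as QwenImageTransformer2DModel now does); nothing in this commit consumes it yet. A hypothetical call site might look like the sketch below, where maybe_enable_flipflop and flipflop_percentage are invented names:

def maybe_enable_flipflop(patcher, flipflop_percentage: float):
    # Only models that declare a `flipflop` dict can have blocks moved into holders.
    if flipflop_percentage > 0.0 and patcher.supports_flipflop():
        patcher.model.diffusion_model.setup_flipflop_holders(flipflop_percentage)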