diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py
index 242672462..754733a86 100644
--- a/comfy/ldm/flipflop_transformer.py
+++ b/comfy/ldm/flipflop_transformer.py
@@ -5,6 +5,8 @@ import copy
 from typing import List, Tuple
 from dataclasses import dataclass
 
+import comfy.model_management
+
 FLIPFLOP_REGISTRY = {}
 
 def register(name):
@@ -105,15 +107,21 @@ class FlipFlopContext:
 
 class FlipFlopHolder:
     def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"):
-        self.inference_device = torch.device(inference_device)
-        self.offloading_device = torch.device(offloading_device)
+        self.load_device = torch.device(inference_device)
+        self.offload_device = torch.device(offloading_device)
 
         self.transformer_blocks = transformer_blocks
-        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=self.inference_device)
-        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=self.inference_device)
+        self.block_module_size = 0
+        if len(self.transformer_blocks) > 0:
+            self.block_module_size = comfy.model_management.module_size(self.transformer_blocks[0])
 
-        self.compute_stream = cuda.default_stream(self.inference_device)
-        self.cpy_stream = cuda.Stream(self.inference_device)
+        self.flip: torch.nn.Module = None
+        self.flop: torch.nn.Module = None
+        # TODO: make initialization happen in model management code/model patcher, not here
+        self.initialize_flipflop_blocks(self.load_device)
+
+        self.compute_stream = cuda.default_stream(self.load_device)
+        self.cpy_stream = cuda.Stream(self.load_device)
 
         self.event_flip = torch.cuda.Event(enable_timing=False)
         self.event_flop = torch.cuda.Event(enable_timing=False)
@@ -134,6 +142,10 @@ class FlipFlopHolder:
     def context(self):
        return FlipFlopContext(self)
 
+    def initialize_flipflop_blocks(self, load_device: torch.device):
+        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=load_device)
+        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=load_device)
+
 class FlipFlopTransformer:
     def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
         self.transformer_blocks = transformer_blocks
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index a338a6805..d1833ffe6 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -336,7 +336,7 @@ class QwenImageTransformer2DModel(nn.Module):
             for _ in range(num_layers)
         ])
 
-        self.flipflop_holders: dict[str, FlipFlopHolder] = {}
+        self.flipflop: dict[str, FlipFlopHolder] = {}
 
         if final_layer:
             self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
@@ -344,8 +344,8 @@
 
     def setup_flipflop_holders(self, block_percentage: float):
         # We hackily move any flipflopped blocks into holder so that our model management system does not see them.
-        num_blocks = int(len(self.transformer_blocks) * block_percentage)
-        self.flipflop_holders["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
+        num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage))
+        self.flipflop["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
         self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks])
 
     def process_img(self, x, index=0, h_offset=0, w_offset=0):
@@ -412,8 +412,8 @@ class QwenImageTransformer2DModel(nn.Module):
     def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
         for i, block in enumerate(self.transformer_blocks):
             encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-        if "blocks_fwd" in self.flipflop_holders:
-            holder = self.flipflop_holders["blocks_fwd"]
+        if "blocks_fwd" in self.flipflop:
+            holder = self.flipflop["blocks_fwd"]
             with holder.context() as ctx:
                 for i, block in enumerate(holder.transformer_blocks):
                     encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 1fd03d9d1..b9a768203 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -604,6 +604,9 @@ class ModelPatcher:
         else:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
+    def supports_flipflop(self):
+        return hasattr(self.model.diffusion_model, "flipflop")
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
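Reviewer note, not part of the diff: the renamed flipflop holder appears to implement double-buffered weight streaming, where two GPU-resident block copies ("flip" and "flop") are reused for every offloaded block while cpy_stream uploads the next block's weights and compute_stream runs the current one, with event_flip/event_flop keeping the two streams ordered. With the corrected formula in setup_flipflop_holders, block_percentage is the fraction that gets flip/flopped; for example, 60 blocks with block_percentage=0.5 keep the first 30 resident and move blocks[30:] into the holder. The sketch below is a minimal, self-contained illustration of that double-buffering pattern under those assumptions; all names are hypothetical and it is not the FlipFlopContext implementation from this PR.

import copy
import itertools
import torch

def flipflop_forward(cpu_blocks, x, device="cuda"):
    # Hypothetical double-buffering loop: two GPU "slots" are refilled from CPU
    # weights on a copy stream while the other slot computes on the default stream.
    device = torch.device(device)
    compute = torch.cuda.default_stream(device)
    copier = torch.cuda.Stream(device)
    slots = [copy.deepcopy(cpu_blocks[0]).to(device),
             copy.deepcopy(cpu_blocks[1 % len(cpu_blocks)]).to(device)]
    free = [torch.cuda.Event(), torch.cuda.Event()]   # slot may be overwritten
    ready = [torch.cuda.Event(), torch.cuda.Event()]  # slot holds the right weights
    for e in free + ready:
        e.record(compute)  # slots start out preloaded with blocks 0 and 1

    def upload(slot, block):
        # Overwrite a slot's parameters/buffers with the next block's CPU weights.
        # Pinned host memory (cf. pinned_staging) makes these copies truly async.
        with torch.no_grad(), torch.cuda.stream(copier):
            copier.wait_event(free[slot])
            for dst, src in zip(itertools.chain(slots[slot].parameters(), slots[slot].buffers()),
                                itertools.chain(block.parameters(), block.buffers())):
                dst.copy_(src, non_blocking=True)
            ready[slot].record(copier)

    for i in range(len(cpu_blocks)):
        slot = i % 2
        compute.wait_event(ready[slot])   # wait until this slot's upload finished
        x = slots[slot](x)                # runs on the default (compute) stream
        free[slot].record(compute)        # slot may be refilled once this compute is done
        if i + 2 < len(cpu_blocks):
            upload(slot, cpu_blocks[i + 2])
    return x

# Usage (assumes a CUDA device is present):
#   blocks = [torch.nn.Linear(64, 64) for _ in range(8)]
#   out = flipflop_forward(blocks, torch.randn(4, 64, device="cuda"))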