From 7c896c55674fbd6d6f8f5e6de91a3c508bb59cff Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 1 Oct 2025 20:13:50 -0700 Subject: [PATCH] Initial automatic support for flipflop within ModelPatcher - only Qwen Image diffusion_model uses FlipFlopModule currently --- comfy/ldm/flipflop_transformer.py | 367 +++++------------------------- comfy/ldm/qwen_image/model.py | 44 +--- comfy/model_patcher.py | 101 ++++++-- 3 files changed, 149 insertions(+), 363 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 03e256c7a..a88613c01 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -1,12 +1,54 @@ from __future__ import annotations import torch -import torch.cuda as cuda import copy -from typing import List, Tuple import comfy.model_management +class FlipFlopModule(torch.nn.Module): + def __init__(self, block_types: tuple[str, ...]): + super().__init__() + self.block_types = block_types + self.flipflop: dict[str, FlipFlopHolder] = {} + + def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device): + for block_type, (flipflop_blocks, total_blocks) in block_info.items(): + if block_type in self.flipflop: + continue + self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[total_blocks-flipflop_blocks:], flipflop_blocks, total_blocks, load_device, offload_device) + + def init_flipflop_block_copies(self, device: torch.device): + for holder in self.flipflop.values(): + holder.init_flipflop_block_copies(device) + + def clean_flipflop_holders(self): + for block_type in list(self.flipflop.keys()): + self.flipflop[block_type].clean_flipflop_blocks() + del self.flipflop[block_type] + + def get_all_blocks(self, block_type: str) -> list[torch.nn.Module]: + return getattr(self, block_type) + + def get_blocks(self, block_type: str) -> torch.nn.ModuleList: + if block_type not in self.block_types: + raise ValueError(f"Block type {block_type} not found in {self.block_types}") + if block_type in self.flipflop: + return getattr(self, block_type)[:self.flipflop[block_type].flip_amount] + return getattr(self, block_type) + + def get_all_block_module_sizes(self, reverse_sort_by_size: bool = False) -> list[tuple[str, int]]: + ''' + Returns a list of (block_type, size) sorted by size. + If reverse_sort_by_size is True, the list is sorted by size in reverse order. 
+ ''' + sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types] + sizes.sort(key=lambda x: x[1], reverse=reverse_sort_by_size) + return sizes + + def get_block_module_size(self, block_type: str) -> int: + return comfy.model_management.module_size(getattr(self, block_type)[0]) + + class FlipFlopContext: def __init__(self, holder: FlipFlopHolder): self.holder = holder @@ -18,7 +60,6 @@ class FlipFlopContext: self.first_flop = True self.last_flip = False self.last_flop = False - # TODO: the 'i' that's passed into func needs to be properly offset to do patches correctly def __enter__(self): self.reset() @@ -31,7 +72,7 @@ class FlipFlopContext: # flip self.holder.compute_stream.wait_event(self.holder.cpy_end_event) with torch.cuda.stream(self.holder.compute_stream): - out = func(i, self.holder.flip, *args, **kwargs) + out = func(i+self.holder.i_offset, self.holder.flip, *args, **kwargs) self.holder.event_flip.record(self.holder.compute_stream) # while flip executes, queue flop to copy to its next block next_flop_i = i + 1 @@ -50,7 +91,7 @@ class FlipFlopContext: if not self.first_flop: self.holder.compute_stream.wait_event(self.holder.cpy_end_event) with torch.cuda.stream(self.holder.compute_stream): - out = func(i, self.holder.flop, *args, **kwargs) + out = func(i+self.holder.i_offset, self.holder.flop, *args, **kwargs) self.holder.event_flop.record(self.holder.compute_stream) # while flop executes, queue flip to copy to its next block next_flip_i = i + 1 @@ -72,13 +113,15 @@ class FlipFlopContext: return self.do_flop(func, i, block, *args, **kwargs) - class FlipFlopHolder: - def __init__(self, blocks: List[torch.nn.Module], flip_amount: int, load_device="cuda", offload_device="cpu"): - self.load_device = torch.device(load_device) - self.offload_device = torch.device(offload_device) + def __init__(self, blocks: list[torch.nn.Module], flip_amount: int, total_amount: int, load_device: torch.device, offload_device: torch.device): + self.load_device = load_device + self.offload_device = offload_device self.blocks = blocks self.flip_amount = flip_amount + self.total_amount = total_amount + # NOTE: used to make sure block indexes passed into block functions match expected patch indexes + self.i_offset = total_amount - flip_amount self.block_module_size = 0 if len(self.blocks) > 0: @@ -86,11 +129,9 @@ class FlipFlopHolder: self.flip: torch.nn.Module = None self.flop: torch.nn.Module = None - # TODO: make initialization happen in model management code/model patcher, not here - self.init_flipflop_blocks(self.load_device) - self.compute_stream = cuda.default_stream(self.load_device) - self.cpy_stream = cuda.Stream(self.load_device) + self.compute_stream = torch.cuda.default_stream(self.load_device) + self.cpy_stream = torch.cuda.Stream(self.load_device) self.event_flip = torch.cuda.Event(enable_timing=False) self.event_flop = torch.cuda.Event(enable_timing=False) @@ -111,7 +152,7 @@ class FlipFlopHolder: def context(self): return FlipFlopContext(self) - def init_flipflop_blocks(self, load_device: torch.device): + def init_flipflop_block_copies(self, load_device: torch.device): self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device) self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device) @@ -120,301 +161,3 @@ class FlipFlopHolder: del self.flop self.flip = None self.flop = None - - -class FlopFlopModule(torch.nn.Module): - def __init__(self, block_types: tuple[str, ...]): - super().__init__() - self.block_types = block_types - self.flipflop: 
dict[str, FlipFlopHolder] = {} - - def setup_flipflop_holders(self, block_percentage: float): - for block_type in self.block_types: - if block_type in self.flipflop: - continue - num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage)) - self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks) - - def clean_flipflop_holders(self): - for block_type in self.flipflop.keys(): - self.flipflop[block_type].clean_flipflop_blocks() - del self.flipflop[block_type] - - def get_blocks(self, block_type: str) -> torch.nn.ModuleList: - if block_type not in self.block_types: - raise ValueError(f"Block type {block_type} not found in {self.block_types}") - if block_type in self.flipflop: - return getattr(self, block_type)[:self.flipflop[block_type].flip_amount] - return getattr(self, block_type) - - def get_all_block_module_sizes(self, sort_by_size: bool = False) -> list[tuple[str, int]]: - ''' - Returns a list of (block_type, size). - If sort_by_size is True, the list is sorted by size. - ''' - sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types] - if sort_by_size: - sizes.sort(key=lambda x: x[1]) - return sizes - - def get_block_module_size(self, block_type: str) -> int: - return comfy.model_management.module_size(getattr(self, block_type)[0]) - - -# Below is the implementation from contentis' prototype flip flop -class FlipFlopTransformer: - def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"): - self.transformer_blocks = transformer_blocks - self.offloading_device = torch.device(offloading_device) - self.inference_device = torch.device(inference_device) - self.staging = pinned_staging - - self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=self.inference_device) - self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=self.inference_device) - - self._cpy_fn = self._copy_state_dict - if self.staging: - self.staging_buffer = self._pin_module(self.transformer_blocks[0]).state_dict() - self._cpy_fn = self._copy_state_dict_with_staging - - self.compute_stream = cuda.default_stream(self.inference_device) - self.cpy_stream = cuda.Stream(self.inference_device) - - self.event_flip = torch.cuda.Event(enable_timing=False) - self.event_flop = torch.cuda.Event(enable_timing=False) - self.cpy_end_event = torch.cuda.Event(enable_timing=False) - - self.block_wrap_fn = block_wrap_fn - self.out_names = out_names - - self.num_blocks = len(self.transformer_blocks) - self.extra_run = self.num_blocks % 2 - - # INIT - self.compute_stream.record_event(self.cpy_end_event) - - def _copy_state_dict(self, dst, src, cpy_start_event=None, cpy_end_event=None): - if cpy_start_event: - self.cpy_stream.wait_event(cpy_start_event) - - with torch.cuda.stream(self.cpy_stream): - for k, v in src.items(): - dst[k].copy_(v, non_blocking=True) - if cpy_end_event: - cpy_end_event.record(self.cpy_stream) - - def _copy_state_dict_with_staging(self, dst, src, cpy_start_event=None, cpy_end_event=None): - if cpy_start_event: - self.cpy_stream.wait_event(cpy_start_event) - - with torch.cuda.stream(self.cpy_stream): - for k, v in src.items(): - self.staging_buffer[k].copy_(v, non_blocking=True) - dst[k].copy_(self.staging_buffer[k], non_blocking=True) - if cpy_end_event: - cpy_end_event.record(self.cpy_stream) - - def _pin_module(self, module): - pinned_module = copy.deepcopy(module) - for param in 
pinned_module.parameters(): - param.data = param.data.pin_memory() - # Pin all buffers (if any) - for buffer in pinned_module.buffers(): - buffer.data = buffer.data.pin_memory() - return pinned_module - - def _reset(self): - if self.extra_run: - self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop) - self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip) - else: - self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip) - self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop) - - self.compute_stream.record_event(self.cpy_end_event) - - @torch.no_grad() - def __call__(self, **feed_dict): - ''' - Flip accounts for even blocks (0 is first block), flop accounts for odd blocks. - ''' - # separated flip flop refactor - num_blocks = len(self.transformer_blocks) - first_flip = True - first_flop = True - last_flip = False - last_flop = False - for i, block in enumerate(self.transformer_blocks): - is_flip = i % 2 == 0 - if is_flip: - # flip - self.compute_stream.wait_event(self.cpy_end_event) - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flip, **feed_dict) - self.event_flip.record(self.compute_stream) - # while flip executes, queue flop to copy to its next block - next_flop_i = i + 1 - if next_flop_i >= num_blocks: - next_flop_i = next_flop_i - num_blocks - last_flip = True - if not first_flip: - self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[next_flop_i].state_dict(), self.event_flop, self.cpy_end_event) - if last_flip: - self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip) - first_flip = False - else: - # flop - if not first_flop: - self.compute_stream.wait_event(self.cpy_end_event) - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flop, **feed_dict) - self.event_flop.record(self.compute_stream) - # while flop executes, queue flip to copy to its next block - next_flip_i = i + 1 - if next_flip_i >= num_blocks: - next_flip_i = next_flip_i - num_blocks - last_flop = True - self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[next_flip_i].state_dict(), self.event_flip, self.cpy_end_event) - if last_flop: - self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop) - first_flop = False - - self.compute_stream.record_event(self.cpy_end_event) - - outputs = [feed_dict[name] for name in self.out_names] - if len(outputs) == 1: - return outputs[0] - return tuple(outputs) - - @torch.no_grad() - def __call__old(self, **feed_dict): - # contentis' prototype flip flop - # Wait for reset - self.compute_stream.wait_event(self.cpy_end_event) - - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flip, **feed_dict) - self.event_flip.record(self.compute_stream) - - for i in range(self.num_blocks // 2 - 1): - - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flop, **feed_dict) - self.event_flop.record(self.compute_stream) - - self._cpy_fn(self.flip.state_dict(), self.transformer_blocks[(i + 1) * 2].state_dict(), self.event_flip, - self.cpy_end_event) - - self.compute_stream.wait_event(self.cpy_end_event) - - with torch.cuda.stream(self.compute_stream): - 
feed_dict = self.block_wrap_fn(self.flip, **feed_dict) - self.event_flip.record(self.compute_stream) - - self._cpy_fn(self.flop.state_dict(), self.transformer_blocks[(i + 1) * 2 + 1].state_dict(), self.event_flop, - self.cpy_end_event) - self.compute_stream.wait_event(self.cpy_end_event) - - - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flop, **feed_dict) - self.event_flop.record(self.compute_stream) - - if self.extra_run: - self._cpy_fn(self.flip.state_dict(), self.transformer_blocks[-1].state_dict(), self.event_flip, - self.cpy_end_event) - self.compute_stream.wait_event(self.cpy_end_event) - with torch.cuda.stream(self.compute_stream): - feed_dict = self.block_wrap_fn(self.flip, **feed_dict) - self.event_flip.record(self.compute_stream) - - self._reset() - - outputs = [feed_dict[name] for name in self.out_names] - if len(outputs) == 1: - return outputs[0] - return tuple(outputs) - -# @register("Flux") -# class Flux: -# @staticmethod -# def double_block_wrap(block, **kwargs): -# kwargs["img"], kwargs["txt"] = block(img=kwargs["img"], -# txt=kwargs["txt"], -# vec=kwargs["vec"], -# pe=kwargs["pe"], -# attn_mask=kwargs.get("attn_mask")) -# return kwargs - -# @staticmethod -# def single_block_wrap(block, **kwargs): -# kwargs["img"] = block(kwargs["img"], -# vec=kwargs["vec"], -# pe=kwargs["pe"], -# attn_mask=kwargs.get("attn_mask")) -# return kwargs - -# double_config = FlipFlopConfig(block_name="double_blocks", -# block_wrap_fn=double_block_wrap, -# out_names=("img", "txt"), -# overwrite_forward="double_transformer_fwd", -# pinned_staging=False) - -# single_config = FlipFlopConfig(block_name="single_blocks", -# block_wrap_fn=single_block_wrap, -# out_names=("img",), -# overwrite_forward="single_transformer_fwd", -# pinned_staging=False) -# @staticmethod -# def patch(model): -# patch_model_from_config(model, Flux.double_config) -# patch_model_from_config(model, Flux.single_config) -# return model - - -# @register("WanModel") -# class Wan: -# @staticmethod -# def wan_blocks_wrap(block, **kwargs): -# kwargs["x"] = block(x=kwargs["x"], -# context=kwargs["context"], -# e=kwargs["e"], -# freqs=kwargs["freqs"], -# context_img_len=kwargs.get("context_img_len")) -# return kwargs - -# blocks_config = FlipFlopConfig(block_name="blocks", -# block_wrap_fn=wan_blocks_wrap, -# out_names=("x",), -# overwrite_forward="block_fwd", -# pinned_staging=False) - - -# @staticmethod -# def patch(model): -# patch_model_from_config(model, Wan.blocks_config) -# return model - -# @register("QwenImageTransformer2DModel") -# class QwenImage: -# @staticmethod -# def qwen_blocks_wrap(block, **kwargs): -# kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"], -# encoder_hidden_states=kwargs["encoder_hidden_states"], -# encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"], -# temb=kwargs["temb"], -# image_rotary_emb=kwargs["image_rotary_emb"], -# transformer_options=kwargs["transformer_options"]) -# return kwargs - -# blocks_config = FlipFlopConfig(block_name="transformer_blocks", -# block_wrap_fn=qwen_blocks_wrap, -# out_names=("encoder_hidden_states", "hidden_states"), -# overwrite_forward="blocks_fwd", -# pinned_staging=False) - - -# @staticmethod -# def patch(model): -# patch_model_from_config(model, QwenImage.blocks_config) -# return model diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 474d831c4..2af4968ac 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ 
-5,12 +5,13 @@ import torch.nn.functional as F from typing import Optional, Tuple from einops import repeat -from comfy.ldm.flipflop_transformer import FlipFlopHolder +from comfy.ldm.flipflop_transformer import FlipFlopModule from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND import comfy.ldm.common_dit import comfy.patcher_extension +import comfy.ops class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): @@ -284,7 +285,7 @@ class LastLayer(nn.Module): return x -class QwenImageTransformer2DModel(nn.Module): +class QwenImageTransformer2DModel(FlipFlopModule): def __init__( self, patch_size: int = 2, @@ -301,9 +302,9 @@ class QwenImageTransformer2DModel(nn.Module): final_layer=True, dtype=None, device=None, - operations=None, + operations: comfy.ops.disable_weight_init=None, ): - super().__init__() + super().__init__(block_types=("transformer_blocks",)) self.dtype = dtype self.patch_size = patch_size self.in_channels = in_channels @@ -336,43 +337,10 @@ class QwenImageTransformer2DModel(nn.Module): for _ in range(num_layers) ]) - self.flipflop: dict[str, FlipFlopHolder] = {} - if final_layer: self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) - def setup_flipflop_holders(self, block_percentage: float): - if "transformer_blocks" in self.flipflop: - return - import comfy.model_management - # We hackily move any flipflopped blocks into holder so that our model management system does not see them. 
- num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage)) - loading = [] - for n, m in self.named_modules(): - params = [] - skip = False - for name, param in m.named_parameters(recurse=False): - params.append(name) - for name, param in m.named_parameters(recurse=True): - if name not in params: - skip = True # skip random weights in non leaf modules - break - if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((comfy.model_management.module_size(m), n, m, params)) - self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks) - self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks]) - - def clean_flipflop_holders(self): - if "transformer_blocks" in self.flipflop: - self.flipflop["transformer_blocks"].clean_flipflop_blocks() - del self.flipflop["transformer_blocks"] - - def get_transformer_blocks(self): - if "transformer_blocks" in self.flipflop: - return self.transformer_blocks[:self.flipflop["transformer_blocks"].flip_amount] - return self.transformer_blocks - def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape patch_size = self.patch_size @@ -501,7 +469,7 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) - for i, block in enumerate(self.get_transformer_blocks()): + for i, block in enumerate(self.get_blocks("transformer_blocks")): encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options) if "transformer_blocks" in self.flipflop: holder = self.flipflop["transformer_blocks"] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7e7b3a4ed..67d1cb233 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -616,19 +616,62 @@ class ModelPatcher: return False return True - def init_flipflop(self): + def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]]): if not self.supports_flipflop(): return - # figure out how many b - self.model.diffusion_model.setup_flipflop_holders(self.model_options["flipflop_block_percentage"]) + self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, self.load_device, self.offload_device) + + def init_flipflop_block_copies(self): + if not self.supports_flipflop(): + return + self.model.diffusion_model.init_flipflop_block_copies(self.load_device) def clean_flipflop(self): if not self.supports_flipflop(): return self.model.diffusion_model.clean_flipflop_holders() - def _load_list(self): + def _calc_flipflop_prefixes(self, lowvram_model_memory=0, prepare_flipflop=False): + flipflop_prefixes = [] + flipflop_blocks_per_type: dict[str, tuple[int, int]] = {} + if lowvram_model_memory > 0 and self.supports_flipflop(): + block_buffer = 3 + valid_block_types = [] + # for each block type, check if have enough room to flipflop + for block_info in self.model.diffusion_model.get_all_block_module_sizes(): + block_size: int = block_info[1] + if block_size * block_buffer < lowvram_model_memory: + valid_block_types.append(block_info) + # if have candidates for flipping, see how many of each type we have can flipflop + if len(valid_block_types) > 0: + leftover_memory = lowvram_model_memory + for block_info in valid_block_types: + block_type: str = block_info[0] + block_size: int = block_info[1] + total_blocks = 
len(self.model.diffusion_model.get_all_blocks(block_type)) + n_fit_in_memory = int(leftover_memory // block_size) + # if all (or more) of this block type would fit in memory, no need to flipflop with it + if n_fit_in_memory >= total_blocks: + continue + # if the amount of this block that would fit in memory is less than buffer, skip this block type + if n_fit_in_memory < block_buffer: + continue + # 2 blocks worth of VRAM may be needed for flipflop, so make sure to account for them. + flipflop_blocks = min((total_blocks - n_fit_in_memory) + 2, total_blocks) + flipflop_blocks_per_type[block_type] = (flipflop_blocks, total_blocks) + leftover_memory -= (total_blocks - flipflop_blocks + 2) * block_size + # if there are blocks to flipflop, need to mark their keys + for block_type, (flipflop_blocks, total_blocks) in flipflop_blocks_per_type.items(): + # blocks to flipflop are at the end + for i in range(total_blocks-flipflop_blocks, total_blocks): + flipflop_prefixes.append(f"diffusion_model.{block_type}.{i}") + if prepare_flipflop and len(flipflop_blocks_per_type) > 0: + self.setup_flipflop(flipflop_blocks_per_type) + return flipflop_prefixes + + def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False): loading = [] + flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop) for n, m in self.model.named_modules(): params = [] skip = False @@ -639,7 +682,12 @@ class ModelPatcher: skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((comfy.model_management.module_size(m), n, m, params)) + flipflop = False + for prefix in flipflop_prefixes: + if n.startswith(prefix): + flipflop = True + break + loading.append((comfy.model_management.module_size(m), n, m, params, flipflop)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -649,16 +697,18 @@ class ModelPatcher: patch_counter = 0 lowvram_counter = 0 lowvram_mem_counter = 0 - if self.supports_flipflop(): - ... 
- loading = self._load_list() + flipflop_counter = 0 + flipflop_mem_counter = 0 + loading = self._load_list(lowvram_model_memory, prepare_flipflop=True) load_completely = [] + load_flipflop = [] loading.sort(reverse=True) for x in loading: n = x[1] m = x[2] params = x[3] + flipflop: bool = x[4] module_mem = x[0] lowvram_weight = False @@ -666,7 +716,7 @@ class ModelPatcher: weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) - if not full_load and hasattr(m, "comfy_cast_weights"): + if not full_load and hasattr(m, "comfy_cast_weights") and not flipflop: if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True lowvram_counter += 1 @@ -698,7 +748,11 @@ class ModelPatcher: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) - if full_load or mem_counter + module_mem < lowvram_model_memory: + if flipflop: + flipflop_counter += 1 + flipflop_mem_counter += module_mem + load_flipflop.append((module_mem, n, m, params)) + elif full_load or mem_counter + module_mem < lowvram_model_memory: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) @@ -714,6 +768,7 @@ class ModelPatcher: mem_counter += move_weight_functions(m, device_to) + # handle load completely load_completely.sort(reverse=True) for x in load_completely: n = x[1] @@ -732,11 +787,30 @@ class ModelPatcher: for x in load_completely: x[2].to(device_to) - if lowvram_counter > 0: - logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}") + # handle flipflop + if len(load_flipflop) > 0: + load_flipflop.sort(reverse=True) + for x in load_flipflop: + n = x[1] + m = x[2] + params = x[3] + if hasattr(m, "comfy_patched_weights"): + if m.comfy_patched_weights == True: + continue + for param in params: + self.patch_weight_to_device("{}.{}".format(n, param), device_to=self.offload_device) + + logging.debug("lowvram: loaded module for flipflop {} {}".format(n, m)) + self.init_flipflop_block_copies() + + if lowvram_counter > 0 or flipflop_counter > 0: + if flipflop_counter > 0: + logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {flipflop_mem_counter / (1024 * 1024):.2f} MB to flipflop, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}") + else: + logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}") self.model.model_lowvram = True else: - logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}") + logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}") self.model.model_lowvram = False if full_load: self.model.to(device_to) @@ -773,6 +847,7 @@ class ModelPatcher: self.eject_model() if unpatch_weights: self.unpatch_hooks() + self.clean_flipflop() if self.model.model_lowvram: for m in self.model.modules(): move_weight_functions(m, device_to)
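For reference, below is a minimal standalone sketch of the flip/flop double-buffering that FlipFlopHolder and FlipFlopContext implement above: two GPU-resident copies of a block, a compute stream running block i on one copy while a second stream copies block i+2's weights from CPU into the other copy, with CUDA events ordering the hand-offs. All names and the block call signature are hypothetical, and it assumes a CUDA device and at least two CPU-resident blocks of identical structure; it illustrates the pattern, not the ComfyUI API.

    # sketch of the flip/flop pattern, under the assumptions stated above
    import copy
    import torch

    @torch.no_grad()
    def run_blocks_flipflop(blocks, x, device="cuda"):
        assert len(blocks) >= 2, "need at least two blocks to flip/flop"
        device = torch.device(device)
        compute = torch.cuda.default_stream(device)
        copier = torch.cuda.Stream(device)
        compute_done = [torch.cuda.Event(), torch.cuda.Event()]  # one per buffer
        copy_done = [torch.cuda.Event(), torch.cuda.Event()]     # one per buffer

        # Two GPU-resident working copies ("flip" and "flop"); all other blocks
        # stay on the offload device and are streamed in on demand.
        buffers = [copy.deepcopy(blocks[0]).to(device),
                   copy.deepcopy(blocks[1]).to(device)]

        n = len(blocks)
        for i in range(n):
            buf = i % 2  # even blocks run on "flip", odd blocks on "flop"
            if i >= 2:
                # block i's weights were queued into this buffer two steps ago;
                # make sure that copy has finished before computing with them
                compute.wait_event(copy_done[buf])
            with torch.cuda.stream(compute):
                x = buffers[buf](x)
                compute_done[buf].record(compute)
            nxt = i + 2
            if nxt < n:
                # while block i+1 runs on the other buffer, overwrite this one
                # with the weights of the block it will execute next
                copier.wait_event(compute_done[buf])
                with torch.cuda.stream(copier):
                    dst = buffers[buf].state_dict()
                    for k, v in blocks[nxt].state_dict().items():
                        # pinned CPU tensors make this copy fully asynchronous
                        dst[k].copy_(v, non_blocking=True)
                    copy_done[buf].record(copier)
        return x

The i_offset bookkeeping in FlipFlopHolder exists for the same reason the sketch indexes blocks by their absolute position: only the trailing blocks of each type are handed to the holder, so the index passed into the block function has to be shifted by total_amount - flip_amount for patch lookups to match the block's position in the full ModuleList.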
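The VRAM budgeting in _calc_flipflop_prefixes can be summarized with a small worked example (single block type, sizes in MB, all numbers made up): blocks that fit in the budget stay resident, the trailing blocks flipflop, and two extra blocks' worth of VRAM is reserved for the working copies.

    # simplified version of the budget arithmetic in _calc_flipflop_prefixes,
    # for a single block type; returns how many trailing blocks would flipflop
    def plan_flipflop(budget_mb, block_size_mb, total_blocks, block_buffer=3):
        if block_size_mb * block_buffer >= budget_mb:
            return 0  # not even a few blocks fit: block type is not a candidate
        n_fit = int(budget_mb // block_size_mb)
        if n_fit >= total_blocks:
            return 0  # everything fits resident: no need to flipflop
        if n_fit < block_buffer:
            return 0  # too few resident blocks to be worth it
        # keep the leading blocks resident; the rest cycle through 2 GPU copies
        return min((total_blocks - n_fit) + 2, total_blocks)

    # 60 blocks of 300 MB with an 8192 MB budget -> 35 trailing blocks flipflop,
    # leaving 25 resident blocks + 2 working copies = 27 * 300 MB = 8100 MB on GPU
    print(plan_flipflop(8192, 300, 60))  # 35

The resulting flipflop block indexes are the ones _calc_flipflop_prefixes turns into "diffusion_model.{block_type}.{i}" prefixes, so _load_list can tag those modules and load() can patch their weights to the offload device instead of counting them against lowvram_model_memory.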