From 5329180fce39baac993d6ba65fea2fc7814c2961 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 3 Oct 2025 16:21:01 -0700 Subject: [PATCH] Made flipflop consider partial_unload, partial_offload, and add flip+flop to mem counters --- comfy/ldm/flipflop_transformer.py | 29 ++++++++++++++++++----- comfy/model_patcher.py | 38 ++++++++++++++++++++----------- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 9e9c28468..0379d14ff 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -11,21 +11,31 @@ class FlipFlopModule(torch.nn.Module): self.block_types = block_types self.enable_flipflop = enable_flipflop self.flipflop: dict[str, FlipFlopHolder] = {} + self.block_info: dict[str, tuple[int, int]] = {} + self.flipflop_prefixes: list[str] = [] - def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device): + def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], flipflop_prefixes: list[str], load_device: torch.device, offload_device: torch.device): for block_type, (flipflop_blocks, total_blocks) in block_info.items(): if block_type in self.flipflop: continue self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[total_blocks-flipflop_blocks:], flipflop_blocks, total_blocks, load_device, offload_device) + self.block_info[block_type] = (flipflop_blocks, total_blocks) + self.flipflop_prefixes = flipflop_prefixes.copy() - def init_flipflop_block_copies(self, device: torch.device): + def init_flipflop_block_copies(self, device: torch.device) -> int: + memory_freed = 0 for holder in self.flipflop.values(): - holder.init_flipflop_block_copies(device) + memory_freed += holder.init_flipflop_block_copies(device) + return memory_freed def clean_flipflop_holders(self): + memory_freed = 0 for block_type in list(self.flipflop.keys()): - self.flipflop[block_type].clean_flipflop_blocks() + memory_freed += self.flipflop[block_type].clean_flipflop_blocks() del self.flipflop[block_type] + self.block_info = {} + self.flipflop_prefixes = [] + return memory_freed def get_all_blocks(self, block_type: str) -> list[torch.nn.Module]: return getattr(self, block_type) @@ -71,6 +81,8 @@ class FlipFlopModule(torch.nn.Module): class FlipFlopContext: def __init__(self, holder: FlipFlopHolder): + # NOTE: there is a bug when there are an odd number of blocks to flipflop. + # Worked around right now by always making sure it will be even, but need to resolve. self.holder = holder self.reset() @@ -172,12 +184,17 @@ class FlipFlopHolder: def context(self): return FlipFlopContext(self) - def init_flipflop_block_copies(self, load_device: torch.device): + def init_flipflop_block_copies(self, load_device: torch.device) -> int: self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device) self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device) + return comfy.model_management.module_size(self.flip) + comfy.model_management.module_size(self.flop) - def clean_flipflop_blocks(self): + def clean_flipflop_blocks(self) -> int: + memory_freed = 0 + memory_freed += comfy.model_management.module_size(self.flip) + memory_freed += comfy.model_management.module_size(self.flop) del self.flip del self.flop self.flip = None self.flop = None + return memory_freed diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 08055b65c..dfb38d6e8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -620,21 +620,26 @@ class ModelPatcher: return False return True - def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]]): + def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]], flipflop_prefixes: list[str]): if not self.supports_flipflop(): return logging.info(f"setting up flipflop with {flipflop_blocks_per_type}") - self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, self.load_device, self.offload_device) + self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, flipflop_prefixes, self.load_device, self.offload_device) - def init_flipflop_block_copies(self): + def init_flipflop_block_copies(self) -> int: if not self.supports_flipflop(): - return - self.model.diffusion_model.init_flipflop_block_copies(self.load_device) + return 0 + return self.model.diffusion_model.init_flipflop_block_copies(self.load_device) - def clean_flipflop(self): + def clean_flipflop(self) -> int: if not self.supports_flipflop(): - return - self.model.diffusion_model.clean_flipflop_holders() + return 0 + return self.model.diffusion_model.clean_flipflop_holders() + + def _get_existing_flipflop_prefixes(self): + if self.supports_flipflop(): + return self.model.diffusion_model.flipflop_prefixes + return [] def _calc_flipflop_prefixes(self, lowvram_model_memory=0, prepare_flipflop=False): flipflop_prefixes = [] @@ -678,12 +683,15 @@ class ModelPatcher: for i in range(total_blocks-flipflop_blocks, total_blocks): flipflop_prefixes.append(f"diffusion_model.{block_type}.{i}") if prepare_flipflop and len(flipflop_blocks_per_type) > 0: - self.setup_flipflop(flipflop_blocks_per_type) + self.setup_flipflop(flipflop_blocks_per_type, flipflop_prefixes) return flipflop_prefixes - def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False): + def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False, get_existing_flipflop=False): loading = [] - flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop) + if get_existing_flipflop: + flipflop_prefixes = self._get_existing_flipflop_prefixes() + else: + flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop) for n, m in self.model.named_modules(): params = [] skip = False @@ -817,7 +825,7 @@ class ModelPatcher: end_time = time.perf_counter() logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds") start_time = time.perf_counter() - self.init_flipflop_block_copies() + mem_counter += self.init_flipflop_block_copies() end_time = time.perf_counter() logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds") @@ -905,8 +913,9 @@ class ModelPatcher: with self.use_ejected(): hooks_unpatched = False memory_freed = 0 + memory_freed += self.clean_flipflop() patch_counter = 0 - unload_list = self._load_list() + unload_list = self._load_list(get_existing_flipflop=True) unload_list.sort() for unload in unload_list: if memory_to_free < memory_freed: @@ -915,7 +924,10 @@ class ModelPatcher: n = unload[1] m = unload[2] params = unload[3] + flipflop: bool = unload[4] + if flipflop: + continue lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: move_weight = True