Made flipflop consider partial_unload and partial_offload, and added flip+flop to mem counters

Jedrzej Kosinski 2025-10-03 16:21:01 -07:00
parent 0fdd327c2f
commit 5329180fce
2 changed files with 48 additions and 19 deletions


@@ -11,21 +11,31 @@ class FlipFlopModule(torch.nn.Module):
         self.block_types = block_types
         self.enable_flipflop = enable_flipflop
         self.flipflop: dict[str, FlipFlopHolder] = {}
+        self.block_info: dict[str, tuple[int, int]] = {}
+        self.flipflop_prefixes: list[str] = []

-    def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device):
+    def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], flipflop_prefixes: list[str], load_device: torch.device, offload_device: torch.device):
         for block_type, (flipflop_blocks, total_blocks) in block_info.items():
             if block_type in self.flipflop:
                 continue
             self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[total_blocks-flipflop_blocks:], flipflop_blocks, total_blocks, load_device, offload_device)
+            self.block_info[block_type] = (flipflop_blocks, total_blocks)
+        self.flipflop_prefixes = flipflop_prefixes.copy()

-    def init_flipflop_block_copies(self, device: torch.device):
+    def init_flipflop_block_copies(self, device: torch.device) -> int:
+        memory_freed = 0
         for holder in self.flipflop.values():
-            holder.init_flipflop_block_copies(device)
+            memory_freed += holder.init_flipflop_block_copies(device)
+        return memory_freed

     def clean_flipflop_holders(self):
+        memory_freed = 0
         for block_type in list(self.flipflop.keys()):
-            self.flipflop[block_type].clean_flipflop_blocks()
+            memory_freed += self.flipflop[block_type].clean_flipflop_blocks()
             del self.flipflop[block_type]
+        self.block_info = {}
+        self.flipflop_prefixes = []
+        return memory_freed

     def get_all_blocks(self, block_type: str) -> list[torch.nn.Module]:
         return getattr(self, block_type)
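The shape of the change in this hunk: setup_flipflop_holders now records which blocks are flipflop-managed (block_info plus the module-name prefixes), and the init/clean calls report byte counts so the patcher can fold them into its memory bookkeeping. A minimal, self-contained sketch of that contract with toy classes (illustrative names, not the actual ComfyUI types):

```python
# Toy sketch of the accounting contract: init reports bytes added to the
# load device, clean reports bytes freed, and the module sums across holders.
class ToyHolder:
    def __init__(self, size_bytes: int):
        self.size_bytes = size_bytes
        self.loaded = False

    def init_flipflop_block_copies(self, device) -> int:
        self.loaded = True
        return self.size_bytes  # bytes the flip+flop copies now occupy

    def clean_flipflop_blocks(self) -> int:
        self.loaded = False
        return self.size_bytes  # bytes released back to the device

class ToyModule:
    def __init__(self, holders: dict):
        self.flipflop = holders

    def init_flipflop_block_copies(self, device) -> int:
        # Mirrors FlipFlopModule: accumulate what every holder loaded.
        return sum(h.init_flipflop_block_copies(device) for h in self.flipflop.values())

    def clean_flipflop_holders(self) -> int:
        # Mirrors FlipFlopModule: accumulate what every holder freed.
        freed = sum(h.clean_flipflop_blocks() for h in list(self.flipflop.values()))
        self.flipflop.clear()
        return freed

m = ToyModule({"double_blocks": ToyHolder(512), "single_blocks": ToyHolder(256)})
mem_counter = m.init_flipflop_block_copies("cuda")  # 768
memory_freed = m.clean_flipflop_holders()           # 768
```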
@@ -71,6 +81,8 @@ class FlipFlopModule(torch.nn.Module):

 class FlipFlopContext:
     def __init__(self, holder: FlipFlopHolder):
+        # NOTE: there is a bug when there is an odd number of blocks to flipflop.
+        # Worked around for now by always making sure the count is even, but this needs to be resolved.
         self.holder = holder
         self.reset()
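A possible reading of that NOTE (an assumption on my part, not a confirmed diagnosis): with two resident buffers, block i executes on buffer i % 2, so an even block count ends each pass on the opposite buffer from the one it started on, while an odd count ends on the buffer the next pass wants to start with, breaking the alternation across passes. A toy illustration of the parity assumption:

```python
# Toy illustration of two-buffer alternation: block i runs on buffer i % 2.
def assign_buffers(num_blocks: int) -> list[str]:
    return ["flip" if i % 2 == 0 else "flop" for i in range(num_blocks)]

print(assign_buffers(4))  # ['flip', 'flop', 'flip', 'flop'] -> pass ends on 'flop'
print(assign_buffers(5))  # ['flip', 'flop', 'flip', 'flop', 'flip'] -> pass ends on
                          # 'flip', the same buffer the next pass starts on
```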
@@ -172,12 +184,17 @@ class FlipFlopHolder:
     def context(self):
         return FlipFlopContext(self)

-    def init_flipflop_block_copies(self, load_device: torch.device):
+    def init_flipflop_block_copies(self, load_device: torch.device) -> int:
         self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
         self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
+        return comfy.model_management.module_size(self.flip) + comfy.model_management.module_size(self.flop)

-    def clean_flipflop_blocks(self):
+    def clean_flipflop_blocks(self) -> int:
+        memory_freed = 0
+        memory_freed += comfy.model_management.module_size(self.flip)
+        memory_freed += comfy.model_management.module_size(self.flop)
         del self.flip
         del self.flop
         self.flip = None
         self.flop = None
+        return memory_freed
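comfy.model_management.module_size is what prices the flip and flop copies here. As a rough mental model (an approximation, not the exact ComfyUI implementation), it amounts to summing the bytes of a module's parameters and buffers:

```python
import torch

def approx_module_size(module: torch.nn.Module) -> int:
    # Rough stand-in for comfy.model_management.module_size:
    # bytes held by all parameters and buffers of the module.
    tensors = list(module.parameters()) + list(module.buffers())
    return sum(t.numel() * t.element_size() for t in tensors)

block = torch.nn.Linear(4096, 4096)  # ~64 MiB in fp32
print(approx_module_size(block))     # 67125248 bytes (weight + bias)
```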


@@ -620,21 +620,26 @@ class ModelPatcher:
             return False
         return True

-    def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]]):
+    def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]], flipflop_prefixes: list[str]):
         if not self.supports_flipflop():
             return
         logging.info(f"setting up flipflop with {flipflop_blocks_per_type}")
-        self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, self.load_device, self.offload_device)
+        self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, flipflop_prefixes, self.load_device, self.offload_device)

-    def init_flipflop_block_copies(self):
+    def init_flipflop_block_copies(self) -> int:
         if not self.supports_flipflop():
-            return
-        self.model.diffusion_model.init_flipflop_block_copies(self.load_device)
+            return 0
+        return self.model.diffusion_model.init_flipflop_block_copies(self.load_device)

-    def clean_flipflop(self):
+    def clean_flipflop(self) -> int:
         if not self.supports_flipflop():
-            return
-        self.model.diffusion_model.clean_flipflop_holders()
+            return 0
+        return self.model.diffusion_model.clean_flipflop_holders()

+    def _get_existing_flipflop_prefixes(self):
+        if self.supports_flipflop():
+            return self.model.diffusion_model.flipflop_prefixes
+        return []
+
     def _calc_flipflop_prefixes(self, lowvram_model_memory=0, prepare_flipflop=False):
         flipflop_prefixes = []
@@ -678,12 +683,15 @@ class ModelPatcher:
             for i in range(total_blocks-flipflop_blocks, total_blocks):
                 flipflop_prefixes.append(f"diffusion_model.{block_type}.{i}")
         if prepare_flipflop and len(flipflop_blocks_per_type) > 0:
-            self.setup_flipflop(flipflop_blocks_per_type)
+            self.setup_flipflop(flipflop_blocks_per_type, flipflop_prefixes)
         return flipflop_prefixes

-    def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False):
+    def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False, get_existing_flipflop=False):
         loading = []
-        flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop)
+        if get_existing_flipflop:
+            flipflop_prefixes = self._get_existing_flipflop_prefixes()
+        else:
+            flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop)
         for n, m in self.model.named_modules():
             params = []
             skip = False
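The new get_existing_flipflop flag lets the unload path reuse the prefixes recorded by setup_flipflop instead of recomputing them (and possibly re-creating holders) via _calc_flipflop_prefixes. The prefixes themselves are plain module-name strings like diffusion_model.<block_type>.<i>; a small sketch of how such prefixes could be matched against named_modules() names (is_flipflop is a hypothetical helper, the real tagging lives inside _load_list):

```python
# Illustrative prefix matching against named_modules() names.
# The "." boundary check keeps e.g. "double_blocks.3" from matching
# "double_blocks.30".
def is_flipflop(module_name: str, flipflop_prefixes: list[str]) -> bool:
    return any(module_name == p or module_name.startswith(p + ".")
               for p in flipflop_prefixes)

prefixes = [f"diffusion_model.double_blocks.{i}" for i in range(30, 38)]
print(is_flipflop("diffusion_model.double_blocks.31.attn", prefixes))  # True
print(is_flipflop("diffusion_model.double_blocks.3.attn", prefixes))   # False
```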
@@ -817,7 +825,7 @@ class ModelPatcher:
             end_time = time.perf_counter()
             logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds")
             start_time = time.perf_counter()
-            self.init_flipflop_block_copies()
+            mem_counter += self.init_flipflop_block_copies()
             end_time = time.perf_counter()
             logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds")
@@ -905,8 +913,9 @@ class ModelPatcher:
         with self.use_ejected():
             hooks_unpatched = False
             memory_freed = 0
+            memory_freed += self.clean_flipflop()
             patch_counter = 0
-            unload_list = self._load_list()
+            unload_list = self._load_list(get_existing_flipflop=True)
             unload_list.sort()
             for unload in unload_list:
                 if memory_to_free < memory_freed:
@@ -915,7 +924,10 @@ class ModelPatcher:
                 n = unload[1]
                 m = unload[2]
                 params = unload[3]
+                flipflop: bool = unload[4]
+                if flipflop:
+                    continue
                 lowvram_possible = hasattr(m, "comfy_cast_weights")
                 if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
                     move_weight = True
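Putting the partial-unload changes together: device copies held by the flipflop holders are reclaimed up front via clean_flipflop and counted toward memory_freed, and the per-module loop then skips entries tagged with the new flipflop flag so those blocks are not offloaded a second time. A condensed sketch of that control flow, with toy tuples mirroring the unload-list layout (assuming index 0 is the module's size, as suggested by the sort):

```python
# Condensed sketch of the partial-unload flow above (toy data; real
# entries carry size, name, module, params, and the new flipflop flag).
def partial_unload(memory_to_free: int, unload_list, clean_flipflop) -> int:
    memory_freed = 0
    memory_freed += clean_flipflop()           # reclaim flip/flop copies first
    for size, name, module, params, flipflop in sorted(unload_list, key=lambda e: e[0]):
        if memory_to_free < memory_freed:      # already freed enough
            break
        if flipflop:
            continue                           # managed by the holders, skip
        memory_freed += size                   # pretend we offloaded this module
    return memory_freed

entries = [
    (100, "a", None, [], False),
    (150, "b", None, [], True),   # flipflop-managed, skipped in the loop
    (200, "c", None, [], False),
]
print(partial_unload(250, entries, lambda: 50))  # 50 + 100 + 200 = 350
```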