mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-21 12:00:49 +08:00
Made flipflop consider partial_unload, partial_offload, and add flip+flop to mem counters
This commit is contained in:
parent
0fdd327c2f
commit
5329180fce
@ -11,21 +11,31 @@ class FlipFlopModule(torch.nn.Module):
|
|||||||
self.block_types = block_types
|
self.block_types = block_types
|
||||||
self.enable_flipflop = enable_flipflop
|
self.enable_flipflop = enable_flipflop
|
||||||
self.flipflop: dict[str, FlipFlopHolder] = {}
|
self.flipflop: dict[str, FlipFlopHolder] = {}
|
||||||
|
self.block_info: dict[str, tuple[int, int]] = {}
|
||||||
|
self.flipflop_prefixes: list[str] = []
|
||||||
|
|
||||||
def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], load_device: torch.device, offload_device: torch.device):
|
def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], flipflop_prefixes: list[str], load_device: torch.device, offload_device: torch.device):
|
||||||
for block_type, (flipflop_blocks, total_blocks) in block_info.items():
|
for block_type, (flipflop_blocks, total_blocks) in block_info.items():
|
||||||
if block_type in self.flipflop:
|
if block_type in self.flipflop:
|
||||||
continue
|
continue
|
||||||
self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[total_blocks-flipflop_blocks:], flipflop_blocks, total_blocks, load_device, offload_device)
|
self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[total_blocks-flipflop_blocks:], flipflop_blocks, total_blocks, load_device, offload_device)
|
||||||
|
self.block_info[block_type] = (flipflop_blocks, total_blocks)
|
||||||
|
self.flipflop_prefixes = flipflop_prefixes.copy()
|
||||||
|
|
||||||
def init_flipflop_block_copies(self, device: torch.device):
|
def init_flipflop_block_copies(self, device: torch.device) -> int:
|
||||||
|
memory_freed = 0
|
||||||
for holder in self.flipflop.values():
|
for holder in self.flipflop.values():
|
||||||
holder.init_flipflop_block_copies(device)
|
memory_freed += holder.init_flipflop_block_copies(device)
|
||||||
|
return memory_freed
|
||||||
|
|
||||||
def clean_flipflop_holders(self):
|
def clean_flipflop_holders(self):
|
||||||
|
memory_freed = 0
|
||||||
for block_type in list(self.flipflop.keys()):
|
for block_type in list(self.flipflop.keys()):
|
||||||
self.flipflop[block_type].clean_flipflop_blocks()
|
memory_freed += self.flipflop[block_type].clean_flipflop_blocks()
|
||||||
del self.flipflop[block_type]
|
del self.flipflop[block_type]
|
||||||
|
self.block_info = {}
|
||||||
|
self.flipflop_prefixes = []
|
||||||
|
return memory_freed
|
||||||
|
|
||||||
def get_all_blocks(self, block_type: str) -> list[torch.nn.Module]:
|
def get_all_blocks(self, block_type: str) -> list[torch.nn.Module]:
|
||||||
return getattr(self, block_type)
|
return getattr(self, block_type)
|
||||||
@ -71,6 +81,8 @@ class FlipFlopModule(torch.nn.Module):
|
|||||||
|
|
||||||
class FlipFlopContext:
|
class FlipFlopContext:
|
||||||
def __init__(self, holder: FlipFlopHolder):
|
def __init__(self, holder: FlipFlopHolder):
|
||||||
|
# NOTE: there is a bug when there are an odd number of blocks to flipflop.
|
||||||
|
# Worked around right now by always making sure it will be even, but need to resolve.
|
||||||
self.holder = holder
|
self.holder = holder
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@ -172,12 +184,17 @@ class FlipFlopHolder:
|
|||||||
def context(self):
|
def context(self):
|
||||||
return FlipFlopContext(self)
|
return FlipFlopContext(self)
|
||||||
|
|
||||||
def init_flipflop_block_copies(self, load_device: torch.device):
|
def init_flipflop_block_copies(self, load_device: torch.device) -> int:
|
||||||
self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
|
self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
|
||||||
self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
|
self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
|
||||||
|
return comfy.model_management.module_size(self.flip) + comfy.model_management.module_size(self.flop)
|
||||||
|
|
||||||
def clean_flipflop_blocks(self):
|
def clean_flipflop_blocks(self) -> int:
|
||||||
|
memory_freed = 0
|
||||||
|
memory_freed += comfy.model_management.module_size(self.flip)
|
||||||
|
memory_freed += comfy.model_management.module_size(self.flop)
|
||||||
del self.flip
|
del self.flip
|
||||||
del self.flop
|
del self.flop
|
||||||
self.flip = None
|
self.flip = None
|
||||||
self.flop = None
|
self.flop = None
|
||||||
|
return memory_freed
|
||||||
|
|||||||
@ -620,21 +620,26 @@ class ModelPatcher:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]]):
|
def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]], flipflop_prefixes: list[str]):
|
||||||
if not self.supports_flipflop():
|
if not self.supports_flipflop():
|
||||||
return
|
return
|
||||||
logging.info(f"setting up flipflop with {flipflop_blocks_per_type}")
|
logging.info(f"setting up flipflop with {flipflop_blocks_per_type}")
|
||||||
self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, self.load_device, self.offload_device)
|
self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, flipflop_prefixes, self.load_device, self.offload_device)
|
||||||
|
|
||||||
def init_flipflop_block_copies(self):
|
def init_flipflop_block_copies(self) -> int:
|
||||||
if not self.supports_flipflop():
|
if not self.supports_flipflop():
|
||||||
return
|
return 0
|
||||||
self.model.diffusion_model.init_flipflop_block_copies(self.load_device)
|
return self.model.diffusion_model.init_flipflop_block_copies(self.load_device)
|
||||||
|
|
||||||
def clean_flipflop(self):
|
def clean_flipflop(self) -> int:
|
||||||
if not self.supports_flipflop():
|
if not self.supports_flipflop():
|
||||||
return
|
return 0
|
||||||
self.model.diffusion_model.clean_flipflop_holders()
|
return self.model.diffusion_model.clean_flipflop_holders()
|
||||||
|
|
||||||
|
def _get_existing_flipflop_prefixes(self):
|
||||||
|
if self.supports_flipflop():
|
||||||
|
return self.model.diffusion_model.flipflop_prefixes
|
||||||
|
return []
|
||||||
|
|
||||||
def _calc_flipflop_prefixes(self, lowvram_model_memory=0, prepare_flipflop=False):
|
def _calc_flipflop_prefixes(self, lowvram_model_memory=0, prepare_flipflop=False):
|
||||||
flipflop_prefixes = []
|
flipflop_prefixes = []
|
||||||
@ -678,12 +683,15 @@ class ModelPatcher:
|
|||||||
for i in range(total_blocks-flipflop_blocks, total_blocks):
|
for i in range(total_blocks-flipflop_blocks, total_blocks):
|
||||||
flipflop_prefixes.append(f"diffusion_model.{block_type}.{i}")
|
flipflop_prefixes.append(f"diffusion_model.{block_type}.{i}")
|
||||||
if prepare_flipflop and len(flipflop_blocks_per_type) > 0:
|
if prepare_flipflop and len(flipflop_blocks_per_type) > 0:
|
||||||
self.setup_flipflop(flipflop_blocks_per_type)
|
self.setup_flipflop(flipflop_blocks_per_type, flipflop_prefixes)
|
||||||
return flipflop_prefixes
|
return flipflop_prefixes
|
||||||
|
|
||||||
def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False):
|
def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False, get_existing_flipflop=False):
|
||||||
loading = []
|
loading = []
|
||||||
flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop)
|
if get_existing_flipflop:
|
||||||
|
flipflop_prefixes = self._get_existing_flipflop_prefixes()
|
||||||
|
else:
|
||||||
|
flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop)
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
params = []
|
params = []
|
||||||
skip = False
|
skip = False
|
||||||
@ -817,7 +825,7 @@ class ModelPatcher:
|
|||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds")
|
logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds")
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
self.init_flipflop_block_copies()
|
mem_counter += self.init_flipflop_block_copies()
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds")
|
logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds")
|
||||||
|
|
||||||
@ -905,8 +913,9 @@ class ModelPatcher:
|
|||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
hooks_unpatched = False
|
hooks_unpatched = False
|
||||||
memory_freed = 0
|
memory_freed = 0
|
||||||
|
memory_freed += self.clean_flipflop()
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
unload_list = self._load_list()
|
unload_list = self._load_list(get_existing_flipflop=True)
|
||||||
unload_list.sort()
|
unload_list.sort()
|
||||||
for unload in unload_list:
|
for unload in unload_list:
|
||||||
if memory_to_free < memory_freed:
|
if memory_to_free < memory_freed:
|
||||||
@ -915,7 +924,10 @@ class ModelPatcher:
|
|||||||
n = unload[1]
|
n = unload[1]
|
||||||
m = unload[2]
|
m = unload[2]
|
||||||
params = unload[3]
|
params = unload[3]
|
||||||
|
flipflop: bool = unload[4]
|
||||||
|
|
||||||
|
if flipflop:
|
||||||
|
continue
|
||||||
lowvram_possible = hasattr(m, "comfy_cast_weights")
|
lowvram_possible = hasattr(m, "comfy_cast_weights")
|
||||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||||
move_weight = True
|
move_weight = True
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user