mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 01:52:59 +08:00
In-progress commit on making flipflop async weight streaming native, made loaded partially/loaded completely log messages have labels because having to memorize their meaning for dev work is annoying
This commit is contained in:
parent
d0bd221495
commit
01f4512bf8
@ -3,42 +3,9 @@ import torch
|
||||
import torch.cuda as cuda
|
||||
import copy
|
||||
from typing import List, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
FLIPFLOP_REGISTRY = {}
|
||||
|
||||
def register(name):
|
||||
def decorator(cls):
|
||||
FLIPFLOP_REGISTRY[name] = cls
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlipFlopConfig:
|
||||
block_name: str
|
||||
block_wrap_fn: callable
|
||||
out_names: Tuple[str]
|
||||
overwrite_forward: str
|
||||
pinned_staging: bool = False
|
||||
inference_device: str = "cuda"
|
||||
offloading_device: str = "cpu"
|
||||
|
||||
|
||||
def patch_model_from_config(model, config: FlipFlopConfig):
|
||||
block_list = getattr(model, config.block_name)
|
||||
flip_flop_transformer = FlipFlopTransformer(block_list,
|
||||
block_wrap_fn=config.block_wrap_fn,
|
||||
out_names=config.out_names,
|
||||
offloading_device=config.offloading_device,
|
||||
inference_device=config.inference_device,
|
||||
pinned_staging=config.pinned_staging)
|
||||
delattr(model, config.block_name)
|
||||
setattr(model, config.block_name, flip_flop_transformer)
|
||||
setattr(model, config.overwrite_forward, flip_flop_transformer.__call__)
|
||||
|
||||
|
||||
class FlipFlopContext:
|
||||
def __init__(self, holder: FlipFlopHolder):
|
||||
@ -46,11 +13,12 @@ class FlipFlopContext:
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.num_blocks = len(self.holder.transformer_blocks)
|
||||
self.num_blocks = len(self.holder.blocks)
|
||||
self.first_flip = True
|
||||
self.first_flop = True
|
||||
self.last_flip = False
|
||||
self.last_flop = False
|
||||
# TODO: the 'i' that's passed into func needs to be properly offset to do patches correctly
|
||||
|
||||
def __enter__(self):
|
||||
self.reset()
|
||||
@ -71,9 +39,9 @@ class FlipFlopContext:
|
||||
next_flop_i = next_flop_i - self.num_blocks
|
||||
self.last_flip = True
|
||||
if not self.first_flip:
|
||||
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
|
||||
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
|
||||
if self.last_flip:
|
||||
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
|
||||
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
|
||||
self.first_flip = False
|
||||
return out
|
||||
|
||||
@ -89,9 +57,9 @@ class FlipFlopContext:
|
||||
if next_flip_i >= self.num_blocks:
|
||||
next_flip_i = next_flip_i - self.num_blocks
|
||||
self.last_flop = True
|
||||
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
|
||||
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
|
||||
if self.last_flop:
|
||||
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
|
||||
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
|
||||
self.first_flop = False
|
||||
return out
|
||||
|
||||
@ -106,19 +74,20 @@ class FlipFlopContext:
|
||||
|
||||
|
||||
class FlipFlopHolder:
|
||||
def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"):
|
||||
self.load_device = torch.device(inference_device)
|
||||
self.offload_device = torch.device(offloading_device)
|
||||
self.transformer_blocks = transformer_blocks
|
||||
def __init__(self, blocks: List[torch.nn.Module], flip_amount: int, load_device="cuda", offload_device="cpu"):
|
||||
self.load_device = torch.device(load_device)
|
||||
self.offload_device = torch.device(offload_device)
|
||||
self.blocks = blocks
|
||||
self.flip_amount = flip_amount
|
||||
|
||||
self.block_module_size = 0
|
||||
if len(self.transformer_blocks) > 0:
|
||||
self.block_module_size = comfy.model_management.module_size(self.transformer_blocks[0])
|
||||
if len(self.blocks) > 0:
|
||||
self.block_module_size = comfy.model_management.module_size(self.blocks[0])
|
||||
|
||||
self.flip: torch.nn.Module = None
|
||||
self.flop: torch.nn.Module = None
|
||||
# TODO: make initialization happen in model management code/model patcher, not here
|
||||
self.initialize_flipflop_blocks(self.load_device)
|
||||
self.init_flipflop_blocks(self.load_device)
|
||||
|
||||
self.compute_stream = cuda.default_stream(self.load_device)
|
||||
self.cpy_stream = cuda.Stream(self.load_device)
|
||||
@ -142,10 +111,57 @@ class FlipFlopHolder:
|
||||
def context(self):
|
||||
return FlipFlopContext(self)
|
||||
|
||||
def initialize_flipflop_blocks(self, load_device: torch.device):
|
||||
self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=load_device)
|
||||
self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=load_device)
|
||||
def init_flipflop_blocks(self, load_device: torch.device):
|
||||
self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
|
||||
self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
|
||||
|
||||
def clean_flipflop_blocks(self):
|
||||
del self.flip
|
||||
del self.flop
|
||||
self.flip = None
|
||||
self.flop = None
|
||||
|
||||
|
||||
class FlopFlopModule(torch.nn.Module):
|
||||
def __init__(self, block_types: tuple[str, ...]):
|
||||
super().__init__()
|
||||
self.block_types = block_types
|
||||
self.flipflop: dict[str, FlipFlopHolder] = {}
|
||||
|
||||
def setup_flipflop_holders(self, block_percentage: float):
|
||||
for block_type in self.block_types:
|
||||
if block_type in self.flipflop:
|
||||
continue
|
||||
num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage))
|
||||
self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks)
|
||||
|
||||
def clean_flipflop_holders(self):
|
||||
for block_type in self.flipflop.keys():
|
||||
self.flipflop[block_type].clean_flipflop_blocks()
|
||||
del self.flipflop[block_type]
|
||||
|
||||
def get_blocks(self, block_type: str) -> torch.nn.ModuleList:
|
||||
if block_type not in self.block_types:
|
||||
raise ValueError(f"Block type {block_type} not found in {self.block_types}")
|
||||
if block_type in self.flipflop:
|
||||
return getattr(self, block_type)[:self.flipflop[block_type].flip_amount]
|
||||
return getattr(self, block_type)
|
||||
|
||||
def get_all_block_module_sizes(self, sort_by_size: bool = False) -> list[tuple[str, int]]:
|
||||
'''
|
||||
Returns a list of (block_type, size).
|
||||
If sort_by_size is True, the list is sorted by size.
|
||||
'''
|
||||
sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types]
|
||||
if sort_by_size:
|
||||
sizes.sort(key=lambda x: x[1])
|
||||
return sizes
|
||||
|
||||
def get_block_module_size(self, block_type: str) -> int:
|
||||
return comfy.model_management.module_size(getattr(self, block_type)[0])
|
||||
|
||||
|
||||
# Below is the implementation from contentis' prototype flip flop
|
||||
class FlipFlopTransformer:
|
||||
def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
|
||||
self.transformer_blocks = transformer_blocks
|
||||
@ -379,28 +395,26 @@ class FlipFlopTransformer:
|
||||
# patch_model_from_config(model, Wan.blocks_config)
|
||||
# return model
|
||||
|
||||
# @register("QwenImageTransformer2DModel")
|
||||
# class QwenImage:
|
||||
# @staticmethod
|
||||
# def qwen_blocks_wrap(block, **kwargs):
|
||||
# kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
|
||||
# encoder_hidden_states=kwargs["encoder_hidden_states"],
|
||||
# encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
|
||||
# temb=kwargs["temb"],
|
||||
# image_rotary_emb=kwargs["image_rotary_emb"],
|
||||
# transformer_options=kwargs["transformer_options"])
|
||||
# return kwargs
|
||||
|
||||
@register("QwenImageTransformer2DModel")
|
||||
class QwenImage:
|
||||
@staticmethod
|
||||
def qwen_blocks_wrap(block, **kwargs):
|
||||
kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
|
||||
encoder_hidden_states=kwargs["encoder_hidden_states"],
|
||||
encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
|
||||
temb=kwargs["temb"],
|
||||
image_rotary_emb=kwargs["image_rotary_emb"],
|
||||
transformer_options=kwargs["transformer_options"])
|
||||
return kwargs
|
||||
|
||||
blocks_config = FlipFlopConfig(block_name="transformer_blocks",
|
||||
block_wrap_fn=qwen_blocks_wrap,
|
||||
out_names=("encoder_hidden_states", "hidden_states"),
|
||||
overwrite_forward="blocks_fwd",
|
||||
pinned_staging=False)
|
||||
# blocks_config = FlipFlopConfig(block_name="transformer_blocks",
|
||||
# block_wrap_fn=qwen_blocks_wrap,
|
||||
# out_names=("encoder_hidden_states", "hidden_states"),
|
||||
# overwrite_forward="blocks_fwd",
|
||||
# pinned_staging=False)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def patch(model):
|
||||
patch_model_from_config(model, QwenImage.blocks_config)
|
||||
return model
|
||||
|
||||
# @staticmethod
|
||||
# def patch(model):
|
||||
# patch_model_from_config(model, QwenImage.blocks_config)
|
||||
# return model
|
||||
|
||||
@ -343,11 +343,36 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def setup_flipflop_holders(self, block_percentage: float):
|
||||
if "transformer_blocks" in self.flipflop:
|
||||
return
|
||||
import comfy.model_management
|
||||
# We hackily move any flipflopped blocks into holder so that our model management system does not see them.
|
||||
num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage))
|
||||
self.flipflop["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
|
||||
loading = []
|
||||
for n, m in self.named_modules():
|
||||
params = []
|
||||
skip = False
|
||||
for name, param in m.named_parameters(recurse=False):
|
||||
params.append(name)
|
||||
for name, param in m.named_parameters(recurse=True):
|
||||
if name not in params:
|
||||
skip = True # skip random weights in non leaf modules
|
||||
break
|
||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
||||
self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks)
|
||||
self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks])
|
||||
|
||||
def clean_flipflop_holders(self):
|
||||
if "transformer_blocks" in self.flipflop:
|
||||
self.flipflop["transformer_blocks"].clean_flipflop_blocks()
|
||||
del self.flipflop["transformer_blocks"]
|
||||
|
||||
def get_transformer_blocks(self):
|
||||
if "transformer_blocks" in self.flipflop:
|
||||
return self.transformer_blocks[:self.flipflop["transformer_blocks"].flip_amount]
|
||||
return self.transformer_blocks
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
@ -409,17 +434,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
||||
if "blocks_fwd" in self.flipflop:
|
||||
holder = self.flipflop["blocks_fwd"]
|
||||
with holder.context() as ctx:
|
||||
for i, block in enumerate(holder.transformer_blocks):
|
||||
encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
@ -487,12 +501,14 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
encoder_hidden_states, hidden_states = self.blocks_fwd(hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb, image_rotary_emb=image_rotary_emb,
|
||||
patches=patches, control=control, blocks_replace=blocks_replace, x=x,
|
||||
transformer_options=transformer_options)
|
||||
for i, block in enumerate(self.get_transformer_blocks()):
|
||||
encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
||||
if "transformer_blocks" in self.flipflop:
|
||||
holder = self.flipflop["transformer_blocks"]
|
||||
with holder.context() as ctx:
|
||||
for i, block in enumerate(holder.blocks):
|
||||
encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
||||
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
@ -605,7 +605,27 @@ class ModelPatcher:
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
|
||||
def supports_flipflop(self):
|
||||
return hasattr(self.model.diffusion_model, "flipflop")
|
||||
# flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM
|
||||
if not hasattr(self.model, "diffusion_model"):
|
||||
return False
|
||||
if not hasattr(self.model.diffusion_model, "flipflop"):
|
||||
return False
|
||||
if not comfy.model_management.is_nvidia():
|
||||
return False
|
||||
if comfy.model_management.vram_state in (comfy.model_management.VRAMState.HIGH_VRAM, comfy.model_management.VRAMState.SHARED):
|
||||
return False
|
||||
return True
|
||||
|
||||
def init_flipflop(self):
|
||||
if not self.supports_flipflop():
|
||||
return
|
||||
# figure out how many b
|
||||
self.model.diffusion_model.setup_flipflop_holders(self.model_options["flipflop_block_percentage"])
|
||||
|
||||
def clean_flipflop(self):
|
||||
if not self.supports_flipflop():
|
||||
return
|
||||
self.model.diffusion_model.clean_flipflop_holders()
|
||||
|
||||
def _load_list(self):
|
||||
loading = []
|
||||
@ -628,6 +648,9 @@ class ModelPatcher:
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
lowvram_mem_counter = 0
|
||||
if self.supports_flipflop():
|
||||
...
|
||||
loading = self._load_list()
|
||||
|
||||
load_completely = []
|
||||
@ -647,6 +670,7 @@ class ModelPatcher:
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
lowvram_mem_counter += module_mem
|
||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||
continue
|
||||
|
||||
@ -709,10 +733,10 @@ class ModelPatcher:
|
||||
x[2].to(device_to)
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}")
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
|
||||
@ -3,33 +3,6 @@ from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
from comfy.ldm.flipflop_transformer import FLIPFLOP_REGISTRY
|
||||
|
||||
class FlipFlopOld(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="FlipFlop",
|
||||
display_name="FlipFlop (Old)",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
io.Model.Input(id="model")
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output()
|
||||
],
|
||||
description="Apply FlipFlop transformation to model using registry-based patching"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model) -> io.NodeOutput:
|
||||
patch_cls = FLIPFLOP_REGISTRY.get(model.model.diffusion_model.__class__.__name__, None)
|
||||
if patch_cls is None:
|
||||
raise ValueError(f"Model {model.model.diffusion_model.__class__.__name__} not supported")
|
||||
|
||||
model.model.diffusion_model = patch_cls.patch(model.model.diffusion_model)
|
||||
|
||||
return io.NodeOutput(model)
|
||||
|
||||
class FlipFlop(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -62,7 +35,6 @@ class FlipFlopExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
FlipFlopOld,
|
||||
FlipFlop,
|
||||
]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user