diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py
index 754733a86..03e256c7a 100644
--- a/comfy/ldm/flipflop_transformer.py
+++ b/comfy/ldm/flipflop_transformer.py
@@ -3,42 +3,9 @@ import torch
 import torch.cuda as cuda
 import copy
 from typing import List, Tuple
-from dataclasses import dataclass
 
 import comfy.model_management
 
-FLIPFLOP_REGISTRY = {}
-
-def register(name):
-    def decorator(cls):
-        FLIPFLOP_REGISTRY[name] = cls
-        return cls
-    return decorator
-
-
-@dataclass
-class FlipFlopConfig:
-    block_name: str
-    block_wrap_fn: callable
-    out_names: Tuple[str]
-    overwrite_forward: str
-    pinned_staging: bool = False
-    inference_device: str = "cuda"
-    offloading_device: str = "cpu"
-
-
-def patch_model_from_config(model, config: FlipFlopConfig):
-    block_list = getattr(model, config.block_name)
-    flip_flop_transformer = FlipFlopTransformer(block_list,
-                                                block_wrap_fn=config.block_wrap_fn,
-                                                out_names=config.out_names,
-                                                offloading_device=config.offloading_device,
-                                                inference_device=config.inference_device,
-                                                pinned_staging=config.pinned_staging)
-    delattr(model, config.block_name)
-    setattr(model, config.block_name, flip_flop_transformer)
-    setattr(model, config.overwrite_forward, flip_flop_transformer.__call__)
-
 
 class FlipFlopContext:
     def __init__(self, holder: FlipFlopHolder):
@@ -46,11 +13,12 @@ class FlipFlopContext:
         self.reset()
 
     def reset(self):
-        self.num_blocks = len(self.holder.transformer_blocks)
+        self.num_blocks = len(self.holder.blocks)
         self.first_flip = True
         self.first_flop = True
         self.last_flip = False
         self.last_flop = False
+        # TODO: the 'i' that's passed into func needs to be properly offset to do patches correctly
 
     def __enter__(self):
         self.reset()
@@ -71,9 +39,9 @@
             next_flop_i = next_flop_i - self.num_blocks
             self.last_flip = True
         if not self.first_flip:
-            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
+            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
         if self.last_flip:
-            self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
+            self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
         self.first_flip = False
         return out
@@ -89,9 +57,9 @@
         if next_flip_i >= self.num_blocks:
             next_flip_i = next_flip_i - self.num_blocks
             self.last_flop = True
-        self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
+        self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
         if self.last_flop:
-            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
+            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
         self.first_flop = False
         return out
@@ -106,19 +74,20 @@
 
 
 class FlipFlopHolder:
-    def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"):
-        self.load_device = torch.device(inference_device)
-        self.offload_device = torch.device(offloading_device)
-        self.transformer_blocks = transformer_blocks
+    def __init__(self, blocks: List[torch.nn.Module], flip_amount: int, load_device="cuda", offload_device="cpu"):
+        self.load_device = torch.device(load_device)
+        self.offload_device = torch.device(offload_device)
+        self.blocks = blocks
+        self.flip_amount = flip_amount
         self.block_module_size = 0
-        if len(self.transformer_blocks) > 0:
-            self.block_module_size = comfy.model_management.module_size(self.transformer_blocks[0])
+        if len(self.blocks) > 0:
+            self.block_module_size = comfy.model_management.module_size(self.blocks[0])
 
         self.flip: torch.nn.Module = None
         self.flop: torch.nn.Module = None
 
         # TODO: make initialization happen in model management code/model patcher, not here
-        self.initialize_flipflop_blocks(self.load_device)
+        self.init_flipflop_blocks(self.load_device)
 
         self.compute_stream = cuda.default_stream(self.load_device)
         self.cpy_stream = cuda.Stream(self.load_device)
@@ -142,10 +111,57 @@ class FlipFlopHolder:
     def context(self):
         return FlipFlopContext(self)
 
-    def initialize_flipflop_blocks(self, load_device: torch.device):
-        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=load_device)
-        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=load_device)
+    def init_flipflop_blocks(self, load_device: torch.device):
+        self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
+        self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
 
+    def clean_flipflop_blocks(self):
+        del self.flip
+        del self.flop
+        self.flip = None
+        self.flop = None
+
+
+class FlipFlopModule(torch.nn.Module):
+    def __init__(self, block_types: tuple[str, ...]):
+        super().__init__()
+        self.block_types = block_types
+        self.flipflop: dict[str, FlipFlopHolder] = {}
+
+    def setup_flipflop_holders(self, block_percentage: float):
+        for block_type in self.block_types:
+            if block_type in self.flipflop:
+                continue
+            num_blocks = int(len(getattr(self, block_type)) * (1.0 - block_percentage))
+            self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[num_blocks:], num_blocks)
+
+    def clean_flipflop_holders(self):
+        for block_type in list(self.flipflop.keys()):
+            self.flipflop[block_type].clean_flipflop_blocks()
+            del self.flipflop[block_type]
+
+    def get_blocks(self, block_type: str) -> torch.nn.ModuleList:
+        if block_type not in self.block_types:
+            raise ValueError(f"Block type {block_type} not found in {self.block_types}")
+        if block_type in self.flipflop:
+            return getattr(self, block_type)[:self.flipflop[block_type].flip_amount]
+        return getattr(self, block_type)
+
+    def get_all_block_module_sizes(self, sort_by_size: bool = False) -> list[tuple[str, int]]:
+        '''
+        Returns a list of (block_type, size).
+        If sort_by_size is True, the list is sorted by size.
+        '''
+        sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types]
+        if sort_by_size:
+            sizes.sort(key=lambda x: x[1])
+        return sizes
+
+    def get_block_module_size(self, block_type: str) -> int:
+        return comfy.model_management.module_size(getattr(self, block_type)[0])
+
+
+# Below is the implementation from contentis' prototype flip flop
 
 class FlipFlopTransformer:
     def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
         self.transformer_blocks = transformer_blocks
@@ -379,28 +395,26 @@
 #         patch_model_from_config(model, Wan.blocks_config)
 #         return model
 
+# @register("QwenImageTransformer2DModel")
+# class QwenImage:
+#     @staticmethod
+#     def qwen_blocks_wrap(block, **kwargs):
+#         kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
+#                                                                          encoder_hidden_states=kwargs["encoder_hidden_states"],
+#                                                                          encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
+#                                                                          temb=kwargs["temb"],
+#                                                                          image_rotary_emb=kwargs["image_rotary_emb"],
+#                                                                          transformer_options=kwargs["transformer_options"])
+#         return kwargs
 
-@register("QwenImageTransformer2DModel")
-class QwenImage:
-    @staticmethod
-    def qwen_blocks_wrap(block, **kwargs):
-        kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
-                                                                         encoder_hidden_states=kwargs["encoder_hidden_states"],
-                                                                         encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
-                                                                         temb=kwargs["temb"],
-                                                                         image_rotary_emb=kwargs["image_rotary_emb"],
-                                                                         transformer_options=kwargs["transformer_options"])
-        return kwargs
-
-    blocks_config = FlipFlopConfig(block_name="transformer_blocks",
-                                   block_wrap_fn=qwen_blocks_wrap,
-                                   out_names=("encoder_hidden_states", "hidden_states"),
-                                   overwrite_forward="blocks_fwd",
-                                   pinned_staging=False)
+#     blocks_config = FlipFlopConfig(block_name="transformer_blocks",
+#                                    block_wrap_fn=qwen_blocks_wrap,
+#                                    out_names=("encoder_hidden_states", "hidden_states"),
+#                                    overwrite_forward="blocks_fwd",
+#                                    pinned_staging=False)
 
-    @staticmethod
-    def patch(model):
-        patch_model_from_config(model, QwenImage.blocks_config)
-        return model
-
+#     @staticmethod
+#     def patch(model):
+#         patch_model_from_config(model, QwenImage.blocks_config)
+#         return model
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index d1833ffe6..474d831c4 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -343,11 +343,36 @@ class QwenImageTransformer2DModel(nn.Module):
         self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
 
     def setup_flipflop_holders(self, block_percentage: float):
+        if "transformer_blocks" in self.flipflop:
+            return
+        import comfy.model_management
         # We hackily move any flipflopped blocks into holder so that our model management system does not see them.
         num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage))
-        self.flipflop["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
+        loading = []
+        for n, m in self.named_modules():
+            params = []
+            skip = False
+            for name, param in m.named_parameters(recurse=False):
+                params.append(name)
+            for name, param in m.named_parameters(recurse=True):
+                if name not in params:
+                    skip = True  # skip random weights in non leaf modules
+                    break
+            if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
+                loading.append((comfy.model_management.module_size(m), n, m, params))
+        self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks)
         self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks])
 
+    def clean_flipflop_holders(self):
+        if "transformer_blocks" in self.flipflop:
+            self.flipflop["transformer_blocks"].clean_flipflop_blocks()
+            del self.flipflop["transformer_blocks"]
+
+    def get_transformer_blocks(self):
+        if "transformer_blocks" in self.flipflop:
+            return self.transformer_blocks[:self.flipflop["transformer_blocks"].flip_amount]
+        return self.transformer_blocks
+
     def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
@@ -409,17 +434,6 @@ class QwenImageTransformer2DModel(nn.Module):
         return encoder_hidden_states, hidden_states
 
-    def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
-        for i, block in enumerate(self.transformer_blocks):
-            encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-        if "blocks_fwd" in self.flipflop:
-            holder = self.flipflop["blocks_fwd"]
-            with holder.context() as ctx:
-                for i, block in enumerate(holder.transformer_blocks):
-                    encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-
-        return encoder_hidden_states, hidden_states
-
     def _forward(
         self,
         x,
@@ -487,12 +501,14 @@ class QwenImageTransformer2DModel(nn.Module):
         patches = transformer_options.get("patches", {})
         blocks_replace = patches_replace.get("dit", {})
 
-        encoder_hidden_states, hidden_states = self.blocks_fwd(hidden_states=hidden_states,
-                                                                encoder_hidden_states=encoder_hidden_states,
-                                                                encoder_hidden_states_mask=encoder_hidden_states_mask,
-                                                                temb=temb, image_rotary_emb=image_rotary_emb,
-                                                                patches=patches, control=control, blocks_replace=blocks_replace, x=x,
-                                                                transformer_options=transformer_options)
+        for i, block in enumerate(self.get_transformer_blocks()):
+            encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
+        if "transformer_blocks" in self.flipflop:
+            holder = self.flipflop["transformer_blocks"]
+            with holder.context() as ctx:
+                for i, block in enumerate(holder.blocks):
+                    encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
+
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index b9a768203..7e7b3a4ed 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -605,7 +605,27 @@ class ModelPatcher:
                 set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
     def supports_flipflop(self):
-        return hasattr(self.model.diffusion_model, "flipflop")
+        # flipflop requires a diffusion_model with explicit flipflop support, an NVIDIA device for CUDA streams, and a VRAM state that actually offloads (not HIGH_VRAM or SHARED)
+        if not hasattr(self.model, "diffusion_model"):
+            return False
+        if not hasattr(self.model.diffusion_model, "flipflop"):
+            return False
+        if not comfy.model_management.is_nvidia():
+            return False
+        if comfy.model_management.vram_state in (comfy.model_management.VRAMState.HIGH_VRAM, comfy.model_management.VRAMState.SHARED):
+            return False
+        return True
+
+    def init_flipflop(self):
+        if not self.supports_flipflop():
+            return
+        # figure out how many b
+        self.model.diffusion_model.setup_flipflop_holders(self.model_options["flipflop_block_percentage"])
+
+    def clean_flipflop(self):
+        if not self.supports_flipflop():
+            return
+        self.model.diffusion_model.clean_flipflop_holders()
 
     def _load_list(self):
         loading = []
@@ -628,6 +648,9 @@ class ModelPatcher:
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
+        lowvram_mem_counter = 0
+        if self.supports_flipflop():
+            ...
 
         loading = self._load_list()
         load_completely = []
@@ -647,6 +670,7 @@ class ModelPatcher:
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
                     lowvram_counter += 1
+                    lowvram_mem_counter += module_mem
                     if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                         continue
@@ -709,10 +733,10 @@ class ModelPatcher:
                     x[2].to(device_to)
 
         if lowvram_counter > 0:
-            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
             self.model.model_lowvram = True
         else:
-            logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}")
            self.model.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
diff --git a/comfy_extras/nodes_flipflop.py b/comfy_extras/nodes_flipflop.py
index 90ea1f6d5..4ddd5c479 100644
--- a/comfy_extras/nodes_flipflop.py
+++ b/comfy_extras/nodes_flipflop.py
@@ -3,33 +3,6 @@
 from typing_extensions import override
 
 from comfy_api.latest import ComfyExtension, io
 
-from comfy.ldm.flipflop_transformer import FLIPFLOP_REGISTRY
-
-class FlipFlopOld(io.ComfyNode):
-    @classmethod
-    def define_schema(cls) -> io.Schema:
-        return io.Schema(
-            node_id="FlipFlop",
-            display_name="FlipFlop (Old)",
-            category="_for_testing",
-            inputs=[
-                io.Model.Input(id="model")
-            ],
-            outputs=[
-                io.Model.Output()
-            ],
-            description="Apply FlipFlop transformation to model using registry-based patching"
-        )
-
-    @classmethod
-    def execute(cls, model) -> io.NodeOutput:
-        patch_cls = FLIPFLOP_REGISTRY.get(model.model.diffusion_model.__class__.__name__, None)
-        if patch_cls is None:
-            raise ValueError(f"Model {model.model.diffusion_model.__class__.__name__} not supported")
-
-        model.model.diffusion_model = patch_cls.patch(model.model.diffusion_model)
-
-        return io.NodeOutput(model)
 
 class FlipFlop(io.ComfyNode):
     @classmethod
@@ -62,7 +35,6 @@ class FlipFlopExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
         return [
-            FlipFlopOld,
             FlipFlop,
         ]
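
For context, the mechanism this diff implements is flip/flop double buffering: two GPU-resident copies of a block ("flip" and "flop") are reused for the whole offloaded tail of the block list, while a dedicated CUDA copy stream uploads the weights of the block two positions ahead as the compute stream runs the current one, with events keeping the two streams from racing. The sketch below is only a minimal, self-contained illustration of that pattern under those assumptions; the names (prefetch_blocks, bufs, ready, done) are hypothetical and this is not the API added by the PR.

# Minimal sketch of flip/flop double-buffered offloading (hypothetical names, not the PR's API).
import copy
import torch

@torch.no_grad()
def prefetch_blocks(cpu_blocks, x, device="cuda"):
    compute = torch.cuda.default_stream(device)
    copier = torch.cuda.Stream(device=device)
    done = [torch.cuda.Event() for _ in range(2)]   # compute finished using buffer k
    ready = [torch.cuda.Event() for _ in range(2)]  # upload into buffer k finished

    # "flip" and "flop": two resident GPU copies that are overwritten in turn.
    bufs = [copy.deepcopy(cpu_blocks[k]).to(device) for k in range(2)]
    for k in range(2):
        ready[k].record(copier)  # buffers already hold blocks 0 and 1

    for i in range(len(cpu_blocks)):
        k = i % 2
        compute.wait_event(ready[k])        # weights for block i must be uploaded
        x = bufs[k](x)                      # run block i on the compute stream
        done[k].record(compute)

        nxt = i + 2                         # start refilling this buffer for block i + 2
        if nxt < len(cpu_blocks):
            with torch.cuda.stream(copier):
                copier.wait_event(done[k])  # don't clobber weights still in use
                src = cpu_blocks[nxt].state_dict()
                dst = bufs[k].state_dict()
                for name in dst:
                    dst[name].copy_(src[name], non_blocking=True)
                ready[k].record(copier)
    return x  # caller synchronizes before reading the result

Truly asynchronous host-to-device copies also need pinned host memory, which is what the pinned_staging option on the existing FlipFlopTransformer is for; the sketch above omits that detail.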
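The new FlipFlop node's schema and execute body fall outside the hunks shown for comfy_extras/nodes_flipflop.py. Since ModelPatcher.init_flipflop reads model_options["flipflop_block_percentage"], the node presumably just records the requested offload fraction on a cloned patcher, along the lines of this hypothetical sketch (the block_percentage input name and its default are assumptions, not taken from the diff):

    @classmethod
    def execute(cls, model, block_percentage=0.5) -> io.NodeOutput:
        m = model.clone()  # ComfyUI ModelPatcher.clone()
        m.model_options["flipflop_block_percentage"] = block_percentage
        return io.NodeOutput(m)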