Fix percentage logic, begin adding elements to ModelPatcher to track flip flop compatibility

commit 8a8162e8da (parent ff789c8beb)
@@ -5,6 +5,8 @@ import copy
 from typing import List, Tuple
 from dataclasses import dataclass
 
+import comfy.model_management
+
 FLIPFLOP_REGISTRY = {}
 
 def register(name):
@@ -105,15 +107,21 @@ class FlipFlopContext:
 
 class FlipFlopHolder:
     def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"):
-        self.inference_device = torch.device(inference_device)
-        self.offloading_device = torch.device(offloading_device)
+        self.load_device = torch.device(inference_device)
+        self.offload_device = torch.device(offloading_device)
         self.transformer_blocks = transformer_blocks
 
-        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=self.inference_device)
-        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=self.inference_device)
+        self.block_module_size = 0
+        if len(self.transformer_blocks) > 0:
+            self.block_module_size = comfy.model_management.module_size(self.transformer_blocks[0])
 
-        self.compute_stream = cuda.default_stream(self.inference_device)
-        self.cpy_stream = cuda.Stream(self.inference_device)
+        self.flip: torch.nn.Module = None
+        self.flop: torch.nn.Module = None
+        # TODO: make initialization happen in model management code/model patcher, not here
+        self.initialize_flipflop_blocks(self.load_device)
+
+        self.compute_stream = cuda.default_stream(self.load_device)
+        self.cpy_stream = cuda.Stream(self.load_device)
 
         self.event_flip = torch.cuda.Event(enable_timing=False)
         self.event_flop = torch.cuda.Event(enable_timing=False)
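The new block_module_size field caches the size in bytes of one transformer block via comfy.model_management.module_size (an existing ComfyUI helper), which is the kind of number the memory accounting needs once flip-flopped blocks are hidden from the normal loader. A rough sketch of what this enables; both helper functions below are hypothetical and not part of this commit:

def flipflop_resident_memory(holder):
    # Only the two staging copies (flip and flop) occupy the load device,
    # no matter how many blocks are cycled through them.
    return 2 * holder.block_module_size

def flipflop_offloaded_memory(holder):
    # Everything handed to the holder stays on the offload device instead.
    return holder.block_module_size * len(holder.transformer_blocks)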
@@ -134,6 +142,10 @@ class FlipFlopHolder:
     def context(self):
         return FlipFlopContext(self)
 
+    def initialize_flipflop_blocks(self, load_device: torch.device):
+        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=load_device)
+        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=load_device)
+
 class FlipFlopTransformer:
     def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
         self.transformer_blocks = transformer_blocks
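Per the TODO in __init__, these staging copies seem intended to be created by model management / ModelPatcher code rather than inside the constructor. A minimal sketch of such an external call, assuming a hypothetical lazy_init_flipflop helper:

def lazy_init_flipflop(holder, load_device):
    # Create the two resident staging copies only once, on the device the
    # model management code picked, instead of the one baked into __init__.
    if holder.flip is None or holder.flop is None:
        holder.initialize_flipflop_blocks(torch.device(load_device))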
@@ -336,7 +336,7 @@ class QwenImageTransformer2DModel(nn.Module):
             for _ in range(num_layers)
         ])
 
-        self.flipflop_holders: dict[str, FlipFlopHolder] = {}
+        self.flipflop: dict[str, FlipFlopHolder] = {}
 
         if final_layer:
             self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
@@ -344,8 +344,8 @@ class QwenImageTransformer2DModel(nn.Module):
 
     def setup_flipflop_holders(self, block_percentage: float):
         # We hackily move any flipflopped blocks into holder so that our model management system does not see them.
-        num_blocks = int(len(self.transformer_blocks) * block_percentage)
-        self.flipflop_holders["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
+        num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage))
+        self.flipflop["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
         self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks])
 
     def process_img(self, x, index=0, h_offset=0, w_offset=0):
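The percentage fix inverts the meaning of block_percentage: after this change it describes the fraction of blocks handed to the FlipFlopHolder rather than the fraction that stays resident. A worked example with a hypothetical 60-block model and block_percentage = 0.7:

blocks = list(range(60))      # stand-ins for 60 transformer blocks
block_percentage = 0.7

# Old logic: int(60 * 0.7) = 42 blocks stayed resident, only 18 were flip-flopped.
old_num_blocks = int(len(blocks) * block_percentage)          # 42

# New logic: int(60 * (1.0 - 0.7)) = 18 stay resident, 42 are flip-flopped,
# so block_percentage now matches the portion that actually gets flip-flopped.
num_blocks = int(len(blocks) * (1.0 - block_percentage))      # 18
resident = blocks[:num_blocks]                                 # 18 blocks
flipflopped = blocks[num_blocks:]                              # 42 blocks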
@@ -412,8 +412,8 @@ class QwenImageTransformer2DModel(nn.Module):
     def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
         for i, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-        if "blocks_fwd" in self.flipflop_holders:
-            holder = self.flipflop_holders["blocks_fwd"]
+        if "blocks_fwd" in self.flipflop:
+            holder = self.flipflop["blocks_fwd"]
             with holder.context() as ctx:
                 for i, block in enumerate(holder.transformer_blocks):
                     encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
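For context, the holder.context() block above streams the non-resident blocks through the two staging copies while weight copies overlap with compute on separate CUDA streams. The sketch below shows the general flip/flop double-buffering pattern, not the actual FlipFlopContext implementation; run_flipflopped and the compute_done events are illustrative additions, and real overlap of the host-to-device copies also needs pinned staging memory and non-blocking transfers (which is what the pinned_staging option hints at).

import torch

def run_flipflopped(holder, forward_one, *args):
    # Simplified sketch of flip/flop double buffering; NOT the actual
    # FlipFlopContext code. forward_one(block, *tensors) is assumed to
    # return a tuple of tensors that feeds the next block.
    staging = [holder.flip, holder.flop]
    copy_done = [holder.event_flip, holder.event_flop]
    # Illustrative extra events so the copy stream never overwrites a staging
    # module the compute stream is still reading from.
    compute_done = [torch.cuda.Event(), torch.cuda.Event()]
    blocks = holder.transformer_blocks
    out = args

    def prefetch(idx):
        buf = idx % 2
        with torch.cuda.stream(holder.cpy_stream):
            if idx >= 2:
                holder.cpy_stream.wait_event(compute_done[buf])
            # Copy the offloaded block's weights into the resident staging copy.
            staging[buf].load_state_dict(blocks[idx].state_dict())
            copy_done[buf].record(holder.cpy_stream)

    prefetch(0)
    for i in range(len(blocks)):
        buf = i % 2
        holder.compute_stream.wait_event(copy_done[buf])
        if i + 1 < len(blocks):
            prefetch(i + 1)  # overlap the next weight copy with this compute
        with torch.cuda.stream(holder.compute_stream):
            out = forward_one(staging[buf], *out)
            compute_done[buf].record(holder.compute_stream)
    return out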
@@ -604,6 +604,9 @@ class ModelPatcher:
             else:
                 set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
+    def supports_flipflop(self):
+        return hasattr(self.model.diffusion_model, "flipflop")
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
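A hedged sketch of how loading code might use the new supports_flipflop() check; enable_flipflop_offload and its block_percentage default are hypothetical, and it assumes a Qwen-style setup_flipflop_holders entry point on the diffusion model:

def enable_flipflop_offload(patcher, block_percentage=0.5):
    # Only models whose diffusion_model exposes a flipflop dict can be
    # flip-flopped; everything else keeps the normal loading path.
    if not patcher.supports_flipflop():
        return False
    patcher.model.diffusion_model.setup_flipflop_holders(block_percentage)
    return True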