# ComfyUI/comfy/ldm/flipflop_transformer.py

import torch
import torch.cuda as cuda
import copy
from typing import Callable, List, Tuple
from dataclasses import dataclass
# Maps a model class name to the patcher class that knows how to flip-flop it.
FLIPFLOP_REGISTRY = {}


def register(name):
    def decorator(cls):
        FLIPFLOP_REGISTRY[name] = cls
        return cls
    return decorator

@dataclass
class FlipFlopConfig:
    block_name: str              # attribute on the model that holds the block list
    block_wrap_fn: Callable      # adapts a block's call signature to/from the shared feed_dict
    out_names: Tuple[str, ...]   # feed_dict keys returned by the patched forward
    overwrite_forward: str       # attribute name the flip-flop __call__ is bound to
    pinned_staging: bool = False
    inference_device: str = "cuda"
    offloading_device: str = "cpu"

def patch_model_from_config(model, config: FlipFlopConfig):
    block_list = getattr(model, config.block_name)
    flip_flop_transformer = FlipFlopTransformer(block_list,
                                                block_wrap_fn=config.block_wrap_fn,
                                                out_names=config.out_names,
                                                offloading_device=config.offloading_device,
                                                inference_device=config.inference_device,
                                                pinned_staging=config.pinned_staging)
    # Swap the original block list for the wrapper and expose its __call__ under the
    # forward name that the model's forward() dispatches to.
    delattr(model, config.block_name)
    setattr(model, config.block_name, flip_flop_transformer)
    setattr(model, config.overwrite_forward, flip_flop_transformer.__call__)

class FlipFlopTransformer:
    def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str, ...],
                 pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
        self.transformer_blocks = transformer_blocks
        self.offloading_device = torch.device(offloading_device)
        self.inference_device = torch.device(inference_device)
        self.staging = pinned_staging
        # Two resident device copies ("flip" and "flop") that take turns computing while the
        # other one is refilled with the next block's weights.
        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=self.inference_device)
        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=self.inference_device)
        self._cpy_fn = self._copy_state_dict
        if self.staging:
            # Optional pinned host buffer so host-to-device copies can run fully asynchronously.
            self.staging_buffer = self._pin_module(self.transformer_blocks[0]).state_dict()
            self._cpy_fn = self._copy_state_dict_with_staging
        self.compute_stream = cuda.default_stream(self.inference_device)
        self.cpy_stream = cuda.Stream(self.inference_device)
        self.event_flip = torch.cuda.Event(enable_timing=False)
        self.event_flop = torch.cuda.Event(enable_timing=False)
        self.cpy_end_event = torch.cuda.Event(enable_timing=False)
        self.block_wrap_fn = block_wrap_fn
        self.out_names = out_names
        self.num_blocks = len(self.transformer_blocks)
        self.extra_run = self.num_blocks % 2
        # INIT: record the copy-end event once so the first wait in __call__ returns immediately.
        self.compute_stream.record_event(self.cpy_end_event)

    def _copy_state_dict(self, dst, src, cpy_start_event=None, cpy_end_event=None):
        if cpy_start_event:
            self.cpy_stream.wait_event(cpy_start_event)
        with torch.cuda.stream(self.cpy_stream):
            for k, v in src.items():
                dst[k].copy_(v, non_blocking=True)
        if cpy_end_event:
            cpy_end_event.record(self.cpy_stream)

    def _copy_state_dict_with_staging(self, dst, src, cpy_start_event=None, cpy_end_event=None):
        if cpy_start_event:
            self.cpy_stream.wait_event(cpy_start_event)
        with torch.cuda.stream(self.cpy_stream):
            for k, v in src.items():
                # Stage through pinned host memory before the device copy.
                self.staging_buffer[k].copy_(v, non_blocking=True)
                dst[k].copy_(self.staging_buffer[k], non_blocking=True)
        if cpy_end_event:
            cpy_end_event.record(self.cpy_stream)

    def _pin_module(self, module):
        pinned_module = copy.deepcopy(module)
        for param in pinned_module.parameters():
            param.data = param.data.pin_memory()
        # Pin all buffers (if any)
        for buffer in pinned_module.buffers():
            buffer.data = buffer.data.pin_memory()
        return pinned_module

    def _reset(self):
        # Reload flip/flop with blocks 0 and 1 so the next call starts from a clean state.
        if self.extra_run:
            self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop)
            self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip)
        else:
            self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip)
            self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop)
        self.compute_stream.record_event(self.cpy_end_event)

    @torch.no_grad()
    def __call__(self, **feed_dict):
        '''
        Flip handles the even-indexed blocks (block 0 is the first), flop handles the odd-indexed blocks.
        While one copy computes block i on the compute stream, the other copy is refilled with
        block i + 1's weights on the copy stream.
        '''
        # separated flip flop refactor
        first_flip = True
        first_flop = True
        last_flip = False
        last_flop = False
        for i, block in enumerate(self.transformer_blocks):
            is_flip = i % 2 == 0
            if is_flip:
                # flip
                self.compute_stream.wait_event(self.cpy_end_event)
                with torch.cuda.stream(self.compute_stream):
                    feed_dict = self.block_wrap_fn(self.flip, **feed_dict)
                    self.event_flip.record(self.compute_stream)
                # while flip executes, queue flop to copy its next block
                next_flop_i = i + 1
                if next_flop_i >= self.num_blocks:
                    next_flop_i = next_flop_i - self.num_blocks
                    last_flip = True
                if not first_flip:
                    self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[next_flop_i].state_dict(), self.event_flop, self.cpy_end_event)
                if last_flip:
                    self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip)
                first_flip = False
            else:
                # flop
                if not first_flop:
                    self.compute_stream.wait_event(self.cpy_end_event)
                with torch.cuda.stream(self.compute_stream):
                    feed_dict = self.block_wrap_fn(self.flop, **feed_dict)
                    self.event_flop.record(self.compute_stream)
                # while flop executes, queue flip to copy its next block
                next_flip_i = i + 1
                if next_flip_i >= self.num_blocks:
                    next_flip_i = next_flip_i - self.num_blocks
                    last_flop = True
                self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[next_flip_i].state_dict(), self.event_flip, self.cpy_end_event)
                if last_flop:
                    self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop)
                first_flop = False
        self.compute_stream.record_event(self.cpy_end_event)
        outputs = [feed_dict[name] for name in self.out_names]
        if len(outputs) == 1:
            return outputs[0]
        return tuple(outputs)
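
    # Example schedule produced by __call__ for a 4-block model (blocks 0-3), traced from the
    # loop above; compute runs on compute_stream, weight transfers on cpy_stream:
    #   i=0: flip computes block 0 | (flop already holds block 1, nothing to copy)
    #   i=1: flop computes block 1 | copy blocks[2] -> flip
    #   i=2: flip computes block 2 | copy blocks[3] -> flop
    #   i=3: flop computes block 3 | copy blocks[0] -> flip, then blocks[1] -> flop
    # Transfers overlap with compute, and flip/flop finish reloaded with blocks 0 and 1,
    # ready for the next call.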

    @torch.no_grad()
    def __call__old(self, **feed_dict):
        # contentis' prototype flip flop
        # Wait for reset
        self.compute_stream.wait_event(self.cpy_end_event)
        with torch.cuda.stream(self.compute_stream):
            feed_dict = self.block_wrap_fn(self.flip, **feed_dict)
            self.event_flip.record(self.compute_stream)
        for i in range(self.num_blocks // 2 - 1):
            with torch.cuda.stream(self.compute_stream):
                feed_dict = self.block_wrap_fn(self.flop, **feed_dict)
                self.event_flop.record(self.compute_stream)
            self._cpy_fn(self.flip.state_dict(), self.transformer_blocks[(i + 1) * 2].state_dict(), self.event_flip,
                         self.cpy_end_event)
            self.compute_stream.wait_event(self.cpy_end_event)
            with torch.cuda.stream(self.compute_stream):
                feed_dict = self.block_wrap_fn(self.flip, **feed_dict)
                self.event_flip.record(self.compute_stream)
            self._cpy_fn(self.flop.state_dict(), self.transformer_blocks[(i + 1) * 2 + 1].state_dict(), self.event_flop,
                         self.cpy_end_event)
            self.compute_stream.wait_event(self.cpy_end_event)
        with torch.cuda.stream(self.compute_stream):
            feed_dict = self.block_wrap_fn(self.flop, **feed_dict)
            self.event_flop.record(self.compute_stream)
        if self.extra_run:
            self._cpy_fn(self.flip.state_dict(), self.transformer_blocks[-1].state_dict(), self.event_flip,
                         self.cpy_end_event)
            self.compute_stream.wait_event(self.cpy_end_event)
            with torch.cuda.stream(self.compute_stream):
                feed_dict = self.block_wrap_fn(self.flip, **feed_dict)
                self.event_flip.record(self.compute_stream)
        self._reset()
        outputs = [feed_dict[name] for name in self.out_names]
        if len(outputs) == 1:
            return outputs[0]
        return tuple(outputs)
# @register("Flux")
# class Flux:
# @staticmethod
# def double_block_wrap(block, **kwargs):
# kwargs["img"], kwargs["txt"] = block(img=kwargs["img"],
# txt=kwargs["txt"],
# vec=kwargs["vec"],
# pe=kwargs["pe"],
# attn_mask=kwargs.get("attn_mask"))
# return kwargs
# @staticmethod
# def single_block_wrap(block, **kwargs):
# kwargs["img"] = block(kwargs["img"],
# vec=kwargs["vec"],
# pe=kwargs["pe"],
# attn_mask=kwargs.get("attn_mask"))
# return kwargs
# double_config = FlipFlopConfig(block_name="double_blocks",
# block_wrap_fn=double_block_wrap,
# out_names=("img", "txt"),
# overwrite_forward="double_transformer_fwd",
# pinned_staging=False)
# single_config = FlipFlopConfig(block_name="single_blocks",
# block_wrap_fn=single_block_wrap,
# out_names=("img",),
# overwrite_forward="single_transformer_fwd",
# pinned_staging=False)
# @staticmethod
# def patch(model):
# patch_model_from_config(model, Flux.double_config)
# patch_model_from_config(model, Flux.single_config)
# return model
# @register("WanModel")
# class Wan:
# @staticmethod
# def wan_blocks_wrap(block, **kwargs):
# kwargs["x"] = block(x=kwargs["x"],
# context=kwargs["context"],
# e=kwargs["e"],
# freqs=kwargs["freqs"],
# context_img_len=kwargs.get("context_img_len"))
# return kwargs
# blocks_config = FlipFlopConfig(block_name="blocks",
# block_wrap_fn=wan_blocks_wrap,
# out_names=("x",),
# overwrite_forward="block_fwd",
# pinned_staging=False)
# @staticmethod
# def patch(model):
# patch_model_from_config(model, Wan.blocks_config)
# return model
@register("QwenImageTransformer2DModel")
class QwenImage:
@staticmethod
def qwen_blocks_wrap(block, **kwargs):
kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
encoder_hidden_states=kwargs["encoder_hidden_states"],
encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
temb=kwargs["temb"],
image_rotary_emb=kwargs["image_rotary_emb"])
return kwargs
blocks_config = FlipFlopConfig(block_name="transformer_blocks",
block_wrap_fn=qwen_blocks_wrap,
out_names=("encoder_hidden_states", "hidden_states"),
overwrite_forward="block_fwd",
pinned_staging=False)
@staticmethod
def patch(model):
patch_model_from_config(model, QwenImage.blocks_config)
return model
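
# Illustrative usage (a sketch, not part of this module's original API): assuming `model` is a
# loaded QwenImageTransformer2DModel whose forward() has been adapted to dispatch through the
# attribute named by `overwrite_forward` ("block_fwd" above), the registry could be applied
# like this:
#
#     patcher = FLIPFLOP_REGISTRY[type(model).__name__]
#     model = patcher.patch(model)
#     # model.transformer_blocks is now a FlipFlopTransformer, and model.block_fwd runs the
#     # blocks on two resident GPU copies while the next block's weights stream in from CPU.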