import os
import math
import time
import torch
import psutil
import asyncio
import threading
import torch.nn as nn
from pathlib import Path
import concurrent.futures
from einops import rearrange
import torch.nn.functional as F
from collections import OrderedDict
from safetensors import safe_open
from transformers.cache_utils import StaticCache
from typing import Optional, Tuple, Any, List, Dict
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock

# When there is not exactly one visible GPU the MoE experts are instantiated up front;
# in single-GPU mode they are streamed from disk / pinned CPU memory instead
# (see MoELRUCache and LazyMoELoader below).
INIT_MOE = torch.cuda.device_count() != 1

if not INIT_MOE:
    MOE_LAYER_SIZE = (1024**3) * 2.65  # approx size of one layer's experts, in bytes

    torch.cuda.set_device(0)
    props = torch.cuda.get_device_properties(0)

    # How many MoE layers fit in half of the remaining system RAM
    # (total RAM minus this process's RSS minus a 2 GiB margin).
    LAYERS_IN_CPU = math.floor((int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
                                    - psutil.Process(os.getpid()).memory_info().rss
                                    - (2 * 1024**3)) * 0.50) / MOE_LAYER_SIZE)


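# StaticCache variant used by the sampling loop below.  It works with both the list-based
# (key_cache / value_cache) and the layer-object-based transformers StaticCache internals,
# moves the cached tensors to the device of the incoming key/value states, and accepts
# per-sample 2D cache positions.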
class HunyuanStaticCache(StaticCache):

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        if hasattr(self, "key_cache") and hasattr(self, "value_cache"):
            if self.key_cache[layer_idx].device != key_states.device:
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
            k_out = self.key_cache[layer_idx]
            v_out = self.value_cache[layer_idx]
            key_states = key_states.to(k_out.dtype)
            value_states = value_states.to(v_out.dtype)
        else:
            if self.layers[layer_idx].keys is None:
                self.layers[layer_idx].lazy_initialization(key_states)
            k_out = self.layers[layer_idx].keys
            v_out = self.layers[layer_idx].values

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            if cache_position.dim() == 1:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            else:
                assert cache_position.dim() == 2, f"multiple batch dims are not supported yet, {cache_position.shape=}"
                batch_size, _ = cache_position.shape
                for i in range(batch_size):
                    unbatched_dim = 1
                    k_out[i].index_copy_(unbatched_dim, cache_position[i], key_states[i])
                    v_out[i].index_copy_(unbatched_dim, cache_position[i], value_states[i])

        return k_out, v_out


def real_batched_index_select(t, dim, idx):
    return torch.stack([torch.index_select(t[i], dim - 1, idx[i]) for i in range(len(t))])


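# Standard sinusoidal timestep embedding.  Worked example for dim=4, t=[5.]:
#   freqs = exp(-ln(10000) * [0, 1] / 2) = [1.0, 0.01]
#   args  = [[5.0, 0.05]]  ->  embedding = [[cos(5.0), cos(0.05), sin(5.0), sin(0.05)]]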
def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
        )
    return embedding


class TimestepEmbedder(nn.Module):
    def __init__(self,
                 hidden_size,
                 act_layer=nn.GELU,
                 frequency_embedding_size=256,
                 max_period=10000,
                 out_size=None,
                 dtype=None,
                 device=None
                 ):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size

        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )

    def forward(self, t):
        t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    return x


def get_meshgrid_nd(start, *args, dim=2):
    start = _to_tuple(start, dim=dim)
    stop = _to_tuple(args[0], dim=dim)
    num = [int(stop[i] - start[i]) for i in range(dim)]

    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")
    grid = torch.stack(grid, dim=0)

    return grid


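# 2D RoPE over a mixed text/image token stream.  Text spans get identical x and y indices
# (reducing to ordinary 1D RoPE), while each image span described in image_infos gets a
# dense (y, x) grid centered on the positions the image occupies in the flat sequence.
# Returns cos/sin tables of shape [seq_len, n_elem]; build_batch_2d_rope further below
# stacks one table per sample.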
def build_2d_rope(
    seq_len: int, n_elem: int, image_infos: Optional[List[Tuple[slice, Tuple[int, int]]]] = None,
    device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    assert n_elem % 4 == 0, f"n_elem must be divisible by 4, but got {n_elem}."

    # theta
    if base_rescale_factor != 1.0:
        base *= base_rescale_factor ** (n_elem / (n_elem - 2))
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
    theta = theta.reshape(1, n_elem // 4, 2)  # [1, half_d, 2]

    # position indices
    if image_infos is None:
        image_infos = []

    image_infos_list = [image_infos]
    sample_seq_lens = [seq_len]

    # Prepare position indices for each sample
    x_sections = []
    y_sections = []
    for sample_id, sample_image_infos in enumerate(image_infos_list):
        last_pos = 0
        for sec_slice, (h, w) in sample_image_infos:
            L = sec_slice.start  # start from 0, so image_slice.start is just L
            # previous text
            if last_pos < L:
                y_sections.append(torch.arange(last_pos, L))
                x_sections.append(torch.arange(last_pos, L))
            elif h is None:
                # Interleaved data has overlapping positions for the <boi> <size> <ratio> <timestep> <eoi> tokens.
                y_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                x_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                continue
            else:
                # Interleaved data has overlapping positions for the noised image and the successive clean
                # image, leading to last_pos (= last text end L + noise w * h) > L (last text end L).
                pass
            # current image
            beta_y = L + (w * h - h) / 2
            beta_x = L + (w * h - w) / 2
            grid = get_meshgrid_nd((beta_y, beta_x), (beta_y + h, beta_x + w))  # [2, h, w]
            grid = grid.reshape(2, -1)  # (y, x)
            y_sections.append(grid[0])
            x_sections.append(grid[1])
            # step
            last_pos = L + w * h
        # final text
        y_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id]))
        x_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id]))

    x_pos = torch.cat(x_sections).long()
    y_pos = torch.cat(y_sections).long()
    # If there are overlapping positions, remove them.
    x_pos = x_pos[:seq_len]
    y_pos = y_pos[:seq_len]
    all_pos = torch.stack((y_pos, x_pos), dim=1).unsqueeze(1).to(device)  # [seq_len, 1, 2]

    # calc rope
    idx_theta = (all_pos * theta).reshape(all_pos.shape[0], n_elem // 2).repeat(1, 2)

    cos = torch.cos(idx_theta)
    sin = torch.sin(idx_theta)

    if return_all_pos:
        return cos, sin, all_pos

    return cos, sin


def build_batch_2d_rope(
    seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None,
    device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    cos_list, sin_list, all_pos_list = [], [], []
    if image_infos is None:
        image_infos = [None]
    for i, image_info in enumerate(image_infos):
        res = build_2d_rope(
            seq_len, n_elem, image_infos=image_info, device=device,
            base=base, base_rescale_factor=base_rescale_factor,
            return_all_pos=return_all_pos,
        )
        if return_all_pos:
            cos, sin, all_pos = res
        else:
            cos, sin = res
            all_pos = None
        cos_list.append(cos)
        sin_list.append(sin)
        all_pos_list.append(all_pos)

    stacked_cos = torch.stack(cos_list, dim=0)
    stacked_sin = torch.stack(sin_list, dim=0)

    if return_all_pos:
        return stacked_cos, stacked_sin, all_pos_list

    return stacked_cos, stacked_sin


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    if position_ids is not None:
        cos = cos[position_ids]
        sin = sin[position_ids]

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def default(val, d):
    return val if val is not None else d


def conv_nd(dims, *args, **kwargs):
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)


def normalization(channels, **kwargs):
    return nn.GroupNorm(32, channels, **kwargs)


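# Expert-capacity style top-k routing.  For logits of shape [tokens, num_experts] it returns,
# with capacity = max number of tokens routed to any single expert:
#   combine_weights: [tokens, num_experts, capacity] float weights used to merge expert outputs
#   dispatch_mask:   [tokens, num_experts, capacity] boolean token -> expert/slot assignment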
def topkgating(logits: torch.Tensor, topk: int):
    logits = logits.float()
    gates = F.softmax(logits, dim=1)

    num_experts = int(gates.shape[1])

    _, expert_index = torch.topk(gates, topk)
    expert_mask = F.one_hot(expert_index, num_experts)

    expert_index_flat = expert_index.flatten()
    tokens_per_expert = torch.bincount(expert_index_flat, minlength=num_experts)
    expert_capacity = torch.max(tokens_per_expert).item()

    gates_s = torch.clamp(
        torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps
    )
    router_probs = gates / gates_s

    expert_index = torch.transpose(expert_index, 0, 1)
    expert_index = expert_index.reshape(-1)
    expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)

    token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1
    token_priority = token_priority.reshape((topk, -1, num_experts))
    token_priority = torch.transpose(token_priority, 0, 1)

    token_priority = torch.max(token_priority, dim=1)[0]

    valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
    token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
    dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
    valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity)
    dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)

    combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)

    return combine_weights, dispatch_mask


class HunyuanRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


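# UNetDown turns a latent image [B, C, H, W] into a token sequence [B, (H/p) * (W/p), D]
# via a conv stem and (for patch_size > 1) strided ResBlocks; UNetUp further below is the
# matching decoder that maps tokens back to a latent image.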
class UNetDown(nn.Module):
    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()

        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]

        self.model = nn.ModuleList(
            [conv_nd(
                2,
                in_channels=in_channels,
                out_channels=hidden_channels,
                kernel_size=3,
                padding=1,
                **factory_kwargs
            )]
        )

        if self.patch_size == 1:
            self.model.append(ResBlock(
                channels=hidden_channels,
                emb_channels=emb_channels,
                out_channels=out_channels,
                use_scale_shift_norm=True,
                dropout=dropout,
                **factory_kwargs
            ))
        else:
            for i in range(self.patch_size // 2):
                self.model.append(ResBlock(
                    channels=hidden_channels,
                    emb_channels=emb_channels,
                    use_scale_shift_norm=True,
                    out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
                    dropout=dropout,
                    down=True,
                    **factory_kwargs
                ))

    def forward(self, x, t):
        assert x.shape[2] % self.patch_size == 0 and x.shape[3] % self.patch_size == 0
        for module in self.model:
            if isinstance(module, ResBlock):
                x = module(x, t)
            else:
                x = module(x)
        _, _, token_h, token_w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x, token_h, token_w


class UNetUp(nn.Module):

    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None, operations=None, out_norm=False):
        operations = operations or nn
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()

        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]

        self.model = nn.ModuleList()

        if self.patch_size == 1:
            self.model.append(ResBlock(
                channels=in_channels,
                emb_channels=emb_channels,
                out_channels=hidden_channels,
                use_scale_shift_norm=True,
                dropout=dropout,
                **factory_kwargs
            ))
        else:
            for i in range(self.patch_size // 2):
                self.model.append(ResBlock(
                    channels=in_channels if i == 0 else hidden_channels,
                    emb_channels=emb_channels,
                    out_channels=hidden_channels,
                    use_scale_shift_norm=True,
                    dropout=dropout,
                    up=True,
                    **factory_kwargs
                ))

        if out_norm:
            self.model.append(nn.Sequential(
                normalization(hidden_channels, **factory_kwargs),
                nn.SiLU(),
                nn.Conv2d(
                    in_channels=hidden_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    **factory_kwargs
                ),
            ))
        else:
            self.model.append(nn.Conv2d(
                in_channels=hidden_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                **factory_kwargs
            ))

    # x: (batch_size, seq_len, model_dim)
    def forward(self, x, t, token_h, token_w):
        x = rearrange(x, 'b (h w) c -> b c h w', h=token_h, w=token_w)
        for module in self.model:
            if isinstance(module, ResBlock):
                x = module(x, t)
            else:
                x = module(x)
        return x


class HunyuanTopKGate(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = 8
        self.min_capacity = 8
        num_experts = 64
        self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32)

    def forward(self, hidden_states):
        bsz, seq_len, hidden_size = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_size)
        if self.wg.weight.dtype == torch.float32:
            hidden_states = hidden_states.float()
        logits = self.wg(hidden_states)
        gate_output = topkgating(logits, self.moe_topk)

        return gate_output


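# SwiGLU feed-forward block: a fused gate-and-up projection is chunked in two, one half is
# gated by SiLU of the other, then projected back down to hidden_size.  The same module is
# used for the shared MLP and for every routed expert.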
class HunyuanMLP(nn.Module):
    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config["hidden_size"]

        self.intermediate_size = 3072

        self.act_fn = torch.nn.functional.silu
        self.intermediate_size *= 2  # SwiGLU
        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device)
        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device)

    def forward(self, x):
        self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
        if x.ndim == 2:
            x = x.unsqueeze(0)
        gate_and_up_proj = self.gate_and_up_proj(x)
        x1, x2 = gate_and_up_proj.chunk(2, dim=2)
        down_proj = self.down_proj(x1 * self.act_fn(x2))
        return down_proj


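# Single-GPU expert streaming.  Experts are staged in pinned CPU memory and copied to the
# GPU on a dedicated CUDA stream; offloads back to CPU run on a second stream and record a
# CUDA event so that later loads can wait until GPU memory has actually been freed.  All
# async work runs on a private asyncio event loop in a daemon thread.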
class MoELRUCache(nn.Module):
    def __init__(self):
        super().__init__()
        self.gpu_cache = OrderedDict()
        self.cpu_cache = OrderedDict()
        self.offload_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

        self.last_offload_event = None
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    async def _async_offload_to_cpu(self, layer_idx):
        # Asynchronously offload every cached expert of this layer from the GPU to pinned CPU memory.
        num_experts = 64
        moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
                     for i in range(num_experts)
                     if (layer_idx * num_experts + i) in self.gpu_cache]
        event = torch.cuda.Event()

        with torch.cuda.stream(self.offload_stream):
            for index, moe in moe_group:
                moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
                for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
                    if p_gpu.device.type == "meta":
                        continue
                    with torch.no_grad():
                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
                        p_cpu.copy_(p_gpu, non_blocking=True)

                self.cpu_cache[index] = moe_cpu

            self.offload_stream.record_event(event)

        self.last_offload_event = event

        def finalize_offload_layer():
            event.synchronize()
            for index, moe in moe_group:
                moe.to("meta")
                self.gpu_cache.pop(index, None)
                del moe
            torch.cuda.empty_cache()

        threading.Thread(target=finalize_offload_layer, daemon=True).start()

    async def _async_load_to_gpu(self, index, moe):

        # If there is enough free VRAM, load immediately; otherwise wait for pending offloads.
        while True:
            free_bytes, _ = torch.cuda.mem_get_info()
            if free_bytes > 2 * MOE_LAYER_SIZE:
                break

            if self.last_offload_event is not None:
                self.last_offload_event.synchronize()
            torch.cuda.empty_cache()
            await asyncio.sleep(0.01)

        # Asynchronously copy the expert from (pinned) CPU memory to the GPU.
        with torch.cuda.stream(self.load_stream):
            moe_gpu = HunyuanMLP(moe.config).to("cuda", non_blocking=True)
            for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
                with torch.no_grad():
                    p_gpu.data = torch.empty_like(p_cpu, device="cuda")
                    p_gpu.copy_(p_cpu, non_blocking=True)

        def finalize_load():
            self.gpu_cache[index] = moe_gpu
            self.cpu_cache.pop(index, None)

        threading.Thread(target=finalize_load, daemon=True).start()

    def add_cpu(self, moe, index):
        moe_cpu = moe.to("cpu")

        for _, p in moe_cpu.named_parameters():
            if not p.is_pinned():
                if p.device.type == "cpu":
                    p.data = p.data.pin_memory()
                else:
                    return

        self.cpu_cache[index] = moe_cpu
        self.cpu_cache.move_to_end(index)


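# Lazily loads one expert's weights straight from the safetensors checkpoint on a
# background event loop, then hands the pinned CPU copy to MoELRUCache and schedules the
# CPU -> GPU transfer.  _schedule_disk_load returns a concurrent.futures.Future; the MoE
# forward pass calls .result() on it when the expert is actually needed.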
class LazyMoELoader(nn.Module):
    def __init__(self, cache, config):
        super().__init__()
        self.cache = cache
        self.config = config
        self._loop = cache._loop

    def get_checkpoint(self):
        comfyui_dir = Path.home() / "ComfyUI"
        checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
        checkpoint = checkpoint.resolve()
        if not os.path.exists(checkpoint):
            raise ValueError(f"For single-GPU use, the Hunyuan Image 3 checkpoint is expected at: {checkpoint}")
        return checkpoint

    def lazy_init(self, layer_idx, expert_idx):
        checkpoint = self.get_checkpoint()
        prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}."
        additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight"
        sd = {}

        with safe_open(checkpoint, framework="pt", device="cpu") as f:
            for k in f.keys():
                if k.startswith(prefix) or k.startswith(additional_prefix):
                    new_k = k.split(f"experts.{expert_idx}.", 1)[1]
                    sd[new_k] = f.get_tensor(k)

        model = HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True, device="meta")
        model.to_empty(device="cpu")

        model.load_state_dict(sd)
        return model

    async def lazy_load_from_disk(self, layer_idx, expert_idx):
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)

    def _schedule_disk_load(self, layer_idx, expert_idx):

        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)

        def _on_disk_loaded(fut):
            moe_cpu = fut.result()

            def _add_cpu_in_main_thread():
                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)

            asyncio.run_coroutine_threadsafe(
                self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
                self.cache._loop
            )
            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()

        future.add_done_callback(_on_disk_loaded)
        return future


def enough_vram(required_bytes):
    free, total = torch.cuda.mem_get_info()
    return free > required_bytes


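# MoE block: a shared SwiGLU MLP is always applied and the routed expert outputs are added
# on top.  Only the experts actually selected for this batch are materialized; their weights
# are stacked so the expert FFNs run as two batched matmuls (torch.bmm) instead of a Python
# loop, roughly:
#   tokens_padded: [used_experts, capacity, hidden]
#   out_padded   : [used_experts, capacity, hidden] -> recombined per token via combine_weights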
class HunyuanMoE(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = 8
        self.num_experts = 64
        self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx)
        if INIT_MOE:
            self.experts = nn.ModuleList(
                [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
            )
        else:
            self.experts = []
        self.moe_lru = moe_lru

    def forward(self, hidden_states):
        if not INIT_MOE:
            torch.cuda.set_device(0)
        else:
            torch.cuda.set_device(hidden_states.device.index)
        bsz, seq_len, hidden_size = hidden_states.shape

        hidden_states_mlp = self.shared_mlp(hidden_states)

        reshaped_input = hidden_states.reshape(-1, hidden_size)

        with torch.cuda.nvtx.range("MoE"):
            combine_weights, dispatch_mask = self.gate(hidden_states)

            dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(reshaped_input), reshaped_input)
            device = hidden_states.device
            dtype = reshaped_input.dtype

            used_mask = (dispatch_mask.sum(dim=(0, 2)) > 0)
            used_indices = used_mask.nonzero(as_tuple=False).squeeze(1).tolist()

            combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype)

            if len(used_indices) > 0:
                tokens_padded = dispatched_input[used_indices]

                # Resolve the used experts (waiting on any pending disk/CPU loads) and collect their weights.
                l1, l2 = [], []
                for i in used_indices:
                    expert = self.experts[i]
                    if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
                        expert = expert.result()
                    expert = expert.to(device)
                    l1.append(expert.gate_and_up_proj)
                    l2.append(expert.down_proj)

                compute_device = hidden_states.device
                l1 = [m.to(compute_device) for m in l1]
                l2 = [m.to(compute_device) for m in l2]

                W1 = torch.stack([m.weight for m in l1], dim=0)
                W2 = torch.stack([m.weight for m in l2], dim=0)

                W1_T = W1.transpose(1, 2)
                W2_T = W2.transpose(1, 2)

                while not enough_vram(3 * (1024 ** 3)):
                    event = self.moe_lru.last_offload_event if self.moe_lru is not None else None
                    if event is not None and not event.query():
                        time.sleep(0.001)

                x = torch.bmm(tokens_padded, W1_T)
                x1, x2 = x.chunk(2, dim=2)
                gated = x1 * F.silu(x2)
                out_padded = torch.bmm(gated, W2_T)

                combine_weights_used = combine_weights[:, used_indices, :]

                combined_output = torch.einsum("suc,ucm->sm",
                                               combine_weights_used.type_as(out_padded),
                                               out_padded
                                               )

                del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded

            combined_output = combined_output.reshape(bsz, seq_len, hidden_size)

        output = hidden_states_mlp + combined_output

        return output


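# Grouped-query self-attention: a fused QKV projection is split into num_heads query heads
# and num_key_value_heads K/V heads, 2D RoPE is applied, Q/K get per-head RMSNorm, the
# static KV cache is updated in place, and K/V are repeated to match the query heads before
# calling ComfyUI's optimized_attention.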
class HunyuanImage3Attention(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_type = 'self'

        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_attention_heads"]
        self.head_dim = config["attention_head_dim"]
        self.num_key_value_heads = 8
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config["max_position_embeddings"]
        self.rope_theta = 10000.0
        self.is_causal = True
        self.hidden_size_q = self.head_dim * self.num_heads
        self.hidden_size_kv = self.head_dim * self.num_key_value_heads

        # define layers
        self.qkv_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size_q + 2 * self.hidden_size_kv,
            bias=False
        )
        self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=False)

        self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])
        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ):

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        qkv_states = qkv_states.reshape(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2,
                                        self.head_dim)
        query_states, key_states, value_states = torch.split(qkv_states, [self.num_key_value_groups, 1, 1], dim=3)

        query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = custom_pos_emb
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        query_states = self.query_layernorm(query_states)
        key_states = self.key_layernorm(key_states)

        query_states = query_states.to(value_states.dtype)
        key_states = key_states.to(value_states.dtype)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": position_ids}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            query_states = query_states.to(key_states.dtype)

        key_states = torch.repeat_interleave(key_states, dim=1, repeats=self.num_key_value_groups)
        value_states = torch.repeat_interleave(value_states, dim=1, repeats=self.num_key_value_groups)

        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, mask=attention_mask, skip_reshape=True)

        attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value


class HunyuanImage3DecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: int, moe_lru=None):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.layer_idx = layer_idx
        self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx)

        self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru)

        self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
        self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor | Any]:

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, past_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            custom_pos_emb=custom_pos_emb,
            **kwargs,
        )
        hidden_states = residual + hidden_states
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + hidden_states

        outputs = (hidden_states, past_key_value)

        return outputs


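# Decoder stack.  In single-GPU mode (INIT_MOE == False) the experts of upcoming layers are
# prefetched from disk while the current layer runs, and each layer's experts are offloaded
# back to CPU right after it finishes, so only a few layers' worth of experts are resident
# on the GPU at any time.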
class HunyuanImage3Model(nn.Module):
    def __init__(self, config, moe_lru=None):
        super().__init__()
        self.padding_idx = 128009
        self.vocab_size = 133120
        self.config = config
        self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
        self.layers = nn.ModuleList(
            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru=moe_lru) for layer_idx in range(config["num_hidden_layers"])]
        )

        self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])

        self.shared_tensor = None
        self.moe_lru = moe_lru
        self.additional_layers_set = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache=True,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        mode: str = "gen_image",
        first_step: Optional[bool] = None,
        gen_timestep_scatter_index: Optional[torch.Tensor] = None,
    ):

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        hidden_states = inputs_embeds

        next_decoder_cache = None
        next_layers = 0
        if not INIT_MOE:
            additional_layers = max(1, int(torch.cuda.mem_get_info()[0] // (MOE_LAYER_SIZE * 2)))
            sparse_interval = max(1, len(self.layers) // additional_layers)

            if len(self.layers[0].mlp.experts) == 0:
                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]

        for layer_idx, decoder_layer in enumerate(self.layers):

            if not INIT_MOE:
                # Prefetch the next layer's experts from disk while the current layer runs.
                if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0:  # not loaded
                    experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                    self.layers[layer_idx + 1].mlp.experts = [expert._schedule_disk_load(layer_idx + 1, i) for i, expert in enumerate(experts)]

                if not self.additional_layers_set:
                    if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
                        experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                        self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
                        next_layers += 1

            with torch.no_grad():
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    custom_pos_emb=custom_pos_emb,
                    mode=mode,
                    first_step=first_step,
                    gen_timestep_scatter_index=gen_timestep_scatter_index,
                )

            if not INIT_MOE:
                if self.additional_layers_set and layer_idx <= self.additional_layers_set:
                    pass
                else:
                    torch.cuda.synchronize()
                    asyncio.run_coroutine_threadsafe(
                        self.moe_lru._async_offload_to_cpu(layer_idx),
                        self.moe_lru._loop
                    )
                    self.layers[layer_idx].mlp.experts = []

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[1]

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache
        self.additional_layers_set = True
        return tuple(v for v in [hidden_states, next_cache] if v is not None)


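# Top-level multimodal diffusion wrapper: embeds the diffusion timestep, patchifies the
# noisy image tokens with UNetDown, runs them through the decoder stack together with the
# text and conditioning-image embeddings (with a block-wise attention mask that lets image
# tokens attend bidirectionally within each image), and decodes the image positions of the
# hidden states back to a latent prediction with UNetUp.  self.first_step tracks whether
# the full prompt pass has been done yet, which changes how x is packed and unpacked.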
class HunyuanImage3ForCausalMM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
        self.patch_embed = UNetDown(
            patch_size=1,
            emb_channels=config["hidden_size"],
            in_channels=32,
            hidden_channels=1024,
            out_channels=config["hidden_size"],
        )
        self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"])

        self.final_layer = UNetUp(
            patch_size=1,
            emb_channels=config["hidden_size"],
            in_channels=config["hidden_size"],
            hidden_channels=1024,
            out_channels=32,
            out_norm=True,
        )
        self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"])

        self.moe_lru = None
        if not INIT_MOE:
            self.moe_lru = MoELRUCache()

        self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru)

        self.pad_id = 128009
        self.vocab_size = 133120

        self.lm_head = nn.Linear(config["hidden_size"], 133120, bias=False)
        self.first_step = True

        self.kv_cache = None
        self.token_dims = ()

    @staticmethod
    def get_pos_emb(custom_pos_emb, position_ids):
        cos, sin = custom_pos_emb
        cos = real_batched_index_select(cos, dim=1, idx=position_ids)
        sin = real_batched_index_select(sin, dim=1, idx=position_ids)
        return cos, sin

    def ragged_final_layer(self, x, image_mask, timestep, token_h, token_w, first_step):
        bsz, seq_len, n_embd = x.shape
        if first_step:
            image_output = x.masked_select(image_mask.unsqueeze(-1).bool()).reshape(bsz, -1, n_embd)
        else:
            image_output = x[:, 1:, :]
        timestep_emb = self.time_embed_2(timestep)
        pred = self.final_layer(image_output, timestep_emb, token_h, token_w)
        return pred

    def forward(self, x, condition, timestep, **kwargs):

        joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()

        gen_timestep_scatter_index = 4

        with torch.no_grad():
            joint_image[:, 2:3, :] = x[:, 2:3, :]  # updates the image ratio token

        if self.first_step:
            token_height, token_width = x[:, -2:, 0].tolist()[0]
            self.token_dims = (int(token_height), int(token_width))
            x = x[:, :-2, :]
        else:
            token_height, token_width = self.token_dims

        img_slices = []

        for i in range(x.size(0)):
            vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0]
            vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1

            vit_start = vae_end + 1
            vit_end = joint_image.size(1) - 1

            joint_slices_i = [
                slice(vae_start, vae_end),
                slice(vit_start, vit_end),
            ]
            gen_slices_i = [slice(3 + vit_end, x[i].size(0) - 1 + vit_end)]
            img_slices.append(joint_slices_i + gen_slices_i)

        img_s = img_slices[0]
        rope_image_info = [[(img_s[0], (384 // 16, 384 // 16)), (img_s[1], (256 // 16, 256 // 16)), (img_s[2], (token_height, token_width))]]

        cond_timestep = torch.zeros(inputs_embeds.size(0))
        t_emb = self.time_embed(cond_timestep)

        bsz, seq_len, n_embd = inputs_embeds.shape

        if self.first_step:
            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index + 1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
        else:
            t_emb = self.time_embed(timestep)
            x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb)
            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
            x = torch.cat([timestep_emb, x], dim=1)

        # cond_timestep_scatter_index: write the current timestep embedding into the conditioning image tokens
        with torch.no_grad():
            joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)

        inputs_embeds = torch.cat([inputs_embeds, joint_image, x], dim=1)

        attention_mask = torch.ones(inputs_embeds.shape[1], inputs_embeds.shape[1], dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
        for i in range(bsz):
            for _, image_slice in enumerate(img_slices[i]):
                attention_mask[i, image_slice, image_slice] = True
        attention_mask = attention_mask.unsqueeze(1)

        # pos embed
        position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
        cos, sin = build_batch_2d_rope(
            image_infos=rope_image_info,
            seq_len=inputs_embeds.shape[1],
            n_elem=self.config["hidden_size"] // self.config["num_attention_heads"],
            base=10000.0,
        )
        custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device))
        custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)

        if self.kv_cache is None:
            # TODO: should change when higgsv2 gets merged
            self.kv_cache = HunyuanStaticCache(
                config=self.config,
                batch_size=x.size(0) * 2,
                max_cache_len=inputs_embeds.shape[1],
                dtype=x.dtype,
            )

        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=self.kv_cache,
            inputs_embeds=inputs_embeds,
            custom_pos_emb=custom_pos_emb,
            first_step=self.first_step,
            gen_timestep_scatter_index=gen_timestep_scatter_index,
        )
        hidden_states = outputs[0]

        # keep the (possibly updated) cache for the next step
        past_key_value = outputs[1]
        if past_key_value is not None:
            self.kv_cache = past_key_value

        hidden_states = hidden_states.to(inputs_embeds.device)
        img_mask = torch.zeros(hidden_states.size(1), device=hidden_states.device)
        img_mask[-x.size(1) + 4:] = 1
        img_mask[-1] = 0

        diffusion_prediction = self.ragged_final_layer(
            hidden_states, img_mask, timestep, int(token_height), int(token_width), self.first_step)

        if self.first_step:
            self.first_step = False

        return diffusion_prediction