mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-18 18:30:19 +08:00
1187 lines
47 KiB
Python
1187 lines
47 KiB
Python
import os
|
|
import math
|
|
import torch
|
|
import asyncio
|
|
import threading
|
|
import torch.nn as nn
|
|
from pathlib import Path
|
|
import concurrent.futures
|
|
from einops import rearrange
|
|
import torch.nn.functional as F
|
|
from collections import OrderedDict
|
|
from safetensors import safe_open
|
|
from transformers.cache_utils import StaticCache
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Optional, Tuple, Any, List, Dict
|
|
from comfy.ldm.modules.attention import optimized_attention
|
|
from comfy_extras.nodes_hunyuan_image import COMPUTED_RESO_GROUPS
|
|
from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock
|
|
|
|
# Build every MoE expert eagerly inside each HunyuanMoE layer unless exactly one CUDA
# device is visible; with a single GPU, HunyuanMoE leaves ``experts`` empty so they can
# be supplied later (lazily, by the model).
INIT_MOE = not (torch.cuda.device_count() == 1)

# Rough byte size used as a VRAM-headroom unit (~5.15 GiB); presumably the size of one
# layer's expert weights. Free memory is compared against multiples of it.
MOE_LAYER_SIZE = 5.15 * (1024 ** 3)  # approx
|
|
|
|
class HunyuanStaticCache(StaticCache):
    """``StaticCache`` whose ``update`` follows the incoming device and supports per-row positions.

    Handles both the legacy ``key_cache``/``value_cache`` list layout and the newer
    ``layers[...]`` layout of ``transformers`` caches.
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write new key/value states into layer ``layer_idx``'s cache and return the cache tensors.

        Args:
            key_states, value_states: new states to store. With a 1-D ``cache_position`` they are
                scattered along dim 2; with a 2-D ``cache_position`` each batch row is scattered
                separately. Without positions the whole cache is overwritten.
            layer_idx: index of the layer whose cache is updated.
            cache_kwargs: optional dict; only ``"cache_position"`` is read. ``None`` is treated
                as an empty dict.

        Returns:
            The (in-place updated) key and value cache tensors for this layer.
        """
        # cache_kwargs is Optional in the signature; calling .get on None used to crash.
        cache_position = (cache_kwargs or {}).get("cache_position")

        if hasattr(self, "key_cache") and hasattr(self, "value_cache"):
            # Legacy layout: keep the cached tensors on the same device as the incoming states.
            if self.key_cache[layer_idx].device != key_states.device:
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
            k_out = self.key_cache[layer_idx]
            v_out = self.value_cache[layer_idx]
            key_states = key_states.to(k_out.dtype)
            value_states = value_states.to(v_out.dtype)
        else:
            # Newer layout: per-layer objects allocate their tensors lazily on first use.
            layer = self.layers[layer_idx]
            if layer.keys is None:
                layer.lazy_initialization(key_states)
            k_out = layer.keys
            v_out = layer.values

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        elif cache_position.dim() == 1:
            # Same positions for every row; dim 2 is presumably the sequence axis of
            # (batch, kv_heads, seq, head_dim) states.
            k_out.index_copy_(2, cache_position, key_states)
            v_out.index_copy_(2, cache_position, value_states)
        else:
            assert cache_position.dim() == 2, f"multiple batch dims not yet {cache_position.shape=}"
            # Per-row positions: once the batch dim is indexed away the sequence axis is dim 1.
            batch_size, _ = cache_position.shape
            for i in range(batch_size):
                unbatched_dim = 1
                k_out[i].index_copy_(unbatched_dim, cache_position[i], key_states[i])
                v_out[i].index_copy_(unbatched_dim, cache_position[i], value_states[i])

        return k_out, v_out
|
|
|
|
def real_batched_index_select(t, dim, idx):
    """Gather ``idx[i]`` along ``dim`` for every batch element ``t[i]`` and stack the results.

    ``dim`` is expressed for the batched tensor, so each per-sample selection uses ``dim - 1``.
    """
    per_sample_dim = dim - 1
    picked = [torch.index_select(sample, per_sample_dim, idx[i]) for i, sample in enumerate(t)]
    return torch.stack(picked)
|
|
|
|
def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of (possibly fractional) timesteps.

    Args:
        t: timesteps, expected to be 1-D (one per batch element).
        dim: output width; odd widths get one extra zero column.
        max_period: sets the slowest frequency (approaching ``1 / max_period``).

    Returns:
        Float32 tensor ``(len(t), dim)`` laid out as ``[cos(...), sin(...)(, 0)]``.
    """
    half = dim // 2
    # Geometric frequency ladder, computed on CPU in float32 and then moved to t's device.
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(device=t.device)
    angles = t[:, None].float() * freqs[None]
    pieces = [torch.cos(angles), torch.sin(angles)]
    if dim % 2:
        # Pad one zero column so the result keeps the requested odd width.
        pieces.append(torch.zeros_like(angles[:, :1]))
    return torch.cat(pieces, dim=-1)
|
|
|
|
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps: sinusoidal features followed by a two-layer MLP.

    Args:
        hidden_size: width of the hidden layer (and of the output unless ``out_size`` is given).
        act_layer: activation module class placed between the two linear layers.
        frequency_embedding_size: width of the sinusoidal features fed into the MLP.
        max_period: forwarded to ``timestep_embedding``.
        out_size: output width; defaults to ``hidden_size``.
        dtype, device: factory kwargs for the linear layers.
        operations: namespace providing ``Linear`` (e.g. ComfyUI ops); falls back to
            ``torch.nn`` when omitted.
    """

    def __init__(self,
                 hidden_size,
                 act_layer=nn.GELU,
                 frequency_embedding_size=256,
                 max_period=10000,
                 out_size=None,
                 dtype=None,
                 device=None,
                 operations = None
                 ):
        # Same fallback as UNetUp: the default ``operations=None`` used to crash on ``.Linear``.
        operations = operations or nn
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size

        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            act_layer(),
            operations.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )

    def forward(self, t):
        """Return the MLP embedding of timesteps ``t``; features are cast to the MLP weight dtype."""
        t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
|
|
|
|
|
|
def _to_tuple(x, dim=2):
|
|
if isinstance(x, int):
|
|
return (x,) * dim
|
|
return x
|
|
|
|
|
|
def get_meshgrid_nd(start, *args, dim=2):
    """Build a ``dim``-D coordinate grid from ``start`` (inclusive) to ``args[0]`` (exclusive).

    Axis ``i`` gets ``int(stop[i] - start[i])`` float32 samples beginning at ``start[i]``
    (unit spacing when the extent is integral). ``start``/``stop`` may be ints, broadcast
    to every axis, or per-axis sequences. Only ``args[0]`` is used.

    Returns:
        Tensor of shape ``(dim, n_0, ..., n_{dim-1})`` in ``indexing="ij"`` order.
    """
    lo = _to_tuple(start, dim=dim)
    hi = _to_tuple(args[0], dim=dim)
    counts = [int(hi[axis] - lo[axis]) for axis in range(dim)]

    axes = []
    for axis, count in enumerate(counts):
        # count + 1 evenly spaced points, endpoint dropped -> half-open interval.
        samples = torch.linspace(lo[axis], hi[axis], count + 1, dtype=torch.float32)
        axes.append(samples[:count])

    return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=0)
|
|
|
|
def build_2d_rope(
    seq_len: int, n_elem: int, image_infos: Optional[List[Tuple[slice, Tuple[int, int]]]] = None,
    device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    """Build 2-D rotary cos/sin tables for a mixed text/image token sequence.

    Text tokens use their 1-D sequence index as both the y and x coordinate. For each image
    section ``(slice, (h, w))`` the ``h * w`` tokens receive a 2-D (y, x) grid whose centre
    sits at the middle of the span the image occupies. A section with ``h is None`` that does
    not leave a gap after the previous section keeps plain 1-D positions.

    Args:
        seq_len: number of tokens; positions are truncated to this length.
        n_elem: rotary width; must be divisible by 4 (half of the angles for y, half for x).
        image_infos: optional ``(slice, (h, w))`` sections in sequence order.
        device: device of the returned tensors.
        base: rotary base; scaled by ``base_rescale_factor ** (n_elem / (n_elem - 2))`` when
            the factor is not 1.
        return_all_pos: also return the integer positions ``[rows, 1, 2]`` as (y, x).

    Returns:
        ``(cos, sin)``, each ``[rows, n_elem]`` with rows = at most ``seq_len``, plus ``all_pos``
        when requested.
    """
    assert n_elem % 4 == 0, f"n_elem must be divisible by 4, but got {n_elem}."

    # Inverse frequencies grouped as [1, n_elem // 4, 2] so they broadcast against (y, x) pairs.
    if base_rescale_factor != 1.0:
        base *= base_rescale_factor ** (n_elem / (n_elem - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
    inv_freq = inv_freq.reshape(1, n_elem // 4, 2)

    sections = [] if image_infos is None else image_infos
    ys, xs = [], []
    cursor = 0  # first sequence position not yet emitted
    for span, (h, w) in sections:
        begin = span.start  # sequence starts at 0, so the slice start is the absolute offset
        if cursor < begin:
            # Text between the previous section and this one.
            ys.append(torch.arange(cursor, begin))
            xs.append(torch.arange(cursor, begin))
        elif h is None:
            # Non-image section overlapping already-emitted positions: 1-D positions, no advance.
            ys.append(torch.arange(span.start, span.stop))
            xs.append(torch.arange(span.start, span.stop))
            continue
        # NOTE(review): an ``h is None`` section that follows a text gap reaches this point and
        # would fail on ``w * h``; confirm callers never produce that layout.
        # Image: h x w grid centred on the middle of its h*w-token span.
        y0 = begin + (w * h - h) / 2
        x0 = begin + (w * h - w) / 2
        grid = get_meshgrid_nd((y0, x0), (y0 + h, x0 + w)).reshape(2, -1)  # rows: (y, x)
        ys.append(grid[0])
        xs.append(grid[1])
        cursor = begin + w * h
    # Trailing text after the last section.
    ys.append(torch.arange(cursor, seq_len))
    xs.append(torch.arange(cursor, seq_len))

    # Overlapping sections can yield more than seq_len positions; keep only the first seq_len.
    x_pos = torch.cat(xs).long()[:seq_len]
    y_pos = torch.cat(ys).long()[:seq_len]
    all_pos = torch.stack((y_pos, x_pos), dim=1).unsqueeze(1).to(device)  # [seq_len, 1, 2]

    # (y, x) * inv_freq -> [rows, n_elem // 4, 2] -> n_elem // 2 angles, duplicated to n_elem.
    angles = (all_pos * inv_freq).reshape(all_pos.shape[0], n_elem // 2).repeat(1, 2)
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    return (cos, sin, all_pos) if return_all_pos else (cos, sin)
|
|
|
|
|
|
def build_batch_2d_rope(
    seq_len: int, n_elem: int, image_infos = None,
    device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    """Batched ``build_2d_rope``: one rotary table per entry of ``image_infos``.

    ``image_infos`` is a per-sample list whose entries are what ``build_2d_rope`` accepts;
    ``None`` builds a single text-only table. ``cos``/``sin`` are stacked along a new leading
    batch dim; when ``return_all_pos`` is set the per-sample positions come back as a list.
    """
    samples = [None] if image_infos is None else image_infos
    tables = [
        build_2d_rope(
            seq_len, n_elem, image_infos=sample_infos, device=device,
            base=base, base_rescale_factor=base_rescale_factor,
            return_all_pos=return_all_pos,
        )
        for sample_infos in samples
    ]

    stacked_cos = torch.stack([table[0] for table in tables], dim=0)
    stacked_sin = torch.stack([table[1] for table in tables], dim=0)

    if return_all_pos:
        return stacked_cos, stacked_sin, [table[2] for table in tables]
    return stacked_cos, stacked_sin
|
|
|
|
|
|
def rotate_half(x):
    """Rotate the last dim by halves: ``[a, b] -> [-b, a]`` where ``a`` is ``d // 2`` wide."""
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to queries and keys.

    Args:
        q, k: query and key tensors.
        cos, sin: rotary tables; rows are first gathered with ``position_ids`` when given.
        position_ids: optional index tensor into ``cos``/``sin``.
        unsqueeze_dim: axis inserted into ``cos``/``sin`` so they broadcast against ``q``/``k``
            (presumably the heads axis).

    Returns:
        ``(q_rot, k_rot)`` with ``x_rot = x * cos + rotate_half(x) * sin``.
    """
    if position_ids is not None:
        cos, sin = cos[position_ids], sin[position_ids]
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise the fallback ``d``."""
    if val is None:
        return d
    return val
|
|
|
|
def topkgating(logits: torch.Tensor, topk: int):
    """Top-k MoE routing producing dense combine/dispatch tensors.

    Args:
        logits: router logits ``[tokens, num_experts]``.
        topk: experts selected per token.

    Returns:
        ``(combine_weights, dispatch_mask)``, both ``[tokens, num_experts, capacity]``.
        ``dispatch_mask[t, e, c]`` is True when token ``t`` occupies slot ``c`` of expert ``e``;
        ``combine_weights`` holds the router probability renormalised over the token's top-k
        experts at those positions. ``capacity`` is the busiest expert's load, so no routed
        token is dropped.
    """
    probs = F.softmax(logits.float(), dim=1)
    n_experts = int(probs.shape[1])

    _, chosen = torch.topk(probs, topk)            # [T, k]
    chosen_onehot = F.one_hot(chosen, n_experts)    # [T, k, E]

    # Capacity = number of tokens routed to the busiest expert.
    load = torch.bincount(chosen.flatten(), minlength=n_experts)
    capacity = torch.max(load).item()

    # Renormalise each token's probabilities by the mass of its selected experts.
    topk_mass = torch.matmul(chosen_onehot.float(), probs.unsqueeze(-1)).sum(dim=1)
    topk_mass = torch.clamp(topk_mass, min=torch.finfo(probs.dtype).eps)
    router_probs = probs / topk_mass

    # Slot assignment: walk the choices rank-major (all first choices, then all second, ...)
    # and count arrivals per expert.
    flat_choices = torch.transpose(chosen, 0, 1).reshape(-1)
    flat_onehot = F.one_hot(flat_choices, n_experts).to(torch.int32)
    slot = torch.cumsum(flat_onehot, dim=0) * flat_onehot - 1
    slot = torch.transpose(slot.reshape((topk, -1, n_experts)), 0, 1)
    slot = torch.max(slot, dim=1)[0]               # [T, E]; -1 where the expert was not chosen

    in_range = torch.logical_and(slot >= 0, slot < capacity)
    slot = torch.masked_fill(slot, ~in_range, 0)
    dispatch_mask = F.one_hot(slot, capacity).to(torch.bool)
    dispatch_mask = torch.masked_fill(dispatch_mask, ~in_range.unsqueeze(-1).expand(-1, -1, capacity), 0)

    combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
    return combine_weights, dispatch_mask
|
|
|
|
class HunyuanRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-channel scale.

    Statistics are computed in float32; the normalised values are cast back to the input
    dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size, device=device, dtype=dtype))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalised = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)
|
|
|
|
|
|
class UNetDown(nn.Module):
    """Turn an NCHW latent into a token sequence: 3x3 conv stem, ResBlocks, then flatten.

    ``patch_size`` must be 1, 2, 4 or 8. With 1 a single ResBlock (no downsampling flag) is
    used; otherwise ``patch_size // 2`` ResBlocks with ``down=True`` are stacked and only the
    last one emits ``out_channels``.
    NOTE(review): for patch_size 8 that is four ``down=True`` blocks; confirm the resulting
    downsampling factor is the intended one.
    """

    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None, operations=None):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()

        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]

        stem = operations.Conv2d(in_channels=in_channels, out_channels=hidden_channels,
                                 kernel_size=3, padding=1, **factory_kwargs)
        self.model = nn.ModuleList([stem])

        shared = dict(channels=hidden_channels, emb_channels=emb_channels,
                      use_scale_shift_norm=True, dropout=dropout,
                      operations=operations, **factory_kwargs)
        if self.patch_size == 1:
            self.model.append(ResBlock(out_channels=out_channels, **shared))
        else:
            for step in range(self.patch_size // 2):
                is_last = (step + 1) * 2 == self.patch_size
                self.model.append(ResBlock(
                    out_channels=out_channels if is_last else hidden_channels,
                    down=True,
                    **shared,
                ))

    def forward(self, x, t):
        """Encode ``x`` (conditioned on ``t``); return ``(tokens, token_h, token_w)``.

        ``tokens`` is ``[B, token_h * token_w, C]`` flattened in row-major (h, w) order.
        """
        assert x.shape[2] % self.patch_size == 0 and x.shape[3] % self.patch_size == 0
        for layer in self.model:
            x = layer(x, t) if isinstance(layer, ResBlock) else layer(x)
        _, _, token_h, token_w = x.shape
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        return tokens, token_h, token_w
|
|
|
|
|
|
class UNetUp(nn.Module):
    """Turn a token sequence back into an NCHW map: un-flatten, ResBlocks, then a 3x3 conv head.

    ``patch_size`` must be 1, 2, 4 or 8. With 1 a single ResBlock is used; otherwise
    ``patch_size // 2`` ResBlocks with ``up=True`` (the first taking ``in_channels``). The head
    is a 3x3 conv to ``out_channels``, preceded by GroupNorm(32) + SiLU when ``out_norm`` is set.
    """

    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None, operations = None, out_norm=False):
        operations = operations or nn
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()

        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]

        self.model = nn.ModuleList()

        common = dict(emb_channels=emb_channels, out_channels=hidden_channels,
                      use_scale_shift_norm=True, dropout=dropout,
                      operations=operations, **factory_kwargs)
        if self.patch_size == 1:
            self.model.append(ResBlock(channels=in_channels, **common))
        else:
            for step in range(self.patch_size // 2):
                self.model.append(ResBlock(
                    channels=hidden_channels if step else in_channels,
                    up=True,
                    **common,
                ))

        def out_conv():
            return operations.Conv2d(in_channels=hidden_channels, out_channels=out_channels,
                                     kernel_size=3, padding=1, **factory_kwargs)

        if out_norm:
            head = nn.Sequential(
                operations.GroupNorm(32, hidden_channels, **factory_kwargs),
                nn.SiLU(),
                out_conv(),
            )
        else:
            head = out_conv()
        self.model.append(head)

    def forward(self, x, t, token_h, token_w):
        """Map tokens ``[batch, token_h * token_w, channels]`` back to a feature map, conditioned on ``t``."""
        x = rearrange(x, 'b (h w) c -> b c h w', h=token_h, w=token_w)
        for layer in self.model:
            x = layer(x, t) if isinstance(layer, ResBlock) else layer(x)
        return x
|
|
|
|
class HunyuanTopKGate(nn.Module):
    """MoE router: a bias-free float32 linear layer giving 64 expert logits, then top-8 gating."""

    def __init__(self, config, layer_idx: Optional[int] = None, device=None, dtype=None, operations=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = 8
        self.min_capacity = 8  # kept as an attribute; not read by forward
        expert_count = 64
        # The router weight is always float32, independent of ``dtype``.
        self.wg = operations.Linear(config["hidden_size"], expert_count, bias=False,
                                    dtype=torch.float32, device=device)

    def forward(self, hidden_states):
        """Route ``[batch, seq, hidden]`` states; returns ``topkgating``'s ``(combine_weights, dispatch_mask)``."""
        _, _, hidden_size = hidden_states.shape
        flat = hidden_states.reshape(-1, hidden_size)
        if self.wg.weight.dtype == torch.float32:
            # Match the router's precision before projecting.
            flat = flat.float()
        return topkgating(self.wg(flat), self.moe_topk)
|
|
|
|
class HunyuanMLP(nn.Module):
    """Gated MLP used for the shared expert and for each routed expert.

    A fused ``gate_and_up_proj`` (hidden -> 2 * 3072) is split into halves ``(a, b)``,
    combined as ``a * silu(b)``, and projected back by ``down_proj`` (3072 -> hidden).
    """

    def __init__(self, config, layer_idx=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config["hidden_size"]
        self.act_fn = torch.nn.functional.silu
        # Doubled because the fused projection holds both SwiGLU halves.
        self.intermediate_size = 3072 * 2
        self.gate_and_up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device, dtype=dtype)
        self.down_proj = operations.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device, dtype=dtype)

    def forward(self, x):
        """Apply the MLP; a 2-D input gets a leading batch dim, so the output is always 3-D."""
        # Move the projections to the input's device (experts may be held elsewhere).
        self.gate_and_up_proj = self.gate_and_up_proj.to(x.device)
        self.down_proj = self.down_proj.to(x.device)
        if x.ndim == 2:
            x = x.unsqueeze(0)
        first, second = self.gate_and_up_proj(x).chunk(2, dim=2)
        return self.down_proj(first * self.act_fn(second))
|
|
|
|
class MoELRUCache(nn.Module):
    """Bookkeeping for MoE experts shuttled between pinned CPU memory and the GPU.

    ``cpu_cache`` / ``gpu_cache`` map a global expert index (``layer_idx * 64 + expert``) to an
    expert module. Transfers are issued on dedicated CUDA streams from coroutines that run on
    a private asyncio event loop driven by a daemon thread.
    """

    def __init__(self):
        super().__init__()
        self.gpu_cache = OrderedDict()
        self.cpu_cache = OrderedDict()
        self.offload_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

        self.last_offload_event = None
        self._loop = asyncio.new_event_loop()
        # Bounds concurrent offloads.
        self._gpu_sem = asyncio.Semaphore(2)
        # NOTE(review): must be set by the owner before any transfer; HunyuanMLP needs a real
        # ``operations`` namespace to build the replica modules.
        self.operations = None
        self.dtype = None
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    async def _async_offload_to_cpu(self, layer_idx):
        """Copy every GPU-resident expert of ``layer_idx`` into pinned CPU replicas, then free the GPU copies."""
        async with self._gpu_sem:
            num_experts = 64
            moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
                         for i in range(num_experts)
                         if (layer_idx * num_experts + i) in self.gpu_cache]
            event = torch.cuda.Event()

            with torch.cuda.stream(self.offload_stream):
                for index, moe in moe_group:
                    moe_cpu = HunyuanMLP(moe.config, device="cpu", dtype=self.dtype, operations=self.operations)
                    for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
                        if p_gpu.device.type == "meta":
                            continue
                        with torch.no_grad():
                            p_cpu.data = torch.empty_like(p_gpu, device="cpu", dtype=self.dtype, pin_memory=True)
                            p_cpu.copy_(p_gpu, non_blocking=True)
                    self.cpu_cache[index] = moe_cpu
                self.offload_stream.record_event(event)

            self.last_offload_event = event

            def finalize_offload_layer():
                # Release the GPU copies only after the device-to-host copies have finished.
                event.synchronize()
                for index, moe in moe_group:
                    moe.to("meta")
                    self.gpu_cache.pop(index, None)
                    del moe
                torch.cuda.empty_cache()

            # This class has no ``cache`` attribute (the old ``self.cache._loop`` raised
            # AttributeError); the event loop lives on ``self._loop``.
            self._loop.call_soon_threadsafe(finalize_offload_layer)

    async def _async_load_to_gpu(self, index, moe):
        """Copy CPU expert ``moe`` to the GPU once VRAM allows, then register it under ``index``."""
        # Poll (without blocking the loop) until free VRAM exceeds two MOE_LAYER_SIZE units.
        while True:
            free_bytes, _ = torch.cuda.mem_get_info()
            if free_bytes > 2 * MOE_LAYER_SIZE:
                break
            torch.cuda.empty_cache()
            await asyncio.sleep(0.01)

        event = torch.cuda.Event()
        with torch.no_grad():
            with torch.cuda.stream(self.load_stream):
                moe_gpu = HunyuanMLP(moe.config, device="meta", dtype=self.dtype, operations=self.operations)
                # Meta tensors carry no storage, so copying into them keeps no data. Allocate
                # real (uninitialised) GPU storage first, then fill it on the load stream.
                moe_gpu.to_empty(device="cuda")
                for p_cpu, p_gpu in zip(moe.parameters(), moe_gpu.parameters()):
                    p_gpu.copy_(p_cpu, non_blocking=True)
                self.load_stream.record_event(event)

        def finalize_load():
            # Publish the expert only after its weights have actually arrived.
            event.synchronize()
            self.gpu_cache[index] = moe_gpu
            self.cpu_cache.pop(index, None)

        self._loop.call_soon_threadsafe(finalize_load)

    def add_cpu(self, moe, index):
        """Move ``moe`` to CPU, pin its parameters and store it in ``cpu_cache`` under ``index``.

        Returns without caching if any unpinned parameter is not on the CPU after the move.
        """
        moe_cpu = moe.to("cpu")

        for _, p in moe_cpu.named_parameters():
            if not p.is_pinned():
                if p.device.type == "cpu":
                    p.data = p.data.pin_memory()
                else:
                    return

        self.cpu_cache[index] = moe_cpu
        self.cpu_cache.move_to_end(index)
|
|
|
|
def parse_layer_expert(key):
|
|
parts = key.split(".")
|
|
layer = int(parts[2])
|
|
expert = int(parts[5])
|
|
return layer, expert
|
|
|
|
class LazyMoELoader(nn.Module):
    """Load individual MoE experts from the single-file checkpoint on demand.

    Expert tensors are read with ``safetensors``; each loaded expert is handed to the
    ``MoELRUCache`` (``cache``), which pins it in CPU memory and schedules a GPU copy.

    Args:
        cache: ``MoELRUCache`` whose event loop runs the loading coroutines.
        config: model config dict (``"hidden_size"`` is required by the expert MLPs).
        max_workers: threads used for blocking disk reads.
        max_concurrent_loads: maximum disk loads in flight at once.
        operations: namespace providing ``Linear`` for the expert MLPs; may instead be
            assigned after construction (``loader.operations = ...``).
        dtype: dtype for the expert MLPs; may instead be assigned after construction.
    """

    def __init__(self, cache, config, max_workers = 16, max_concurrent_loads = 32, operations=None, dtype=None):
        super().__init__()
        self.cache = cache
        self.config = config
        self.operations = operations
        self.dtype = dtype
        self._loop = cache._loop
        self.expert_key_index = self.index_safetensors()
        self._checkpoint = self.get_checkpoint()
        # NOTE(review): confirm the installed safetensors ``safe_open`` accepts ``mmap=``.
        self._file = safe_open(self._checkpoint, framework="pt", device="cpu", mmap=True)
        # Expert modules are created on first use. Building them here read
        # ``self.dtype``/``self.operations`` before they existed (AttributeError), and callers
        # assign those attributes only after the loader is constructed.
        self.expert_pool = {}

        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._semaphore = asyncio.Semaphore(max_concurrent_loads)

    def _new_meta_expert(self, layer_idx):
        """Build one empty (meta-device) expert MLP for ``layer_idx``."""
        return HunyuanMLP(
            self.config,
            layer_idx=layer_idx,
            device="meta",
            dtype=self.dtype,
            operations=self.operations,
        )

    def _get_expert(self, layer_idx, expert_idx):
        """Return the pooled expert module, creating a meta-device one on first request."""
        layer_pool = self.expert_pool.setdefault(layer_idx, {})
        if expert_idx not in layer_pool:
            layer_pool[expert_idx] = self._new_meta_expert(layer_idx)
        return layer_pool[expert_idx]

    def build_meta_experts(self):
        """Create a meta-device expert for every (layer, expert) in the checkpoint index.

        Requires ``self.operations`` (and optionally ``self.dtype``) to be set.
        """
        pool = {}
        for layer, experts in self.expert_key_index.items():
            pool[layer] = {}
            for expert in experts:
                pool[layer][expert] = self._new_meta_expert(layer)
        return pool

    def get_checkpoint(self):
        """Return the resolved checkpoint path; raise ValueError if the file is missing."""
        comfyui_dir = Path.home() / "ComfyUI"
        # NOTE(review): ComfyUI's usual folder is ``models/checkpoints`` (plural); confirm the
        # singular directory name here is intended.
        checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
        checkpoint = checkpoint.resolve()
        if not os.path.exists(checkpoint):
            raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")
        return checkpoint

    def index_safetensors(self):
        """Map ``layer -> expert -> [tensor keys]`` for every expert tensor in the checkpoint."""
        checkpoint = self.get_checkpoint()
        index = {}
        with safe_open(checkpoint, framework="pt", device="cpu") as f:
            for k in f.keys():
                if "experts." in k:
                    layer, expert = parse_layer_expert(k)
                    index.setdefault(layer, {}).setdefault(expert, []).append(k)
        return index

    def lazy_init(self, layer_idx, expert_idx):
        """Fill expert ``expert_idx`` of ``layer_idx`` with its checkpoint tensors (CPU) and return it."""
        keys = self.expert_key_index[layer_idx][expert_idx]
        model = self._get_expert(layer_idx, expert_idx)

        def strip_expert_prefix(k):
            return k.split(f"experts.{expert_idx}.", 1)[1]

        sd = {strip_expert_prefix(k): self._file.get_tensor(k) for k in keys}

        for name, tensor in sd.items():
            # Names are dotted (e.g. "gate_and_up_proj.weight"), so ``getattr(model, name)``
            # fails, and a meta parameter cannot take CPU data through ``.data =``. Replace
            # the parameter on its owning submodule instead.
            owner_path, _, param_name = name.rpartition(".")
            owner = model.get_submodule(owner_path) if owner_path else model
            old = getattr(owner, param_name)
            setattr(owner, param_name, nn.Parameter(tensor, requires_grad=old.requires_grad))
        return model

    async def lazy_load_from_disk(self, layer_idx, expert_idx):
        """Run ``lazy_init`` on the loader's thread pool, at most ``max_concurrent_loads`` at a time."""
        loop = asyncio.get_running_loop()
        async with self._semaphore:
            return await loop.run_in_executor(self._executor, self.lazy_init, layer_idx, expert_idx)

    def _schedule_disk_load(self, layer_idx, expert_idx):
        """Start loading one expert on the cache's loop; return a ``concurrent.futures.Future``.

        When the disk read finishes, the expert is pinned in ``cache.cpu_cache`` and a GPU copy
        is scheduled, from a helper thread.
        """
        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)

        def _on_disk_loaded(fut):
            # A failed/cancelled load is reported to whoever calls fut.result(); nothing to cache.
            if fut.cancelled() or fut.exception() is not None:
                return
            moe_cpu = fut.result()

            def _add_cpu_in_main_thread():
                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)

                asyncio.run_coroutine_threadsafe(
                    self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
                    self.cache._loop
                )

            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()

        future.add_done_callback(_on_disk_loaded)
        return future
|
|
|
|
def enough_vram(required_bytes):
    """Return True when the current CUDA device reports strictly more than ``required_bytes`` free."""
    free_bytes = torch.cuda.mem_get_info()[0]
    return free_bytes > required_bytes
|
|
|
|
class HunyuanMoE(nn.Module):
    """Mixture-of-experts feed-forward block: a shared MLP plus top-8-of-64 routed expert MLPs.

    With ``INIT_MOE`` all experts are built eagerly. Otherwise ``experts`` starts empty and is
    expected to be filled externally with expert modules or futures resolving to them.
    """

    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = 8
        self.num_experts = 64
        self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
        if INIT_MOE:
            self.experts = nn.ModuleList(
                [HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations) for _ in range(self.num_experts)]
            )
        else:
            self.experts = []
        self.moe_lru = moe_lru

    @staticmethod
    def _is_future(obj):
        """True for both thread-pool (``concurrent.futures``) and ``asyncio`` futures."""
        return isinstance(obj, (concurrent.futures.Future, asyncio.Future))

    @staticmethod
    def _compute_expert_outputs(experts_list, tokens_padded, device):
        """Batched gated MLP over a group of experts.

        ``tokens_padded`` is ``[n_experts, capacity, hidden]``; returns ``[n_experts, capacity, hidden]``.
        """
        w1 = torch.stack([m.gate_and_up_proj.weight.to(device) for m in experts_list], dim=0)
        x = torch.bmm(tokens_padded.to(device), w1.transpose(1, 2))
        x1, x2 = x.chunk(2, dim=2)
        gated = x1 * F.silu(x2)
        w2 = torch.stack([m.down_proj.weight.to(device) for m in experts_list], dim=0)
        return torch.bmm(gated, w2.transpose(1, 2))

    def _run_used_experts(self, used_indices, dispatched_input, combine_weights, device):
        """Evaluate the experts in ``used_indices`` and mix their outputs back per token.

        Ready experts run first; experts that are still loading are awaited afterwards so the
        GPU works on available experts meanwhile. Returns ``[tokens, hidden]``.
        """
        tokens_padded = dispatched_input[used_indices]
        position = {expert_idx: pos for pos, expert_idx in enumerate(used_indices)}
        out_parts = {}

        ready_indices, pending_indices = [], []
        for i in used_indices:
            expert = self.experts[i]
            if self._is_future(expert):
                if not expert.done():
                    pending_indices.append(i)
                    continue
                self.experts[i] = expert.result()
            ready_indices.append(i)

        def run_group(group):
            experts_list = [self.experts[i] for i in group]
            tokens = tokens_padded[[position[i] for i in group]]
            out = self._compute_expert_outputs(experts_list, tokens, device)
            for k, expert_idx in enumerate(group):
                out_parts[expert_idx] = out[k:k + 1]

        if ready_indices:
            run_group(ready_indices)

        if pending_indices:
            # Block until the remaining experts are available. LazyMoELoader hands out
            # concurrent.futures.Future objects, which previously were never resolved here and
            # then crashed on ``.gate_and_up_proj``.
            for i in pending_indices:
                self.experts[i] = self.experts[i].result()
            run_group(pending_indices)

        out_padded_all = torch.cat([out_parts[i] for i in used_indices], dim=0)
        # combine_weights spans all experts; keep only the used ones so its expert axis lines up
        # with out_padded_all (previously a shape mismatch whenever some expert was unused).
        used_weights = combine_weights[:, used_indices, :].to(out_padded_all.dtype)
        return torch.einsum("suc,uco->so", used_weights, out_padded_all)

    def forward(self, hidden_states):
        """Apply the shared MLP plus the routed experts to ``[batch, seq, hidden]`` states.

        Returns a tensor of the same shape as ``hidden_states``.
        """
        if not INIT_MOE:
            torch.cuda.set_device(0)
        else:
            torch.cuda.set_device(hidden_states.device.index)
        bsz, seq_len, hidden_size = hidden_states.shape

        hidden_states_mlp = self.shared_mlp(hidden_states)
        reshaped_input = hidden_states.reshape(-1, hidden_size)
        device = hidden_states.device
        dtype = reshaped_input.dtype

        with torch.cuda.nvtx.range("MoE"):
            # Both [tokens, num_experts, capacity].
            combine_weights, dispatch_mask = self.gate(hidden_states)
            # [num_experts, capacity, hidden]: each expert's token slots.
            dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(reshaped_input), reshaped_input)

            used_mask = (dispatch_mask.sum(dim=(0, 2)) > 0)
            used_indices = used_mask.nonzero(as_tuple=False).squeeze(1).tolist()

            combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype)
            if used_indices:
                combined_output = self._run_used_experts(used_indices, dispatched_input, combine_weights, device)

        combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
        return hidden_states_mlp + combined_output
|
|
|
|
class HunyuanImage3Attention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and per-head QK RMSNorm.

    The fused ``qkv_proj`` output is laid out per key/value head as
    ``[groups query heads, 1 key head, 1 value head]``; there are 8 key/value heads, each
    shared by ``num_attention_heads // 8`` query heads.
    """

    def __init__(self, config, layer_idx: int, device=None, dtype=None, operations=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_type = 'self'

        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_attention_heads"]
        self.head_dim = config["attention_head_dim"]
        self.num_key_value_heads = 8
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config["max_position_embeddings"]
        self.rope_theta = 10000.0
        self.is_causal = True
        self.hidden_size_q = self.head_dim * self.num_heads
        self.hidden_size_kv = self.head_dim * self.num_key_value_heads

        fused_width = self.hidden_size_q + 2 * self.hidden_size_kv
        self.qkv_proj = operations.Linear(self.hidden_size, fused_width, bias=False, device=device, dtype=dtype)
        self.o_proj = operations.Linear(self.hidden_size_q, self.hidden_size, bias=False, device=device, dtype=dtype)

        norm_eps = config["rms_norm_eps"]
        self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=norm_eps, device=device, dtype=dtype)
        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=norm_eps, device=device, dtype=dtype)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``[bsz, seq, heads * head_dim]`` into contiguous ``[bsz, heads, seq, head_dim]``."""
        return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ):
        """Self-attention over ``hidden_states`` ``[batch, seq, hidden]``.

        Args:
            past_key_value: optional cache; its ``update(k, v, layer_idx, cache_kwargs)`` result
                replaces the keys/values, with ``position_ids`` passed as ``cache_position``.
            attention_mask: optional mask forwarded to ``optimized_attention``.
            position_ids: cache positions for ``past_key_value.update``.
            custom_pos_emb: ``(cos, sin)`` rotary tables (required).

        Returns:
            ``(attn_output, past_key_value)``.
        """
        bsz, q_len, _ = hidden_states.size()

        # Split the fused projection per key/value head: [groups x query, 1 x key, 1 x value].
        fused = self.qkv_proj(hidden_states).reshape(
            bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim
        )
        q, k, v = torch.split(fused, [self.num_key_value_groups, 1, 1], dim=3)

        def to_heads(t, n_heads):
            return t.reshape(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

        q = to_heads(q, self.num_heads)
        k = to_heads(k, self.num_key_value_heads)
        v = to_heads(v, self.num_key_value_heads)

        cos, sin = custom_pos_emb
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        q = self.query_layernorm(q).to(v.dtype)
        k = self.key_layernorm(k).to(v.dtype)

        if past_key_value is not None:
            k, v = past_key_value.update(k, v, self.layer_idx, {"cache_position": position_ids})
        q = q.to(k.dtype)

        # Give every query head its own copy of the shared key/value head.
        k = torch.repeat_interleave(k, dim=1, repeats=self.num_key_value_groups)
        v = torch.repeat_interleave(v, dim=1, repeats=self.num_key_value_groups)

        if q.device.type == "cuda" and attention_mask is not None:
            q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        attn_output = optimized_attention(q, k, v, self.num_heads, mask=attention_mask, skip_reshape=True)
        return self.o_proj(attn_output), past_key_value
|
|
|
|
class HunyuanImage3DecoderLayer(nn.Module):
    """Pre-norm decoder block: RMSNorm -> attention -> residual, then RMSNorm -> MoE -> residual."""

    def __init__(self, config, layer_idx: int, moe_lru=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.layer_idx = layer_idx
        self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
        self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru, device=device, dtype=dtype, operations=operations)
        eps = config["rms_norm_eps"]
        self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=eps, device=device, dtype=dtype)
        self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=eps, device=device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor | Any]:
        """Run one decoder block; returns ``(hidden_states, past_key_value)``."""
        attn_out, past_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            custom_pos_emb=custom_pos_emb,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        return (hidden_states, past_key_value)
|
|
|
|
class HunyuanImage3Model(nn.Module):
    """Transformer trunk of HunyuanImage-3 that streams MoE experts layer by layer.

    A decoder layer's ``mlp.experts`` is an empty list while its experts are not
    loaded. ``forward`` schedules expert loads (through ``LazyMoELoader``) for
    upcoming layers before they run. After a layer has run, ``forward`` asks
    ``moe_lru`` to offload it and empties its expert list again.
    """

    def __init__(self, config, moe_lru=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.padding_idx = 128009
        self.vocab_size = 133120
        self.config = config
        self.wte = operations.Embedding(self.vocab_size, config["hidden_size"], self.padding_idx, device=device, dtype=dtype)
        self.layers = nn.ModuleList(
            [
                HunyuanImage3DecoderLayer(config, layer_idx, moe_lru=moe_lru, device=device, dtype=dtype, operations=operations)
                for layer_idx in range(config["num_hidden_layers"])
            ]
        )
        # Built with the same device/dtype as the per-layer norms. Note that
        # forward() in this class does not apply it.
        self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"], device=device, dtype=dtype)

        self.shared_tensor = None
        self.moe_lru = moe_lru
        # Set to True at the end of the first forward(); see the NOTE(review) there.
        self.additional_layers_set = False
        self.moe_loader = LazyMoELoader(self.moe_lru, self.config)
        self.moe_loader.operations = operations
        self.moe_loader.dtype = dtype

    def _uses_lazy_experts(self, idx):
        """Return True if layer ``idx`` exists and its ``experts`` is a plain list."""
        return idx < len(self.layers) and isinstance(self.layers[idx].mlp.experts, list)

    def _needs_expert_load(self, idx):
        """Return True if layer ``idx`` exists and currently holds no experts."""
        return idx < len(self.layers) and len(self.layers[idx].mlp.experts) == 0

    def _schedule_expert_load(self, idx):
        """Store ``moe_loader._schedule_disk_load(idx, e)`` for all 64 experts of layer ``idx``."""
        self.layers[idx].mlp.experts = [self.moe_loader._schedule_disk_load(idx, i) for i in range(64)]

    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache=True,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        mode: str = "gen_image",
        first_step: Optional[bool] = None,
        gen_timestep_scatter_index: Optional[torch.Tensor] = None,
    ):
        """Run ``inputs_embeds`` through every decoder layer.

        Returns:
            ``(hidden_states, cache)``, or ``(hidden_states,)`` when ``use_cache``
            is false or the layers produced no cache. ``ln_f`` is not applied.
        """
        hidden_states = inputs_embeds
        next_decoder_cache = None
        num_layers = len(self.layers)

        # Number of extra layers whose experts fit in free CUDA memory, budgeting
        # 2 * MOE_LAYER_SIZE bytes each. Their reloads are spread every
        # `sparse_interval` layers. If nothing fits, use num_layers: it can never
        # pass the `layer_idx > sparse_interval` test below, and it avoids the
        # previous division by zero.
        additional_layers = int(torch.cuda.mem_get_info()[0] // (MOE_LAYER_SIZE * 2))
        if additional_layers > 0:
            sparse_interval = max(1, num_layers // additional_layers)
        else:
            sparse_interval = max(1, num_layers)
        next_layers = 0

        if self._needs_expert_load(0):
            self._schedule_expert_load(0)

        for layer_idx, decoder_layer in enumerate(self.layers):
            # Prefetch the next two layers that stream their experts. A layer whose
            # `experts` is not a plain list is stepped over. Indices past the last
            # layer are now skipped instead of raising IndexError.
            next_layer = layer_idx + 1 if self._uses_lazy_experts(layer_idx + 1) else layer_idx + 2
            second_next_layer = next_layer + 1 if self._uses_lazy_experts(layer_idx + 2) else next_layer + 2
            if self._needs_expert_load(next_layer):
                self._schedule_expert_load(next_layer)
            if self._needs_expert_load(second_next_layer):
                self._schedule_expert_load(second_next_layer)

            # Periodically re-schedule the earliest layers. Presumably this keeps
            # them resident for the next forward() call — TODO confirm.
            if layer_idx % sparse_interval == 0 and layer_idx > sparse_interval:
                self._schedule_expert_load(next_layers)
                next_layers += 1

            with torch.no_grad():
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    custom_pos_emb=custom_pos_emb,
                    mode=mode,
                    first_step=first_step,
                    gen_timestep_scatter_index=gen_timestep_scatter_index,
                )

            # NOTE(review): additional_layers_set is a bool. From the second call
            # on, this therefore keeps only layers 0 and 1 (layer_idx <= True)
            # resident. A layer count may have been intended; behavior kept as-is.
            keep_resident = self.additional_layers_set and layer_idx <= self.additional_layers_set
            # Without an LRU cache (moe_lru is None) there is nothing to offload to.
            if not keep_resident and self.moe_lru is not None:
                asyncio.run_coroutine_threadsafe(
                    self.moe_lru._async_offload_to_cpu(layer_idx),
                    self.moe_lru._loop,
                )
                del decoder_layer.mlp.experts
                decoder_layer.mlp.experts = []

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = layer_outputs[1]

        next_cache = next_decoder_cache if use_cache else None
        self.additional_layers_set = True
        return tuple(v for v in (hidden_states, next_cache) if v is not None)
|
|
|
|
|
|
class HunyuanImage3ForCausalMM(nn.Module):
|
|
def __init__(self, operations = None, dtype = None, device = None, **kwargs):
    """Build the HunyuanImage-3 causal multimodal model.

    Every extra keyword argument becomes part of the config dict stored on
    ``self.config``. ``hidden_size`` is read here; the sub-models additionally
    read keys such as ``rms_norm_eps`` and ``num_hidden_layers``.
    """
    super().__init__()
    self.config = cfg = kwargs
    hidden = cfg["hidden_size"]
    common = dict(device=device, dtype=dtype, operations=operations)

    # Embeds the diffusion timestep into the token written at the timestep slot.
    self.timestep_emb = TimestepEmbedder(hidden_size=hidden, **common)
    # 32-channel latent -> transformer-width tokens.
    self.patch_embed = UNetDown(
        patch_size=1,
        emb_channels=hidden,
        in_channels=32,
        hidden_channels=1024,
        out_channels=hidden,
        **common,
    )
    self.time_embed = TimestepEmbedder(hidden_size=hidden, **common)
    # Transformer-width tokens -> 32-channel latent prediction.
    self.final_layer = UNetUp(
        patch_size=1,
        emb_channels=hidden,
        in_channels=hidden,
        hidden_channels=1024,
        out_channels=32,
        out_norm=True,
        **common,
    )
    self.time_embed_2 = TimestepEmbedder(hidden_size=hidden, **common)

    # An expert LRU cache is only created when INIT_MOE is false, i.e. when
    # exactly one CUDA device is visible.
    if INIT_MOE:
        self.moe_lru = None
    else:
        lru = MoELRUCache()
        lru.operations = operations
        lru.dtype = dtype
        self.moe_lru = lru

    self.model = HunyuanImage3Model(cfg, moe_lru=self.moe_lru, **common)

    self.pad_id = 128009
    self.vocab_size = 133120
    self.lm_head = operations.Linear(hidden, self.vocab_size, bias=False, device=device, dtype=dtype)

    # forward() builds the full sequence while first_step is True and clears it afterwards.
    self.first_step = True
    self.kv_cache = None
    # Token lookups used by forward(); presumably attached by the loader — TODO confirm.
    self.encode_tok = None
    self.special_tok = None
|
|
|
|
@staticmethod
def get_pos_emb(custom_pos_emb, position_ids):
    """Gather rotary tables at ``position_ids``.

    ``custom_pos_emb`` must unpack into exactly two tables, treated as
    ``(cos, sin)``. Each table is indexed along dim 1 with
    ``real_batched_index_select``, and the pair is returned in the same order.
    """
    cos_table, sin_table = custom_pos_emb
    gathered = [real_batched_index_select(table, dim=1, idx=position_ids) for table in (cos_table, sin_table)]
    return gathered[0], gathered[1]
|
|
|
|
def ragged_final_layer(self, x, image_mask, timestep, token_h, token_w, first_step):
    """Decode the image tokens of ``x`` back to the latent prediction.

    On the first step, the tokens selected by ``image_mask`` are kept and
    reshaped to ``(batch, -1, hidden)``. The mask is unsqueezed on its last
    dim and broadcast against ``x``. On later steps, everything after the
    leading token is used. The kept tokens go through ``final_layer``,
    conditioned on ``time_embed_2(timestep)``, on a ``token_h`` x ``token_w`` grid.
    """
    batch, _, width = x.shape
    if not first_step:
        tokens = x[:, 1:, :]
    else:
        keep = image_mask.unsqueeze(-1).bool()
        tokens = x.masked_select(keep).reshape(batch, -1, width)
    return self.final_layer(tokens, self.time_embed_2(timestep), token_h, token_w)
|
|
|
|
def forward(self, x, condition, timestep, **kwargs):
    """Run one denoising step and return the latent-space prediction.

    Args:
        x: latent batch. Dims 2 and 3 times 16 give the pixel resolution
            looked up in ``COMPUTED_RESO_GROUPS`` (presumably (B, C, H, W) —
            TODO confirm).
        condition: tensor unbound along dim 0 into six parts. These are the
            conditioning joint-image embeddings, their VAE-token mask, the
            prompt embeddings, and three unconditional counterparts (unused here).
        timestep: diffusion timestep tensor.
        **kwargs: accepted and ignored.

    While ``self.first_step`` is true, the full prompt + image segment is built
    and the KV cache is created if missing. Later calls feed only the timestep
    token and the image patches, plus ``joint_image`` when conditioning exists.
    ``self.first_step`` is cleared before returning.

    Raises:
        ValueError: if the pixel resolution is not in ``COMPUTED_RESO_GROUPS``.
    """
    joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()

    bsz, seq_len, n_embd = inputs_embeds.shape
    # Conditioning counts as absent when the first joint-image token is all -100
    # (presumably a padding sentinel written by the caller — TODO confirm).
    cond_exists = (joint_image[:, 0, :] != -100.0).any(dim=1).any()

    height, width = x.size(2) * 16, x.size(3) * 16
    # NOTE(review): with a single-token <boi>, index 3 of the image segment built
    # below is <timestep> and index 4 is the first patch. Confirm how many tokens
    # encode_tok("<boi>") yields.
    gen_timestep_scatter_index = 4

    def embed_tokens(text, lookup=self.encode_tok):
        # `lookup` is either a dict of token ids or a tokenizer callable.
        ids = lookup[text] if isinstance(lookup, dict) else lookup(text)
        return self.model.wte(torch.tensor(ids, device=inputs_embeds.device)).unsqueeze(0).expand(bsz, -1, -1)

    hw = f"{int(height)}x{int(width)}"
    ratio_idx = next((i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw), None)
    if ratio_idx is None:
        raise ValueError(f"Unsupported resolution {hw}: not found in COMPUTED_RESO_GROUPS")
    img_ratio = embed_tokens(f"<img_ratio_{ratio_idx}>", self.special_tok)

    if cond_exists:
        with torch.no_grad():
            # Slot 2 of the conditioning sequence receives the aspect-ratio token.
            joint_image[:, 2:3, :] = img_ratio

    img_slices = []
    # Created on the compute device so the timestep embedder sees a co-located input.
    cond_timestep = torch.zeros(x.size(0), device=x.device)
    t_emb = self.time_embed(timestep)
    x, token_height, token_width = self.patch_embed(x, t_emb)

    if self.first_step:
        # Image segment: <boi> <img_size_1024> <img_ratio_k> <timestep> patches <eoi>.
        x = torch.cat([
            embed_tokens("<boi>"),
            embed_tokens("<img_size_1024>", lookup=self.special_tok),
            img_ratio,
            embed_tokens("<timestep>", self.special_tok),
            x,
            embed_tokens("<eoi>"),
        ], dim=1)
        x[:, gen_timestep_scatter_index:gen_timestep_scatter_index + 1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
        input_args = [inputs_embeds, x]
    else:
        timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
        inputs_embeds = torch.cat([timestep_emb, x], dim=1)
        input_args = [inputs_embeds]

    # Per-sample slices that get bidirectional attention: the generated-image
    # segment, and (if conditioning) the VAE and ViT token ranges of joint_image.
    for i in range(x.size(0)):
        gen_offset = seq_len + x.size(1)
        if cond_exists:
            vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0]
            vae_start = vae_mask_indices[0].item() + gen_offset
            vae_end = vae_mask_indices[-1].item() + 1 + gen_offset
            vit_start = vae_end + 1
            vit_end = joint_image.size(1) - 1 + gen_offset
            joint_slices_i = [slice(vae_start, vae_end), slice(vit_start, vit_end)]
        else:
            joint_slices_i = []
        gen_slices_i = [slice(seq_len, gen_offset)]
        img_slices.append(gen_slices_i + joint_slices_i)

    # NOTE(review): rope info is built from batch item 0 only and reuses
    # joint_slices_i from the last loop iteration.
    img_s = img_slices[0]
    rope_img = [(img_s[0], (token_height, token_width))]
    rope_image_info = [rope_img if len(joint_slices_i) == 0 else rope_img + [(img_s[1], (384 // 16, 384 // 16)), (img_s[2], (256 // 16, 256 // 16))]]

    if cond_exists:
        with torch.no_grad():
            # Slot 3 of the conditioning sequence receives the zero-timestep embedding.
            joint_image[:, 3:4, :] = self.timestep_emb(cond_timestep.reshape(-1)).reshape(bsz, -1, n_embd)
        inputs_embeds = torch.cat([*input_args, joint_image], dim=1)
    else:
        inputs_embeds = torch.cat([*input_args], dim=1)

    total_len = inputs_embeds.shape[1]
    # Causal mask, made bidirectional inside each image slice. Built on the
    # compute device so attention does not mix CPU and GPU tensors.
    attention_mask = torch.ones(total_len, total_len, dtype=torch.bool, device=inputs_embeds.device).tril(diagonal=0).repeat(bsz, 1, 1)
    for i in range(bsz):
        for image_slice in img_slices[i]:
            attention_mask[i, image_slice, image_slice] = True
    attention_mask = attention_mask.unsqueeze(1)

    position_ids = torch.arange(0, total_len, dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
    cos, sin = build_batch_2d_rope(
        image_infos=rope_image_info,
        seq_len=total_len,
        n_elem=128,  # head dim
        base=10000.0,
    )
    # NOTE(review): packed as (sin, cos), while get_pos_emb unpacks its input as
    # (cos, sin). Confirm the order the attention layer expects before changing.
    custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device))
    custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)

    if self.kv_cache is None:
        # TODO: should change when higgsv2 gets merged
        self.kv_cache = HunyuanStaticCache(
            config=self.config,
            batch_size=x.size(0) * 2,
            max_cache_len=total_len,
            dtype=x.dtype,
        )

    outputs = self.model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=self.kv_cache,
        inputs_embeds=inputs_embeds,
        custom_pos_emb=custom_pos_emb,
        first_step=self.first_step,
        gen_timestep_scatter_index=gen_timestep_scatter_index,
    )
    hidden_states = outputs[0]

    # The model drops the cache from its output tuple when it is None.
    past_key_value = outputs[1] if len(outputs) > 1 else None
    if past_key_value is not None:
        self.kv_cache = past_key_value

    hidden_states = hidden_states.to(inputs_embeds.device)
    # Mask lives on the same device as hidden_states; masked_select needs that.
    img_mask = torch.zeros(hidden_states.size(1), device=hidden_states.device)
    # NOTE(review): this marks [seq_len + x.size(1) + 4, -1), which lies past the
    # generated-image segment [seq_len, seq_len + x.size(1)) used in gen_slices_i.
    # Confirm the intended range.
    img_mask[seq_len + x.size(1) + 4:] = 1
    img_mask[-1] = 0

    diffusion_prediction = self.ragged_final_layer(
        hidden_states, img_mask, timestep, int(token_height), int(token_width), self.first_step)

    if self.first_step:
        self.first_step = False

    return diffusion_prediction
|