import os
import math
import asyncio
import threading
import concurrent.futures
from pathlib import Path
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Tuple, Any, List, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from safetensors import safe_open
from transformers.cache_utils import StaticCache

from comfy.ldm.modules.attention import optimized_attention
from comfy_extras.nodes_hunyuan_image import COMPUTED_RESO_GROUPS
from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock

# When exactly one CUDA device is visible the routed experts are NOT built
# up-front; they are streamed from disk on demand (see LazyMoELoader).
INIT_MOE = torch.cuda.device_count() != 1
MOE_LAYER_SIZE = (1024**3) * 5.15 # approx


class HunyuanStaticCache(StaticCache):
    """StaticCache whose ``update`` supports 1-D and per-sample 2-D cache positions."""

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write ``key_states``/``value_states`` into layer ``layer_idx`` and return the buffers.

        ``cache_kwargs["cache_position"]`` selects the sequence slots to write:
        ``None`` overwrites the whole buffer, a 1-D tensor is shared by the
        batch, a 2-D tensor gives one index row per batch element.
        """
        # cache_kwargs defaults to None, so it must be guarded before .get().
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None

        if hasattr(self, "key_cache") and hasattr(self, "value_cache"):
            # Cache object exposing per-layer key_cache / value_cache lists.
            if self.key_cache[layer_idx].device != key_states.device:
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
            k_out = self.key_cache[layer_idx]
            v_out = self.value_cache[layer_idx]
            key_states = key_states.to(k_out.dtype)
            value_states = value_states.to(v_out.dtype)
        else:
            # Cache object exposing per-layer objects with lazily created buffers.
            if self.layers[layer_idx].keys is None:
                self.layers[layer_idx].lazy_initialization(key_states)
            k_out = self.layers[layer_idx].keys
            v_out = self.layers[layer_idx].values

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            if cache_position.dim() == 1:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            else:
                assert cache_position.dim() == 2, f"multiple batch dims not yet {cache_position.shape=}"
                batch_size, _ = cache_position.shape
                for i in range(batch_size):
                    # After indexing the batch element the sequence axis moves from 2 to 1.
                    unbatched_dim = 1
                    k_out[i].index_copy_(unbatched_dim, cache_position[i], key_states[i])
                    v_out[i].index_copy_(unbatched_dim, cache_position[i], value_states[i])

        return k_out, v_out


def real_batched_index_select(t, dim, idx):
    """Apply ``torch.index_select`` independently per batch element with its own ``idx[i]``.

    ``dim`` is expressed relative to the batched tensor ``t``.
    """
    return torch.stack([torch.index_select(t[i], dim - 1, idx[i]) for i in range(len(t))])


def timestep_embedding(t, dim, max_period=10000):
    """Return sinusoidal embeddings ``[cos | sin]`` of width ``dim`` for the 1-D tensor ``t``.

    A zero column is appended when ``dim`` is odd.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
        )
    return embedding


class TimestepEmbedder(nn.Module):
    """Sinusoidal timestep features followed by a two-layer MLP."""

    def __init__(self,
                 hidden_size,
                 act_layer=nn.GELU,
                 frequency_embedding_size=256,
                 max_period=10000,
                 out_size=None,
                 dtype=None,
                 device=None,
                 operations = None
                 ):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size

        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            act_layer(),
            operations.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )

    def forward(self, t):
        """Embed timesteps ``t``; features are cast to the first linear layer's weight dtype."""
        t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


def _to_tuple(x, dim=2):
    """Expand an int into a ``dim``-tuple; anything else is returned unchanged."""
    if isinstance(x, int):
        return (x,) * dim
    return x


def get_meshgrid_nd(start, *args, dim=2):
    """Build a ``[dim, n_0, ..., n_{dim-1}]`` float grid from ``start`` to ``args[0]`` (exclusive).

    Only the first positional extra argument (the stop value) is used. The
    number of points per axis is ``int(stop - start)``.
    """
    start = _to_tuple(start, dim=dim)
    stop = _to_tuple(args[0], dim=dim)
    num = [int(stop[i] - start[i]) for i in range(dim)]

    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        # n + 1 samples with the last dropped => half-open interval [a, b).
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")
    grid = torch.stack(grid, dim=0)

    return grid


def build_2d_rope(
        seq_len: int, n_elem: int, image_infos: Optional[List[Tuple[slice, Tuple[int, int]]]] = None,
        device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
        return_all_pos: bool = False,
):
    """Build 2-D rotary cos/sin tables for one sequence mixing text and image tokens.

    Text tokens use the same index for both axes; each image section
    ``(slice, (h, w))`` gets an ``h x w`` grid of (y, x) positions.

    Returns ``(cos, sin)`` of shape ``[seq_len, n_elem]`` and, when
    ``return_all_pos`` is set, also the ``[seq_len, 1, 2]`` position tensor.
    """
    assert n_elem % 4 == 0, f"n_elem must be divisible by 4, but got {n_elem}."

    # theta
    if base_rescale_factor != 1.0:
        base *= base_rescale_factor ** (n_elem / (n_elem - 2))
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
    theta = theta.reshape(1, n_elem // 4, 2)    # [1, half_d, 2]

    # position indices
    if image_infos is None:
        image_infos = []

    image_infos_list = [image_infos]
    sample_seq_lens = [seq_len]

    x_sections = []
    y_sections = []
    for sample_id, sample_image_infos in enumerate(image_infos_list):
        last_pos = 0
        for sec_slice, (h, w) in sample_image_infos:
            L = sec_slice.start   # start from 0, so image_slice.start is just L
            # previous text
            if last_pos < L:
                y_sections.append(torch.arange(last_pos, L))
                x_sections.append(torch.arange(last_pos, L))
            elif h is None:
                # Section without a grid size: plain sequential positions.
                y_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                x_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                continue
            else:
                pass
            # current image
            beta_y = L + (w * h - h) / 2
            beta_x = L + (w * h - w) / 2
            grid = get_meshgrid_nd((beta_y, beta_x), (beta_y + h, beta_x + w))  # [2, h, w]
            grid = grid.reshape(2, -1)     # (y, x)
            y_sections.append(grid[0])
            x_sections.append(grid[1])
            # step
            last_pos = L + w * h
        # final text
        y_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id]))
        x_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id]))

    x_pos = torch.cat(x_sections).long()
    y_pos = torch.cat(y_sections).long()
    # If there are overlap positions, we need to remove them.
    x_pos = x_pos[:seq_len]
    y_pos = y_pos[:seq_len]
    all_pos = torch.stack((y_pos, x_pos), dim=1).unsqueeze(1).to(device)    # [seq_len, 1, 2]

    # calc rope
    idx_theta = (all_pos * theta).reshape(all_pos.shape[0], n_elem // 2).repeat(1, 2)

    cos = torch.cos(idx_theta)
    sin = torch.sin(idx_theta)

    if return_all_pos:
        return cos, sin, all_pos

    return cos, sin


def build_batch_2d_rope(
        seq_len: int, n_elem: int, image_infos = None,
        device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
        return_all_pos: bool = False,
):
    """Run :func:`build_2d_rope` once per entry of ``image_infos`` and stack the results.

    Returns ``(cos, sin)`` stacked on a new leading batch axis, plus the list
    of per-sample position tensors when ``return_all_pos`` is set.
    """
    cos_list, sin_list, all_pos_list = [], [], []
    if image_infos is None:
        image_infos = [None]
    for image_info in image_infos:
        res = build_2d_rope(
            seq_len, n_elem, image_infos=image_info, device=device,
            base=base, base_rescale_factor=base_rescale_factor,
            return_all_pos=return_all_pos,
        )
        if return_all_pos:
            cos, sin, all_pos = res
        else:
            cos, sin = res
            all_pos = None
        cos_list.append(cos)
        sin_list.append(sin)
        all_pos_list.append(all_pos)

    stacked_cos = torch.stack(cos_list, dim=0)
    stacked_sin = torch.stack(sin_list, dim=0)

    if return_all_pos:
        return stacked_cos, stacked_sin, all_pos_list

    return stacked_cos, stacked_sin


def rotate_half(x):
    """Split the last axis in two halves ``(x1, x2)`` and return ``(-x2, x1)``."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary embedding to ``q`` and ``k``.

    ``cos``/``sin`` are optionally gathered by ``position_ids`` and get an
    extra axis at ``unsqueeze_dim`` so they broadcast against ``q``/``k``.
    """
    if position_ids is not None:
        cos = cos[position_ids]
        sin = sin[position_ids]

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    return val if val is not None else d


def topkgating(logits: torch.Tensor, topk: int):
    """Top-k softmax routing.

    Args:
        logits: ``[tokens, experts]`` router scores.
        topk: number of experts selected per token.

    Returns:
        ``(combine_weights, dispatch_mask)``, both ``[tokens, experts, capacity]``
        where ``capacity`` is the largest number of tokens routed to any single
        expert. ``dispatch_mask`` is boolean; ``combine_weights`` carries the
        gate probability renormalised over the selected experts.
    """
    logits = logits.float()
    gates = F.softmax(logits, dim=1)

    num_experts = int(gates.shape[1])

    _, expert_index = torch.topk(gates, topk)
    expert_mask = F.one_hot(expert_index, num_experts)

    expert_index_flat = expert_index.flatten()
    tokens_per_expert = torch.bincount(expert_index_flat, minlength=num_experts)
    expert_capacity = torch.max(tokens_per_expert).item()

    # Sum of the selected gate values per token (clamped to avoid division by zero).
    gates_s = torch.clamp(
        torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps
    )
    router_probs = gates / gates_s

    # Flatten rank-major ([topk * tokens]) so the cumulative sum below counts
    # all first choices before any second choice.
    expert_index = torch.transpose(expert_index, 0, 1)
    expert_index = expert_index.reshape(-1)
    expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)

    # Slot of each (token, expert) pair inside the expert's buffer; -1 = not routed.
    token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1
    token_priority = token_priority.reshape((topk, -1, num_experts))
    token_priority = torch.transpose(token_priority, 0, 1)

    token_priority = torch.max(token_priority, dim=1)[0]

    valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
    token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
    dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
    valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity)
    dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)

    combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)

    return combine_weights, dispatch_mask


class HunyuanRMSNorm(nn.Module):
    """RMS normalisation computed in float32 with a learned per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size, device=device, dtype=dtype))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise over the last axis and cast back to the input dtype before scaling."""
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class UNetDown(nn.Module):
    """Conv + ResBlock stack turning an image latent into a token sequence."""

    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None, operations=None):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()

        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]

        self.model = nn.ModuleList(
            [operations.Conv2d(
                in_channels=in_channels,
                out_channels=hidden_channels,
                kernel_size=3,
                padding=1,
                **factory_kwargs
            )]
        )

        if self.patch_size == 1:
            self.model.append(ResBlock(
                channels=hidden_channels,
                emb_channels=emb_channels,
                out_channels=out_channels,
                use_scale_shift_norm = True,
                dropout=dropout,
                **factory_kwargs,
                operations = operations
            ))
        else:
            for i in range(self.patch_size // 2):
                self.model.append(ResBlock(
                    channels=hidden_channels,
                    emb_channels=emb_channels,
                    use_scale_shift_norm = True,
                    # Only the last downsampling block switches to out_channels.
                    out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
                    dropout=dropout,
                    down=True,
                    **factory_kwargs,
                    operations = operations
                ))

    def forward(self, x, t):
        """Return ``(tokens [b, h*w, c], token_h, token_w)`` for input ``x`` and embedding ``t``."""
        assert x.shape[2] % self.patch_size == 0 and x.shape[3] % self.patch_size == 0
        for module in self.model:
            if isinstance(module, ResBlock):
                x = module(x, t)
            else:
                x = module(x)
        _, _, token_h, token_w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x, token_h, token_w


class UNetUp(nn.Module):
    """ResBlock stack + output conv turning a token sequence back into an image latent."""

    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None, operations = None, out_norm=False):
        operations = operations or nn
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()

        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]

        self.model = nn.ModuleList()

        if self.patch_size == 1:
            self.model.append(ResBlock(
                channels=in_channels,
                emb_channels=emb_channels,
                out_channels=hidden_channels,
                use_scale_shift_norm = True,
                dropout=dropout,
                **factory_kwargs,
                operations = operations
            ))
        else:
            for i in range(self.patch_size // 2):
                self.model.append(ResBlock(
                    channels=in_channels if i == 0 else hidden_channels,
                    emb_channels=emb_channels,
                    out_channels=hidden_channels,
                    use_scale_shift_norm = True,
                    dropout=dropout,
                    up=True,
                    **factory_kwargs,
                    operations = operations
                ))

        if out_norm:
            self.model.append(nn.Sequential(
                operations.GroupNorm(32, hidden_channels, **factory_kwargs),
                nn.SiLU(),
                operations.Conv2d(
                    in_channels=hidden_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    **factory_kwargs
                ),
            ))
        else:
            self.model.append(operations.Conv2d(
                in_channels=hidden_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                **factory_kwargs
            ))

    # batch_size, seq_len, model_dim
    def forward(self, x, t, token_h, token_w):
        """Reshape ``x [b, h*w, c]`` to ``[b, c, h, w]`` and run the up stack with embedding ``t``."""
        x = rearrange(x, 'b (h w) c -> b c h w', h=token_h, w=token_w)
        for module in self.model:
            if isinstance(module, ResBlock):
                x = module(x, t)
            else:
                x = module(x)
        return x


class HunyuanTopKGate(nn.Module):
    """Float32 linear router over 64 experts followed by top-8 gating."""

    def __init__(self, config, layer_idx: Optional[int] = None, device=None, dtype=None, operations=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = 8
        self.min_capacity = 8
        num_experts = 64
        # The router weight is always float32, independent of ``dtype``.
        self.wg = operations.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32, device=device)

    def forward(self, hidden_states):
        """Return ``topkgating`` output for ``hidden_states [bsz, seq, hidden]`` flattened to tokens."""
        bsz, seq_len, hidden_size = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_size)
        if self.wg.weight.dtype == torch.float32:
            hidden_states = hidden_states.float()
        logits = self.wg(hidden_states)
        gate_output = topkgating(logits, self.moe_topk)
        return gate_output


class HunyuanMLP(nn.Module):
    """Gated MLP with a fused gate/up projection (``down(x1 * silu(x2))``)."""

    def __init__(self, config, layer_idx=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config["hidden_size"]

        self.intermediate_size = 3072

        self.act_fn = torch.nn.functional.silu
        self.intermediate_size *= 2  # SwiGLU
        self.gate_and_up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device, dtype=dtype)
        self.down_proj = operations.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device, dtype=dtype)

    def forward(self, x):
        """Apply the MLP; a 2-D input gets a leading batch axis. Layers are moved to ``x.device`` first."""
        self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
        if x.ndim == 2:
            x = x.unsqueeze(0)
        gate_and_up_proj = self.gate_and_up_proj(x)
        x1, x2 = gate_and_up_proj.chunk(2, dim=2)
        down_proj = self.down_proj(x1 * self.act_fn(x2))
        return down_proj


class MoELRUCache(nn.Module):
    """Holds expert modules in ``gpu_cache`` / ``cpu_cache`` keyed by ``layer * 64 + expert``.

    Transfers run as coroutines on a private asyncio loop that is driven by a
    daemon thread started in ``__init__``.
    """

    def __init__(self):
        super().__init__()
        self.gpu_cache = OrderedDict()
        self.cpu_cache = OrderedDict()
        self.offload_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

        self.last_offload_event = None
        self._loop = asyncio.new_event_loop()
        self._gpu_sem = asyncio.Semaphore(2)
        self.operations = None
        self.dtype = None
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    async def _async_offload_to_cpu(self, layer_idx):
        """Copy every cached expert of ``layer_idx`` to pinned CPU memory, then drop the GPU copy."""
        # async offload from gpu (removed)
        async with self._gpu_sem:
            num_experts = 64
            moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
                         for i in range(num_experts)
                         if (layer_idx * num_experts + i) in self.gpu_cache]
            event = torch.cuda.Event()

            with torch.cuda.stream(self.offload_stream):
                for index, moe in moe_group:
                    moe_cpu = HunyuanMLP(moe.config, device="cpu", dtype=self.dtype, operations=self.operations)
                    for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
                        if p_gpu.device.type == "meta":
                            continue
                        with torch.no_grad():
                            p_cpu.data = torch.empty_like(p_gpu, device="cpu", dtype = self.dtype, pin_memory=True)
                            p_cpu.copy_(p_gpu, non_blocking=True)

                    self.cpu_cache[index] = moe_cpu

                self.offload_stream.record_event(event)

            self.last_offload_event = event

            def finalize_offload_layer():
                # Wait for the async copies before releasing the source tensors.
                event.synchronize()
                for index, moe in moe_group:
                    moe.to("meta")
                    self.gpu_cache.pop(index, None)
                    del moe
                torch.cuda.empty_cache()

            # This class owns the loop directly; it has no ``cache`` attribute.
            self._loop.call_soon_threadsafe(finalize_offload_layer)

    async def _async_load_to_gpu(self, index, moe):
        """Wait for free VRAM, then register a copy of ``moe`` under ``index`` in ``gpu_cache``."""
        # if enough memory load, otherwise wait for offload
        while True:
            free_bytes, _ = torch.cuda.mem_get_info()
            if free_bytes > 2 * MOE_LAYER_SIZE:
                break

            torch.cuda.empty_cache()
            await asyncio.sleep(0.01)

        # async loading from cpu -> gpu
        with torch.no_grad():
            with torch.cuda.stream(self.load_stream):
                # NOTE(review): the destination module is created on the "meta"
                # device, so these copies presumably do not materialise weights
                # on the GPU - confirm the intended target device.
                moe_gpu = HunyuanMLP(moe.config, device="meta", dtype=self.dtype, operations=self.operations)
                for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
                    p_gpu.data.copy_(p_cpu, non_blocking=True)

        def finalize_load():
            self.gpu_cache[index] = moe_gpu
            self.cpu_cache.pop(index, None)

        # This class owns the loop directly; it has no ``cache`` attribute.
        self._loop.call_soon_threadsafe(finalize_load)

    def add_cpu(self, moe, index):
        """Pin ``moe``'s parameters in CPU memory and store it as the most recent ``cpu_cache`` entry.

        Returns without caching if a non-pinned parameter is not on the CPU.
        """
        moe_cpu = moe.to("cpu")

        for _, p in moe_cpu.named_parameters():
            if not p.is_pinned():
                if p.device.type == "cpu":
                    p.data = p.data.pin_memory()
                else:
                    return

        self.cpu_cache[index] = moe_cpu
        self.cpu_cache.move_to_end(index)


def parse_layer_expert(key):
    """Extract ``(layer, expert)`` from a dotted key using components 2 and 5."""
    parts = key.split(".")
    layer = int(parts[2])
    expert = int(parts[5])
    return layer, expert


class LazyMoELoader(nn.Module):
    """Streams expert weights from a safetensors checkpoint into pre-built meta modules."""

    def __init__(self, cache, config, max_workers = 16, max_concurrent_loads = 32,
                 dtype=None, operations=None):
        """
        Args:
            cache: the ``MoELRUCache`` whose event loop is used for scheduling.
            config: model config dict handed to every ``HunyuanMLP``.
            max_workers: size of the thread pool stored in ``_executor``.
            max_concurrent_loads: initial value of ``_semaphore``.
            dtype, operations: forwarded to ``HunyuanMLP``; they must be set
                before ``build_meta_experts`` runs, which reads both.
        """
        super().__init__()
        self.cache = cache
        self.config = config
        self.dtype = dtype
        self.operations = operations
        self._loop = cache._loop
        self.expert_key_index = self.index_safetensors()
        self._checkpoint = self.get_checkpoint()
        # NOTE(review): confirm the installed safetensors accepts ``mmap``.
        self._file = safe_open(self._checkpoint, framework="pt", device="cpu", mmap=True)
        self.expert_pool = self.build_meta_experts()

        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._semaphore = asyncio.Semaphore(max_concurrent_loads)

    def build_meta_experts(self):
        """Create one meta-device ``HunyuanMLP`` per indexed ``(layer, expert)`` pair."""
        pool = {}
        for layer, experts in self.expert_key_index.items():
            pool[layer] = {}
            for expert in experts:
                pool[layer][expert] = HunyuanMLP(
                    self.config,
                    layer_idx=layer,
                    device="meta",
                    dtype = self.dtype,
                    operations = self.operations
                )
        return pool

    def get_checkpoint(self):
        """Return the resolved checkpoint path under ``~/ComfyUI``; raise ``ValueError`` if missing."""
        comfyui_dir = Path.home() / "ComfyUI"
        checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
        checkpoint = checkpoint.resolve()
        if not os.path.exists(checkpoint):
            raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")
        return checkpoint

    def index_safetensors(self):
        """Map ``layer -> expert -> [tensor keys]`` for every checkpoint key containing ``"experts."``."""
        checkpoint = self.get_checkpoint()
        index = {}
        with safe_open(checkpoint, framework="pt", device="cpu") as f:
            for k in f.keys():
                if "experts." in k:
                    layer, expert = parse_layer_expert(k)
                    index.setdefault(layer, {}).setdefault(expert, []).append(k)
        return index

    def lazy_init(self, layer_idx, expert_idx):
        """Load the tensors of one expert from disk into its pooled module and return it."""
        keys = self.expert_key_index[layer_idx][expert_idx]
        model = self.expert_pool[layer_idx][expert_idx]

        def strip_expert_prefix(k):
            return k.split(f"experts.{expert_idx}.", 1)[1]

        sd = {strip_expert_prefix(k): self._file.get_tensor(k) for k in keys}

        for name, tensor in sd.items():
            # ``name`` is the dotted remainder of the checkpoint key; plain
            # getattr cannot resolve dotted paths, get_parameter can.
            model.get_parameter(name).data = tensor
        return model

    async def lazy_load_from_disk(self, layer_idx, expert_idx):
        """Run :meth:`lazy_init` in the loop's default executor."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)

    def _schedule_disk_load(self, layer_idx, expert_idx):
        """Schedule a disk load on the cache loop; return its ``concurrent.futures.Future``.

        Once loaded, a helper thread registers the expert in the CPU cache and
        schedules the follow-up GPU load.
        """
        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)

        def _on_disk_loaded(fut):
            moe_cpu = fut.result()

            def _add_cpu_in_main_thread():
                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
                asyncio.run_coroutine_threadsafe(
                    self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
                    self.cache._loop
                )

            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()

        future.add_done_callback(_on_disk_loaded)
        return future


def enough_vram(required_bytes):
    """Return True when free CUDA memory exceeds ``required_bytes``."""
    free, total = torch.cuda.mem_get_info()
    return free > required_bytes


class HunyuanMoE(nn.Module):
    """Shared MLP plus 64 routed experts (top-8).

    With ``INIT_MOE`` false, ``experts`` starts as an empty list that the
    model later fills with futures / loaded modules.
    """

    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = 8
        self.num_experts = 64
        self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)
        if INIT_MOE:
            self.experts = nn.ModuleList(
                [HunyuanMLP(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations) for _ in range(self.num_experts)]
            )
        else:
            self.experts = []
        self.moe_lru = moe_lru

    @staticmethod
    def _is_future(obj):
        """True for both future flavours an expert slot may hold."""
        return isinstance(obj, (concurrent.futures.Future, asyncio.Future))

    @staticmethod
    def _compute_expert_outputs(experts_list, tokens_padded, device):
        """Batched expert MLP: ``tokens_padded [u, c, m]`` -> ``[u, c, m]`` via stacked weights."""
        l1, l2 = [], []
        for m in experts_list:
            l1.append(m.gate_and_up_proj)
            l2.append(m.down_proj)

        W1 = torch.stack([m.weight.to(device) for m in l1], dim=0)
        W1_T = W1.transpose(1, 2)
        x = torch.bmm(tokens_padded.to(device), W1_T)
        x1, x2 = x.chunk(2, dim=2)
        gated = x1 * F.silu(x2)

        W2 = torch.stack([m.weight.to(device) for m in l2], dim=0)
        W2_T = W2.transpose(1, 2)
        out_padded = torch.bmm(gated, W2_T)
        return out_padded

    def forward(self, hidden_states):
        """Return ``shared_mlp(x) + routed_experts(x)`` for ``hidden_states [bsz, seq, hidden]``."""
        # do the forward statement over the already loaded experts to give time for the pending experts
        # makes the gpu not sit idle
        if not INIT_MOE:
            torch.cuda.set_device(0)
        else:
            torch.cuda.set_device(hidden_states.device.index)
        bsz, seq_len, hidden_size = hidden_states.shape

        hidden_states_mlp = self.shared_mlp(hidden_states)

        reshaped_input = hidden_states.reshape(-1, hidden_size)

        with torch.cuda.nvtx.range("MoE"):
            combine_weights, dispatch_mask = self.gate(hidden_states)

            dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(reshaped_input), reshaped_input)
            device = hidden_states.device
            dtype = reshaped_input.dtype

            # Experts that received at least one token.
            used_mask = (dispatch_mask.sum(dim=(0, 2)) > 0)
            used_indices = used_mask.nonzero(as_tuple=False).squeeze(1).tolist()

            combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype)

            if used_indices:
                tokens_padded = dispatched_input[used_indices]
                position_of = {expert_idx: pos for pos, expert_idx in enumerate(used_indices)}

                out_parts = {}
                ready_indices, pending_indices = [], []
                for i in used_indices:
                    expert = self.experts[i]
                    if self._is_future(expert):
                        if expert.done():
                            self.experts[i] = expert.result()
                            ready_indices.append(i)
                        else:
                            pending_indices.append(i)
                    else:
                        ready_indices.append(i)

                # First compute with everything already in memory ...
                if ready_indices:
                    ready_experts = [self.experts[i] for i in ready_indices]
                    tokens_for_ready = tokens_padded[[position_of[i] for i in ready_indices]]
                    out_ready = self._compute_expert_outputs(ready_experts, tokens_for_ready, device)
                    for idx_pos, expert_idx in enumerate(ready_indices):
                        out_parts[expert_idx] = out_ready[idx_pos:idx_pos+1]

                # ... then resolve the stragglers. _schedule_disk_load hands out
                # concurrent.futures.Future objects, so both future types must
                # be resolved here, otherwise Future objects reach the matmul.
                for i in pending_indices:
                    expert = self.experts[i]
                    if self._is_future(expert):
                        self.experts[i] = expert.result()

                if pending_indices:
                    pending_experts = [self.experts[i] for i in pending_indices]
                    tokens_for_pending = tokens_padded[[position_of[i] for i in pending_indices]]
                    out_pending = self._compute_expert_outputs(pending_experts, tokens_for_pending, device)
                    for idx_pos, expert_idx in enumerate(pending_indices):
                        out_parts[expert_idx] = out_pending[idx_pos:idx_pos+1]

                out_list_ordered = [out_parts[i] for i in used_indices]
                out_padded_all = torch.cat(out_list_ordered, dim=0)

                # Keep only the columns of the experts that were actually run so
                # the expert axis matches out_padded_all even if some are unused.
                used_weights = combine_weights[:, used_indices].to(out_padded_all.dtype)
                combined_output = torch.einsum("suc,uco->so", used_weights, out_padded_all)

                del out_padded_all, out_list_ordered, out_parts

        combined_output = combined_output.reshape(bsz, seq_len, hidden_size)

        output = hidden_states_mlp + combined_output

        return output


class HunyuanImage3Attention(nn.Module):
    """Grouped-query self-attention with fused QKV projection, RoPE and per-head RMSNorm on q/k."""

    def __init__(self, config, layer_idx: int, device=None, dtype=None, operations=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_type = 'self'

        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_attention_heads"]
        self.head_dim = config["attention_head_dim"]
        self.num_key_value_heads = 8
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config["max_position_embeddings"]
        self.rope_theta = 10000.0
        self.is_causal = True
        self.hidden_size_q = self.head_dim * self.num_heads
        self.hidden_size_kv = self.head_dim * self.num_key_value_heads

        # define layers
        self.qkv_proj = operations.Linear(
            self.hidden_size,
            self.hidden_size_q + 2 * self.hidden_size_kv,
            bias=False, device=device, dtype=dtype
        )
        self.o_proj = operations.Linear(self.hidden_size_q, self.hidden_size, bias=False, device=device, dtype=dtype)

        self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``[bsz, seq, heads*dim]`` to contiguous ``[bsz, heads, seq, dim]``."""
        return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
            self,
            hidden_states: torch.Tensor,
            past_key_value,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
            **kwargs,
    ):
        """Return ``(attn_output, past_key_value)``.

        ``custom_pos_emb`` is unpacked as ``(cos, sin)``. When
        ``past_key_value`` is given, its ``update`` is called with
        ``position_ids`` as the cache position.
        """
        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        # Fused layout: per KV head, ``groups`` query heads followed by one key and one value head.
        qkv_states = qkv_states.reshape(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2,
                                        self.head_dim)
        query_states, key_states, value_states = torch.split(qkv_states, [self.num_key_value_groups, 1, 1], dim=3)

        query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = custom_pos_emb
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        query_states = self.query_layernorm(query_states)
        key_states = self.key_layernorm(key_states)

        query_states = query_states.to(value_states.dtype)
        key_states = key_states.to(value_states.dtype)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": position_ids}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            query_states = query_states.to(key_states.dtype)

        # Expand KV heads so every query head has a matching key/value head.
        key_states = torch.repeat_interleave(key_states, dim=1, repeats = self.num_key_value_groups)
        value_states = torch.repeat_interleave(value_states, dim=1, repeats = self.num_key_value_groups)

        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, mask = attention_mask, skip_reshape=True)

        attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value


class HunyuanImage3DecoderLayer(nn.Module):
    """Pre-norm transformer block: self-attention then MoE MLP, each with a residual."""

    def __init__(self, config, layer_idx: int, moe_lru=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.layer_idx = layer_idx
        self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx, device=device, dtype=dtype, operations=operations)

        self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru, device=device, dtype=dtype, operations=operations)

        self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"], device=device, dtype=dtype)
        self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'], device=device, dtype=dtype)

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            use_cache: Optional[bool] = False,
            custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
            **kwargs,
    ) -> Tuple[torch.FloatTensor | Any]:
        """Return ``(hidden_states, past_key_value)``; extra kwargs are forwarded to attention."""
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, past_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            custom_pos_emb=custom_pos_emb,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + hidden_states

        outputs = (hidden_states, past_key_value)

        return outputs


class HunyuanImage3Model(nn.Module):
    """Decoder stack; with a ``moe_lru`` cache it prefetches/offloads experts layer by layer."""

    def __init__(self, config, moe_lru=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.padding_idx = 128009
        self.vocab_size = 133120
        self.config = config
        self.wte = operations.Embedding(133120, config["hidden_size"], self.padding_idx, device=device, dtype=dtype)
        self.layers = nn.ModuleList(
            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru, device=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
        )

        self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"], device=device, dtype=dtype)

        self.shared_tensor = None
        self.moe_lru = moe_lru
        self.additional_layers_set = False
        # The loader needs the cache's event loop, so it only exists when
        # experts are streamed (moe_lru given). dtype/operations are passed in
        # because the loader builds its meta experts inside __init__.
        self.moe_loader = None
        if self.moe_lru is not None:
            self.moe_loader = LazyMoELoader(self.moe_lru, self.config, dtype=dtype, operations=operations)

    def _schedule_layer_experts(self, layer_idx):
        """Kick off disk loads for all 64 experts of ``layer_idx``; return the futures."""
        return [self.moe_loader._schedule_disk_load(layer_idx, i) for i in range(64)]

    def _experts_are_list(self, layer_idx):
        """True when ``layer_idx`` is past the last layer or its experts are a plain list."""
        if layer_idx >= len(self.layers):
            return True
        return isinstance(self.layers[layer_idx].mlp.experts, list)

    def forward(
            self,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            use_cache = True,
            custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
            mode: str = "gen_image",
            first_step: Optional[bool] = None,
            gen_timestep_scatter_index: Optional[torch.Tensor] = None,
    ):
        """Run all decoder layers on ``inputs_embeds``.

        Returns a tuple of the non-``None`` values among
        ``(hidden_states, cache)``; the cache is only included when
        ``use_cache`` is true.
        """
        hidden_states = inputs_embeds

        next_decoder_cache = None
        num_layers = len(self.layers)
        lazy_experts = self.moe_loader is not None

        next_layers = 0
        # Default: an interval no layer index can exceed => no extra prefetch.
        sparse_interval = num_layers
        if lazy_experts:
            additional_layers = torch.cuda.mem_get_info()[0] // (MOE_LAYER_SIZE * 2)
            if additional_layers > 0:  # avoid dividing by zero when VRAM is tight
                sparse_interval = max(1, num_layers // additional_layers)

            if len(self.layers[0].mlp.experts) == 0:
                self.layers[0].mlp.experts = self._schedule_layer_experts(0)

        for layer_idx, decoder_layer in enumerate(self.layers):

            if lazy_experts:
                # maybe the second layer loading should depend on how much gpu memory is there
                # Bounds-checked lookahead: the last layers have no successor.
                next_layer = layer_idx + 1 if self._experts_are_list(layer_idx + 1) else layer_idx + 2
                second_next_layer = next_layer + 1 if self._experts_are_list(layer_idx + 2) else next_layer + 2

                for upcoming in (next_layer, second_next_layer):
                    if upcoming < num_layers and len(self.layers[upcoming].mlp.experts) == 0: # not loaded
                        self.layers[upcoming].mlp.experts = self._schedule_layer_experts(upcoming)

                if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
                    self.layers[next_layers].mlp.experts = self._schedule_layer_experts(next_layers)
                    next_layers += 1

            with torch.no_grad():
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    custom_pos_emb=custom_pos_emb,
                    mode=mode,
                    first_step=first_step,
                    gen_timestep_scatter_index=gen_timestep_scatter_index,
                )

            # Offloading only applies when experts are streamed; resident
            # experts (no cache) must stay attached to their layer.
            if lazy_experts:
                # NOTE(review): additional_layers_set is a bool, so this keeps
                # layers 0 and 1 on later calls - confirm that is the intent.
                keep_resident = self.additional_layers_set and layer_idx <= self.additional_layers_set
                if not keep_resident:
                    asyncio.run_coroutine_threadsafe(
                        self.moe_lru._async_offload_to_cpu(layer_idx),
                        self.moe_lru._loop
                    )
                    del self.layers[layer_idx].mlp.experts
                    self.layers[layer_idx].mlp.experts = []

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[1]

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache
        self.additional_layers_set = True
        return tuple(v for v in [hidden_states, next_cache] if v is not None)


class HunyuanImage3ForCausalMM(nn.Module):
    """Top-level model: patch/timestep embedders, decoder stack, image head and LM head."""

    def __init__(self, operations = None, dtype = None, device = None, **kwargs):
        super().__init__()
        config = kwargs
        self.config = config
        factory_kwargs = {"device": device, "dtype": dtype, "operations": operations}

        self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)
        self.patch_embed = UNetDown(
            patch_size=1,
            emb_channels=config["hidden_size"],
            in_channels=32,
            hidden_channels=1024,
            out_channels=config["hidden_size"],
            **factory_kwargs
        )
        self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)

        self.final_layer = UNetUp(
            patch_size=1,
            emb_channels=config["hidden_size"],
            in_channels=config["hidden_size"],
            hidden_channels=1024,
            out_channels=32,
            out_norm=True,
            **factory_kwargs
        )
        self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)

        # The expert cache only exists in the single-GPU (streamed experts) setup.
        self.moe_lru = None
        if not INIT_MOE:
            self.moe_lru = MoELRUCache()
            self.moe_lru.operations = operations
            self.moe_lru.dtype = dtype

        self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru, **factory_kwargs)

        self.pad_id = 128009
        self.vocab_size = 133120

        self.lm_head = operations.Linear(config["hidden_size"], 133120, bias=False, device=device, dtype=dtype)
        self.first_step = True

        self.kv_cache = None
        self.encode_tok = None
        self.special_tok = None

    @staticmethod
    def get_pos_emb(custom_pos_emb, position_ids):
        """Gather both tables of ``custom_pos_emb`` per sample at ``position_ids``."""
        cos, sin = custom_pos_emb
        cos = real_batched_index_select(cos, dim=1, idx=position_ids)
        sin = real_batched_index_select(sin, dim=1, idx=position_ids)
        return cos, sin

    def ragged_final_layer(self, x, image_mask, timestep, token_h, token_w, first_step):
        """Select the image tokens of ``x`` and decode them with ``final_layer``.

        On the first step tokens are picked with ``image_mask``; afterwards
        every token except the first is used.
        """
        bsz, seq_len, n_embd = x.shape
        if first_step:
            image_output = x.masked_select(image_mask.unsqueeze(-1).bool()).reshape(bsz, -1, n_embd)
        else:
            image_output = x[:, 1:, :]
        timestep_emb = self.time_embed_2(timestep)
        pred = self.final_layer(image_output, timestep_emb, token_h, token_w)
        return pred

    def forward(self, x, condition, timestep, **kwargs):
        """Predict the diffusion output for latent ``x`` at ``timestep``.

        ``condition`` must unbind into six tensors: joint image, its VAE mask,
        the prompt embeddings and their three unconditional counterparts (the
        latter are unpacked but not used here).
        """
        joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()

        bsz, seq_len, n_embd = inputs_embeds.shape
        model_device = inputs_embeds.device

        # A joint image filled with -100 marks "no conditioning image".
        cond_exists = (joint_image[:, 0, :] != -100.0).any(dim=1).any()

        height, width = x.size(2) * 16, x.size(3) * 16
        gen_timestep_scatter_index = 4

        def fn(string, func = self.encode_tok):
            # ``func`` is either a callable tokenizer or a dict of token ids.
            token_ids = func(string) if not isinstance(func, dict) else func[string]
            return self.model.wte(torch.tensor(token_ids, device=inputs_embeds.device)).unsqueeze(0).expand(bsz, -1, -1)

        hw = f"{int(height)}x{int(width)}"
        # Raises IndexError when the resolution is not a known group.
        ratio_idx = [i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw][0]
        # NOTE(review): every token-name literal passed to fn() in this method
        # is an empty string and ratio_idx is unused - the names look like they
        # were lost; confirm against the upstream file.
        img_ratio = fn(f"", self.special_tok)

        if cond_exists:
            with torch.no_grad():
                joint_image[:, 2:3, :] = img_ratio

        img_slices = []
        # Created on the input's device so the embedder MLP sees matching devices.
        cond_timestep = torch.zeros(x.size(0), device=x.device)
        t_emb = self.time_embed(timestep)

        if self.first_step:
            x, token_height, token_width = self.patch_embed(x, t_emb)
            x = torch.cat([fn(""), fn("", func = self.special_tok), img_ratio, fn("", self.special_tok), x, fn("")], dim = 1)
            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
        else:
            x, token_height, token_width = self.patch_embed(x, t_emb)
            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
            inputs_embeds = torch.cat([timestep_emb, x], dim=1)

        input_args = [inputs_embeds, x] if self.first_step else [inputs_embeds]

        for i in range(x.size(0)):
            gen_offset = seq_len + x.size(1)
            if cond_exists:
                vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0]
                vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1

                vae_start += gen_offset
                vae_end += gen_offset

                vit_start = vae_end + 1
                vit_end = joint_image.size(1) - 1 + gen_offset

                joint_slices_i = [
                    slice(vae_start, vae_end),
                    slice(vit_start, vit_end),
                ]
            else:
                joint_slices_i = []
            gen_slices_i = [slice(seq_len, gen_offset)]
            img_slices.append(gen_slices_i + joint_slices_i)

        img_s = img_slices[0]
        rope_img = [(img_s[0], (token_height, token_width))]
        rope_image_info = [rope_img if len(joint_slices_i) == 0 else rope_img + [(img_s[1], (384 // 16, 384 // 16)), (img_s[2], (256 // 16, 256 // 16))]]

        # cond_vae_images
        # cond_timestep_scatter_index
        if cond_exists:
            with torch.no_grad():
                joint_image[:, 3:4, :] = self.timestep_emb(cond_timestep.reshape(-1)).reshape(bsz, -1, n_embd)

            inputs_embeds = torch.cat([*input_args, joint_image], dim = 1)
        else:
            inputs_embeds = torch.cat([*input_args], dim = 1)

        # Causal mask, with full attention inside every image slice. Built on
        # the embeddings' device so attention does not mix CPU and CUDA tensors.
        total_len = inputs_embeds.shape[1]
        attention_mask = torch.ones(total_len, total_len, dtype=torch.bool, device=inputs_embeds.device).tril(diagonal=0).repeat(bsz, 1, 1)
        for i in range(bsz):
            for image_slice in img_slices[i]:
                attention_mask[i, image_slice, image_slice] = True
        attention_mask = attention_mask.unsqueeze(1)

        # pos embed
        position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
        cos, sin = build_batch_2d_rope(
            image_infos=rope_image_info,
            seq_len=inputs_embeds.shape[1],
            n_elem=128, # head dim
            base=10000.0,
        )
        # NOTE(review): the tuple is packed as (sin, cos) but both get_pos_emb
        # and the attention layers unpack it as (cos, sin) - confirm the order.
        custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device))

        custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)

        if self.kv_cache is None:
            # TODO: should change when higgsv2 gets merged
            self.kv_cache = HunyuanStaticCache(
                config=self.config,
                batch_size=x.size(0) * 2,
                max_cache_len = inputs_embeds.shape[1],
                dtype=x.dtype,
            )

        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=self.kv_cache,
            inputs_embeds=inputs_embeds,
            custom_pos_emb=custom_pos_emb,
            first_step=self.first_step,
            gen_timestep_scatter_index=gen_timestep_scatter_index,
        )
        hidden_states = outputs[0]

        # safety no-op
        past_key_value = outputs[1]
        if past_key_value is not None:
            self.kv_cache = past_key_value
        hidden_states = hidden_states.to(model_device)
        # Mask lives on the same device as the tensor it selects from.
        img_mask = torch.zeros(hidden_states.size(1), device=hidden_states.device)
        img_mask[seq_len + x.size(1)+4:] = 1
        img_mask[-1] = 0

        diffusion_prediction = self.ragged_final_layer(
            hidden_states, img_mask, timestep, int(token_height), int(token_width), self.first_step)

        if self.first_step:
            self.first_step = False

        return diffusion_prediction