import os import math import time import torch import psutil import asyncio import threading import torch.nn as nn from pathlib import Path import concurrent.futures from einops import rearrange import torch.nn.functional as F from collections import OrderedDict from safetensors import safe_open from transformers.cache_utils import StaticCache from typing import Optional, Tuple, Any, List, Dict from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock INIT_MOE = torch.cuda.device_count() != 1 if not INIT_MOE: MOE_LAYER_SIZE = (1024**3) * 2.65 # approx torch.cuda.set_device(0) props = torch.cuda.get_device_properties(0) LAYERS_IN_CPU = math.floor((int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')) - psutil.Process(os.getpid()).memory_info().rss - (2*1024**3)) * 0.50) / MOE_LAYER_SIZE) class HunyuanStaticCache(StaticCache): def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: cache_position = cache_kwargs.get("cache_position") if hasattr(self, "key_cache") and hasattr(self, "value_cache"): if self.key_cache[layer_idx].device != key_states.device: self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] key_states = key_states.to(k_out.dtype) value_states = value_states.to(v_out.dtype) else: if self.layers[layer_idx].keys is None: self.layers[layer_idx].lazy_initialization(key_states) k_out = self.layers[layer_idx].keys v_out = self.layers[layer_idx].values if cache_position is None: k_out.copy_(key_states) v_out.copy_(value_states) else: if cache_position.dim() == 1: k_out.index_copy_(2, cache_position, key_states) v_out.index_copy_(2, cache_position, value_states) else: assert cache_position.dim() == 2, f"multiple batch dims not yet {cache_position.shape=}" batch_size, _ = cache_position.shape for i in range(batch_size): unbatched_dim = 1 k_out[i].index_copy_(unbatched_dim, cache_position[i], key_states[i]) v_out[i].index_copy_(unbatched_dim, cache_position[i], value_states[i]) return k_out, v_out def real_batched_index_select(t, dim, idx): return torch.stack([torch.index_select(t[i], dim - 1, idx[i]) for i in range(len(t))]) def timestep_embedding(t, dim, max_period=10000): half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat( [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 ) return embedding class TimestepEmbedder(nn.Module): def __init__(self, hidden_size, act_layer=nn.GELU, frequency_embedding_size=256, max_period=10000, out_size=None, dtype=None, device=None ): factory_kwargs = {'dtype': dtype, 'device': device} super().__init__() self.frequency_embedding_size = frequency_embedding_size self.max_period = max_period if out_size is None: out_size = hidden_size self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), act_layer(), nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), ) def forward(self, t): t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) t_emb = self.mlp(t_freq) return t_emb def _to_tuple(x, dim=2): if isinstance(x, int): return (x,) * dim return x def get_meshgrid_nd(start, *args, dim=2): start = _to_tuple(start, dim=dim) stop = _to_tuple(args[0], dim=dim) num = [stop[i] - start[i] for i in range(dim)] num_int = [int(x) for x in num] num = num_int axis_grid = [] for i in range(dim): a, b, n = start[i], stop[i], num[i] g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] axis_grid.append(g) grid = torch.meshgrid(*axis_grid, indexing="ij") grid = torch.stack(grid, dim=0) return grid def build_2d_rope( seq_len: int, n_elem: int, image_infos: Optional[List[Tuple[slice, Tuple[int, int]]]] = None, device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0, return_all_pos: bool = False, ): assert n_elem % 4 == 0, f"n_elem must be divisible by 4, but got {n_elem}." # theta if base_rescale_factor != 1.0: base *= base_rescale_factor ** (n_elem / (n_elem - 2)) theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) theta = theta.reshape(1, n_elem // 4, 2) # [1, half_d, 2] # position indices if image_infos is None: image_infos = [] image_infos_list = [image_infos] sample_seq_lens = [seq_len] # Prepare position indices for each sample x_sections = [] y_sections = [] for sample_id, sample_image_infos in enumerate(image_infos_list): last_pos = 0 for sec_slice, (h, w) in sample_image_infos: L = sec_slice.start # start from 0, so image_slice.start is just L # previous text if last_pos < L: y_sections.append(torch.arange(last_pos, L)) x_sections.append(torch.arange(last_pos, L)) elif h is None: # Interleave data has overlapped positions for tokens. y_sections.append(torch.arange(sec_slice.start, sec_slice.stop)) x_sections.append(torch.arange(sec_slice.start, sec_slice.stop)) continue else: # Interleave data has overlapped positions for noised image and the successive clean image, # leading to last_pos (= last text end L + noise w * h) > L (last text end L). pass # current image beta_y = L + (w * h - h) / 2 beta_x = L + (w * h - w) / 2 grid = get_meshgrid_nd((beta_y, beta_x), (beta_y + h, beta_x + w)) # [2, h, w] grid = grid.reshape(2, -1) # (y, x) y_sections.append(grid[0]) x_sections.append(grid[1]) # step last_pos = L + w * h # final text y_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id])) x_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id])) x_pos = torch.cat(x_sections).long() y_pos = torch.cat(y_sections).long() # If there are overlap positions, we need to remove them. x_pos = x_pos[:seq_len] y_pos = y_pos[:seq_len] all_pos = torch.stack((y_pos, x_pos), dim=1).unsqueeze(1).to(device) # [seq_len, 1, 2] # calc rope idx_theta = (all_pos * theta).reshape(all_pos.shape[0], n_elem // 2).repeat(1, 2) cos = torch.cos(idx_theta) sin = torch.sin(idx_theta) if return_all_pos: return cos, sin, all_pos return cos, sin def build_batch_2d_rope( seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None, device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0, return_all_pos: bool = False, ): cos_list, sin_list, all_pos_list = [], [], [] if image_infos is None: image_infos = [None] for i, image_info in enumerate(image_infos): res = build_2d_rope( seq_len, n_elem, image_infos=image_info, device=device, base=base, base_rescale_factor=base_rescale_factor, return_all_pos=return_all_pos, ) if return_all_pos: cos, sin, all_pos = res else: cos, sin = res all_pos = None cos_list.append(cos) sin_list.append(sin) all_pos_list.append(all_pos) stacked_cos = torch.stack(cos_list, dim=0) stacked_sin = torch.stack(sin_list, dim=0) if return_all_pos: return stacked_cos, stacked_sin, all_pos_list return stacked_cos, stacked_sin def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): if position_ids is not None: cos = cos[position_ids] sin = sin[position_ids] cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def default(val, d): return val if val is not None else d def conv_nd(dims, *args, **kwargs): if dims == 1: return nn.Conv1d(*args, **kwargs) elif dims == 2: return nn.Conv2d(*args, **kwargs) elif dims == 3: return nn.Conv3d(*args, **kwargs) def normalization(channels, **kwargs): return nn.GroupNorm(32, channels, **kwargs) def topkgating(logits: torch.Tensor, topk: int): logits = logits.float() gates = F.softmax(logits, dim=1) num_experts = int(gates.shape[1]) _, expert_index = torch.topk(gates, topk) expert_mask = F.one_hot(expert_index, num_experts) expert_index_flat = expert_index.flatten() tokens_per_expert = torch.bincount(expert_index_flat, minlength=num_experts) expert_capacity = torch.max(tokens_per_expert).item() gates_s = torch.clamp( torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps ) router_probs = gates / gates_s expert_index = torch.transpose(expert_index, 0, 1) expert_index = expert_index.reshape(-1) expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1 token_priority = token_priority.reshape((topk, -1, num_experts)) token_priority = torch.transpose(token_priority, 0, 1) token_priority = torch.max(token_priority, dim=1)[0] valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity) token_priority = torch.masked_fill(token_priority, ~valid_mask, 0) dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool) valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity) dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0) combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) return combine_weights, dispatch_mask class HunyuanRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) class UNetDown(nn.Module): def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels, dropout=0.0, device=None, dtype=None): factory_kwargs = {'dtype': dtype, 'device': device} super().__init__() self.patch_size = patch_size assert self.patch_size in [1, 2, 4, 8] self.model = nn.ModuleList( [conv_nd( 2, in_channels=in_channels, out_channels=hidden_channels, kernel_size=3, padding=1, **factory_kwargs )] ) if self.patch_size == 1: self.model.append(ResBlock( channels=hidden_channels, emb_channels=emb_channels, out_channels=out_channels, use_scale_shift_norm = True, dropout=dropout, **factory_kwargs )) else: for i in range(self.patch_size // 2): self.model.append(ResBlock( channels=hidden_channels, emb_channels=emb_channels, use_scale_shift_norm = True, out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels, dropout=dropout, down=True, **factory_kwargs )) def forward(self, x, t): assert x.shape[2] % self.patch_size == 0 and x.shape[3] % self.patch_size == 0 for module in self.model: if isinstance(module, ResBlock): x = module(x, t) else: x = module(x) _, _, token_h, token_w = x.shape x = rearrange(x, 'b c h w -> b (h w) c') return x, token_h, token_w class UNetUp(nn.Module): def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels, dropout=0.0, device=None, dtype=None, operations = None, out_norm=False): operations = operations or nn factory_kwargs = {'dtype': dtype, 'device': device} super().__init__() self.patch_size = patch_size assert self.patch_size in [1, 2, 4, 8] self.model = nn.ModuleList() if self.patch_size == 1: self.model.append(ResBlock( channels=in_channels, emb_channels=emb_channels, out_channels=hidden_channels, use_scale_shift_norm = True, dropout=dropout, **factory_kwargs )) else: for i in range(self.patch_size // 2): self.model.append(ResBlock( channels=in_channels if i == 0 else hidden_channels, emb_channels=emb_channels, out_channels=hidden_channels, use_scale_shift_norm = True, dropout=dropout, up=True, **factory_kwargs )) if out_norm: self.model.append(nn.Sequential( normalization(hidden_channels, **factory_kwargs), nn.SiLU(), nn.Conv2d( in_channels=hidden_channels, out_channels=out_channels, kernel_size=3, padding=1, **factory_kwargs ), )) else: self.model.append(nn.Conv2d( in_channels=hidden_channels, out_channels=out_channels, kernel_size=3, padding=1, **factory_kwargs )) # batch_size, seq_len, model_dim def forward(self, x, t, token_h, token_w): x = rearrange(x, 'b (h w) c -> b c h w', h=token_h, w=token_w) for module in self.model: if isinstance(module, ResBlock): x = module(x, t) else: x = module(x) return x class HunyuanTopKGate(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx self.moe_topk = 8 self.min_capacity = 8 num_experts = 64 self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32) def forward(self, hidden_states): bsz, seq_len, hidden_size = hidden_states.shape hidden_states = hidden_states.reshape(-1, hidden_size) if self.wg.weight.dtype == torch.float32: hidden_states = hidden_states.float() logits = self.wg(hidden_states) gate_output = topkgating(logits, self.moe_topk) return gate_output class HunyuanMLP(nn.Module): def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None): super().__init__() self.config = config self.layer_idx = layer_idx self.hidden_size = config["hidden_size"] self.intermediate_size = 3072 self.act_fn = torch.nn.functional.silu self.intermediate_size *= 2 # SwiGLU self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device) self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device) def forward(self, x): self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device) if x.ndim == 2: x = x.unsqueeze(0) gate_and_up_proj = self.gate_and_up_proj(x) x1, x2 = gate_and_up_proj.chunk(2, dim=2) down_proj = self.down_proj(x1 * self.act_fn(x2)) return down_proj class MoELRUCache(nn.Module): def __init__(self): super().__init__() self.gpu_cache = OrderedDict() self.cpu_cache = OrderedDict() self.offload_stream = torch.cuda.Stream() self.load_stream = torch.cuda.Stream() self.last_offload_event = None self._loop = asyncio.new_event_loop() threading.Thread(target=self._loop.run_forever, daemon=True).start() async def _async_offload_to_cpu(self, layer_idx): # async offload from gpu (removed) num_experts = 64 moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i]) for i in range(num_experts) if (layer_idx * num_experts + i) in self.gpu_cache] event = torch.cuda.Event() with torch.cuda.stream(self.offload_stream): for index, moe in moe_group: moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True) for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()): if p_gpu.device.type == "meta": continue with torch.no_grad(): p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True) p_cpu.copy_(p_gpu, non_blocking=True) self.cpu_cache[index] = moe_cpu self.offload_stream.record_event(event) self.last_offload_event = event def finalize_offload_layer(): event.synchronize() for index, moe in moe_group: moe.to("meta") self.gpu_cache.pop(index, None) del moe torch.cuda.empty_cache() threading.Thread(target=finalize_offload_layer, daemon=True).start() async def _async_load_to_gpu(self, index, moe): # if enough memory load, otherwise wait for offload while True: free_bytes, _ = torch.cuda.mem_get_info() if free_bytes > 2 * MOE_LAYER_SIZE: break self.last_offload_event.synchronize() torch.cuda.empty_cache() await asyncio.sleep(0.01) # async loading from cpu -> gpu with torch.cuda.stream(self.load_stream): moe_gpu = HunyuanMLP(moe.config).to("cuda", non_blocking=True) for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()): with torch.no_grad(): p_gpu.data = torch.empty_like(p_cpu, device="cuda") p_gpu.copy_(p_cpu, non_blocking=True) def finalize_load(): self.gpu_cache[index] = moe_gpu self.cpu_cache.pop(index, None) threading.Thread(target=finalize_load, daemon=True).start() def add_cpu(self, moe, index): moe_cpu = moe.to("cpu") for _, p in moe_cpu.named_parameters(): if not p.is_pinned(): if p.device.type == "cpu": p.data = p.data.pin_memory() else: return self.cpu_cache[index] = moe_cpu self.cpu_cache.move_to_end(index) class LazyMoELoader(nn.Module): def __init__(self, cache, config): super().__init__() self.cache = cache self.config = config self._loop = cache._loop def get_checkpoint(self): comfyui_dir = Path.home() / "ComfyUI" checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors" checkpoint = checkpoint.resolve() if not os.path.exists(checkpoint): raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}") return checkpoint def lazy_init(self, layer_idx, expert_idx): checkpoint = self.get_checkpoint() prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}." additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight" sd = {} with safe_open(checkpoint, framework="pt", device="cpu") as f: for k in f.keys(): if k.startswith(prefix) or k.startswith(additional_prefix): new_k = k.split(f"experts.{expert_idx}.", 1)[1] sd[new_k] = f.get_tensor(k) model = HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True, device="meta") model.to_empty(device = "cpu") model.load_state_dict(sd) return model async def lazy_load_from_disk(self, layer_idx, expert_idx): loop = asyncio.get_event_loop() return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx) def _schedule_disk_load(self, layer_idx, expert_idx): coro = self.lazy_load_from_disk(layer_idx, expert_idx) future = asyncio.run_coroutine_threadsafe(coro, self._loop) def _on_disk_loaded(fut): moe_cpu = fut.result() def _add_cpu_in_main_thread(): self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx) asyncio.run_coroutine_threadsafe( self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu), self.cache._loop ) threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start() future.add_done_callback(_on_disk_loaded) return future def enough_vram(required_bytes): free, total = torch.cuda.mem_get_info() return free > required_bytes class HunyuanMoE(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None): super().__init__() self.config = config self.layer_idx = layer_idx self.moe_topk = 8 self.num_experts = 64 self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True) self.gate = HunyuanTopKGate(config, layer_idx=layer_idx) if INIT_MOE: self.experts = nn.ModuleList( [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)] ) else: self.experts = [] self.moe_lru = moe_lru def forward(self, hidden_states): if not INIT_MOE: torch.cuda.set_device(0) else: torch.cuda.set_device(hidden_states.device.index) bsz, seq_len, hidden_size = hidden_states.shape hidden_states_mlp = self.shared_mlp(hidden_states) reshaped_input = hidden_states.reshape(-1, hidden_size) with torch.cuda.nvtx.range("MoE"): combine_weights, dispatch_mask = self.gate(hidden_states) dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(reshaped_input), reshaped_input) device = hidden_states.device dtype = reshaped_input.dtype used_mask = (dispatch_mask.sum(dim=(0, 2)) > 0) used_indices = used_mask.nonzero(as_tuple=False).squeeze(1).tolist() combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype) if len(used_indices) == 0: pass else: tokens_padded = dispatched_input[used_indices] l1_layers, l2_layers = [], [] for i in used_indices: expert = self.experts[i] if isinstance(expert, (asyncio.Future, concurrent.futures.Future)): expert = expert.result() expert = expert.to(device) l1_layers.append(expert.gate_and_up_proj) l2_layers.append(expert.down_proj) l1, l2 = [], [] for i in used_indices: expert = self.experts[i] if isinstance(expert, (asyncio.Future, concurrent.futures.Future)): expert = expert.result() expert = expert.to(device) l1.append(expert.gate_and_up_proj) l2.append(expert.down_proj) compute_device = hidden_states.device l1 = [m.to(compute_device) for m in l1] l2 = [m.to(compute_device) for m in l2] W1 = torch.stack([m.weight for m in l1], dim=0) W2 = torch.stack([m.weight for m in l2], dim=0) W1_T = W1.transpose(1, 2) W2_T = W2.transpose(1, 2) while not enough_vram(3*(1024 ** 3)): event = self.moe_lru.last_offload_event if event is not None and not event.query(): time.sleep(0.001) x = torch.bmm(tokens_padded, W1_T) x1, x2 = x.chunk(2, dim=2) gated = x1 * F.silu(x2) out_padded = torch.bmm(gated, W2_T) combine_weights_used = combine_weights[:, used_indices, :] combined_output = torch.einsum("suc,ucm->sm", combine_weights_used.type_as(out_padded), out_padded ) del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded combined_output = combined_output.reshape(bsz, seq_len, hidden_size) output = hidden_states_mlp + combined_output return output class HunyuanImage3Attention(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.attention_type = 'self' self.hidden_size = config["hidden_size"] self.num_heads = config["num_attention_heads"] self.head_dim = config["attention_head_dim"] self.num_key_value_heads = 8 self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config["max_position_embeddings"] self.rope_theta = 10000.0 self.is_causal = True self.hidden_size_q = self.head_dim * self.num_heads self.hidden_size_kv = self.head_dim * self.num_key_value_heads # define layers self.qkv_proj = nn.Linear( self.hidden_size, self.hidden_size_q + 2 * self.hidden_size_kv, bias=False ) self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=False) self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"]) self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"]) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, past_key_value, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, **kwargs, ): bsz, q_len, _ = hidden_states.size() qkv_states = self.qkv_proj(hidden_states) qkv_states = qkv_states.reshape(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim) query_states, key_states, value_states = torch.split(qkv_states, [self.num_key_value_groups, 1, 1], dim=3) query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = custom_pos_emb query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states = self.query_layernorm(query_states) key_states = self.key_layernorm(key_states) query_states = query_states.to(value_states.dtype) key_states = key_states.to(value_states.dtype) if past_key_value is not None: cache_kwargs = {"cache_position": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) query_states = query_states.to(key_states.dtype) key_states = torch.repeat_interleave(key_states, dim=1, repeats = self.num_key_value_groups) value_states = torch.repeat_interleave(value_states, dim=1, repeats = self.num_key_value_groups) if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, mask = attention_mask, skip_reshape=True) attn_output = self.o_proj(attn_output) return attn_output, past_key_value class HunyuanImage3DecoderLayer(nn.Module): def __init__(self, config, layer_idx: int, moe_lru=None): super().__init__() self.hidden_size = config["hidden_size"] self.layer_idx = layer_idx self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx) self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru) self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]) self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps']) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, use_cache: Optional[bool] = False, custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor | Any]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, past_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache, custom_pos_emb=custom_pos_emb, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states, past_key_value) return outputs class HunyuanImage3Model(nn.Module): def __init__(self, config, moe_lru=None): super().__init__() self.padding_idx = 128009 self.vocab_size = 133120 self.config = config self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx) self.layers = nn.ModuleList( [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])] ) self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]) self.shared_tensor = None self.moe_lru = moe_lru self.self.additional_layers_set = False def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache = True, custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, mode: str = "gen_image", first_step: Optional[bool] = None, gen_timestep_scatter_index: Optional[torch.Tensor] = None, ): if inputs_embeds is None: inputs_embeds = self.wte(input_ids) hidden_states = inputs_embeds next_decoder_cache = None next_layers = 0 additional_layers = torch.cuda.mem_get_info()[0] // (MOE_LAYER_SIZE * 2) sparse_interval = max(1, len(self.layers) // additional_layers) if len(self.layers[0].mlp.experts) == 0: experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)] for layer_idx, decoder_layer in enumerate(self.layers): if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)] if not self.additional_layers_set: if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval: experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)] next_layers += 1 with torch.no_grad(): layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, use_cache=use_cache, custom_pos_emb=custom_pos_emb, mode=mode, first_step=first_step, gen_timestep_scatter_index=gen_timestep_scatter_index, ) if layer_idx >= 0: if self.additional_layers_set and layer_idx <= self.additional_layers_set: pass else: torch.cuda.synchronize() asyncio.run_coroutine_threadsafe( self.moe_lru._async_offload_to_cpu(layer_idx), self.moe_lru._loop ) self.layers[layer_idx].mlp.experts = [] hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[1] next_cache = None if use_cache: next_cache = next_decoder_cache self.additional_layers_set = True return tuple(v for v in [hidden_states, next_cache] if v is not None) class HunyuanImage3ForCausalMM(nn.Module): def __init__(self, config): super().__init__() self.config = config self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"]) self.patch_embed = UNetDown( patch_size=1, emb_channels=config["hidden_size"], in_channels=32, hidden_channels=1024, out_channels=config["hidden_size"], ) self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"]) self.final_layer = UNetUp( patch_size=1, emb_channels=config["hidden_size"], in_channels=config["hidden_size"], hidden_channels=1024, out_channels=32, out_norm=True, ) self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"]) self.moe_lru = None if not INIT_MOE: self.moe_lru = MoELRUCache() self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru) self.pad_id = 128009 self.vocab_size = 133120 self.lm_head = nn.Linear(config["hidden_size"], 133120, bias=False) self.first_step = True self.kv_cache = None self.token_dims = () @staticmethod def get_pos_emb(custom_pos_emb, position_ids): cos, sin = custom_pos_emb cos = real_batched_index_select(cos, dim=1, idx=position_ids) sin = real_batched_index_select(sin, dim=1, idx=position_ids) return cos, sin def ragged_final_layer(self, x, image_mask, timestep, token_h, token_w, first_step): bsz, seq_len, n_embd = x.shape if first_step: image_output = x.masked_select(image_mask.unsqueeze(-1).bool()).reshape(bsz, -1, n_embd) else: image_output = x[:, 1:, :] timestep_emb = self.time_embed_2(timestep) pred = self.final_layer(image_output, timestep_emb, token_h, token_w) return pred def forward(self, x, condition, timestep, **kwargs): joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind() gen_timestep_scatter_index = 4 with torch.no_grad(): joint_image[:, 2:3, :] = x[:, 2:3, :] # updates image ratio if self.first_step: token_height, token_width = x[:, -2:, 0].tolist()[0] self.token_dims = (int(token_height), int(token_width)) x = x[:, :-2, :] else: token_height, token_width = self.token_dims img_slices = [] for i in range(x.size(0)): vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0] vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1 vit_start = vae_end + 1 vit_end = joint_image.size(1) - 1 joint_slices_i = [ slice(vae_start, vae_end), slice(vit_start, vit_end), ] gen_slices_i = [slice(3 + vit_end, x[i].size(0) - 1 + vit_end)] img_slices.append(joint_slices_i + gen_slices_i) img_s = img_slices[0] rope_image_info = [[(img_s[0], (384 // 16, 384 // 16)), (img_s[1], (256 // 16, 256 // 16)), (img_s[2], (token_height, token_width))]] cond_timestep = torch.zeros(inputs_embeds.size(0)) t_emb = self.time_embed(cond_timestep) bsz, seq_len, n_embd = inputs_embeds.shape if self.first_step: x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) else: t_emb = self.time_embed(timestep) x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb) timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd) x = torch.cat([timestep_emb, x], dim=1) #///////////// # cond_vae_images # cond_timestep_scatter_index with torch.no_grad(): joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) inputs_embeds = torch.cat([inputs_embeds, joint_image, x], dim = 1) attention_mask = torch.ones(inputs_embeds.shape[1], inputs_embeds.shape[1], dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1) for i in range(bsz): for _, image_slice in enumerate(img_slices[i]): attention_mask[i, image_slice, image_slice] = True attention_mask = attention_mask.unsqueeze(1) # pos embed position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) cos, sin = build_batch_2d_rope( image_infos=rope_image_info, seq_len=inputs_embeds.shape[1], n_elem=self.config["hidden_size"] // self.config["num_attention_heads"], base=10000.0, ) custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device)) custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids) if self.kv_cache is None: # TODO: should change when higgsv2 gets merged self.kv_cache = HunyuanStaticCache( config=self.config, batch_size=x.size(0) * 2, max_cache_len = inputs_embeds.shape[1], dtype=x.dtype, ) outputs = self.model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=self.kv_cache, inputs_embeds=inputs_embeds, custom_pos_emb=custom_pos_emb, first_step=self.first_step, gen_timestep_scatter_index=gen_timestep_scatter_index, ) hidden_states = outputs[0] # safety no-op past_key_value = outputs[1] if past_key_value is not None: self.kv_cache = past_key_value hidden_states = hidden_states.to(inputs_embeds.device) img_mask = torch.zeros(hidden_states.size(1)) img_mask[-x.size(1)+4:] = 1; img_mask[-1] = 0 diffusion_prediction = self.ragged_final_layer( hidden_states, img_mask, timestep, int(token_height), int(token_width), self.first_step) if self.first_step: self.first_step = False return diffusion_prediction