import os
import gc
import math
import torch
import psutil
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from collections import OrderedDict
from safetensors import safe_open
from contextlib import contextmanager
from transformers.cache_utils import StaticCache
from typing import Optional, Tuple, Any, List, Dict

from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock

INIT_MOE = torch.cuda.device_count() != 1
if not INIT_MOE:
    MOE_LAYER_SIZE = (1024**3) * 2.65  # approx
    CPU_MOE_RATIO = None
    torch.cuda.set_device(0)
    props = torch.cuda.get_device_properties(0)
    INIT_CUDA_MEM = (props.total_memory - torch.cuda.memory_reserved()) * 0.9
    ADDITIONAL_LAYERS_IN_GPU = math.floor(INIT_CUDA_MEM / MOE_LAYER_SIZE)
else:
    # Fallback definitions so topkgating() below stays defined when more
    # (or fewer) than one CUDA device is visible.
    MOE_LAYER_SIZE = (1024**3) * 2.65  # approx
    CPU_MOE_RATIO = None
    ADDITIONAL_LAYERS_IN_GPU = 0


class HunyuanStaticCache(StaticCache):
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cache_position = cache_kwargs.get("cache_position")
        if hasattr(self, "key_cache") and hasattr(self, "value_cache"):
            if self.key_cache[layer_idx].device != key_states.device:
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
            k_out = self.key_cache[layer_idx]
            v_out = self.value_cache[layer_idx]
            key_states = key_states.to(k_out.dtype)
            value_states = value_states.to(v_out.dtype)
        else:
            if self.layers[layer_idx].keys is None:
                self.layers[layer_idx].lazy_initialization(key_states)
            k_out = self.layers[layer_idx].keys
            v_out = self.layers[layer_idx].values

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            if cache_position.dim() == 1:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            else:
                assert cache_position.dim() == 2, f"multiple batch dims not yet {cache_position.shape=}"
                batch_size, _ = cache_position.shape
                for i in range(batch_size):
                    unbatched_dim = 1
                    k_out[i].index_copy_(unbatched_dim, cache_position[i], key_states[i])
                    v_out[i].index_copy_(unbatched_dim, cache_position[i], value_states[i])

        return k_out, v_out


def real_batched_index_select(t, dim, idx):
    return torch.stack([torch.index_select(t[i], dim - 1, idx[i]) for i in range(len(t))])


def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
        )
    return embedding


class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, act_layer=nn.GELU, frequency_embedding_size=256,
                 max_period=10000, out_size=None, dtype=None, device=None):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )

    def forward(self, t):
        t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb

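
# --- Illustrative usage sketch (not part of the model) ----------------------
# A minimal, CPU-only example of the sinusoidal timestep embedding defined
# above. The helper name `_demo_timestep_embedding` is hypothetical and exists
# purely for illustration; it is only executed from the optional __main__
# smoke test at the bottom of this file.
def _demo_timestep_embedding():
    t = torch.tensor([0.0, 250.0, 999.0])
    emb = timestep_embedding(t, 256)              # [3, 256] sinusoidal features
    embedder = TimestepEmbedder(hidden_size=128)  # 256 -> 128 -> 128 MLP
    out = embedder(t)                             # [3, 128]
    return emb.shape, out.shape
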

def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    return x


def get_meshgrid_nd(start, *args, dim=2):
    start = _to_tuple(start, dim=dim)
    stop = _to_tuple(args[0], dim=dim)
    num = [int(stop[i] - start[i]) for i in range(dim)]
    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")
    grid = torch.stack(grid, dim=0)
    return grid


def build_2d_rope(
    seq_len: int,
    n_elem: int,
    image_infos: Optional[List[Tuple[slice, Tuple[int, int]]]] = None,
    device: Optional[torch.device] = None,
    base: int = 10000,
    base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    assert n_elem % 4 == 0, f"n_elem must be divisible by 4, but got {n_elem}."

    # theta
    if base_rescale_factor != 1.0:
        base *= base_rescale_factor ** (n_elem / (n_elem - 2))
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
    theta = theta.reshape(1, n_elem // 4, 2)  # [1, half_d, 2]

    # position indices
    if image_infos is None:
        image_infos = []
    image_infos_list = [image_infos]
    sample_seq_lens = [seq_len]

    # Prepare position indices for each sample
    x_sections = []
    y_sections = []
    for sample_id, sample_image_infos in enumerate(image_infos_list):
        last_pos = 0
        for sec_slice, (h, w) in sample_image_infos:
            L = sec_slice.start  # start from 0, so image_slice.start is just L
            # previous text
            if last_pos < L:
                y_sections.append(torch.arange(last_pos, L))
                x_sections.append(torch.arange(last_pos, L))
            elif h is None:
                # Interleave data has overlapped positions for tokens.
                y_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                x_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                continue
            else:
                # Interleave data has overlapped positions for noised image and the successive clean image,
                # leading to last_pos (= last text end L + noise w * h) > L (last text end L).
                pass

            # current image
            beta_y = L + (w * h - h) / 2
            beta_x = L + (w * h - w) / 2
            grid = get_meshgrid_nd((beta_y, beta_x), (beta_y + h, beta_x + w))  # [2, h, w]
            grid = grid.reshape(2, -1)  # (y, x)
            y_sections.append(grid[0])
            x_sections.append(grid[1])

            # step
            last_pos = L + w * h

        # final text
        y_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id]))
        x_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id]))

    x_pos = torch.cat(x_sections).long()
    y_pos = torch.cat(y_sections).long()

    # If there are overlap positions, we need to remove them.
    x_pos = x_pos[:seq_len]
    y_pos = y_pos[:seq_len]
    all_pos = torch.stack((y_pos, x_pos), dim=1).unsqueeze(1).to(device)  # [seq_len, 1, 2]

    # calc rope
    idx_theta = (all_pos * theta).reshape(all_pos.shape[0], n_elem // 2).repeat(1, 2)
    cos = torch.cos(idx_theta)
    sin = torch.sin(idx_theta)

    if return_all_pos:
        return cos, sin, all_pos
    return cos, sin


def build_batch_2d_rope(
    seq_len: int,
    n_elem: int,
    image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None,
    device: Optional[torch.device] = None,
    base: int = 10000,
    base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    cos_list, sin_list, all_pos_list = [], [], []
    if image_infos is None:
        image_infos = [None]
    for i, image_info in enumerate(image_infos):
        res = build_2d_rope(
            seq_len,
            n_elem,
            image_infos=image_info,
            device=device,
            base=base,
            base_rescale_factor=base_rescale_factor,
            return_all_pos=return_all_pos,
        )
        if return_all_pos:
            cos, sin, all_pos = res
        else:
            cos, sin = res
            all_pos = None
        cos_list.append(cos)
        sin_list.append(sin)
        all_pos_list.append(all_pos)

    stacked_cos = torch.stack(cos_list, dim=0)
    stacked_sin = torch.stack(sin_list, dim=0)
    if return_all_pos:
        return stacked_cos, stacked_sin, all_pos_list
    return stacked_cos, stacked_sin


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    if position_ids is not None:
        cos = cos[position_ids]
        sin = sin[position_ids]
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

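
# --- Illustrative usage sketch (not part of the model) ----------------------
# Shows the shapes produced by build_batch_2d_rope() and how they feed
# apply_rotary_pos_emb(). A minimal CPU sketch; the helper name `_demo_2d_rope`
# is hypothetical and is only called from the optional __main__ smoke test at
# the bottom of this file.
def _demo_2d_rope():
    seq_len, head_dim, n_heads, bsz = 16, 8, 2, 1
    # With image_infos=None every position is treated as text, so the 2-D
    # (y, x) grid degenerates to ordinary 1-D positions.
    cos, sin = build_batch_2d_rope(seq_len=seq_len, n_elem=head_dim)  # each [1, 16, 8]
    q = torch.randn(bsz, n_heads, seq_len, head_dim)
    k = torch.randn(bsz, n_heads, seq_len, head_dim)
    # unsqueeze_dim=1 broadcasts cos/sin [1, 16, 8] -> [1, 1, 16, 8] over heads.
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    return q_rot.shape, k_rot.shape
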

def default(val, d):
    return val if val is not None else d


def conv_nd(dims, *args, **kwargs):
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)


def normalization(channels, **kwargs):
    return nn.GroupNorm(32, channels, **kwargs)


def topkgating(
    logits: torch.Tensor,
    topk: int,
    norm_topk_prob: bool = True,
):
    logits = logits.float()
    gates = F.softmax(logits, dim=1)

    extra = ADDITIONAL_LAYERS_IN_GPU
    values_all, indices_all = torch.topk(gates, topk + extra, dim=1)
    expert_weight = values_all[:, :topk]
    expert_index = indices_all[:, :topk]

    if CPU_MOE_RATIO is not None:
        _, cpu_expert_index = torch.topk(gates, int(CPU_MOE_RATIO * 64), dim=1)
        cpu_expert_index = cpu_expert_index[:, (8 + ADDITIONAL_LAYERS_IN_GPU):]
    else:
        # No CPU offload ratio configured (multi-GPU path); nothing to prefetch.
        cpu_expert_index = None

    if norm_topk_prob and topk > 1:
        denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps)
        expert_weight = expert_weight / denom

    return expert_weight, expert_index, cpu_expert_index, indices_all


class HunyuanRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class UNetDown(nn.Module):
    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]
        self.model = nn.ModuleList(
            [conv_nd(
                2,
                in_channels=in_channels,
                out_channels=hidden_channels,
                kernel_size=3,
                padding=1,
                **factory_kwargs
            )]
        )
        if self.patch_size == 1:
            self.model.append(ResBlock(
                in_channels=hidden_channels,
                emb_channels=emb_channels,
                out_channels=out_channels,
                dropout=dropout,
                **factory_kwargs
            ))
        else:
            for i in range(self.patch_size // 2):
                self.model.append(ResBlock(
                    in_channels=hidden_channels,
                    emb_channels=emb_channels,
                    out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
                    dropout=dropout,
                    down=True,
                    **factory_kwargs
                ))

    def forward(self, x, t):
        assert x.shape[2] % self.patch_size == 0 and x.shape[3] % self.patch_size == 0
        for module in self.model:
            if isinstance(module, ResBlock):
                x = module(x, t)
            else:
                x = module(x)
        _, _, token_h, token_w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x, token_h, token_w


class UNetUp(nn.Module):
    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None, operations=None, out_norm=False):
        operations = operations or nn
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]
        self.model = nn.ModuleList()
        if self.patch_size == 1:
            self.model.append(ResBlock(
                in_channels=in_channels,
                emb_channels=emb_channels,
                out_channels=hidden_channels,
                dropout=dropout,
                **factory_kwargs
            ))
        else:
            for i in range(self.patch_size // 2):
                self.model.append(ResBlock(
                    in_channels=in_channels if i == 0 else hidden_channels,
                    emb_channels=emb_channels,
                    out_channels=hidden_channels,
                    dropout=dropout,
                    up=True,
                    **factory_kwargs
                ))
        if out_norm:
            self.model.append(nn.Sequential(
                normalization(hidden_channels, **factory_kwargs),
                nn.SiLU(),
                nn.Conv2d(
                    in_channels=hidden_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    **factory_kwargs
                ),
            ))
        else:
            self.model.append(nn.Conv2d(
                in_channels=hidden_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                **factory_kwargs
            ))

    # batch_size, seq_len, model_dim
    def forward(self, x, t, token_h, token_w):
        x = rearrange(x, 'b (h w) c -> b c h w', h=token_h, w=token_w)
        for module in self.model:
            if isinstance(module, ResBlock):
                x = module(x, t)
            else:
                x = module(x)
        return x


class HunyuanTopKGate(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = 8
        self.min_capacity = 8
        num_experts = 64
        self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32)
        self.norm_topk_prob = True

    def forward(self, hidden_states):
        bsz, seq_len, hidden_size = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_size)
        if self.wg.weight.dtype == torch.float32:
            hidden_states = hidden_states.float()
        logits = self.wg(hidden_states)
        gate_output = topkgating(logits, self.moe_topk, norm_topk_prob=self.norm_topk_prob)
        return gate_output


class HunyuanMLP(nn.Module):
    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config["hidden_size"]
        self.intermediate_size = 3072
        self.act_fn = torch.nn.functional.silu
        self.intermediate_size *= 2  # SwiGLU
        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False)

    def forward(self, x):
        # Keep the projections on the same device as the incoming tokens and
        # give dispatched 2-D token batches a batch dimension.
        self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
        if x.ndim == 2:
            x = x.unsqueeze(0)
        gate_and_up_proj = self.gate_and_up_proj(x)
        x1, x2 = gate_and_up_proj.chunk(2, dim=2)
        down_proj = self.down_proj(x1 * self.act_fn(x2))
        return down_proj

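
# --- Illustrative usage sketch (not part of the model) ----------------------
# A tiny CPU example of the SwiGLU MLP above using a toy config; the helper
# name `_demo_swiglu_mlp` and the small hidden size are hypothetical choices
# made only to keep the sketch cheap. It is called only from the optional
# __main__ smoke test at the bottom of this file.
def _demo_swiglu_mlp():
    config = {"hidden_size": 64}
    mlp = HunyuanMLP(config)   # gate_and_up: 64 -> 6144, down: 3072 -> 64
    x = torch.randn(2, 5, 64)
    y = mlp(x)                 # [2, 5, 64]
    return y.shape
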

class MoELRUCache(nn.Module):
    def __init__(self, cpu_mem: int = 50, safety_buffer_bytes=3 * (1024**3), max_gpu_eviction_attempts=8):
        super().__init__()
        global CPU_MOE_RATIO
        _, total = torch.cuda.mem_get_info()
        max_gpu_mem_gb = max((total - 2 * safety_buffer_bytes) / (1024**3), 1)
        self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3)
        self.MAX_CPU_MEM = int(cpu_mem * 1024**3)
        self.gpu_cache = OrderedDict()
        self.cpu_cache = OrderedDict()
        self.gpu_mem_usage = 0
        self.cpu_mem_usage = 0
        # use ~55% of the remaining system RAM, leaving headroom for the OS
        try:
            self.MAX_CPU_MEM = int(
                ((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
                 - psutil.Process(os.getpid()).memory_info().rss
                 - safety_buffer_bytes) * 0.55
            )
        except Exception:
            self.MAX_CPU_MEM = int(cpu_mem * (1024**3) * 0.5)
        # TODO
        ADDITIONAL_LAYERS_IN_CPU = math.floor((50 * (1024**3)) / MOE_LAYER_SIZE)
        CPU_MOE_RATIO = min(64 - ADDITIONAL_LAYERS_IN_GPU, ADDITIONAL_LAYERS_IN_CPU) / 64
        self.SAFETY_BUFFER = int(safety_buffer_bytes)
        self.MAX_GPU_EVICT_ATTEMPTS = max_gpu_eviction_attempts

    def _gpu_free_bytes(self):
        free, total = torch.cuda.mem_get_info()
        return int(free)

    def _estimate_size(self, moe):
        # include parameters + buffers
        size = 0
        for p in moe.parameters():
            size += p.numel() * p.element_size()
        for b in moe.buffers():
            size += b.numel() * b.element_size()
        return int(size)

    def _evict_until_free(self, required_bytes, max_attempts=16):
        attempts = 0
        while self._gpu_free_bytes() < required_bytes and attempts < max_attempts:
            evicted = self._evict_from_gpu()
            if not evicted:
                break
            attempts += 1
        return self._gpu_free_bytes() >= required_bytes

    @contextmanager
    def ensure_headroom(self, required_bytes):
        safety = getattr(self, "SAFETY_BUFFER", 0)
        target_free = int(required_bytes + safety)
        if getattr(self, "_headroom", None) is not None:
            try:
                del self._headroom
            except Exception:
                pass
        self._headroom = None
        ok = self._evict_until_free(target_free)
        if not ok and self._gpu_free_bytes() < target_free:
            # last ditch
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        try:
            yield
        finally:
            if getattr(self, "_headroom", None) is None:
                try:
                    self._headroom = torch.empty((target_free,), dtype=torch.uint8, device="cuda:0")
                except Exception:
                    self._headroom = None

    def add_gpu(self, moe, index, allowed_retries=3):
        size = self._estimate_size(moe)
        while self.gpu_mem_usage + size > self.MAX_GPU_MEM:
            if not self._evict_from_gpu():
                break
        attempts = 0
        while self._gpu_free_bytes() < size + self.SAFETY_BUFFER and attempts < self.MAX_GPU_EVICT_ATTEMPTS:
            if not self._evict_from_gpu():
                break
            attempts += 1
        for _ in range(allowed_retries):
            try:
                moe_cuda = moe.to("cuda:0")
                break
            except RuntimeError as e:
                if "out of memory" not in str(e).lower():
                    raise
                evicted = self._evict_from_gpu()
                if not evicted:
                    # can't evict
                    raise
        else:
            raise RuntimeError("Failed to move expert to GPU after evictions")
        self.gpu_cache[index] = moe_cuda
        self.gpu_cache.move_to_end(index)
        self.gpu_mem_usage += size
        return

    def add_cpu(self, moe, index):
        size = self._estimate_size(moe)
        while self.cpu_mem_usage + size > self.MAX_CPU_MEM:
            if not self._evict_from_cpu():
                break
        moe_cpu = moe.to("cpu")
        self.cpu_cache[index] = moe_cpu
        self.cpu_cache.move_to_end(index)
        self.cpu_mem_usage += size

    def get_from_device(self, index):
        if index in self.gpu_cache:
            moe = self.gpu_cache[index]
            self.gpu_cache.move_to_end(index)
            return moe
        if index in self.cpu_cache:
            moe = self.cpu_cache.pop(index)
            self.cpu_mem_usage = max(0, self.cpu_mem_usage - self._estimate_size(moe))
            try:
                self.add_gpu(moe, index)
                return self.gpu_cache[index]
            except RuntimeError:
                self.cpu_cache[index] = moe
                self.cpu_cache.move_to_end(index)
                self.cpu_mem_usage += self._estimate_size(moe)
                raise
        return None  # load from disk

    def _evict_from_gpu(self):
        if not self.gpu_cache:
            return False
        idx, moe = self.gpu_cache.popitem(last=False)
        size = self._estimate_size(moe)
        self.gpu_mem_usage = max(0, self.gpu_mem_usage - size)
        if self.cpu_mem_usage + size <= self.MAX_CPU_MEM:
            try:
                moe_cpu = moe.to("cpu")
            except Exception:
                # drop the model if cpu is full
                del moe
                return True
            self.cpu_cache[idx] = moe_cpu
            self.cpu_cache.move_to_end(idx)
            self.cpu_mem_usage += size
            return True
        else:
            del moe
            return True

    def _evict_from_cpu(self):
        if not self.cpu_cache:
            return False
        _, moe = self.cpu_cache.popitem(last=False)
        size = self._estimate_size(moe)
        self.cpu_mem_usage = max(0, self.cpu_mem_usage - size)
        del moe
        gc.collect()
        return True


class LazyMoELoader(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device

    def lazy_init(self, config, layer_idx, expert_idx):
        checkpoint = "./models/checkpoint/hunyuan_image_3.safetensors"
        if not os.path.exists(checkpoint):
            raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")
        prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}."
        additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight"
        sd = {}
        with safe_open(checkpoint, framework="pt", device=self.device) as f:
            for k in f.keys():
                if k.startswith(prefix) or k.startswith(additional_prefix):
                    new_k = k.split(f"experts.{expert_idx}.", 1)[1]
                    sd[new_k] = f.get_tensor(k)
        expert = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True)
        expert.load_state_dict(sd)
        return expert.to(self.device)


class HunyuanMoE(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = 8
        self.num_experts = 64
        self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx)
        if INIT_MOE:
            self.experts = nn.ModuleList(
                [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
            )
        else:
            self.experts = None
        self.moe_lru = moe_lru

    def forward(self, hidden_states):
        if not INIT_MOE:
            torch.cuda.set_device(0)
        else:
            torch.cuda.set_device(hidden_states.device.index)
        bsz, seq_len, hidden_size = hidden_states.shape
        hidden_states_mlp = self.shared_mlp(hidden_states)
        reshaped_input = hidden_states.reshape(-1, hidden_size)

        with torch.cuda.nvtx.range("MoE"):
            expert_weight, expert_index, cpu_expert_index, indices_all = self.gate(hidden_states)

            if not INIT_MOE:
                if ADDITIONAL_LAYERS_IN_GPU > 0:
                    additional_expert_index = indices_all[:, expert_index.size(1): expert_index.size(1) + ADDITIONAL_LAYERS_IN_GPU]
                    flat = additional_expert_index.reshape(-1).to("cpu")
                    counts = torch.bincount(flat, minlength=self.num_experts)
                    top_extra = torch.topk(counts, k=min(ADDITIONAL_LAYERS_IN_GPU, (counts > 0).sum().item())).indices.tolist()
                    for expert_id in top_extra:
                        if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
                            expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
                            self.moe_lru.add_gpu(expert_cpu, expert_id + self.layer_idx)
                if cpu_expert_index is not None and cpu_expert_index.numel() > 0:
                    for expert_id in torch.unique(cpu_expert_index).cpu().tolist():
                        if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
                            expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
                            self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx)

            combined_output = torch.zeros_like(reshaped_input)
            for e in range(self.num_experts):
                token_mask = (expert_index == e)
                if not token_mask.any():
                    continue
                token_ids = token_mask.nonzero(as_tuple=False)
                token_positions = token_ids[:, 0]
                topk_slot = token_ids[:, 1]
                tokens = reshaped_input[token_positions]
                weights = expert_weight[token_positions, topk_slot]
                if self.experts is not None and INIT_MOE:
                    out = self.experts[e](tokens)
                elif self.experts is None:
                    expert = self.moe_lru.get_from_device(e + self.layer_idx)
                    if expert is None:
                        expert = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, e)
                        out = expert(tokens)
                        self.moe_lru.add_gpu(expert, e + self.layer_idx)
                    else:
                        tokens = tokens.to(next(expert.parameters()).device)
                        out = expert(tokens.view(bsz, -1, hidden_size))
                out = out * weights.to(out.device).unsqueeze(-1)
                combined_output.index_add_(
                    0,
                    token_positions.to(combined_output.device),
                    out.reshape(-1, hidden_size).to(combined_output.device),
                )

            # dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
            # chunks = dispatched_input.chunk(self.num_experts, dim=0)
            # expert_outputs = []
            # for chunk, expert in zip(chunks, self.experts):
            #     expert_outputs.append(expert(chunk))
            # expert_output = torch.cat(expert_outputs, dim=0)
            # combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)

        combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
        output = hidden_states_mlp + combined_output
        return output


class HunyuanImage3Attention(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_type = 'self'
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_attention_heads"]
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = 8
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config["max_position_embeddings"]
        self.rope_theta = 10000.0
        self.is_causal = True
        self.hidden_size_q = self.head_dim * self.num_heads
        self.hidden_size_kv = self.head_dim * self.num_key_value_heads

        # define layers
        self.qkv_proj = nn.Linear(
            self.hidden_size, self.hidden_size_q + 2 * self.hidden_size_kv, bias=False
        )
        self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=False)
        self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])
        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"])

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()
        qkv_states = self.qkv_proj(hidden_states)
        qkv_states = qkv_states.reshape(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim)
        query_states, key_states, value_states = torch.split(qkv_states, [self.num_key_value_groups, 1, 1], dim=3)
        query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = custom_pos_emb
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        query_states = self.query_layernorm(query_states)
        key_states = self.key_layernorm(key_states)
        query_states = query_states.to(value_states.dtype)
        key_states = key_states.to(value_states.dtype)

        key_states = torch.repeat_interleave(key_states, dim=1, repeats=self.num_key_value_groups)
        value_states = torch.repeat_interleave(value_states, dim=1, repeats=self.num_key_value_groups)

        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, mask=attention_mask, skip_reshape=True)
        attn_output = self.o_proj(attn_output)
        return attn_output


class HunyuanImage3DecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: int, moe_lru=None):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.layer_idx = layer_idx
        self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx)
        self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru)
        self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
        self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps'])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor | Any]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            custom_pos_emb=custom_pos_emb,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        return outputs


class HunyuanImage3Model(nn.Module):
    def __init__(self, config, moe_lru=None):
        super().__init__()
        self.padding_idx = 128009
        self.vocab_size = 133120
        self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
        self.layers = nn.ModuleList(
            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru=moe_lru) for layer_idx in range(config["num_hidden_layers"])]
        )
        self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
        self.shared_tensor = None

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache=True,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        mode: str = "gen_image",
        first_step: Optional[bool] = None,
        gen_timestep_scatter_index: Optional[torch.Tensor] = None,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds
        next_decoder_cache = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                custom_pos_emb=custom_pos_emb,
                mode=mode,
                first_step=first_step,
                gen_timestep_scatter_index=gen_timestep_scatter_index,
            )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = layer_outputs[1] if len(layer_outputs) > 1 else None

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache
        return tuple(v for v in [hidden_states, next_cache] if v is not None)


class HunyuanImage3ForCausalMM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
        self.patch_embed = UNetDown(
            patch_size=16,
            emb_channels=config["hidden_size"],
            in_channels=32,
            hidden_channels=1024,
            out_channels=config["hidden_size"],
        )
        self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"])
        self.final_layer = UNetUp(
            patch_size=16,
            emb_channels=config["hidden_size"],
            in_channels=config["hidden_size"],
            hidden_channels=1024,
            out_channels=32,
            out_norm=True,
        )
        self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"])
        self.moe_lru = None
        if not INIT_MOE:
            self.moe_lru = MoELRUCache()
        self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru)
        self.pad_id = 128009
        self.vocab_size = 133120
        self.lm_head = nn.Linear(config["hidden_size"], 133120, bias=False)
        self.first_step = True
        self.kv_cache = None

    @staticmethod
    def get_pos_emb(custom_pos_emb, position_ids):
        cos, sin = custom_pos_emb
        cos = real_batched_index_select(cos, dim=1, idx=position_ids)
        sin = real_batched_index_select(sin, dim=1, idx=position_ids)
        return cos, sin

    def ragged_final_layer(self, x, image_mask, timestep, token_h, token_w, first_step):
        bsz, seq_len, n_embd = x.shape
        if first_step:
            image_output = x.masked_select(image_mask.unsqueeze(-1).bool()).reshape(bsz, -1, n_embd)
        else:
            image_output = x[:, 1:, :]
        timestep_emb = self.time_embed_2(timestep)
        pred = self.final_layer(image_output, timestep_emb, token_h, token_w)
        return pred

    def forward(self, x, condition, timestep, **kwargs):
        image_mask = torch.arange(1, x.size(1) - 1).to(torch.bool)
        gen_timestep_scatter_index = 4
        cond, uncond = condition[:4], condition[4:]
        joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1], cond[2]

        if self.kv_cache is None:
            # TODO: should change when higgsv2 gets merged
            self.kv_cache = HunyuanStaticCache(
                config=self.config,
                batch_size=x.size(0) * 2,
                max_cache_len=input_ids.shape[1],
                dtype=x.dtype,
            )

        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)

        height, width = x.shape[2] * 16, x.shape[3] * 16
        token_height = height // (16 * 16)
        token_width = width // (16 * 16)
        rope_image_info = [[(None, (token_height, token_width))] * 2]
        seq_len = input_ids.shape[1]
        cos, sin = build_batch_2d_rope(
            image_infos=rope_image_info,
            seq_len=seq_len,
            n_elem=self.config["hidden_size"] // self.config["num_attention_heads"],
            base=10000.0,
        )
        custom_pos_emb = (sin, cos)
        custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)

        inputs_embeds = self.model.wte(input_ids)
        cond_timestep = torch.zeros(inputs_embeds.size(0))
        t_emb = self.time_embed(cond_timestep)
        bsz, seq_len, n_embd = inputs_embeds.shape

        if self.first_step:
            t_emb = self.time_embed(timestep)
            x[:, 5:-4], token_h, token_w = self.patch_embed(x[:, 5:-4], t_emb)
            x[:, gen_timestep_scatter_index] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
        else:
            t_emb = self.time_embed(timestep)
            x[:, 5:-4], token_h, token_w = self.patch_embed(x, t_emb)
            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
            x = torch.cat([timestep_emb, x], dim=1)

        inputs_embeds = torch.cat([inputs_embeds, x], dim=1)

        # /////////////
        # cond_vae_images
        # cond_timestep_scatter_index
        joint_image[:, 3] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)

        # conditioning images (vae)
        joint_image[:, 7:cond_vae_image_mask.size(0)], token_h, token_w = self.patch_embed(
            joint_image[:, 7:cond_vae_image_mask.size(0)], self.time_embed(cond_timestep)
        )
        inputs_embeds = torch.cat([inputs_embeds, joint_image], dim=1)

        batch_image_slices = [
            input_ids[i] + x[i] for i in range(bsz)
        ]
        attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
        for i in range(bsz):
            for _, image_slice in enumerate(batch_image_slices[i]):
                attention_mask[i, image_slice, image_slice] = True
        attention_mask = attention_mask.unsqueeze(1)

        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=self.kv_cache,
            inputs_embeds=inputs_embeds,
            custom_pos_emb=custom_pos_emb,
            first_step=self.first_step,
            gen_timestep_scatter_index=gen_timestep_scatter_index,
        )
        hidden_states = outputs[0]
        hidden_states = hidden_states.to(input_ids.device)

        diffusion_prediction = self.ragged_final_layer(
            hidden_states, image_mask, timestep, token_h, token_w, self.first_step)

        if self.first_step:
            self.first_step = False

        return diffusion_prediction
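
# --- Optional smoke test -----------------------------------------------------
# A hedged, CPU-only sanity check of the standalone pieces above (sinusoidal
# timestep embedding, 2-D RoPE, SwiGLU MLP). It deliberately does not touch the
# MoE/LRU machinery, which requires a CUDA device and the single-GPU checkpoint
# layout described in LazyMoELoader.
if __name__ == "__main__":
    print("timestep embedding:", _demo_timestep_embedding())
    print("2d rope:", _demo_2d_rope())
    print("swiglu mlp:", _demo_swiglu_mlp())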