From de43880bdb0578a45a0ddb270df53487776bfea2 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 31 Oct 2025 18:56:20 +0200 Subject: [PATCH 01/16] Hunyuan Image 3.0 --- comfy/ldm/hunyuan_image_3/model.py | 1098 +++++++++++++++++++++++++++ comfy/model_detection.py | 9 + comfy_extras/nodes_hunyuan_image.py | 121 +++ nodes.py | 1 + 4 files changed, 1229 insertions(+) create mode 100644 comfy/ldm/hunyuan_image_3/model.py create mode 100644 comfy_extras/nodes_hunyuan_image.py diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py new file mode 100644 index 000000000..f9a4a8485 --- /dev/null +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -0,0 +1,1098 @@ +import os +import gc +import math +import torch +import psutil +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +from collections import OrderedDict +from safetensors import safe_open +from contextlib import contextmanager +from transformers.cache_utils import StaticCache +from typing import Optional, Tuple, Any, List, Dict +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock + +INIT_MOE = torch.cuda.device_count() != 1 + +if not INIT_MOE: + MOE_LAYER_SIZE = (1024**3) * 2.65 # approx + CPU_MOE_RATIO = None + + torch.cuda.set_device(0) + props = torch.cuda.get_device_properties(0) + + INIT_CUDA_MEM = (props.total_memory - torch.cuda.memory_reserved()) * 0.9 + ADDITIONAL_LAYERS_IN_GPU = math.floor(INIT_CUDA_MEM / MOE_LAYER_SIZE) + +class HunyuanStaticCache(StaticCache): + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + cache_position = cache_kwargs.get("cache_position") + if hasattr(self, "key_cache") and hasattr(self, "value_cache"): + if self.key_cache[layer_idx].device != key_states.device: + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + else: + if self.layers[layer_idx].keys is None: + self.layers[layer_idx].lazy_initialization(key_states) + k_out = self.layers[layer_idx].keys + v_out = self.layers[layer_idx].values + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + if cache_position.dim() == 1: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + + else: + assert cache_position.dim() == 2, f"multiple batch dims not yet {cache_position.shape=}" + batch_size, _ = cache_position.shape + for i in range(batch_size): + unbatched_dim = 1 + k_out[i].index_copy_(unbatched_dim, cache_position[i], key_states[i]) + v_out[i].index_copy_(unbatched_dim, cache_position[i], value_states[i]) + + return k_out, v_out + +def real_batched_index_select(t, dim, idx): + return torch.stack([torch.index_select(t[i], dim - 1, idx[i]) for i in range(len(t))]) + +def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + +class TimestepEmbedder(nn.Module): + def __init__(self, + hidden_size, + act_layer=nn.GELU, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None + ): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + return x + + +def get_meshgrid_nd(start, *args, dim=2): + + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + num_int = [int(x) for x in num] + num = num_int + + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") + grid = torch.stack(grid, dim=0) + + return grid + +def build_2d_rope( + seq_len: int, n_elem: int, image_infos: Optional[List[Tuple[slice, Tuple[int, int]]]] = None, + device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0, + return_all_pos: bool = False, +): + + assert n_elem % 4 == 0, f"n_elem must be divisible by 4, but got {n_elem}." + + # theta + if base_rescale_factor != 1.0: + base *= base_rescale_factor ** (n_elem / (n_elem - 2)) + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) + theta = theta.reshape(1, n_elem // 4, 2) # [1, half_d, 2] + + # position indices + if image_infos is None: + image_infos = [] + + image_infos_list = [image_infos] + sample_seq_lens = [seq_len] + + # Prepare position indices for each sample + x_sections = [] + y_sections = [] + for sample_id, sample_image_infos in enumerate(image_infos_list): + last_pos = 0 + for sec_slice, (h, w) in sample_image_infos: + L = sec_slice.start # start from 0, so image_slice.start is just L + # previous text + if last_pos < L: + y_sections.append(torch.arange(last_pos, L)) + x_sections.append(torch.arange(last_pos, L)) + elif h is None: + # Interleave data has overlapped positions for tokens. + y_sections.append(torch.arange(sec_slice.start, sec_slice.stop)) + x_sections.append(torch.arange(sec_slice.start, sec_slice.stop)) + continue + else: + # Interleave data has overlapped positions for noised image and the successive clean image, + # leading to last_pos (= last text end L + noise w * h) > L (last text end L). + pass + # current image + beta_y = L + (w * h - h) / 2 + beta_x = L + (w * h - w) / 2 + grid = get_meshgrid_nd((beta_y, beta_x), (beta_y + h, beta_x + w)) # [2, h, w] + grid = grid.reshape(2, -1) # (y, x) + y_sections.append(grid[0]) + x_sections.append(grid[1]) + # step + last_pos = L + w * h + # final text + y_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id])) + x_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id])) + + x_pos = torch.cat(x_sections).long() + y_pos = torch.cat(y_sections).long() + # If there are overlap positions, we need to remove them. + x_pos = x_pos[:seq_len] + y_pos = y_pos[:seq_len] + all_pos = torch.stack((y_pos, x_pos), dim=1).unsqueeze(1).to(device) # [seq_len, 1, 2] + + # calc rope + idx_theta = (all_pos * theta).reshape(all_pos.shape[0], n_elem // 2).repeat(1, 2) + + cos = torch.cos(idx_theta) + sin = torch.sin(idx_theta) + + if return_all_pos: + return cos, sin, all_pos + + return cos, sin + + +def build_batch_2d_rope( + seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None, + device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0, + return_all_pos: bool = False, +): + cos_list, sin_list, all_pos_list = [], [], [] + if image_infos is None: + image_infos = [None] + for i, image_info in enumerate(image_infos): + res = build_2d_rope( + seq_len, n_elem, image_infos=image_info, device=device, + base=base, base_rescale_factor=base_rescale_factor, + return_all_pos=return_all_pos, + ) + if return_all_pos: + cos, sin, all_pos = res + else: + cos, sin = res + all_pos = None + cos_list.append(cos) + sin_list.append(sin) + all_pos_list.append(all_pos) + + stacked_cos = torch.stack(cos_list, dim=0) + stacked_sin = torch.stack(sin_list, dim=0) + + if return_all_pos: + return stacked_cos, stacked_sin, all_pos_list + + return stacked_cos, stacked_sin + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + if position_ids is not None: + cos = cos[position_ids] + sin = sin[position_ids] + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def default(val, d): + return val if val is not None else d + +def conv_nd(dims, *args, **kwargs): + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + +def normalization(channels, **kwargs): + return nn.GroupNorm(32, channels, **kwargs) + +def topkgating( + logits: torch.Tensor, + topk: int, + norm_topk_prob: bool = True, +): + logits = logits.float() + gates = F.softmax(logits, dim=1) + + extra = ADDITIONAL_LAYERS_IN_GPU + + values_all, indices_all = torch.topk(gates, topk + extra, dim=1) + expert_weight = values_all[:, :topk] + expert_index = indices_all[:, :topk] + + _, cpu_expert_index = torch.topk(gates, int(CPU_MOE_RATIO * 64), dim = 1) + cpu_expert_index = cpu_expert_index[:, (8 + ADDITIONAL_LAYERS_IN_GPU):] + + if norm_topk_prob and topk > 1: + denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps) + expert_weight = expert_weight / denom + + return expert_weight, expert_index, cpu_expert_index, indices_all + +class HunyuanRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class UNetDown(nn.Module): + def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels, + dropout=0.0, device=None, dtype=None): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + + self.patch_size = patch_size + assert self.patch_size in [1, 2, 4, 8] + + self.model = nn.ModuleList( + [conv_nd( + 2, + in_channels=in_channels, + out_channels=hidden_channels, + kernel_size=3, + padding=1, + **factory_kwargs + )] + ) + + if self.patch_size == 1: + self.model.append(ResBlock( + in_channels=hidden_channels, + emb_channels=emb_channels, + out_channels=out_channels, + dropout=dropout, + **factory_kwargs + )) + else: + for i in range(self.patch_size // 2): + self.model.append(ResBlock( + in_channels=hidden_channels, + emb_channels=emb_channels, + out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels, + dropout=dropout, + down=True, + **factory_kwargs + )) + + def forward(self, x, t): + assert x.shape[2] % self.patch_size == 0 and x.shape[3] % self.patch_size == 0 + for module in self.model: + if isinstance(module, ResBlock): + x = module(x, t) + else: + x = module(x) + _, _, token_h, token_w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + return x, token_h, token_w + + +class UNetUp(nn.Module): + + def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels, + dropout=0.0, device=None, dtype=None, operations = None, out_norm=False): + operations = operations or nn + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + + self.patch_size = patch_size + assert self.patch_size in [1, 2, 4, 8] + + self.model = nn.ModuleList() + + if self.patch_size == 1: + self.model.append(ResBlock( + in_channels=in_channels, + emb_channels=emb_channels, + out_channels=hidden_channels, + dropout=dropout, + **factory_kwargs + )) + else: + for i in range(self.patch_size // 2): + self.model.append(ResBlock( + in_channels=in_channels if i == 0 else hidden_channels, + emb_channels=emb_channels, + out_channels=hidden_channels, + dropout=dropout, + up=True, + **factory_kwargs + )) + + if out_norm: + self.model.append(nn.Sequential( + normalization(hidden_channels, **factory_kwargs), + nn.SiLU(), + nn.Conv2d( + in_channels=hidden_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + **factory_kwargs + ), + )) + else: + self.model.append(nn.Conv2d( + in_channels=hidden_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + **factory_kwargs + )) + + # batch_size, seq_len, model_dim + def forward(self, x, t, token_h, token_w): + x = rearrange(x, 'b (h w) c -> b c h w', h=token_h, w=token_w) + for module in self.model: + if isinstance(module, ResBlock): + x = module(x, t) + else: + x = module(x) + return x + +class HunyuanTopKGate(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.moe_topk = 8 + self.min_capacity = 8 + num_experts = 64 + self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32) + + self.norm_topk_prob = True + + def forward(self, hidden_states): + bsz, seq_len, hidden_size = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_size) + if self.wg.weight.dtype == torch.float32: + hidden_states = hidden_states.float() + logits = self.wg(hidden_states) + gate_output = topkgating(logits, self.moe_topk, norm_topk_prob=self.norm_topk_prob,) + + return gate_output + +class HunyuanMLP(nn.Module): + def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config["hidden_size"] + + self.intermediate_size = 3072 + + self.act_fn = torch.nn.functional.silu + self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device) + if x.ndim == 2: + x = x.unsqueeze(0) + self.intermediate_size *= 2 # SwiGLU + self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False) + def forward(self, x): + gate_and_up_proj = self.gate_and_up_proj(x) + x1, x2 = gate_and_up_proj.chunk(2, dim=2) + down_proj = self.down_proj(x1 * self.act_fn(x2)) + return down_proj + +class MoELRUCache(nn.Module): + def __init__(self, cpu_mem: int = 50, safety_buffer_bytes = 3*(1024**3), max_gpu_eviction_attempts = 8): + super().__init__() + global CPU_MOE_RATIO + + _, total = torch.cuda.mem_get_info() + max_gpu_mem_gb = max((total - 2 * safety_buffer_bytes) / (1024**3), 1) + + self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3) + self.MAX_CPU_MEM = int(cpu_mem * 1024**3) + self.gpu_cache = OrderedDict() + self.cpu_cache = OrderedDict() + + self.gpu_mem_usage = 0 + self.cpu_mem_usage = 0 + # 50% for system and headroom + try: + self.MAX_CPU_MEM = int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')) + - psutil.Process(os.getpid()).memory_info().rss + - safety_buffer_bytes) * 0.55 + except: + self.MAX_CPU_MEM = int(cpu_mem * (1024**3) * 0.5) # TODO + + ADDITIONAL_LAYERS_IN_CPU = math.floor((50 * (1024**3)) / MOE_LAYER_SIZE) + CPU_MOE_RATIO = (min(64 - ADDITIONAL_LAYERS_IN_GPU, ADDITIONAL_LAYERS_IN_CPU)) / 64 + + self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3) + self.SAFETY_BUFFER = int(safety_buffer_bytes) + self.MAX_GPU_EVICT_ATTEMPTS = max_gpu_eviction_attempts + + def _gpu_free_bytes(self): + free, total = torch.cuda.mem_get_info() + return int(free) + + def _estimate_size(self, moe): + # include parameters + buffers + size = 0 + for p in moe.parameters(): + size += p.numel() * p.element_size() + for b in moe.buffers(): + size += b.numel() * b.element_size() + return int(size) + + def _evict_until_free(self, required_bytes, max_attempts=16): + attempts = 0 + while self._gpu_free_bytes() < required_bytes and attempts < max_attempts: + evicted = self._evict_from_gpu() + if not evicted: + break + attempts += 1 + return self._gpu_free_bytes() >= required_bytes + + @contextmanager + def ensure_headroom(self, required_bytes): + + safety = getattr(self, "SAFETY_BUFFER", 0) + target_free = int(required_bytes + safety) + + if getattr(self, "_headroom", None) is not None: + try: + del self._headroom + except Exception: + pass + self._headroom = None + + ok = self._evict_until_free(target_free) + if not ok and self._gpu_free_bytes() < target_free: + # last ditch + try: + torch.cuda.empty_cache() + except Exception: + pass + + try: + yield + finally: + if getattr(self, "_headroom", None) is None: + try: + self._headroom = torch.empty((self._headroom_bytes,), dtype=torch.uint8, device="cuda:0") + except Exception: + self._headroom = None + + def add_gpu(self, moe, index, allowed_retries=3): + size = self._estimate_size(moe) + + while self.gpu_mem_usage + size > self.MAX_GPU_MEM: + if not self._evict_from_gpu(): + break + + attempts = 0 + while self._gpu_free_bytes() < size + self.SAFETY_BUFFER and attempts < self.MAX_GPU_EVICT_ATTEMPTS: + if not self._evict_from_gpu(): + break + attempts += 1 + + for _ in range(allowed_retries): + try: + moe_cuda = moe.to("cuda:0") + break + except RuntimeError as e: + if "out of memory" not in str(e).lower(): + raise + evicted = self._evict_from_gpu() + if not evicted: # can't evict + raise + else: + raise RuntimeError("Failed to move expert to GPU after evictions") + + self.gpu_cache[index] = moe_cuda + self.gpu_cache.move_to_end(index) + self.gpu_mem_usage += size + + return + + def add_cpu(self, moe, index): + size = self._estimate_size(moe) + while self.cpu_mem_usage + size > self.MAX_CPU_MEM: + if not self._evict_from_cpu(): + break + moe_cpu = moe.to("cpu") + self.cpu_cache[index] = moe_cpu + self.cpu_cache.move_to_end(index) + self.cpu_mem_usage += size + + def get_from_device(self, index): + if index in self.gpu_cache: + moe = self.gpu_cache[index] + self.gpu_cache.move_to_end(index) + return moe + if index in self.cpu_cache: + moe = self.cpu_cache.pop(index) + self.cpu_mem_usage = max(0, self.cpu_mem_usage - self._estimate_size(moe)) + try: + self.add_gpu(moe, index) + return self.gpu_cache[index] + except RuntimeError: + self.cpu_cache[index] = moe + self.cpu_cache.move_to_end(index) + self.cpu_mem_usage += self._estimate_size(moe) + raise + + return None # load from disk + + def _evict_from_gpu(self): + if not self.gpu_cache: + return False + + idx, moe = self.gpu_cache.popitem(last=False) + size = self._estimate_size(moe) + self.gpu_mem_usage = max(0, self.gpu_mem_usage - size) + + if self.cpu_mem_usage + size <= self.MAX_CPU_MEM: + try: + moe_cpu = moe.to("cpu") + except Exception: + # drop the model if cpu is full + del moe + return True + self.cpu_cache[idx] = moe_cpu + self.cpu_cache.move_to_end(idx) + self.cpu_mem_usage += size + return True + else: + del moe + return True + + def _evict_from_cpu(self): + if not self.cpu_cache: + return False + _, moe = self.cpu_cache.popitem(last=False) + size = self._estimate_size(moe) + self.cpu_mem_usage = max(0, self.cpu_mem_usage - size) + del moe + gc.collect() + return True + +class LazyMoELoader(nn.Module): + def __init__(self, device): + super().__init__() + self.device = device + + def lazy_init(self, config, layer_idx, expert_idx): + checkpoint = "./models/checkpoint/hunyuan_image_3.safetensors" + if not os.path.exists(checkpoint): + raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}") + + prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}." + additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight" + sd = {} + + with safe_open(checkpoint, framework="pt", device=self.device) as f: + for k in f.keys(): + if k.startswith(prefix) or k.startswith(additional_prefix): + new_k = k.split(f"experts.{expert_idx}.", 1)[1] + sd[new_k] = f.get_tensor(k) + + return HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd).to(self.deivce) + +class HunyuanMoE(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.moe_topk = 8 + self.num_experts = 64 + self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True) + self.gate = HunyuanTopKGate(config, layer_idx=layer_idx) + if INIT_MOE: + self.experts = nn.ModuleList( + [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)] + ) + else: + self.experts = None + self.moe_lru = moe_lru + + def forward(self, hidden_states): + if not INIT_MOE: + torch.cuda.set_device(0) + else: + torch.cuda.set_device(hidden_states.device.index) + bsz, seq_len, hidden_size = hidden_states.shape + + hidden_states_mlp = self.shared_mlp(hidden_states) + + reshaped_input = hidden_states.reshape(-1, hidden_size) + + with torch.cuda.nvtx.range("MoE"): + expert_weight, expert_index, cpu_expert_index, indices_all = self.gate(hidden_states) + if not INIT_MOE: + if ADDITIONAL_LAYERS_IN_GPU > 0: + additional_expert_index = indices_all[:, expert_index.size(1): expert_index.size(1) + ADDITIONAL_LAYERS_IN_GPU] + + flat = additional_expert_index.reshape(-1).to("cpu") + counts = torch.bincount(flat, minlength=self.num_experts) + top_extra = torch.topk(counts, k=min(ADDITIONAL_LAYERS_IN_GPU, (counts>0).sum().item())).indices.tolist() + + for expert_id in top_extra: + if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None: + expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id) + self.moe_lru.add_gpu(expert_cpu, expert_id + self.layer_idx) + + if cpu_expert_index is not None and cpu_expert_index.numel() > 0: + for expert_id in torch.unique(cpu_expert_index).cpu().tolist(): + if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None: + expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id) + self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx) + + combined_output = torch.zeros_like(reshaped_input) + for e in range(self.num_experts): + token_mask = (expert_index == e) + if not token_mask.any(): + continue + + token_ids = token_mask.nonzero(as_tuple=False) + token_positions = token_ids[:, 0] + + topk_slot = token_ids[:, 1] + + tokens = reshaped_input[token_positions] + weights = expert_weight[token_positions, topk_slot] + + if self.experts is not None and INIT_MOE: + out = self.experts[e](tokens) + elif self.experts is None: + expert = self.moe_lru.get_from_device(e + self.layer_idx) + if expert is None: + expert = LazyMoELoader() + out = expert.lazy_init(self.config, self.layer_idx, e)(tokens) + self.moe_lru.add_gpu(expert, e + self.layer_idx) + else: + tokens = tokens.to(next(expert.parameters()).device) + out = expert(tokens.view(bsz, -1, hidden_size)) + + out = out * weights.to(out.device).unsqueeze(-1) + + combined_output.to(out.device).index_add_(0, token_positions.to(out.device), out.reshape(-1, hidden_size)) + #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input) + #chunks = dispatched_input.chunk(self.num_experts, dim=0) + #expert_outputs = [] + #for chunk, expert in zip(chunks, self.experts): + # expert_outputs.append(expert(chunk)) + + #expert_output = torch.cat(expert_outputs, dim=0) + #combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output) + + combined_output = combined_output.reshape(bsz, seq_len, hidden_size) + + output = hidden_states_mlp + combined_output + + return output + +class HunyuanImage3Attention(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.attention_type = 'self' + + self.hidden_size = config["hidden_size"] + self.num_heads = config["num_attention_heads"] + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = 8 + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config["max_position_embeddings"] + self.rope_theta = 10000.0 + self.is_causal = True + self.hidden_size_q = self.head_dim * self.num_heads + self.hidden_size_kv = self.head_dim * self.num_key_value_heads + + # define layers + self.qkv_proj = nn.Linear( + self.hidden_size, + self.hidden_size_q + 2 * self.hidden_size_kv, + bias=False + ) + self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=False) + + self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"]) + self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"]) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, + **kwargs, + ): + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.qkv_proj(hidden_states) + qkv_states = qkv_states.reshape(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2, + self.head_dim) + query_states, key_states, value_states = torch.split(qkv_states, [self.num_key_value_groups, 1, 1], dim=3) + + query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = custom_pos_emb + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + query_states = self.query_layernorm(query_states) + key_states = self.key_layernorm(key_states) + + query_states = query_states.to(value_states.dtype) + key_states = key_states.to(value_states.dtype) + + key_states = torch.repeat_interleave(key_states, dim=1, repeats = self.num_key_value_groups) + value_states = torch.repeat_interleave(value_states, dim=1, repeats = self.num_key_value_groups) + + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, mask = attention_mask, skip_reshape=True) + + attn_output = self.o_proj(attn_output) + + return attn_output + +class HunyuanImage3DecoderLayer(nn.Module): + def __init__(self, config, layer_idx: int, moe_lru=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.layer_idx = layer_idx + self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx) + + self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru) + + self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]) + self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps']) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor | Any]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + custom_pos_emb=custom_pos_emb, + **kwargs, + ) + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + return outputs + +class HunyuanImage3Model(nn.Module): + def __init__(self, config, moe_lru=None): + super().__init__(config) + self.padding_idx = 128009 + self.vocab_size = 133120 + self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx) + self.layers = nn.ModuleList( + [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])] + ) + + self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]) + + self.shared_tensor = None + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache = True, + custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, + mode: str = "gen_image", + first_step: Optional[bool] = None, + gen_timestep_scatter_index: Optional[torch.Tensor] = None, + ): + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + next_decoder_cache = None + for layer_idx, decoder_layer in enumerate(self.layers): + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + custom_pos_emb=custom_pos_emb, + mode=mode, + first_step=first_step, + gen_timestep_scatter_index=gen_timestep_scatter_index, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[1] + + next_cache = None + if use_cache: + next_cache = next_decoder_cache + + return tuple(v for v in [hidden_states, next_cache] if v is not None) + + +class HunyuanImage3ForCausalMM(nn.Module): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"]) + self.patch_embed = UNetDown( + patch_size=16, + emb_channels=config["hidden_size"], + in_channels=32, + hidden_channels=1024, + out_channels=config["hidden_size"], + ) + self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"]) + + self.final_layer = UNetUp( + patch_size=16, + emb_channels=config["hidden_size"], + in_channels=config["hidden_size"], + hidden_channels=1024, + out_channels=32, + out_norm=True, + ) + self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"]) + + self.moe_lru = None + if not INIT_MOE: + self.moe_lru = MoELRUCache() + + self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru) + + self.pad_id = 128009 + self.vocab_size = 133120 + + self.lm_head = nn.Linear(config["hidden_size"], 133120, bias=False) + self.first_step = True + + self.kv_cache = None + + @staticmethod + def get_pos_emb(custom_pos_emb, position_ids): + cos, sin = custom_pos_emb + cos = real_batched_index_select(cos, dim=1, idx=position_ids) + sin = real_batched_index_select(sin, dim=1, idx=position_ids) + return cos, sin + + def ragged_final_layer(self, x, image_mask, timestep, token_h, token_w, first_step): + bsz, seq_len, n_embd = x.shape + if first_step: + image_output = x.masked_select(image_mask.unsqueeze(-1).bool()).reshape(bsz, -1, n_embd) + else: + image_output = x[:, 1:, :] + timestep_emb = self.time_embed_2(timestep) + pred = self.final_layer(image_output, timestep_emb, token_h, token_w) + return pred + + def forward(self, x, condition, timestep, **kwargs): + + if self.kv_cache is None: + # TODO: should change when higgsv2 gets merged + self.kv_cache = HunyuanStaticCache( + config=self.config, + batch_size=x.size(0) * 2, + max_cache_len = input_ids.shape[1], + dtype=x.dtype, + ) + + image_mask = torch.arange(1, x.size(1) - 1).to(torch.bool) + gen_timestep_scatter_index = 4 + cond, uncond = condition[:4], condition[4:] + joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1] + + + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) + height, width = x.shape[2] * 16, x.shape[3] * 16 + token_height = height // (16 * 16) + token_width = width // (16 * 16) + + rope_image_info = [[(None, (token_height, token_width))] * 2] + seq_len = input_ids.shape[1] + cos, sin = build_batch_2d_rope( + image_infos=rope_image_info, + seq_len=seq_len, + n_elem=self.config["hidden_size"] // self.config["num_attention_heads"], + base=10000.0, + ) + custom_pos_emb = (sin, cos) + + custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids) + inputs_embeds = self.model.wte(input_ids) + + cond_timestep = torch.zeros(inputs_embeds.size(0)) + t_emb = self.time_embed(cond_timestep) + + bsz, seq_len, n_embd = inputs_embeds.shape + + if self.first_step: + t_emb = self.time_embed(timestep) + x[:, 5:-4], token_h, token_w = self.patch_embed(x[:, 5:-4], t_emb) + x[:, gen_timestep_scatter_index] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) + else: + t_emb = self.time_embed(timestep) + x[:, 5:-4], token_h, token_w = self.patch_embed(x, t_emb) + timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd) + x = torch.cat([timestep_emb, x], dim=1) + + inputs_embeds = torch.cat([inputs_embeds, x], dim = 1) + + #///////////// + # cond_vae_images + + # cond_timestep_scatter_index + joint_image[:, 3] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) + # conditioning images (vae) + joint_image[:, 7:cond_vae_image_mask.size(0)], token_h, token_w = self.patch_embed( + joint_image[:, 7:cond_vae_image_mask.size(0)], self.time_embed(cond_timestep) + ) + + inputs_embeds = torch.cat([inputs_embeds, joint_image], dim = 1) + + batch_image_slices = [ + input_ids[i] + x[i] + for i in range(bsz) + ] + attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1) + for i in range(bsz): + for _, image_slice in enumerate(batch_image_slices[i]): + attention_mask[i, image_slice, image_slice] = True + attention_mask = attention_mask.unsqueeze(1) + + outputs = self.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=self.kv_cache, + inputs_embeds=inputs_embeds, + custom_pos_emb=custom_pos_emb, + first_step=self.first_step, + gen_timestep_scatter_index=gen_timestep_scatter_index, + ) + hidden_states = outputs[0] + + hidden_states = hidden_states.to(input_ids.device) + diffusion_prediction = self.ragged_final_layer( + hidden_states, image_mask, timestep, token_h, token_w, self.first_step) + + if self.first_step: + self.first_step = False + + return diffusion_prediction diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 18232ade3..850a3c0ac 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -406,6 +406,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["patch_size"] = 2 dit_config["text_emb_dim"] = 2048 return dit_config + + if "{}layers.32.mlp.gate_and_up_proj.weight".format(key_prefix) in state_dict_keys: + dit_config = {} + dit_config["image_model"] = "hunyuan_image_3" + dit_config["hidden_size"] = 4096 + dit_config["max_position_embeddings"] = 12800 + dit_config["num_attention_heads"] = 32 + dit_config['rms_norm_eps'] = 1e-05 + return dit_config if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 dit_config = {} diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py new file mode 100644 index 000000000..ada042fc5 --- /dev/null +++ b/comfy_extras/nodes_hunyuan_image.py @@ -0,0 +1,121 @@ +import torch +import comfy.model_management +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + +COMPUTED_RESO_GROUPS = ['512x2048', '512x1984', '512x1920', '512x1856', '512x1792', '512x1728', '512x1664', '512x1600', '512x1536', '576x1472', '640x1408', '704x1344', '768x1280', '832x1216', '896x1152', '960x1088', '1024x1024', '1088x960', '1152x896', '1216x832', '1280x768', '1344x704', '1408x640', '1472x576', '1536x512', '1600x512', '1664x512', '1728x512', '1792x512', '1856x512', '1920x512', '1984x512', '2048x512'] +RATIOS = [torch.tensor(int(r.split("x")[0]) / int(r.split("x")[1])) for r in COMPUTED_RESO_GROUPS] +def get_target_size(height, width): + ratio = height / width + idx = torch.argmin(torch.abs(torch.tensor(RATIOS) - ratio)) + reso = COMPUTED_RESO_GROUPS[idx] + return reso.split("x") + +class EmptyLatentHunyuanImage3(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyLatentHunyuanImage3", + display_name="EmptyLatentHunyuanImage3", + category="image/latent", + inputs = [ + io.Int.Input("height", min = 1, default = 512), + io.Int.Input("width", min = 1, default = 512), + io.Int.Input("batch_size", min = 1, max = 48_000, default = 1), + io.Clip.Input("clip") + ], + outputs=[io.Latent.Output(display_name="latent")] + ) + @classmethod + def execute(cls, height, width, batch_size, clip): + encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids + special_fn = clip.tokenizer.tokenizer.added_tokens_encoder + def fn(string, func = encode_fn): + return torch.tensor(func(string), device=comfy.model_management.intermediate_device()).unsqueeze(0) + + height, width = get_target_size(height, width) + latent = torch.randn(batch_size, 32, height // 16, width // 16, device=comfy.model_management.intermediate_device()) + latent = torch.cat([fn(""), fn("_start"), fn("", special_fn), fn(f"", special_fn), + latent, fn(""), fn("_start"), fn("_end"), fn("_end")], dim = 1) + return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, ) + +class HunyuanImage3Conditioning(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanImage3Conditioning", + display_name="HunyuanImage3Conditioning", + category="conditioning/video_models", + inputs = [ + io.Conditioning.Input("vae_encoding"), + io.Conditioning.Input("vit_encoding"), + io.Conditioning.Input("text_encoding_positive"), + io.Conditioning.Input("text_encoding_negative", optional = True), + io.Clip.Input("clip") + ], + outputs=[io.Conditioning.Output(display_name= "positive"), io.Conditioning.Output(display_name="negative")] + ) + + @classmethod + def execute(cls, vae_encoding, vit_encoding, text_encoding, clip, text_encoding_negative=None): + encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids + special_fn = clip.tokenizer.tokenizer.added_tokens_encoder + def fn(string, func = encode_fn): + return torch.tensor(func(string), device=text_encoding.device).unsqueeze(0) + + text_encoding = text_encoding[0][0] + + text_tokens = torch.cat([fn("_start"), text_encoding, fn("_end")], dim = 1) + vae_tokens = torch.cat([fn("_start"), fn("_start"), fn("_start"), vae_encoding, fn("_end"), fn("_end"), fn("")], dim = 1) + vit_tokens = torch.cat([fn("_start"), fn("_start"), vit_encoding, fn("_end"), fn("_end"), fn("_end")], dim = 1) + n, seq_len, dim = vit_tokens.shape + vit_tokens = vit_tokens.reshape(n * seq_len, dim) + # should dynamically change in model logic + joint_image = torch.cat([fn(""), fn("", special_fn), fn("", special_fn), fn("", special_fn), vae_tokens, vit_tokens, fn("")], dim = 1) + + seq_len_total = joint_image.shape[1] + mask = torch.zeros(seq_len_total, dtype=torch.bool, device=joint_image.device) + positions = {} + current = 4 + + def mark_region(name, tensor): + nonlocal current + start = current + current += tensor.shape[1] + end = current - 1 + positions[f"<{name}>_start"] = start + positions[f"<{name}>_end"] = end + mask[start:end + 1] = True + return start, end + + mark_region("vae_img", vae_tokens) + + mask_list = [] + for prefix in ["text", "vae_img", "vit_img"]: + start = positions[f"<{prefix}>_start"] + end = positions[f"<{prefix}>_end"] + + section_mask = torch.arange(start, end + 1, device=mask.device) + mask_list.append(section_mask) + + mask_list.insert(0, joint_image) + mask_list.append(text_tokens) + ragged_tensors = torch.nested.nested_tensor(mask_list, dtype=torch.long) + + if text_encoding_negative is not None: + uncond_ragged_tensors = cls.execute(vae_encoding, vit_encoding, text_encoding_negative, clip=clip, text_encoding_negative = None) + else: + uncond_ragged_tensors = torch.nested.nested_tensor([torch.zeros_like(t) for t in ragged_tensors.unbind()]) + + return ragged_tensors, uncond_ragged_tensors + +class Image3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + HunyuanImage3Conditioning, + EmptyLatentHunyuanImage3 + ] + +async def comfy_entrypoint() -> Image3Extension: + return Image3Extension() diff --git a/nodes.py b/nodes.py index 1b465b9e6..47a7df6fc 100644 --- a/nodes.py +++ b/nodes.py @@ -2282,6 +2282,7 @@ def init_builtin_extra_nodes(): "nodes_ace.py", "nodes_string.py", "nodes_camera_trajectory.py", + "nodes_hunyuan_image.py", "nodes_edit_model.py", "nodes_tcfg.py" ] From a2fff60d4c5a4bfb7560af0da73809a403519461 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 31 Oct 2025 23:53:13 +0200 Subject: [PATCH 02/16] vectorized implementation of moe/fixes for issues --- comfy/ldm/hunyuan_image_3/model.py | 74 ++++++++++++++++++++++-------- comfy/model_base.py | 5 ++ 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index f9a4a8485..949769839 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -4,6 +4,7 @@ import math import torch import psutil import torch.nn as nn +from pathlib import Path from einops import rearrange import torch.nn.functional as F from collections import OrderedDict @@ -460,13 +461,13 @@ class HunyuanMLP(nn.Module): self.intermediate_size = 3072 self.act_fn = torch.nn.functional.silu - self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device) - if x.ndim == 2: - x = x.unsqueeze(0) self.intermediate_size *= 2 # SwiGLU self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False) def forward(self, x): + self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device) + if x.ndim == 2: + x = x.unsqueeze(0) gate_and_up_proj = self.gate_and_up_proj(x) x1, x2 = gate_and_up_proj.chunk(2, dim=2) down_proj = self.down_proj(x1 * self.act_fn(x2)) @@ -654,7 +655,9 @@ class LazyMoELoader(nn.Module): self.device = device def lazy_init(self, config, layer_idx, expert_idx): - checkpoint = "./models/checkpoint/hunyuan_image_3.safetensors" + comfyui_dir = Path.home() / "ComfyUI" + checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors" + checkpoint = checkpoint.resolve() if not os.path.exists(checkpoint): raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}") @@ -720,34 +723,65 @@ class HunyuanMoE(nn.Module): self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx) combined_output = torch.zeros_like(reshaped_input) + experts_list = [] for e in range(self.num_experts): token_mask = (expert_index == e) if not token_mask.any(): continue + expert = self.moe_lru.get_from_device(e + self.layer_idx) + if expert is None: + expert = LazyMoELoader() + expert = expert.lazy_init(self.config, self.layer_idx, e) + self.moe_lru.add_gpu(expert, e + self.layer_idx) + experts_list.append((e, expert)) + + per_pos, per_tokens, per_weights = [], [], [] + for e, _ in experts_list: + token_mask = (expert_index == e) token_ids = token_mask.nonzero(as_tuple=False) token_positions = token_ids[:, 0] - topk_slot = token_ids[:, 1] tokens = reshaped_input[token_positions] weights = expert_weight[token_positions, topk_slot] - if self.experts is not None and INIT_MOE: - out = self.experts[e](tokens) - elif self.experts is None: - expert = self.moe_lru.get_from_device(e + self.layer_idx) - if expert is None: - expert = LazyMoELoader() - out = expert.lazy_init(self.config, self.layer_idx, e)(tokens) - self.moe_lru.add_gpu(expert, e + self.layer_idx) - else: - tokens = tokens.to(next(expert.parameters()).device) - out = expert(tokens.view(bsz, -1, hidden_size)) + per_pos.append(token_positions) + per_tokens.append(tokens) + per_weights.append(weights) - out = out * weights.to(out.device).unsqueeze(-1) + lengths = [t.shape[0] for t in per_tokens] + E = len(per_tokens) + L = max(lengths) + tokens_padded = torch.zeros((E, L, hidden_size), device=hidden_states.device, dtype=reshaped_input.dtype) + weights_padded = torch.zeros((E, L), device=hidden_states.device, dtype=per_weights[0].dtype) + for i, t in enumerate(per_tokens): + tokens_padded[i, : t.shape[0]] = t + weights_padded[i, : t.shape[0]] = per_weights[i] + + l1, l2 = [], [] + for _, expert in experts_list: + l1.append(expert.gate_and_up_proj) + l2.append(expert.down_proj) + + W1 = torch.stack([l.weight for l in l1]).to(hidden_states.device) + W2 = torch.stack([l.weight for l in l2]).to(hidden_states.device) + + W1_T = W1.transpose(1, 2) + W2_T = W2.transpose(1, 2) + + x = torch.bmm(tokens_padded, W1_T) + x = F.silu(x) + + out_padded = torch.bmm(x, W2_T) + + out_padded = out_padded * weights_padded.unsqueeze(-1) + + for i, token_positions in enumerate(per_pos): + Ni = lengths[i] + out_i = out_padded[i, :Ni] + combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i) - combined_output.to(out.device).index_add_(0, token_positions.to(out.device), out.reshape(-1, hidden_size)) #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input) #chunks = dispatched_input.chunk(self.num_experts, dim=0) #expert_outputs = [] @@ -1014,12 +1048,12 @@ class HunyuanImage3ForCausalMM(nn.Module): dtype=x.dtype, ) - image_mask = torch.arange(1, x.size(1) - 1).to(torch.bool) + image_mask = torch.ones(x.size(1)) + image_mask[:, :5] = torch.zeros(5); image_mask[:, -4:] = torch.zeros(4) gen_timestep_scatter_index = 4 cond, uncond = condition[:4], condition[4:] joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1] - position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) height, width = x.shape[2] * 16, x.shape[3] * 16 token_height = height // (16 * 16) diff --git a/comfy/model_base.py b/comfy/model_base.py index 4392355ea..7a74debee 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.hidream.model import comfy.ldm.chroma.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 +import comfy.ldm.hunyuan_image_3.model import comfy.model_management import comfy.patcher_extension @@ -1196,6 +1197,10 @@ class Hunyuan3Dv2(BaseModel): if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out + +class HunyuanImage3(BaseModel): + def __init__(self, model_config, model_type=ModelType.Flow, device=None): + super().__init__(model_config, model_type, device, unet_model = comfy.ldm.hunyuan_image_3.model.HunyuanImage3ForCausalMM) class HiDream(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): From 70f216bbd0dadfe8c66c4cf6fd93ce1b708e21c1 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 31 Oct 2025 23:58:01 +0200 Subject: [PATCH 03/16] tiny bug --- comfy/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 7a74debee..c9fe4e4ef 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1199,7 +1199,7 @@ class Hunyuan3Dv2(BaseModel): return out class HunyuanImage3(BaseModel): - def __init__(self, model_config, model_type=ModelType.Flow, device=None): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device, unet_model = comfy.ldm.hunyuan_image_3.model.HunyuanImage3ForCausalMM) class HiDream(BaseModel): From 10a17dc85d6d1701d9742d81e552ab654ad98723 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sat, 1 Nov 2025 16:40:49 +0200 Subject: [PATCH 04/16] a bunch of fixes --- comfy/ldm/hunyuan_image_3/model.py | 20 ++++++---- comfy_extras/nodes_hunyuan_image.py | 57 +++++++++-------------------- 2 files changed, 29 insertions(+), 48 deletions(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index 949769839..ba2c1e90c 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -1053,6 +1053,7 @@ class HunyuanImage3ForCausalMM(nn.Module): gen_timestep_scatter_index = 4 cond, uncond = condition[:4], condition[4:] joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1] + joint_image[:, 2] = x[:, 2] # updates image ratio position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) height, width = x.shape[2] * 16, x.shape[3] * 16 @@ -1079,11 +1080,11 @@ class HunyuanImage3ForCausalMM(nn.Module): if self.first_step: t_emb = self.time_embed(timestep) - x[:, 5:-4], token_h, token_w = self.patch_embed(x[:, 5:-4], t_emb) + x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb) x[:, gen_timestep_scatter_index] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) else: t_emb = self.time_embed(timestep) - x[:, 5:-4], token_h, token_w = self.patch_embed(x, t_emb) + x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb) timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd) x = torch.cat([timestep_emb, x], dim=1) @@ -1095,16 +1096,19 @@ class HunyuanImage3ForCausalMM(nn.Module): # cond_timestep_scatter_index joint_image[:, 3] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) # conditioning images (vae) - joint_image[:, 7:cond_vae_image_mask.size(0)], token_h, token_w = self.patch_embed( - joint_image[:, 7:cond_vae_image_mask.size(0)], self.time_embed(cond_timestep) + joint_image[:, 3:cond_vae_image_mask.size(0)+3], token_h, token_w = self.patch_embed( + joint_image[:, 3:cond_vae_image_mask.size(0)+3], self.time_embed(cond_timestep) ) inputs_embeds = torch.cat([inputs_embeds, joint_image], dim = 1) - batch_image_slices = [ - input_ids[i] + x[i] - for i in range(bsz) - ] + batch_image_slices = [] + for i in range(x.size(0)): + # slice the vae and vit parts + slice the latent from x + joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)] + gen_slices_i = [slice(3, x[i].size(1) - 1)] + batch_image_slices.append(joint_slices_i + gen_slices_i) + attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1) for i in range(bsz): for _, image_slice in enumerate(batch_image_slices[i]): diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py index ada042fc5..ba69c0978 100644 --- a/comfy_extras/nodes_hunyuan_image.py +++ b/comfy_extras/nodes_hunyuan_image.py @@ -35,8 +35,7 @@ class EmptyLatentHunyuanImage3(io.ComfyNode): height, width = get_target_size(height, width) latent = torch.randn(batch_size, 32, height // 16, width // 16, device=comfy.model_management.intermediate_device()) - latent = torch.cat([fn(""), fn("_start"), fn("", special_fn), fn(f"", special_fn), - latent, fn(""), fn("_start"), fn("_end"), fn("_end")], dim = 1) + latent = torch.cat([fn(""), fn("", special_fn), fn(f"", special_fn), latent, fn("")], dim = 1) return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, ) class HunyuanImage3Conditioning(io.ComfyNode): @@ -63,51 +62,29 @@ class HunyuanImage3Conditioning(io.ComfyNode): def fn(string, func = encode_fn): return torch.tensor(func(string), device=text_encoding.device).unsqueeze(0) - text_encoding = text_encoding[0][0] + text_tokens = text_encoding[0][0] + # should dynamically change in model logic + joint_image = torch.cat([fn(""), fn("", special_fn), fn("", special_fn), fn("", special_fn), vae_encoding, fn(""), vit_encoding, fn("")], dim = 1) - text_tokens = torch.cat([fn("_start"), text_encoding, fn("_end")], dim = 1) - vae_tokens = torch.cat([fn("_start"), fn("_start"), fn("_start"), vae_encoding, fn("_end"), fn("_end"), fn("")], dim = 1) - vit_tokens = torch.cat([fn("_start"), fn("_start"), vit_encoding, fn("_end"), fn("_end"), fn("_end")], dim = 1) - n, seq_len, dim = vit_tokens.shape - vit_tokens = vit_tokens.reshape(n * seq_len, dim) - # should dynamically change in model logic - joint_image = torch.cat([fn(""), fn("", special_fn), fn("", special_fn), fn("", special_fn), vae_tokens, vit_tokens, fn("")], dim = 1) + vae_mask = torch.ones(joint_image.size(1)) + vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(2) - seq_len_total = joint_image.shape[1] - mask = torch.zeros(seq_len_total, dtype=torch.bool, device=joint_image.device) - positions = {} - current = 4 - - def mark_region(name, tensor): - nonlocal current - start = current - current += tensor.shape[1] - end = current - 1 - positions[f"<{name}>_start"] = start - positions[f"<{name}>_end"] = end - mask[start:end + 1] = True - return start, end - - mark_region("vae_img", vae_tokens) - - mask_list = [] - for prefix in ["text", "vae_img", "vit_img"]: - start = positions[f"<{prefix}>_start"] - end = positions[f"<{prefix}>_end"] - - section_mask = torch.arange(start, end + 1, device=mask.device) - mask_list.append(section_mask) - - mask_list.insert(0, joint_image) - mask_list.append(text_tokens) - ragged_tensors = torch.nested.nested_tensor(mask_list, dtype=torch.long) + ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.unsqueeze(-1).to(joint_image.dtype)]) + uncond_ragged_tensors = None if text_encoding_negative is not None: - uncond_ragged_tensors = cls.execute(vae_encoding, vit_encoding, text_encoding_negative, clip=clip, text_encoding_negative = None) + uncond_ragged_tensors, _ = cls.execute(vae_encoding, vit_encoding, text_encoding_negative, clip=clip, text_encoding_negative = None) else: uncond_ragged_tensors = torch.nested.nested_tensor([torch.zeros_like(t) for t in ragged_tensors.unbind()]) - return ragged_tensors, uncond_ragged_tensors + if uncond_ragged_tensors is not None: + positive = [[ragged_tensors, {}]] + negative = [[uncond_ragged_tensors, {}]] + else: + positive = ragged_tensors + negative = uncond_ragged_tensors + + return positive, negative class Image3Extension(ComfyExtension): @override From ca119c44fb68d1bb694fdbf7a83d97c9095b8968 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sat, 1 Nov 2025 23:06:11 +0200 Subject: [PATCH 05/16] returned kv cache for image generation --- comfy/ldm/hunyuan_image_3/model.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index ba2c1e90c..682fd8781 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -832,6 +832,7 @@ class HunyuanImage3Attention(nn.Module): def forward( self, hidden_states: torch.Tensor, + past_key_value, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, @@ -858,6 +859,11 @@ class HunyuanImage3Attention(nn.Module): query_states = query_states.to(value_states.dtype) key_states = key_states.to(value_states.dtype) + if past_key_value is not None: + cache_kwargs = {"cache_position": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + query_states = query_states.to(key_states.dtype) + key_states = torch.repeat_interleave(key_states, dim=1, repeats = self.num_key_value_groups) value_states = torch.repeat_interleave(value_states, dim=1, repeats = self.num_key_value_groups) @@ -870,7 +876,7 @@ class HunyuanImage3Attention(nn.Module): attn_output = self.o_proj(attn_output) - return attn_output + return attn_output, past_key_value class HunyuanImage3DecoderLayer(nn.Module): def __init__(self, config, layer_idx: int, moe_lru=None): @@ -900,7 +906,7 @@ class HunyuanImage3DecoderLayer(nn.Module): hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states = self.self_attn( + hidden_states, past_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -917,7 +923,7 @@ class HunyuanImage3DecoderLayer(nn.Module): hidden_states = residual + hidden_states - outputs = (hidden_states,) + outputs = (hidden_states, past_key_value) return outputs @@ -1039,6 +1045,9 @@ class HunyuanImage3ForCausalMM(nn.Module): def forward(self, x, condition, timestep, **kwargs): + cond, uncond = condition[:4], condition[4:] + joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1] + if self.kv_cache is None: # TODO: should change when higgsv2 gets merged self.kv_cache = HunyuanStaticCache( @@ -1049,10 +1058,8 @@ class HunyuanImage3ForCausalMM(nn.Module): ) image_mask = torch.ones(x.size(1)) - image_mask[:, :5] = torch.zeros(5); image_mask[:, -4:] = torch.zeros(4) + image_mask[:, :3] = torch.zeros(5); image_mask[:, -1] = torch.zeros(0) gen_timestep_scatter_index = 4 - cond, uncond = condition[:4], condition[4:] - joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1] joint_image[:, 2] = x[:, 2] # updates image ratio position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) @@ -1126,6 +1133,11 @@ class HunyuanImage3ForCausalMM(nn.Module): ) hidden_states = outputs[0] + # safety no-op + past_key_value = outputs[1] + if past_key_value is not None: + self.kv_cache = past_key_value + hidden_states = hidden_states.to(input_ids.device) diffusion_prediction = self.ragged_final_layer( hidden_states, image_mask, timestep, token_h, token_w, self.first_step) From 9e9c536c8ead6e9061b4be723e8fae07c8c30c9e Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 4 Nov 2025 23:55:16 +0200 Subject: [PATCH 06/16] fixes from testing --- comfy/ldm/hunyuan_image_3/model.py | 25 +++++++++++++------------ comfy/model_detection.py | 1 + comfy_extras/nodes_hunyuan_image.py | 23 +++++++++++++++++------ 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index 682fd8781..6ccfef697 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -337,7 +337,7 @@ class UNetDown(nn.Module): if self.patch_size == 1: self.model.append(ResBlock( - in_channels=hidden_channels, + channels=hidden_channels, emb_channels=emb_channels, out_channels=out_channels, dropout=dropout, @@ -346,7 +346,7 @@ class UNetDown(nn.Module): else: for i in range(self.patch_size // 2): self.model.append(ResBlock( - in_channels=hidden_channels, + channels=hidden_channels, emb_channels=emb_channels, out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels, dropout=dropout, @@ -381,7 +381,7 @@ class UNetUp(nn.Module): if self.patch_size == 1: self.model.append(ResBlock( - in_channels=in_channels, + channels=in_channels, emb_channels=emb_channels, out_channels=hidden_channels, dropout=dropout, @@ -390,7 +390,7 @@ class UNetUp(nn.Module): else: for i in range(self.patch_size // 2): self.model.append(ResBlock( - in_channels=in_channels if i == 0 else hidden_channels, + channels=in_channels if i == 0 else hidden_channels, emb_channels=emb_channels, out_channels=hidden_channels, dropout=dropout, @@ -929,7 +929,7 @@ class HunyuanImage3DecoderLayer(nn.Module): class HunyuanImage3Model(nn.Module): def __init__(self, config, moe_lru=None): - super().__init__(config) + super().__init__() self.padding_idx = 128009 self.vocab_size = 133120 self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx) @@ -989,12 +989,12 @@ class HunyuanImage3Model(nn.Module): class HunyuanImage3ForCausalMM(nn.Module): def __init__(self, config): - super().__init__(config) + super().__init__() self.config = config self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"]) self.patch_embed = UNetDown( - patch_size=16, + patch_size=1, emb_channels=config["hidden_size"], in_channels=32, hidden_channels=1024, @@ -1003,7 +1003,7 @@ class HunyuanImage3ForCausalMM(nn.Module): self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"]) self.final_layer = UNetUp( - patch_size=16, + patch_size=1, emb_channels=config["hidden_size"], in_channels=config["hidden_size"], hidden_channels=1024, @@ -1045,8 +1045,7 @@ class HunyuanImage3ForCausalMM(nn.Module): def forward(self, x, condition, timestep, **kwargs): - cond, uncond = condition[:4], condition[4:] - joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1] + joint_image, cond_vae_image_mask, input_ids, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind() if self.kv_cache is None: # TODO: should change when higgsv2 gets merged @@ -1058,9 +1057,11 @@ class HunyuanImage3ForCausalMM(nn.Module): ) image_mask = torch.ones(x.size(1)) - image_mask[:, :3] = torch.zeros(5); image_mask[:, -1] = torch.zeros(0) + image_mask[:3] = torch.zeros(3); image_mask[-1] = torch.zeros(1) gen_timestep_scatter_index = 4 - joint_image[:, 2] = x[:, 2] # updates image ratio + + with torch.no_grad(): + joint_image[:, 2, 0] = x[:, 2, 0, 0] # updates image ratio position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) height, width = x.shape[2] * 16, x.shape[3] * 16 diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 4669eb14b..816aed169 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -490,6 +490,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["max_position_embeddings"] = 12800 dit_config["num_attention_heads"] = 32 dit_config['rms_norm_eps'] = 1e-05 + dit_config["num_hidden_layers"] = 32 return dit_config if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py index ba69c0978..1e748669e 100644 --- a/comfy_extras/nodes_hunyuan_image.py +++ b/comfy_extras/nodes_hunyuan_image.py @@ -30,12 +30,18 @@ class EmptyLatentHunyuanImage3(io.ComfyNode): def execute(cls, height, width, batch_size, clip): encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids special_fn = clip.tokenizer.tokenizer.added_tokens_encoder - def fn(string, func = encode_fn): - return torch.tensor(func(string), device=comfy.model_management.intermediate_device()).unsqueeze(0) + word_embed = clip.tokenizer.wte + + hidden_size = word_embed.weight.shape[1] height, width = get_target_size(height, width) - latent = torch.randn(batch_size, 32, height // 16, width // 16, device=comfy.model_management.intermediate_device()) - latent = torch.cat([fn(""), fn("", special_fn), fn(f"", special_fn), latent, fn("")], dim = 1) + latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device()) + + def fn(string, func = encode_fn): + return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\ + .view(1, hidden_size, 1, 1).expand(batch_size, hidden_size, int(height) // 16, int(width) // 16) + + latent = torch.cat([fn(""), fn("", func = special_fn), fn(f"", special_fn), fn("", special_fn), latent, fn("")], dim = 1) return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, ) class HunyuanImage3Conditioning(io.ComfyNode): @@ -59,15 +65,20 @@ class HunyuanImage3Conditioning(io.ComfyNode): def execute(cls, vae_encoding, vit_encoding, text_encoding, clip, text_encoding_negative=None): encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids special_fn = clip.tokenizer.tokenizer.added_tokens_encoder + + word_embed = clip.tokenizer.wte + batch_size, _, hidden_size = vae_encoding.shape + def fn(string, func = encode_fn): - return torch.tensor(func(string), device=text_encoding.device).unsqueeze(0) + return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\ + .view(1, hidden_size, 1, 1).view(1, 1, hidden_size).expand(batch_size, -1, hidden_size) text_tokens = text_encoding[0][0] # should dynamically change in model logic joint_image = torch.cat([fn(""), fn("", special_fn), fn("", special_fn), fn("", special_fn), vae_encoding, fn(""), vit_encoding, fn("")], dim = 1) vae_mask = torch.ones(joint_image.size(1)) - vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(2) + vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:])) ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.unsqueeze(-1).to(joint_image.dtype)]) From 5056a1f4d4107e4f5f05000beebdeca7c6dc9137 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 6 Nov 2025 00:24:49 +0200 Subject: [PATCH 07/16] important fixes --- comfy/ldm/hunyuan_image_3/model.py | 45 ++++++++++++++--------------- comfy_extras/nodes_hunyuan_image.py | 17 +++++++---- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index 6ccfef697..dca38b20e 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -1045,51 +1045,59 @@ class HunyuanImage3ForCausalMM(nn.Module): def forward(self, x, condition, timestep, **kwargs): - joint_image, cond_vae_image_mask, input_ids, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind() + joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind() if self.kv_cache is None: # TODO: should change when higgsv2 gets merged self.kv_cache = HunyuanStaticCache( config=self.config, batch_size=x.size(0) * 2, - max_cache_len = input_ids.shape[1], + max_cache_len = inputs_embeds.shape[1], dtype=x.dtype, ) - image_mask = torch.ones(x.size(1)) + image_mask = torch.ones(x.size(1), device=x.device) image_mask[:3] = torch.zeros(3); image_mask[-1] = torch.zeros(1) gen_timestep_scatter_index = 4 with torch.no_grad(): - joint_image[:, 2, 0] = x[:, 2, 0, 0] # updates image ratio + joint_image[:, 2, :] = x[:, 2, :] # updates image ratio - position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) + position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) height, width = x.shape[2] * 16, x.shape[3] * 16 token_height = height // (16 * 16) token_width = width // (16 * 16) - rope_image_info = [[(None, (token_height, token_width))] * 2] - seq_len = input_ids.shape[1] + batch_image_slices = [] + for i in range(x.size(0)): + # slice the vae and vit parts + slice the latent from x + joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)] + gen_slices_i = [slice(3, x[i].size(1) - 1)] + batch_image_slices.append(joint_slices_i + gen_slices_i) + + rope_image_info = [ + [(s, (token_height, token_width)) for s in slices_i] + for slices_i in batch_image_slices + ] + seq_len = inputs_embeds.shape[1] cos, sin = build_batch_2d_rope( image_infos=rope_image_info, seq_len=seq_len, n_elem=self.config["hidden_size"] // self.config["num_attention_heads"], base=10000.0, ) - custom_pos_emb = (sin, cos) + custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device)) custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids) - inputs_embeds = self.model.wte(input_ids) cond_timestep = torch.zeros(inputs_embeds.size(0)) t_emb = self.time_embed(cond_timestep) bsz, seq_len, n_embd = inputs_embeds.shape + # FIXME: token_h and token_w for the first step if self.first_step: - t_emb = self.time_embed(timestep) - x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb) - x[:, gen_timestep_scatter_index] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) + x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) else: t_emb = self.time_embed(timestep) x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb) @@ -1103,20 +1111,9 @@ class HunyuanImage3ForCausalMM(nn.Module): # cond_timestep_scatter_index joint_image[:, 3] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) - # conditioning images (vae) - joint_image[:, 3:cond_vae_image_mask.size(0)+3], token_h, token_w = self.patch_embed( - joint_image[:, 3:cond_vae_image_mask.size(0)+3], self.time_embed(cond_timestep) - ) inputs_embeds = torch.cat([inputs_embeds, joint_image], dim = 1) - batch_image_slices = [] - for i in range(x.size(0)): - # slice the vae and vit parts + slice the latent from x - joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)] - gen_slices_i = [slice(3, x[i].size(1) - 1)] - batch_image_slices.append(joint_slices_i + gen_slices_i) - attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1) for i in range(bsz): for _, image_slice in enumerate(batch_image_slices[i]): @@ -1139,7 +1136,7 @@ class HunyuanImage3ForCausalMM(nn.Module): if past_key_value is not None: self.kv_cache = past_key_value - hidden_states = hidden_states.to(input_ids.device) + hidden_states = hidden_states.to(inputs_embeds.device) diffusion_prediction = self.ragged_final_layer( hidden_states, image_mask, timestep, token_h, token_w, self.first_step) diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py index 1e748669e..7da2e5718 100644 --- a/comfy_extras/nodes_hunyuan_image.py +++ b/comfy_extras/nodes_hunyuan_image.py @@ -30,16 +30,20 @@ class EmptyLatentHunyuanImage3(io.ComfyNode): def execute(cls, height, width, batch_size, clip): encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids special_fn = clip.tokenizer.tokenizer.added_tokens_encoder - word_embed = clip.tokenizer.wte - hidden_size = word_embed.weight.shape[1] + # may convert clip.tokenizer -> clip. + word_embed = clip.tokenizer.wte + patch_embed = clip.tokenizer.patch_embed + t_embed = clip.tokenizer.time_embed height, width = get_target_size(height, width) latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device()) + + latent, _, _ = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size))) def fn(string, func = encode_fn): return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\ - .view(1, hidden_size, 1, 1).expand(batch_size, hidden_size, int(height) // 16, int(width) // 16) + .unsqueeze(0).expand(batch_size, -1, -1) latent = torch.cat([fn(""), fn("", func = special_fn), fn(f"", special_fn), fn("", special_fn), latent, fn("")], dim = 1) return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, ) @@ -67,13 +71,16 @@ class HunyuanImage3Conditioning(io.ComfyNode): special_fn = clip.tokenizer.tokenizer.added_tokens_encoder word_embed = clip.tokenizer.wte - batch_size, _, hidden_size = vae_encoding.shape + patch_embed = clip.tokenizer.patch_embed + t_embed = clip.tokenizer.time_embed + batch_size, _, hidden_size = vit_encoding.shape def fn(string, func = encode_fn): return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\ - .view(1, hidden_size, 1, 1).view(1, 1, hidden_size).expand(batch_size, -1, hidden_size) + .view(1, 1, hidden_size).expand(batch_size, -1, hidden_size) text_tokens = text_encoding[0][0] + vae_encoding, _, _ = patch_embed(vae_encoding, t_embed(torch.tensor([0]).repeat(vae_encoding.size(0)))) # should dynamically change in model logic joint_image = torch.cat([fn(""), fn("", special_fn), fn("", special_fn), fn("", special_fn), vae_encoding, fn(""), vit_encoding, fn("")], dim = 1) From 44346c42519ed217958b17486462078970085fd3 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sat, 8 Nov 2025 19:49:02 +0200 Subject: [PATCH 08/16] removed all errors --- comfy/ldm/hunyuan_image_3/model.py | 145 ++++++++++-------- .../modules/diffusionmodules/openaimodel.py | 6 +- comfy_extras/nodes_hunyuan_image.py | 8 +- 3 files changed, 89 insertions(+), 70 deletions(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index dca38b20e..3cbca46cd 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -733,7 +733,7 @@ class HunyuanMoE(nn.Module): expert = LazyMoELoader() expert = expert.lazy_init(self.config, self.layer_idx, e) self.moe_lru.add_gpu(expert, e + self.layer_idx) - experts_list.append((e, expert)) + experts_list.append((e, expert)) per_pos, per_tokens, per_weights = [], [], [] for e, _ in experts_list: @@ -773,7 +773,8 @@ class HunyuanMoE(nn.Module): x = torch.bmm(tokens_padded, W1_T) x = F.silu(x) - out_padded = torch.bmm(x, W2_T) + x1, x2 = x.chunk(2, dim=2) + out_padded = torch.bmm(x1 * F.silu(x2), W2_T) out_padded = out_padded * weights_padded.unsqueeze(-1) @@ -1025,6 +1026,7 @@ class HunyuanImage3ForCausalMM(nn.Module): self.first_step = True self.kv_cache = None + self.token_dims = () @staticmethod def get_pos_emb(custom_pos_emb, position_ids): @@ -1047,6 +1049,76 @@ class HunyuanImage3ForCausalMM(nn.Module): joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind() + gen_timestep_scatter_index = 4 + + with torch.no_grad(): + joint_image[:, 2:3, :] = x[:, 2:3, :] # updates image ratio + + if self.first_step: + token_height, token_width = x[:, -2:, 0].tolist()[0] + self.token_dims = (int(token_height), int(token_width)) + x = x[:, :-2, :] + else: + token_height, token_width = self.token_dims + + img_slices = [] + + for i in range(x.size(0)): + vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0] + vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1 + + vit_start = vae_end + 1 + vit_end = joint_image.size(1) - 1 + + joint_slices_i = [ + slice(vae_start, vae_end), + slice(vit_start, vit_end), + ] + gen_slices_i = [slice(3 + vit_end, x[i].size(0) - 1 + vit_end)] + img_slices.append(joint_slices_i + gen_slices_i) + + img_s = img_slices[0] + rope_image_info = [[(img_s[0], (384 // 16, 384 // 16)), (img_s[1], (256 // 16, 256 // 16)), (img_s[2], (token_height, token_width))]] + + cond_timestep = torch.zeros(inputs_embeds.size(0)) + t_emb = self.time_embed(cond_timestep) + + bsz, seq_len, n_embd = inputs_embeds.shape + + if self.first_step: + x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) + else: + t_emb = self.time_embed(timestep) + x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb) + timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd) + x = torch.cat([timestep_emb, x], dim=1) + + #///////////// + # cond_vae_images + + # cond_timestep_scatter_index + with torch.no_grad(): + joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) + + inputs_embeds = torch.cat([inputs_embeds, joint_image, x], dim = 1) + + attention_mask = torch.ones(inputs_embeds.shape[1], inputs_embeds.shape[1], dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1) + for i in range(bsz): + for _, image_slice in enumerate(img_slices[i]): + attention_mask[i, image_slice, image_slice] = True + attention_mask = attention_mask.unsqueeze(1) + + # pos embed + position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) + cos, sin = build_batch_2d_rope( + image_infos=rope_image_info, + seq_len=inputs_embeds.shape[1], + n_elem=self.config["hidden_size"] // self.config["num_attention_heads"], + base=10000.0, + ) + custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device)) + custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids) + if self.kv_cache is None: # TODO: should change when higgsv2 gets merged self.kv_cache = HunyuanStaticCache( @@ -1056,70 +1128,6 @@ class HunyuanImage3ForCausalMM(nn.Module): dtype=x.dtype, ) - image_mask = torch.ones(x.size(1), device=x.device) - image_mask[:3] = torch.zeros(3); image_mask[-1] = torch.zeros(1) - gen_timestep_scatter_index = 4 - - with torch.no_grad(): - joint_image[:, 2, :] = x[:, 2, :] # updates image ratio - - position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) - height, width = x.shape[2] * 16, x.shape[3] * 16 - token_height = height // (16 * 16) - token_width = width // (16 * 16) - - batch_image_slices = [] - for i in range(x.size(0)): - # slice the vae and vit parts + slice the latent from x - joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)] - gen_slices_i = [slice(3, x[i].size(1) - 1)] - batch_image_slices.append(joint_slices_i + gen_slices_i) - - rope_image_info = [ - [(s, (token_height, token_width)) for s in slices_i] - for slices_i in batch_image_slices - ] - seq_len = inputs_embeds.shape[1] - cos, sin = build_batch_2d_rope( - image_infos=rope_image_info, - seq_len=seq_len, - n_elem=self.config["hidden_size"] // self.config["num_attention_heads"], - base=10000.0, - ) - custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device)) - - custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids) - - cond_timestep = torch.zeros(inputs_embeds.size(0)) - t_emb = self.time_embed(cond_timestep) - - bsz, seq_len, n_embd = inputs_embeds.shape - - # FIXME: token_h and token_w for the first step - if self.first_step: - x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) - else: - t_emb = self.time_embed(timestep) - x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb) - timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd) - x = torch.cat([timestep_emb, x], dim=1) - - inputs_embeds = torch.cat([inputs_embeds, x], dim = 1) - - #///////////// - # cond_vae_images - - # cond_timestep_scatter_index - joint_image[:, 3] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) - - inputs_embeds = torch.cat([inputs_embeds, joint_image], dim = 1) - - attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1) - for i in range(bsz): - for _, image_slice in enumerate(batch_image_slices[i]): - attention_mask[i, image_slice, image_slice] = True - attention_mask = attention_mask.unsqueeze(1) - outputs = self.model( attention_mask=attention_mask, position_ids=position_ids, @@ -1137,8 +1145,11 @@ class HunyuanImage3ForCausalMM(nn.Module): self.kv_cache = past_key_value hidden_states = hidden_states.to(inputs_embeds.device) + img_mask = torch.zeros(hidden_states.size(1)) + img_mask[-x.size(1)+4:] = 1; img_mask[-1] = 0 + diffusion_prediction = self.ragged_final_layer( - hidden_states, image_mask, timestep, token_h, token_w, self.first_step) + hidden_states, img_mask, timestep, int(token_height), int(token_width), self.first_step) if self.first_step: self.first_step = False diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4c8d53cac..510cb9da7 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -268,7 +268,11 @@ class ResBlock(TimestepBlock): if emb_out is not None: if self.exchange_temb_dims: emb_out = emb_out.movedim(1, 2) - h = h + emb_out + try: + h = h + emb_out + except: + emb_out = emb_out.movedim(1, 2) + h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py index 7da2e5718..012ac8a08 100644 --- a/comfy_extras/nodes_hunyuan_image.py +++ b/comfy_extras/nodes_hunyuan_image.py @@ -39,13 +39,17 @@ class EmptyLatentHunyuanImage3(io.ComfyNode): height, width = get_target_size(height, width) latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device()) - latent, _, _ = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size))) + latent, tk_height, tk_width = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size))) + + def tk_fn(token): + return torch.tensor([token], device = latent.device, dtype = latent.dtype).unsqueeze(1).expand(batch_size, 1, latent.size(-1)) def fn(string, func = encode_fn): return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\ .unsqueeze(0).expand(batch_size, -1, -1) latent = torch.cat([fn(""), fn("", func = special_fn), fn(f"", special_fn), fn("", special_fn), latent, fn("")], dim = 1) + latent = torch.cat([latent, tk_fn(tk_height), tk_fn(tk_width)], dim = 1) return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, ) class HunyuanImage3Conditioning(io.ComfyNode): @@ -87,7 +91,7 @@ class HunyuanImage3Conditioning(io.ComfyNode): vae_mask = torch.ones(joint_image.size(1)) vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:])) - ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.unsqueeze(-1).to(joint_image.dtype)]) + ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.to(joint_image.dtype)]) uncond_ragged_tensors = None if text_encoding_negative is not None: From 7b4c1e80312e0da4ea47e85d8e4f1feee1e716c7 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 14 Nov 2025 09:15:16 +0200 Subject: [PATCH 09/16] async cache revamp Added an async loading and offloading of moe layers, having consistent memory with oom errors. Used to give oom error after the third layer with 24 giga bytes gpu, now goes to the end with consistent memory with minimal latency --- comfy/ldm/hunyuan_image_3/model.py | 372 ++++++++++++----------------- 1 file changed, 158 insertions(+), 214 deletions(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index 3cbca46cd..9682a270f 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -1,15 +1,17 @@ import os -import gc import math +import time import torch import psutil +import asyncio +import threading import torch.nn as nn from pathlib import Path +import concurrent.futures from einops import rearrange import torch.nn.functional as F from collections import OrderedDict from safetensors import safe_open -from contextlib import contextmanager from transformers.cache_utils import StaticCache from typing import Optional, Tuple, Any, List, Dict from comfy.ldm.modules.attention import optimized_attention @@ -19,13 +21,13 @@ INIT_MOE = torch.cuda.device_count() != 1 if not INIT_MOE: MOE_LAYER_SIZE = (1024**3) * 2.65 # approx - CPU_MOE_RATIO = None torch.cuda.set_device(0) props = torch.cuda.get_device_properties(0) - INIT_CUDA_MEM = (props.total_memory - torch.cuda.memory_reserved()) * 0.9 - ADDITIONAL_LAYERS_IN_GPU = math.floor(INIT_CUDA_MEM / MOE_LAYER_SIZE) + LAYERS_IN_CPU = math.floor((int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')) + - psutil.Process(os.getpid()).memory_info().rss + - (2*1024**3)) * 0.50) / MOE_LAYER_SIZE) class HunyuanStaticCache(StaticCache): @@ -286,20 +288,15 @@ def topkgating( logits = logits.float() gates = F.softmax(logits, dim=1) - extra = ADDITIONAL_LAYERS_IN_GPU - - values_all, indices_all = torch.topk(gates, topk + extra, dim=1) + values_all, indices_all = torch.topk(gates, topk, dim=1) expert_weight = values_all[:, :topk] expert_index = indices_all[:, :topk] - _, cpu_expert_index = torch.topk(gates, int(CPU_MOE_RATIO * 64), dim = 1) - cpu_expert_index = cpu_expert_index[:, (8 + ADDITIONAL_LAYERS_IN_GPU):] - if norm_topk_prob and topk > 1: denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps) expert_weight = expert_weight / denom - return expert_weight, expert_index, cpu_expert_index, indices_all + return expert_weight, expert_index class HunyuanRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -452,7 +449,7 @@ class HunyuanTopKGate(nn.Module): return gate_output class HunyuanMLP(nn.Module): - def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False): + def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -462,8 +459,8 @@ class HunyuanMLP(nn.Module): self.act_fn = torch.nn.functional.silu self.intermediate_size *= 2 # SwiGLU - self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False) + self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device) + self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device) def forward(self, x): self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device) if x.ndim == 2: @@ -474,204 +471,143 @@ class HunyuanMLP(nn.Module): return down_proj class MoELRUCache(nn.Module): - def __init__(self, cpu_mem: int = 50, safety_buffer_bytes = 3*(1024**3), max_gpu_eviction_attempts = 8): + def __init__(self): super().__init__() - global CPU_MOE_RATIO - - _, total = torch.cuda.mem_get_info() - max_gpu_mem_gb = max((total - 2 * safety_buffer_bytes) / (1024**3), 1) - - self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3) - self.MAX_CPU_MEM = int(cpu_mem * 1024**3) self.gpu_cache = OrderedDict() self.cpu_cache = OrderedDict() + self.offload_stream = torch.cuda.Stream() + self.load_stream = torch.cuda.Stream() - self.gpu_mem_usage = 0 - self.cpu_mem_usage = 0 - # 50% for system and headroom - try: - self.MAX_CPU_MEM = int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')) - - psutil.Process(os.getpid()).memory_info().rss - - safety_buffer_bytes) * 0.55 - except: - self.MAX_CPU_MEM = int(cpu_mem * (1024**3) * 0.5) # TODO + self.last_offload_event = None + self._loop = asyncio.new_event_loop() + threading.Thread(target=self._loop.run_forever, daemon=True).start() - ADDITIONAL_LAYERS_IN_CPU = math.floor((50 * (1024**3)) / MOE_LAYER_SIZE) - CPU_MOE_RATIO = (min(64 - ADDITIONAL_LAYERS_IN_GPU, ADDITIONAL_LAYERS_IN_CPU)) / 64 + async def _async_offload_to_cpu(self, layer_idx): + # async offload from gpu (removed) - self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3) - self.SAFETY_BUFFER = int(safety_buffer_bytes) - self.MAX_GPU_EVICT_ATTEMPTS = max_gpu_eviction_attempts + num_experts = 64 + moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i]) + for i in range(num_experts) + if (layer_idx * num_experts + i) in self.gpu_cache] + event = torch.cuda.Event() - def _gpu_free_bytes(self): - free, total = torch.cuda.mem_get_info() - return int(free) - - def _estimate_size(self, moe): - # include parameters + buffers - size = 0 - for p in moe.parameters(): - size += p.numel() * p.element_size() - for b in moe.buffers(): - size += b.numel() * b.element_size() - return int(size) + with torch.cuda.stream(self.offload_stream): + for index, moe in moe_group: + moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True) + for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()): + if p_gpu.device.type == "meta": + continue + with torch.no_grad(): + p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True) + p_cpu.copy_(p_gpu, non_blocking=True) - def _evict_until_free(self, required_bytes, max_attempts=16): - attempts = 0 - while self._gpu_free_bytes() < required_bytes and attempts < max_attempts: - evicted = self._evict_from_gpu() - if not evicted: - break - attempts += 1 - return self._gpu_free_bytes() >= required_bytes + self.cpu_cache[index] = moe_cpu - @contextmanager - def ensure_headroom(self, required_bytes): + self.offload_stream.record_event(event) - safety = getattr(self, "SAFETY_BUFFER", 0) - target_free = int(required_bytes + safety) + self.last_offload_event = event - if getattr(self, "_headroom", None) is not None: - try: - del self._headroom - except Exception: - pass - self._headroom = None + def finalize_offload_layer(): + event.synchronize() + for index, moe in moe_group: + moe.to("meta") + self.gpu_cache.pop(index, None) + del moe + torch.cuda.empty_cache() - ok = self._evict_until_free(target_free) - if not ok and self._gpu_free_bytes() < target_free: - # last ditch - try: - torch.cuda.empty_cache() - except Exception: - pass + threading.Thread(target=finalize_offload_layer, daemon=True).start() - try: - yield - finally: - if getattr(self, "_headroom", None) is None: - try: - self._headroom = torch.empty((self._headroom_bytes,), dtype=torch.uint8, device="cuda:0") - except Exception: - self._headroom = None + async def _async_load_to_gpu(self, index, moe): - def add_gpu(self, moe, index, allowed_retries=3): - size = self._estimate_size(moe) - - while self.gpu_mem_usage + size > self.MAX_GPU_MEM: - if not self._evict_from_gpu(): + # if enough memory load, otherwise wait for offload + while True: + free_bytes, _ = torch.cuda.mem_get_info() + if free_bytes > 2 * MOE_LAYER_SIZE: break - attempts = 0 - while self._gpu_free_bytes() < size + self.SAFETY_BUFFER and attempts < self.MAX_GPU_EVICT_ATTEMPTS: - if not self._evict_from_gpu(): - break - attempts += 1 + self.last_offload_event.synchronize() + torch.cuda.empty_cache() + await asyncio.sleep(0.01) - for _ in range(allowed_retries): - try: - moe_cuda = moe.to("cuda:0") - break - except RuntimeError as e: - if "out of memory" not in str(e).lower(): - raise - evicted = self._evict_from_gpu() - if not evicted: # can't evict - raise - else: - raise RuntimeError("Failed to move expert to GPU after evictions") + # async loading from cpu -> gpu + with torch.cuda.stream(self.load_stream): + moe_gpu = HunyuanMLP(moe.config).to("cuda", non_blocking=True) + for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()): + with torch.no_grad(): + p_gpu.data = torch.empty_like(p_cpu, device="cuda") + p_gpu.copy_(p_cpu, non_blocking=True) - self.gpu_cache[index] = moe_cuda - self.gpu_cache.move_to_end(index) - self.gpu_mem_usage += size + def finalize_load(): + self.gpu_cache[index] = moe_gpu + self.cpu_cache.pop(index, None) - return + threading.Thread(target=finalize_load, daemon=True).start() def add_cpu(self, moe, index): - size = self._estimate_size(moe) - while self.cpu_mem_usage + size > self.MAX_CPU_MEM: - if not self._evict_from_cpu(): - break moe_cpu = moe.to("cpu") + + for _, p in moe_cpu.named_parameters(): + if not p.is_pinned(): + p.data = p.data.pin_memory() + self.cpu_cache[index] = moe_cpu self.cpu_cache.move_to_end(index) - self.cpu_mem_usage += size - - def get_from_device(self, index): - if index in self.gpu_cache: - moe = self.gpu_cache[index] - self.gpu_cache.move_to_end(index) - return moe - if index in self.cpu_cache: - moe = self.cpu_cache.pop(index) - self.cpu_mem_usage = max(0, self.cpu_mem_usage - self._estimate_size(moe)) - try: - self.add_gpu(moe, index) - return self.gpu_cache[index] - except RuntimeError: - self.cpu_cache[index] = moe - self.cpu_cache.move_to_end(index) - self.cpu_mem_usage += self._estimate_size(moe) - raise - - return None # load from disk - - def _evict_from_gpu(self): - if not self.gpu_cache: - return False - - idx, moe = self.gpu_cache.popitem(last=False) - size = self._estimate_size(moe) - self.gpu_mem_usage = max(0, self.gpu_mem_usage - size) - - if self.cpu_mem_usage + size <= self.MAX_CPU_MEM: - try: - moe_cpu = moe.to("cpu") - except Exception: - # drop the model if cpu is full - del moe - return True - self.cpu_cache[idx] = moe_cpu - self.cpu_cache.move_to_end(idx) - self.cpu_mem_usage += size - return True - else: - del moe - return True - - def _evict_from_cpu(self): - if not self.cpu_cache: - return False - _, moe = self.cpu_cache.popitem(last=False) - size = self._estimate_size(moe) - self.cpu_mem_usage = max(0, self.cpu_mem_usage - size) - del moe - gc.collect() - return True class LazyMoELoader(nn.Module): - def __init__(self, device): + def __init__(self, cache, config): super().__init__() - self.device = device + self.cache = cache + self.config = config + self._loop = cache._loop - def lazy_init(self, config, layer_idx, expert_idx): + def get_checkpoint(self): comfyui_dir = Path.home() / "ComfyUI" checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors" checkpoint = checkpoint.resolve() if not os.path.exists(checkpoint): raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}") + return checkpoint + def lazy_init(self, layer_idx, expert_idx): + checkpoint = self.get_checkpoint() prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}." additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight" sd = {} - with safe_open(checkpoint, framework="pt", device=self.device) as f: + with safe_open(checkpoint, framework="pt", device="cpu") as f: for k in f.keys(): if k.startswith(prefix) or k.startswith(additional_prefix): new_k = k.split(f"experts.{expert_idx}.", 1)[1] sd[new_k] = f.get_tensor(k) - return HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd).to(self.deivce) + return HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd) + + async def lazy_load_from_disk(self, layer_idx, expert_idx): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx) + + def _schedule_disk_load(self, layer_idx, expert_idx): + + coro = self.lazy_load_from_disk(layer_idx, expert_idx) + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + + def _on_disk_loaded(fut): + moe_cpu = fut.result() + def _add_cpu_in_main_thread(): + self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx) + + asyncio.run_coroutine_threadsafe( + self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu), + self.cache._loop + ) + threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start() + + future.add_done_callback(_on_disk_loaded) + return future + +def enough_vram(required_bytes): + free, total = torch.cuda.mem_get_info() + return free > required_bytes class HunyuanMoE(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None): @@ -687,7 +623,7 @@ class HunyuanMoE(nn.Module): [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)] ) else: - self.experts = None + self.experts = [] self.moe_lru = moe_lru def forward(self, hidden_states): @@ -702,38 +638,10 @@ class HunyuanMoE(nn.Module): reshaped_input = hidden_states.reshape(-1, hidden_size) with torch.cuda.nvtx.range("MoE"): - expert_weight, expert_index, cpu_expert_index, indices_all = self.gate(hidden_states) - if not INIT_MOE: - if ADDITIONAL_LAYERS_IN_GPU > 0: - additional_expert_index = indices_all[:, expert_index.size(1): expert_index.size(1) + ADDITIONAL_LAYERS_IN_GPU] - - flat = additional_expert_index.reshape(-1).to("cpu") - counts = torch.bincount(flat, minlength=self.num_experts) - top_extra = torch.topk(counts, k=min(ADDITIONAL_LAYERS_IN_GPU, (counts>0).sum().item())).indices.tolist() - - for expert_id in top_extra: - if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None: - expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id) - self.moe_lru.add_gpu(expert_cpu, expert_id + self.layer_idx) - - if cpu_expert_index is not None and cpu_expert_index.numel() > 0: - for expert_id in torch.unique(cpu_expert_index).cpu().tolist(): - if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None: - expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id) - self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx) + expert_weight, expert_index = self.gate(hidden_states) combined_output = torch.zeros_like(reshaped_input) - experts_list = [] - for e in range(self.num_experts): - token_mask = (expert_index == e) - if not token_mask.any(): - continue - expert = self.moe_lru.get_from_device(e + self.layer_idx) - if expert is None: - expert = LazyMoELoader() - expert = expert.lazy_init(self.config, self.layer_idx, e) - self.moe_lru.add_gpu(expert, e + self.layer_idx) - experts_list.append((e, expert)) + experts_list = [(i, expert) for i, expert in enumerate(self.experts)] per_pos, per_tokens, per_weights = [], [], [] for e, _ in experts_list: @@ -761,6 +669,8 @@ class HunyuanMoE(nn.Module): l1, l2 = [], [] for _, expert in experts_list: + if isinstance(expert, (asyncio.Future, concurrent.futures.Future)): + expert = expert.result() l1.append(expert.gate_and_up_proj) l2.append(expert.down_proj) @@ -769,6 +679,12 @@ class HunyuanMoE(nn.Module): W1_T = W1.transpose(1, 2) W2_T = W2.transpose(1, 2) + + # wait for enough vram for the computations + while not enough_vram(5*(1024 ** 3)): + event = self.moe_lru.last_offload_event + if event is not None and not event.query(): + time.sleep(0.001) x = torch.bmm(tokens_padded, W1_T) x = F.silu(x) @@ -781,7 +697,7 @@ class HunyuanMoE(nn.Module): for i, token_positions in enumerate(per_pos): Ni = lengths[i] out_i = out_padded[i, :Ni] - combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i) + combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype)) #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input) #chunks = dispatched_input.chunk(self.num_experts, dim=0) @@ -933,6 +849,7 @@ class HunyuanImage3Model(nn.Module): super().__init__() self.padding_idx = 128009 self.vocab_size = 133120 + self.config = config self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx) self.layers = nn.ModuleList( [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])] @@ -941,6 +858,7 @@ class HunyuanImage3Model(nn.Module): self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]) self.shared_tensor = None + self.moe_lru = moe_lru def forward( self, @@ -962,19 +880,45 @@ class HunyuanImage3Model(nn.Module): hidden_states = inputs_embeds next_decoder_cache = None - for layer_idx, decoder_layer in enumerate(self.layers): + next_layers = 0 + sparse_interval = max(1, len(self.layers) // 3) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - use_cache=use_cache, - custom_pos_emb=custom_pos_emb, - mode=mode, - first_step=first_step, - gen_timestep_scatter_index=gen_timestep_scatter_index, - ) + if len(self.layers[0].mlp.experts) == 0: + experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] + self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)] + + for layer_idx, decoder_layer in enumerate(self.layers): + + if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded + experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] + self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)] + + if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval: + if len(self.layers[next_layers].mlp.experts) > 0: # for testing + raise ValueError("Problem with offloading") + experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] + self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)] + next_layers += 1 + + with torch.no_grad(): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + custom_pos_emb=custom_pos_emb, + mode=mode, + first_step=first_step, + gen_timestep_scatter_index=gen_timestep_scatter_index, + ) + + if layer_idx >= 0: + asyncio.run_coroutine_threadsafe( + self.moe_lru._async_offload_to_cpu(layer_idx), + self.moe_lru._loop + ) + self.layers[layer_idx].mlp.experts = [] hidden_states = layer_outputs[0] From 12cc6924ac2c954512232dd07b4dca0cd0f5fe8c Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 14 Nov 2025 20:10:52 +0200 Subject: [PATCH 10/16] meta init --- comfy/ldm/hunyuan_image_3/model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index 9682a270f..11eb29d9e 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -580,7 +580,11 @@ class LazyMoELoader(nn.Module): new_k = k.split(f"experts.{expert_idx}.", 1)[1] sd[new_k] = f.get_tensor(k) - return HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd) + model = HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True, device="meta") + model.to_empty(device = "cpu") + + model.load_state_dict(sd) + return model async def lazy_load_from_disk(self, layer_idx, expert_idx): loop = asyncio.get_event_loop() From d731c58353c69e14ecd400fd6f52ea1928b3fc0e Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 16 Nov 2025 16:19:39 +0200 Subject: [PATCH 11/16] improving performance and fixing race condition --- comfy/ldm/hunyuan_image_3/model.py | 169 ++++++++++++++++------------- 1 file changed, 92 insertions(+), 77 deletions(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index 11eb29d9e..60ab43340 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -548,7 +548,10 @@ class MoELRUCache(nn.Module): for _, p in moe_cpu.named_parameters(): if not p.is_pinned(): - p.data = p.data.pin_memory() + if p.device.type == "cpu": + p.data = p.data.pin_memory() + else: + return self.cpu_cache[index] = moe_cpu self.cpu_cache.move_to_end(index) @@ -643,74 +646,81 @@ class HunyuanMoE(nn.Module): with torch.cuda.nvtx.range("MoE"): expert_weight, expert_index = self.gate(hidden_states) - - combined_output = torch.zeros_like(reshaped_input) - experts_list = [(i, expert) for i, expert in enumerate(self.experts)] - - per_pos, per_tokens, per_weights = [], [], [] - for e, _ in experts_list: + device = hidden_states.device + dtype = reshaped_input.dtype + + combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype) + + per_pos = [None] * self.num_experts + per_tokens = [None] * self.num_experts + per_weights = [None] * self.num_experts + for e in range(self.num_experts): token_mask = (expert_index == e) - token_ids = token_mask.nonzero(as_tuple=False) + if token_ids.numel() == 0: + continue token_positions = token_ids[:, 0] topk_slot = token_ids[:, 1] + per_pos[e] = token_positions + per_tokens[e] = reshaped_input[token_positions] + per_weights[e] = expert_weight[token_positions, topk_slot] + + used = [i for i, t in enumerate(per_tokens) if t is not None] + if len(used) == 0: + pass + else: + tokens_list = [per_tokens[i] for i in used] + weights_list = [per_weights[i] for i in used] + lengths = [t.shape[0] for t in tokens_list] + U = len(tokens_list) + L = max(lengths) + H = hidden_size + + tokens_padded = torch.zeros((U, L, H), device=device, dtype=dtype) + weights_padded = torch.zeros((U, L), device=device, dtype=weights_list[0].dtype) + for idx, t in enumerate(tokens_list): + n = t.shape[0] + tokens_padded[idx, :n] = t + weights_padded[idx, :n] = weights_list[idx] + + l1, l2 = [], [] + for i in used: + expert = self.experts[i] + if isinstance(expert, (asyncio.Future, concurrent.futures.Future)): + expert = expert.result() + expert = expert.to(device) + l1.append(expert.gate_and_up_proj) + l2.append(expert.down_proj) - tokens = reshaped_input[token_positions] - weights = expert_weight[token_positions, topk_slot] + compute_device = hidden_states.device + l1 = [m.to(compute_device) for m in l1] + l2 = [m.to(compute_device) for m in l2] - per_pos.append(token_positions) - per_tokens.append(tokens) - per_weights.append(weights) - - lengths = [t.shape[0] for t in per_tokens] - E = len(per_tokens) - L = max(lengths) - tokens_padded = torch.zeros((E, L, hidden_size), device=hidden_states.device, dtype=reshaped_input.dtype) - weights_padded = torch.zeros((E, L), device=hidden_states.device, dtype=per_weights[0].dtype) - for i, t in enumerate(per_tokens): - tokens_padded[i, : t.shape[0]] = t - weights_padded[i, : t.shape[0]] = per_weights[i] - - l1, l2 = [], [] - for _, expert in experts_list: - if isinstance(expert, (asyncio.Future, concurrent.futures.Future)): - expert = expert.result() - l1.append(expert.gate_and_up_proj) - l2.append(expert.down_proj) - - W1 = torch.stack([l.weight for l in l1]).to(hidden_states.device) - W2 = torch.stack([l.weight for l in l2]).to(hidden_states.device) - - W1_T = W1.transpose(1, 2) - W2_T = W2.transpose(1, 2) - - # wait for enough vram for the computations - while not enough_vram(5*(1024 ** 3)): - event = self.moe_lru.last_offload_event - if event is not None and not event.query(): - time.sleep(0.001) - - x = torch.bmm(tokens_padded, W1_T) - x = F.silu(x) - - x1, x2 = x.chunk(2, dim=2) - out_padded = torch.bmm(x1 * F.silu(x2), W2_T) - - out_padded = out_padded * weights_padded.unsqueeze(-1) - - for i, token_positions in enumerate(per_pos): - Ni = lengths[i] - out_i = out_padded[i, :Ni] - combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype)) - - #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input) - #chunks = dispatched_input.chunk(self.num_experts, dim=0) - #expert_outputs = [] - #for chunk, expert in zip(chunks, self.experts): - # expert_outputs.append(expert(chunk)) - - #expert_output = torch.cat(expert_outputs, dim=0) - #combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output) + W1 = torch.stack([m.weight for m in l1], dim=0) + W2 = torch.stack([m.weight for m in l2], dim=0) + + W1_T = W1.transpose(1, 2) + W2_T = W2.transpose(1, 2) + + while not enough_vram(3*(1024 ** 3)): + event = self.moe_lru.last_offload_event + if event is not None and not event.query(): + time.sleep(0.001) + + x = torch.bmm(tokens_padded, W1_T) + x = F.silu(x) + x1, x2 = x.chunk(2, dim=2) + out_padded = torch.bmm(x1 * F.silu(x2), W2_T) + + out_padded = out_padded * weights_padded.unsqueeze(-1) + + for idx, orig_expert_idx in enumerate(used): + pos = per_pos[orig_expert_idx] + n = lengths[idx] + out_i = out_padded[idx, :n] + combined_output.index_add_(0, pos.to(device), out_i.to(combined_output.dtype)) + + del tokens_padded, weights_padded, W1, W2, W1_T, W2_T, x, x1, x2, out_padded combined_output = combined_output.reshape(bsz, seq_len, hidden_size) @@ -863,6 +873,7 @@ class HunyuanImage3Model(nn.Module): self.shared_tensor = None self.moe_lru = moe_lru + self.self.additional_layers_set = False def forward( self, @@ -885,7 +896,8 @@ class HunyuanImage3Model(nn.Module): next_decoder_cache = None next_layers = 0 - sparse_interval = max(1, len(self.layers) // 3) + additional_layers = torch.cuda.mem_get_info()[0] // (MOE_LAYER_SIZE * 2) + sparse_interval = max(1, len(self.layers) // additional_layers) if len(self.layers[0].mlp.experts) == 0: experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] @@ -896,13 +908,12 @@ class HunyuanImage3Model(nn.Module): if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)] - - if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval: - if len(self.layers[next_layers].mlp.experts) > 0: # for testing - raise ValueError("Problem with offloading") - experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] - self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)] - next_layers += 1 + + if not additional_layers_set: + if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval: + experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] + self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)] + next_layers += 1 with torch.no_grad(): layer_outputs = decoder_layer( @@ -918,11 +929,15 @@ class HunyuanImage3Model(nn.Module): ) if layer_idx >= 0: - asyncio.run_coroutine_threadsafe( - self.moe_lru._async_offload_to_cpu(layer_idx), - self.moe_lru._loop - ) - self.layers[layer_idx].mlp.experts = [] + if self.additional_layers_set and layer_idx <= self.additional_layers_set: + pass + else: + torch.cuda.synchronize() + asyncio.run_coroutine_threadsafe( + self.moe_lru._async_offload_to_cpu(layer_idx), + self.moe_lru._loop + ) + self.layers[layer_idx].mlp.experts = [] hidden_states = layer_outputs[0] @@ -932,7 +947,7 @@ class HunyuanImage3Model(nn.Module): next_cache = None if use_cache: next_cache = next_decoder_cache - + self.additional_layers_set = True return tuple(v for v in [hidden_states, next_cache] if v is not None) From 4a5509a4c5e0518242b3f7a30577e4e0e4c63687 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 16 Nov 2025 16:20:35 +0200 Subject: [PATCH 12/16] . --- comfy/ldm/hunyuan_image_3/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index 60ab43340..6b292ed7f 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -909,7 +909,7 @@ class HunyuanImage3Model(nn.Module): experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)] - if not additional_layers_set: + if not self.additional_layers_set: if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval: experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)] From 61b1efdaf078a75e51948033640d8bae82ff8537 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 16 Nov 2025 19:25:37 +0200 Subject: [PATCH 13/16] vectrozied correct implementation of moe forward --- comfy/ldm/hunyuan_image_3/model.py | 112 +++++++++++++++-------------- 1 file changed, 58 insertions(+), 54 deletions(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index 6b292ed7f..8a01cd155 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -280,23 +280,43 @@ def conv_nd(dims, *args, **kwargs): def normalization(channels, **kwargs): return nn.GroupNorm(32, channels, **kwargs) -def topkgating( - logits: torch.Tensor, - topk: int, - norm_topk_prob: bool = True, -): +def topkgating(logits: torch.Tensor, topk: int): logits = logits.float() gates = F.softmax(logits, dim=1) - values_all, indices_all = torch.topk(gates, topk, dim=1) - expert_weight = values_all[:, :topk] - expert_index = indices_all[:, :topk] + num_experts = int(gates.shape[1]) - if norm_topk_prob and topk > 1: - denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps) - expert_weight = expert_weight / denom + _, expert_index = torch.topk(gates, topk) + expert_mask = F.one_hot(expert_index, num_experts) - return expert_weight, expert_index + expert_index_flat = expert_index.flatten() + tokens_per_expert = torch.bincount(expert_index_flat, minlength=num_experts) + expert_capacity = torch.max(tokens_per_expert).item() + + gates_s = torch.clamp( + torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps + ) + router_probs = gates / gates_s + + expert_index = torch.transpose(expert_index, 0, 1) + expert_index = expert_index.reshape(-1) + expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) + + token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1 + token_priority = token_priority.reshape((topk, -1, num_experts)) + token_priority = torch.transpose(token_priority, 0, 1) + + token_priority = torch.max(token_priority, dim=1)[0] + + valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity) + token_priority = torch.masked_fill(token_priority, ~valid_mask, 0) + dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool) + valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity) + dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0) + + combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) + + return combine_weights, dispatch_mask class HunyuanRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -436,15 +456,13 @@ class HunyuanTopKGate(nn.Module): num_experts = 64 self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32) - self.norm_topk_prob = True - def forward(self, hidden_states): bsz, seq_len, hidden_size = hidden_states.shape hidden_states = hidden_states.reshape(-1, hidden_size) if self.wg.weight.dtype == torch.float32: hidden_states = hidden_states.float() logits = self.wg(hidden_states) - gate_output = topkgating(logits, self.moe_topk, norm_topk_prob=self.norm_topk_prob,) + gate_output = topkgating(logits, self.moe_topk) return gate_output @@ -645,46 +663,33 @@ class HunyuanMoE(nn.Module): reshaped_input = hidden_states.reshape(-1, hidden_size) with torch.cuda.nvtx.range("MoE"): - expert_weight, expert_index = self.gate(hidden_states) + combine_weights, dispatch_mask = self.gate(hidden_states) + + dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(reshaped_input), reshaped_input) device = hidden_states.device dtype = reshaped_input.dtype + + used_mask = (dispatch_mask.sum(dim=(0, 2)) > 0) + used_indices = used_mask.nonzero(as_tuple=False).squeeze(1).tolist() combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype) - per_pos = [None] * self.num_experts - per_tokens = [None] * self.num_experts - per_weights = [None] * self.num_experts - for e in range(self.num_experts): - token_mask = (expert_index == e) - token_ids = token_mask.nonzero(as_tuple=False) - if token_ids.numel() == 0: - continue - token_positions = token_ids[:, 0] - topk_slot = token_ids[:, 1] - per_pos[e] = token_positions - per_tokens[e] = reshaped_input[token_positions] - per_weights[e] = expert_weight[token_positions, topk_slot] - - used = [i for i, t in enumerate(per_tokens) if t is not None] - if len(used) == 0: + if len(used_indices) == 0: pass else: - tokens_list = [per_tokens[i] for i in used] - weights_list = [per_weights[i] for i in used] - lengths = [t.shape[0] for t in tokens_list] - U = len(tokens_list) - L = max(lengths) - H = hidden_size + tokens_padded = dispatched_input[used_indices] - tokens_padded = torch.zeros((U, L, H), device=device, dtype=dtype) - weights_padded = torch.zeros((U, L), device=device, dtype=weights_list[0].dtype) - for idx, t in enumerate(tokens_list): - n = t.shape[0] - tokens_padded[idx, :n] = t - weights_padded[idx, :n] = weights_list[idx] + l1_layers, l2_layers = [], [] + for i in used_indices: + expert = self.experts[i] + if isinstance(expert, (asyncio.Future, concurrent.futures.Future)): + expert = expert.result() + expert = expert.to(device) + l1_layers.append(expert.gate_and_up_proj) + l2_layers.append(expert.down_proj) l1, l2 = [], [] - for i in used: + for i in used_indices: expert = self.experts[i] if isinstance(expert, (asyncio.Future, concurrent.futures.Future)): expert = expert.result() @@ -708,19 +713,18 @@ class HunyuanMoE(nn.Module): time.sleep(0.001) x = torch.bmm(tokens_padded, W1_T) - x = F.silu(x) x1, x2 = x.chunk(2, dim=2) - out_padded = torch.bmm(x1 * F.silu(x2), W2_T) + gated = x1 * F.silu(x2) + out_padded = torch.bmm(gated, W2_T) - out_padded = out_padded * weights_padded.unsqueeze(-1) + combine_weights_used = combine_weights[:, used_indices, :] - for idx, orig_expert_idx in enumerate(used): - pos = per_pos[orig_expert_idx] - n = lengths[idx] - out_i = out_padded[idx, :n] - combined_output.index_add_(0, pos.to(device), out_i.to(combined_output.dtype)) + combined_output = torch.einsum("suc,ucm->sm", + combine_weights_used.type_as(out_padded), + out_padded + ) - del tokens_padded, weights_padded, W1, W2, W1_T, W2_T, x, x1, x2, out_padded + del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded combined_output = combined_output.reshape(bsz, seq_len, hidden_size) From 3f717609136439a1b5c55b84190ce75d18a9efed Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 17 Nov 2025 06:50:54 +0200 Subject: [PATCH 14/16] resblock fix --- comfy/ldm/hunyuan_image_3/model.py | 4 ++++ comfy/ldm/modules/diffusionmodules/openaimodel.py | 6 +----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index 8a01cd155..62e09e65a 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -357,6 +357,7 @@ class UNetDown(nn.Module): channels=hidden_channels, emb_channels=emb_channels, out_channels=out_channels, + use_scale_shift_norm = True, dropout=dropout, **factory_kwargs )) @@ -365,6 +366,7 @@ class UNetDown(nn.Module): self.model.append(ResBlock( channels=hidden_channels, emb_channels=emb_channels, + use_scale_shift_norm = True, out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels, dropout=dropout, down=True, @@ -401,6 +403,7 @@ class UNetUp(nn.Module): channels=in_channels, emb_channels=emb_channels, out_channels=hidden_channels, + use_scale_shift_norm = True, dropout=dropout, **factory_kwargs )) @@ -410,6 +413,7 @@ class UNetUp(nn.Module): channels=in_channels if i == 0 else hidden_channels, emb_channels=emb_channels, out_channels=hidden_channels, + use_scale_shift_norm = True, dropout=dropout, up=True, **factory_kwargs diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 510cb9da7..4c8d53cac 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -268,11 +268,7 @@ class ResBlock(TimestepBlock): if emb_out is not None: if self.exchange_temb_dims: emb_out = emb_out.movedim(1, 2) - try: - h = h + emb_out - except: - emb_out = emb_out.movedim(1, 2) - h = h + emb_out + h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h From b84af5b947e1dd00b98a34ce86cf983455040d3e Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 17 Nov 2025 23:03:52 +0200 Subject: [PATCH 15/16] small attention fix --- comfy/ldm/hunyuan_image_3/model.py | 2 +- comfy/model_detection.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index 62e09e65a..c82904fdc 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -745,7 +745,7 @@ class HunyuanImage3Attention(nn.Module): self.hidden_size = config["hidden_size"] self.num_heads = config["num_attention_heads"] - self.head_dim = self.hidden_size // self.num_heads + self.head_dim = config["attention_head_dim"] self.num_key_value_heads = 8 self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config["max_position_embeddings"] diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 816aed169..246596167 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -491,6 +491,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_attention_heads"] = 32 dit_config['rms_norm_eps'] = 1e-05 dit_config["num_hidden_layers"] = 32 + dit_config["attention_head_dim"] = 128 return dit_config if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 From ea176bb87d34baaef2fef236b091dc726e9d7d8d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 20 Nov 2025 23:47:57 +0200 Subject: [PATCH 16/16] basic support for hunyuan model --- comfy/latent_formats.py | 5 +++++ comfy/supported_models.py | 13 ++++++++++++- comfy/text_encoders/hunyuan_image.py | 8 ++++++++ comfy_extras/nodes_hunyuan_image.py | 22 ++++++++++++---------- 4 files changed, 37 insertions(+), 11 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 77e642a94..a13c281dd 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -626,6 +626,11 @@ class Hunyuan3Dv2mini(LatentFormat): latent_dimensions = 1 scale_factor = 1.0188137142395404 +class HunyuanImage3(LatentFormat): + latent_channels = 32 + scale_factor = 0.562679178327931 + latent_dimensions = 3 + class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 4064bdae1..5ae2baef0 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1332,6 +1332,17 @@ class QwenImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) +class HunyuanImage3(supported_models_base.BASE): + unet_config = { + "image_model": "hunyuan_image_3", + } + latent_format = latent_formats.HunyuanImage3 + + def get_model(self, state_dict, prefix="", device=None): + return model_base.HunyuanImage3(self, device = device) + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImage3Tokenizer, comfy.text_encoders.hunyuan_image.DummyClip) + class HunyuanImage21(HunyuanVideo): unet_config = { "image_model": "hunyuan_video", @@ -1374,6 +1385,6 @@ class HunyuanImage21Refiner(HunyuanVideo): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanImage3, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index ff04726e1..ab3512201 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -5,6 +5,14 @@ from transformers import ByT5Tokenizer import os import re +class DummyClip: + def __init__(*args, **kwargs): + pass + +class HunyuanImage3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, tokenizer_path="hunyuan_image_3", max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=..., has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data=..., tokenizer_args=...): + super().__init__(tokenizer_path, max_length, pad_with_end, embedding_directory, embedding_size, embedding_key, tokenizer_class, has_start_token, has_end_token, pad_to_max_length, min_length, pad_token, end_token, min_padding, tokenizer_data, tokenizer_args) + class ByT5SmallTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer") diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py index 012ac8a08..e02702995 100644 --- a/comfy_extras/nodes_hunyuan_image.py +++ b/comfy_extras/nodes_hunyuan_image.py @@ -22,19 +22,20 @@ class EmptyLatentHunyuanImage3(io.ComfyNode): io.Int.Input("height", min = 1, default = 512), io.Int.Input("width", min = 1, default = 512), io.Int.Input("batch_size", min = 1, max = 48_000, default = 1), - io.Clip.Input("clip") + io.Clip.Input("clip"), + io.Model.Input("model") ], outputs=[io.Latent.Output(display_name="latent")] ) @classmethod - def execute(cls, height, width, batch_size, clip): + def execute(cls, height, width, batch_size, clip, model): encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids special_fn = clip.tokenizer.tokenizer.added_tokens_encoder # may convert clip.tokenizer -> clip. - word_embed = clip.tokenizer.wte - patch_embed = clip.tokenizer.patch_embed - t_embed = clip.tokenizer.time_embed + word_embed = model.wte + patch_embed = model.patch_embed + t_embed = model.time_embed height, width = get_target_size(height, width) latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device()) @@ -63,20 +64,21 @@ class HunyuanImage3Conditioning(io.ComfyNode): io.Conditioning.Input("vae_encoding"), io.Conditioning.Input("vit_encoding"), io.Conditioning.Input("text_encoding_positive"), + io.Clip.Input("clip"), + io.Model.Input("model"), io.Conditioning.Input("text_encoding_negative", optional = True), - io.Clip.Input("clip") ], outputs=[io.Conditioning.Output(display_name= "positive"), io.Conditioning.Output(display_name="negative")] ) @classmethod - def execute(cls, vae_encoding, vit_encoding, text_encoding, clip, text_encoding_negative=None): + def execute(cls, vae_encoding, vit_encoding, text_encoding, clip, model, text_encoding_negative=None): encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids special_fn = clip.tokenizer.tokenizer.added_tokens_encoder - word_embed = clip.tokenizer.wte - patch_embed = clip.tokenizer.patch_embed - t_embed = clip.tokenizer.time_embed + word_embed = model.wte + patch_embed = model.patch_embed + t_embed = model.time_embed batch_size, _, hidden_size = vit_encoding.shape def fn(string, func = encode_fn):