diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 77e642a94..a13c281dd 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -626,6 +626,11 @@ class Hunyuan3Dv2mini(LatentFormat): latent_dimensions = 1 scale_factor = 1.0188137142395404 +class HunyuanImage3(LatentFormat): + latent_channels = 32 + scale_factor = 0.562679178327931 + latent_dimensions = 3 + class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py new file mode 100644 index 000000000..c82904fdc --- /dev/null +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -0,0 +1,1128 @@ +import os +import math +import time +import torch +import psutil +import asyncio +import threading +import torch.nn as nn +from pathlib import Path +import concurrent.futures +from einops import rearrange +import torch.nn.functional as F +from collections import OrderedDict +from safetensors import safe_open +from transformers.cache_utils import StaticCache +from typing import Optional, Tuple, Any, List, Dict +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock + +INIT_MOE = torch.cuda.device_count() != 1 + +if not INIT_MOE: + MOE_LAYER_SIZE = (1024**3) * 2.65 # approx + + torch.cuda.set_device(0) + props = torch.cuda.get_device_properties(0) + + LAYERS_IN_CPU = math.floor((int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')) + - psutil.Process(os.getpid()).memory_info().rss + - (2*1024**3)) * 0.50) / MOE_LAYER_SIZE) + +class HunyuanStaticCache(StaticCache): + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + cache_position = cache_kwargs.get("cache_position") + if hasattr(self, "key_cache") and hasattr(self, "value_cache"): + if self.key_cache[layer_idx].device != key_states.device: + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + else: + if self.layers[layer_idx].keys is None: + self.layers[layer_idx].lazy_initialization(key_states) + k_out = self.layers[layer_idx].keys + v_out = self.layers[layer_idx].values + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + if cache_position.dim() == 1: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + + else: + assert cache_position.dim() == 2, f"multiple batch dims not yet {cache_position.shape=}" + batch_size, _ = cache_position.shape + for i in range(batch_size): + unbatched_dim = 1 + k_out[i].index_copy_(unbatched_dim, cache_position[i], key_states[i]) + v_out[i].index_copy_(unbatched_dim, cache_position[i], value_states[i]) + + return k_out, v_out + +def real_batched_index_select(t, dim, idx): + return torch.stack([torch.index_select(t[i], dim - 1, idx[i]) for i in range(len(t))]) + +def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + 
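# A minimal usage sketch (assumes only torch) of the sinusoidal embedding being built here:
# each scalar timestep becomes a dim-wide vector whose first half holds cos(t * f_k) and
# second half sin(t * f_k), with f_k = 10000 ** (-k / (dim // 2)), so low k varies slowly
# with t and high k varies quickly. The helper name below is illustrative only.
def _sketch_timestep_embedding(dim: int = 256) -> torch.Tensor:
    t = torch.tensor([0.0, 250.0, 999.0])            # one timestep per batch element
    emb = timestep_embedding(t, dim)                 # -> shape [3, dim]
    assert emb.shape == (3, dim)
    assert torch.allclose(emb[0, : dim // 2], torch.ones(dim // 2))  # cos(0) == 1
    return emb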
if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + +class TimestepEmbedder(nn.Module): + def __init__(self, + hidden_size, + act_layer=nn.GELU, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None + ): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + return x + + +def get_meshgrid_nd(start, *args, dim=2): + + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + num_int = [int(x) for x in num] + num = num_int + + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") + grid = torch.stack(grid, dim=0) + + return grid + +def build_2d_rope( + seq_len: int, n_elem: int, image_infos: Optional[List[Tuple[slice, Tuple[int, int]]]] = None, + device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0, + return_all_pos: bool = False, +): + + assert n_elem % 4 == 0, f"n_elem must be divisible by 4, but got {n_elem}." + + # theta + if base_rescale_factor != 1.0: + base *= base_rescale_factor ** (n_elem / (n_elem - 2)) + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) + theta = theta.reshape(1, n_elem // 4, 2) # [1, half_d, 2] + + # position indices + if image_infos is None: + image_infos = [] + + image_infos_list = [image_infos] + sample_seq_lens = [seq_len] + + # Prepare position indices for each sample + x_sections = [] + y_sections = [] + for sample_id, sample_image_infos in enumerate(image_infos_list): + last_pos = 0 + for sec_slice, (h, w) in sample_image_infos: + L = sec_slice.start # start from 0, so image_slice.start is just L + # previous text + if last_pos < L: + y_sections.append(torch.arange(last_pos, L)) + x_sections.append(torch.arange(last_pos, L)) + elif h is None: + # Interleave data has overlapped positions for tokens. + y_sections.append(torch.arange(sec_slice.start, sec_slice.stop)) + x_sections.append(torch.arange(sec_slice.start, sec_slice.stop)) + continue + else: + # Interleave data has overlapped positions for noised image and the successive clean image, + # leading to last_pos (= last text end L + noise w * h) > L (last text end L). 
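# Worked example (illustrative values, not taken from the model): for an image whose
# tokens start at text position L = 10 with h = 2, w = 3 latent rows/cols,
#   beta_y = 10 + (3*2 - 2) / 2 = 12.0
#   beta_x = 10 + (3*2 - 3) / 2 = 11.5
# get_meshgrid_nd((12.0, 11.5), (14.0, 14.5)) then assigns per-token (y, x) pairs
#   y = [[12, 12, 12], [13, 13, 13]],  x = [[11.5, 12.5, 13.5], [11.5, 12.5, 13.5]]
# so the h*w image tokens are centred on the 1D span they occupy in the text sequence,
# and the text that follows the image resumes at last_pos = 10 + 3*2 = 16.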
+ pass + # current image + beta_y = L + (w * h - h) / 2 + beta_x = L + (w * h - w) / 2 + grid = get_meshgrid_nd((beta_y, beta_x), (beta_y + h, beta_x + w)) # [2, h, w] + grid = grid.reshape(2, -1) # (y, x) + y_sections.append(grid[0]) + x_sections.append(grid[1]) + # step + last_pos = L + w * h + # final text + y_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id])) + x_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id])) + + x_pos = torch.cat(x_sections).long() + y_pos = torch.cat(y_sections).long() + # If there are overlap positions, we need to remove them. + x_pos = x_pos[:seq_len] + y_pos = y_pos[:seq_len] + all_pos = torch.stack((y_pos, x_pos), dim=1).unsqueeze(1).to(device) # [seq_len, 1, 2] + + # calc rope + idx_theta = (all_pos * theta).reshape(all_pos.shape[0], n_elem // 2).repeat(1, 2) + + cos = torch.cos(idx_theta) + sin = torch.sin(idx_theta) + + if return_all_pos: + return cos, sin, all_pos + + return cos, sin + + +def build_batch_2d_rope( + seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None, + device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0, + return_all_pos: bool = False, +): + cos_list, sin_list, all_pos_list = [], [], [] + if image_infos is None: + image_infos = [None] + for i, image_info in enumerate(image_infos): + res = build_2d_rope( + seq_len, n_elem, image_infos=image_info, device=device, + base=base, base_rescale_factor=base_rescale_factor, + return_all_pos=return_all_pos, + ) + if return_all_pos: + cos, sin, all_pos = res + else: + cos, sin = res + all_pos = None + cos_list.append(cos) + sin_list.append(sin) + all_pos_list.append(all_pos) + + stacked_cos = torch.stack(cos_list, dim=0) + stacked_sin = torch.stack(sin_list, dim=0) + + if return_all_pos: + return stacked_cos, stacked_sin, all_pos_list + + return stacked_cos, stacked_sin + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + if position_ids is not None: + cos = cos[position_ids] + sin = sin[position_ids] + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def default(val, d): + return val if val is not None else d + +def conv_nd(dims, *args, **kwargs): + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + +def normalization(channels, **kwargs): + return nn.GroupNorm(32, channels, **kwargs) + +def topkgating(logits: torch.Tensor, topk: int): + logits = logits.float() + gates = F.softmax(logits, dim=1) + + num_experts = int(gates.shape[1]) + + _, expert_index = torch.topk(gates, topk) + expert_mask = F.one_hot(expert_index, num_experts) + + expert_index_flat = expert_index.flatten() + tokens_per_expert = torch.bincount(expert_index_flat, minlength=num_experts) + expert_capacity = torch.max(tokens_per_expert).item() + + gates_s = torch.clamp( + torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps + ) + router_probs = gates / gates_s + + expert_index = torch.transpose(expert_index, 0, 1) + expert_index = expert_index.reshape(-1) + expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) + + token_priority = 
torch.cumsum(expert_mask, dim=0) * expert_mask - 1 + token_priority = token_priority.reshape((topk, -1, num_experts)) + token_priority = torch.transpose(token_priority, 0, 1) + + token_priority = torch.max(token_priority, dim=1)[0] + + valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity) + token_priority = torch.masked_fill(token_priority, ~valid_mask, 0) + dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool) + valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity) + dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0) + + combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) + + return combine_weights, dispatch_mask + +class HunyuanRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class UNetDown(nn.Module): + def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels, + dropout=0.0, device=None, dtype=None): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + + self.patch_size = patch_size + assert self.patch_size in [1, 2, 4, 8] + + self.model = nn.ModuleList( + [conv_nd( + 2, + in_channels=in_channels, + out_channels=hidden_channels, + kernel_size=3, + padding=1, + **factory_kwargs + )] + ) + + if self.patch_size == 1: + self.model.append(ResBlock( + channels=hidden_channels, + emb_channels=emb_channels, + out_channels=out_channels, + use_scale_shift_norm = True, + dropout=dropout, + **factory_kwargs + )) + else: + for i in range(self.patch_size // 2): + self.model.append(ResBlock( + channels=hidden_channels, + emb_channels=emb_channels, + use_scale_shift_norm = True, + out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels, + dropout=dropout, + down=True, + **factory_kwargs + )) + + def forward(self, x, t): + assert x.shape[2] % self.patch_size == 0 and x.shape[3] % self.patch_size == 0 + for module in self.model: + if isinstance(module, ResBlock): + x = module(x, t) + else: + x = module(x) + _, _, token_h, token_w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + return x, token_h, token_w + + +class UNetUp(nn.Module): + + def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels, + dropout=0.0, device=None, dtype=None, operations = None, out_norm=False): + operations = operations or nn + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + + self.patch_size = patch_size + assert self.patch_size in [1, 2, 4, 8] + + self.model = nn.ModuleList() + + if self.patch_size == 1: + self.model.append(ResBlock( + channels=in_channels, + emb_channels=emb_channels, + out_channels=hidden_channels, + use_scale_shift_norm = True, + dropout=dropout, + **factory_kwargs + )) + else: + for i in range(self.patch_size // 2): + self.model.append(ResBlock( + channels=in_channels if i == 0 else hidden_channels, + emb_channels=emb_channels, + out_channels=hidden_channels, + use_scale_shift_norm = True, + dropout=dropout, + up=True, + **factory_kwargs + )) + + if out_norm: + self.model.append(nn.Sequential( + 
normalization(hidden_channels, **factory_kwargs), + nn.SiLU(), + nn.Conv2d( + in_channels=hidden_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + **factory_kwargs + ), + )) + else: + self.model.append(nn.Conv2d( + in_channels=hidden_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + **factory_kwargs + )) + + # batch_size, seq_len, model_dim + def forward(self, x, t, token_h, token_w): + x = rearrange(x, 'b (h w) c -> b c h w', h=token_h, w=token_w) + for module in self.model: + if isinstance(module, ResBlock): + x = module(x, t) + else: + x = module(x) + return x + +class HunyuanTopKGate(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.moe_topk = 8 + self.min_capacity = 8 + num_experts = 64 + self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32) + + def forward(self, hidden_states): + bsz, seq_len, hidden_size = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_size) + if self.wg.weight.dtype == torch.float32: + hidden_states = hidden_states.float() + logits = self.wg(hidden_states) + gate_output = topkgating(logits, self.moe_topk) + + return gate_output + +class HunyuanMLP(nn.Module): + def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config["hidden_size"] + + self.intermediate_size = 3072 + + self.act_fn = torch.nn.functional.silu + self.intermediate_size *= 2 # SwiGLU + self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device) + self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device) + def forward(self, x): + self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device) + if x.ndim == 2: + x = x.unsqueeze(0) + gate_and_up_proj = self.gate_and_up_proj(x) + x1, x2 = gate_and_up_proj.chunk(2, dim=2) + down_proj = self.down_proj(x1 * self.act_fn(x2)) + return down_proj + +class MoELRUCache(nn.Module): + def __init__(self): + super().__init__() + self.gpu_cache = OrderedDict() + self.cpu_cache = OrderedDict() + self.offload_stream = torch.cuda.Stream() + self.load_stream = torch.cuda.Stream() + + self.last_offload_event = None + self._loop = asyncio.new_event_loop() + threading.Thread(target=self._loop.run_forever, daemon=True).start() + + async def _async_offload_to_cpu(self, layer_idx): + # async offload from gpu (removed) + + num_experts = 64 + moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i]) + for i in range(num_experts) + if (layer_idx * num_experts + i) in self.gpu_cache] + event = torch.cuda.Event() + + with torch.cuda.stream(self.offload_stream): + for index, moe in moe_group: + moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True) + for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()): + if p_gpu.device.type == "meta": + continue + with torch.no_grad(): + p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True) + p_cpu.copy_(p_gpu, non_blocking=True) + + self.cpu_cache[index] = moe_cpu + + self.offload_stream.record_event(event) + + self.last_offload_event = event + + def finalize_offload_layer(): + event.synchronize() + for index, moe in moe_group: + moe.to("meta") + self.gpu_cache.pop(index, None) + del moe + torch.cuda.empty_cache() + 
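# A minimal sketch (assumption, not the cache implementation itself) of the offload
# pattern used here: copy GPU weights into pinned CPU buffers on a side stream, record
# an event, and only release the GPU copy once that event has completed.
def _sketch_offload(param_gpu: torch.Tensor) -> torch.Tensor:
    stream, event = torch.cuda.Stream(), torch.cuda.Event()
    with torch.cuda.stream(stream):
        # pinned (page-locked) memory lets the device-to-host copy run asynchronously
        param_cpu = torch.empty(param_gpu.shape, dtype=param_gpu.dtype, device="cpu", pin_memory=True)
        param_cpu.copy_(param_gpu, non_blocking=True)
        stream.record_event(event)
    event.synchronize()   # safe point: the GPU tensor may now be freed
    return param_cpu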
+ threading.Thread(target=finalize_offload_layer, daemon=True).start() + + async def _async_load_to_gpu(self, index, moe): + + # if enough memory load, otherwise wait for offload + while True: + free_bytes, _ = torch.cuda.mem_get_info() + if free_bytes > 2 * MOE_LAYER_SIZE: + break + + self.last_offload_event.synchronize() + torch.cuda.empty_cache() + await asyncio.sleep(0.01) + + # async loading from cpu -> gpu + with torch.cuda.stream(self.load_stream): + moe_gpu = HunyuanMLP(moe.config).to("cuda", non_blocking=True) + for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()): + with torch.no_grad(): + p_gpu.data = torch.empty_like(p_cpu, device="cuda") + p_gpu.copy_(p_cpu, non_blocking=True) + + def finalize_load(): + self.gpu_cache[index] = moe_gpu + self.cpu_cache.pop(index, None) + + threading.Thread(target=finalize_load, daemon=True).start() + + def add_cpu(self, moe, index): + moe_cpu = moe.to("cpu") + + for _, p in moe_cpu.named_parameters(): + if not p.is_pinned(): + if p.device.type == "cpu": + p.data = p.data.pin_memory() + else: + return + + self.cpu_cache[index] = moe_cpu + self.cpu_cache.move_to_end(index) + +class LazyMoELoader(nn.Module): + def __init__(self, cache, config): + super().__init__() + self.cache = cache + self.config = config + self._loop = cache._loop + + def get_checkpoint(self): + comfyui_dir = Path.home() / "ComfyUI" + checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors" + checkpoint = checkpoint.resolve() + if not os.path.exists(checkpoint): + raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}") + return checkpoint + + def lazy_init(self, layer_idx, expert_idx): + checkpoint = self.get_checkpoint() + prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}." 
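# A minimal sketch (assumption; "expert.bin" is a hypothetical path) of the scheduling
# pattern LazyMoELoader builds on: a daemon thread owns an asyncio loop, blocking disk
# reads run in that loop's executor, and callers hand work over with
# run_coroutine_threadsafe so the decode loop on the main thread never blocks.
def _sketch_background_load(path: str = "expert.bin") -> concurrent.futures.Future:
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def _read():
        # push the blocking read onto the loop's default thread-pool executor
        return await asyncio.get_running_loop().run_in_executor(None, Path(path).read_bytes)

    future = asyncio.run_coroutine_threadsafe(_read(), loop)
    future.add_done_callback(lambda fut: print(f"loaded {len(fut.result())} bytes"))
    return future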
+ additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight" + sd = {} + + with safe_open(checkpoint, framework="pt", device="cpu") as f: + for k in f.keys(): + if k.startswith(prefix) or k.startswith(additional_prefix): + new_k = k.split(f"experts.{expert_idx}.", 1)[1] + sd[new_k] = f.get_tensor(k) + + model = HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True, device="meta") + model.to_empty(device = "cpu") + + model.load_state_dict(sd) + return model + + async def lazy_load_from_disk(self, layer_idx, expert_idx): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx) + + def _schedule_disk_load(self, layer_idx, expert_idx): + + coro = self.lazy_load_from_disk(layer_idx, expert_idx) + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + + def _on_disk_loaded(fut): + moe_cpu = fut.result() + def _add_cpu_in_main_thread(): + self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx) + + asyncio.run_coroutine_threadsafe( + self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu), + self.cache._loop + ) + threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start() + + future.add_done_callback(_on_disk_loaded) + return future + +def enough_vram(required_bytes): + free, total = torch.cuda.mem_get_info() + return free > required_bytes + +class HunyuanMoE(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.moe_topk = 8 + self.num_experts = 64 + self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True) + self.gate = HunyuanTopKGate(config, layer_idx=layer_idx) + if INIT_MOE: + self.experts = nn.ModuleList( + [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)] + ) + else: + self.experts = [] + self.moe_lru = moe_lru + + def forward(self, hidden_states): + if not INIT_MOE: + torch.cuda.set_device(0) + else: + torch.cuda.set_device(hidden_states.device.index) + bsz, seq_len, hidden_size = hidden_states.shape + + hidden_states_mlp = self.shared_mlp(hidden_states) + + reshaped_input = hidden_states.reshape(-1, hidden_size) + + with torch.cuda.nvtx.range("MoE"): + combine_weights, dispatch_mask = self.gate(hidden_states) + + dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(reshaped_input), reshaped_input) + device = hidden_states.device + dtype = reshaped_input.dtype + + used_mask = (dispatch_mask.sum(dim=(0, 2)) > 0) + used_indices = used_mask.nonzero(as_tuple=False).squeeze(1).tolist() + + combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype) + + if len(used_indices) == 0: + pass + else: + tokens_padded = dispatched_input[used_indices] + + l1_layers, l2_layers = [], [] + for i in used_indices: + expert = self.experts[i] + if isinstance(expert, (asyncio.Future, concurrent.futures.Future)): + expert = expert.result() + expert = expert.to(device) + l1_layers.append(expert.gate_and_up_proj) + l2_layers.append(expert.down_proj) + + l1, l2 = [], [] + for i in used_indices: + expert = self.experts[i] + if isinstance(expert, (asyncio.Future, concurrent.futures.Future)): + expert = expert.result() + expert = expert.to(device) + l1.append(expert.gate_and_up_proj) + l2.append(expert.down_proj) + + compute_device = hidden_states.device + l1 = [m.to(compute_device) for m in l1] + l2 = [m.to(compute_device) for m in l2] + + W1 = 
torch.stack([m.weight for m in l1], dim=0) + W2 = torch.stack([m.weight for m in l2], dim=0) + + W1_T = W1.transpose(1, 2) + W2_T = W2.transpose(1, 2) + + while not enough_vram(3*(1024 ** 3)): + event = self.moe_lru.last_offload_event + if event is not None and not event.query(): + time.sleep(0.001) + + x = torch.bmm(tokens_padded, W1_T) + x1, x2 = x.chunk(2, dim=2) + gated = x1 * F.silu(x2) + out_padded = torch.bmm(gated, W2_T) + + combine_weights_used = combine_weights[:, used_indices, :] + + combined_output = torch.einsum("suc,ucm->sm", + combine_weights_used.type_as(out_padded), + out_padded + ) + + del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded + + combined_output = combined_output.reshape(bsz, seq_len, hidden_size) + + output = hidden_states_mlp + combined_output + + return output + +class HunyuanImage3Attention(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.attention_type = 'self' + + self.hidden_size = config["hidden_size"] + self.num_heads = config["num_attention_heads"] + self.head_dim = config["attention_head_dim"] + self.num_key_value_heads = 8 + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config["max_position_embeddings"] + self.rope_theta = 10000.0 + self.is_causal = True + self.hidden_size_q = self.head_dim * self.num_heads + self.hidden_size_kv = self.head_dim * self.num_key_value_heads + + # define layers + self.qkv_proj = nn.Linear( + self.hidden_size, + self.hidden_size_q + 2 * self.hidden_size_kv, + bias=False + ) + self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=False) + + self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"]) + self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"]) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, + **kwargs, + ): + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.qkv_proj(hidden_states) + qkv_states = qkv_states.reshape(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2, + self.head_dim) + query_states, key_states, value_states = torch.split(qkv_states, [self.num_key_value_groups, 1, 1], dim=3) + + query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = custom_pos_emb + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + query_states = self.query_layernorm(query_states) + key_states = self.key_layernorm(key_states) + + query_states = query_states.to(value_states.dtype) + key_states = key_states.to(value_states.dtype) + + if past_key_value is not None: + cache_kwargs = {"cache_position": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + query_states = query_states.to(key_states.dtype) + + key_states = torch.repeat_interleave(key_states, dim=1, repeats = 
self.num_key_value_groups) + value_states = torch.repeat_interleave(value_states, dim=1, repeats = self.num_key_value_groups) + + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, mask = attention_mask, skip_reshape=True) + + attn_output = self.o_proj(attn_output) + + return attn_output, past_key_value + +class HunyuanImage3DecoderLayer(nn.Module): + def __init__(self, config, layer_idx: int, moe_lru=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.layer_idx = layer_idx + self.self_attn = HunyuanImage3Attention(config, layer_idx=layer_idx) + + self.mlp = HunyuanMoE(config, layer_idx=layer_idx, moe_lru=moe_lru) + + self.input_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]) + self.post_attention_layernorm = HunyuanRMSNorm(config["hidden_size"], eps=config['rms_norm_eps']) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor | Any]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, past_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + custom_pos_emb=custom_pos_emb, + **kwargs, + ) + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, past_key_value) + + return outputs + +class HunyuanImage3Model(nn.Module): + def __init__(self, config, moe_lru=None): + super().__init__() + self.padding_idx = 128009 + self.vocab_size = 133120 + self.config = config + self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx) + self.layers = nn.ModuleList( + [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])] + ) + + self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]) + + self.shared_tensor = None + self.moe_lru = moe_lru + self.self.additional_layers_set = False + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache = True, + custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, + mode: str = "gen_image", + first_step: Optional[bool] = None, + gen_timestep_scatter_index: Optional[torch.Tensor] = None, + ): + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + next_decoder_cache = None + next_layers = 0 + additional_layers = torch.cuda.mem_get_info()[0] // (MOE_LAYER_SIZE * 2) + sparse_interval = max(1, len(self.layers) // additional_layers) + + if len(self.layers[0].mlp.experts) == 0: + experts = 
[LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] + self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)] + + for layer_idx, decoder_layer in enumerate(self.layers): + + if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded + experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] + self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)] + + if not self.additional_layers_set: + if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval: + experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] + self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)] + next_layers += 1 + + with torch.no_grad(): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + custom_pos_emb=custom_pos_emb, + mode=mode, + first_step=first_step, + gen_timestep_scatter_index=gen_timestep_scatter_index, + ) + + if layer_idx >= 0: + if self.additional_layers_set and layer_idx <= self.additional_layers_set: + pass + else: + torch.cuda.synchronize() + asyncio.run_coroutine_threadsafe( + self.moe_lru._async_offload_to_cpu(layer_idx), + self.moe_lru._loop + ) + self.layers[layer_idx].mlp.experts = [] + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[1] + + next_cache = None + if use_cache: + next_cache = next_decoder_cache + self.additional_layers_set = True + return tuple(v for v in [hidden_states, next_cache] if v is not None) + + +class HunyuanImage3ForCausalMM(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"]) + self.patch_embed = UNetDown( + patch_size=1, + emb_channels=config["hidden_size"], + in_channels=32, + hidden_channels=1024, + out_channels=config["hidden_size"], + ) + self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"]) + + self.final_layer = UNetUp( + patch_size=1, + emb_channels=config["hidden_size"], + in_channels=config["hidden_size"], + hidden_channels=1024, + out_channels=32, + out_norm=True, + ) + self.time_embed_2 = TimestepEmbedder(hidden_size=config["hidden_size"]) + + self.moe_lru = None + if not INIT_MOE: + self.moe_lru = MoELRUCache() + + self.model = HunyuanImage3Model(config, moe_lru=self.moe_lru) + + self.pad_id = 128009 + self.vocab_size = 133120 + + self.lm_head = nn.Linear(config["hidden_size"], 133120, bias=False) + self.first_step = True + + self.kv_cache = None + self.token_dims = () + + @staticmethod + def get_pos_emb(custom_pos_emb, position_ids): + cos, sin = custom_pos_emb + cos = real_batched_index_select(cos, dim=1, idx=position_ids) + sin = real_batched_index_select(sin, dim=1, idx=position_ids) + return cos, sin + + def ragged_final_layer(self, x, image_mask, timestep, token_h, token_w, first_step): + bsz, seq_len, n_embd = x.shape + if first_step: + image_output = x.masked_select(image_mask.unsqueeze(-1).bool()).reshape(bsz, -1, n_embd) + else: + image_output = x[:, 1:, :] + timestep_emb = self.time_embed_2(timestep) + pred = self.final_layer(image_output, timestep_emb, token_h, token_w) + return pred + + def forward(self, x, condition, timestep, **kwargs): + + joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, 
uncond_vae_mask, uncond_inputs = condition.unbind() + + gen_timestep_scatter_index = 4 + + with torch.no_grad(): + joint_image[:, 2:3, :] = x[:, 2:3, :] # updates image ratio + + if self.first_step: + token_height, token_width = x[:, -2:, 0].tolist()[0] + self.token_dims = (int(token_height), int(token_width)) + x = x[:, :-2, :] + else: + token_height, token_width = self.token_dims + + img_slices = [] + + for i in range(x.size(0)): + vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0] + vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1 + + vit_start = vae_end + 1 + vit_end = joint_image.size(1) - 1 + + joint_slices_i = [ + slice(vae_start, vae_end), + slice(vit_start, vit_end), + ] + gen_slices_i = [slice(3 + vit_end, x[i].size(0) - 1 + vit_end)] + img_slices.append(joint_slices_i + gen_slices_i) + + img_s = img_slices[0] + rope_image_info = [[(img_s[0], (384 // 16, 384 // 16)), (img_s[1], (256 // 16, 256 // 16)), (img_s[2], (token_height, token_width))]] + + cond_timestep = torch.zeros(inputs_embeds.size(0)) + t_emb = self.time_embed(cond_timestep) + + bsz, seq_len, n_embd = inputs_embeds.shape + + if self.first_step: + x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) + else: + t_emb = self.time_embed(timestep) + x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb) + timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd) + x = torch.cat([timestep_emb, x], dim=1) + + #///////////// + # cond_vae_images + + # cond_timestep_scatter_index + with torch.no_grad(): + joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd) + + inputs_embeds = torch.cat([inputs_embeds, joint_image, x], dim = 1) + + attention_mask = torch.ones(inputs_embeds.shape[1], inputs_embeds.shape[1], dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1) + for i in range(bsz): + for _, image_slice in enumerate(img_slices[i]): + attention_mask[i, image_slice, image_slice] = True + attention_mask = attention_mask.unsqueeze(1) + + # pos embed + position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) + cos, sin = build_batch_2d_rope( + image_infos=rope_image_info, + seq_len=inputs_embeds.shape[1], + n_elem=self.config["hidden_size"] // self.config["num_attention_heads"], + base=10000.0, + ) + custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device)) + custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids) + + if self.kv_cache is None: + # TODO: should change when higgsv2 gets merged + self.kv_cache = HunyuanStaticCache( + config=self.config, + batch_size=x.size(0) * 2, + max_cache_len = inputs_embeds.shape[1], + dtype=x.dtype, + ) + + outputs = self.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=self.kv_cache, + inputs_embeds=inputs_embeds, + custom_pos_emb=custom_pos_emb, + first_step=self.first_step, + gen_timestep_scatter_index=gen_timestep_scatter_index, + ) + hidden_states = outputs[0] + + # safety no-op + past_key_value = outputs[1] + if past_key_value is not None: + self.kv_cache = past_key_value + + hidden_states = hidden_states.to(inputs_embeds.device) + img_mask = torch.zeros(hidden_states.size(1)) + img_mask[-x.size(1)+4:] = 1; img_mask[-1] = 0 + + diffusion_prediction = self.ragged_final_layer( + hidden_states, img_mask, timestep, int(token_height), int(token_width), 
self.first_step) + + if self.first_step: + self.first_step = False + + return diffusion_prediction diff --git a/comfy/model_base.py b/comfy/model_base.py index 7c788d085..56391257f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -46,6 +46,7 @@ import comfy.ldm.chroma.model import comfy.ldm.chroma_radiance.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 +import comfy.ldm.hunyuan_image_3.model import comfy.ldm.qwen_image.model import comfy.model_management @@ -1355,6 +1356,10 @@ class Hunyuan3Dv2(BaseModel): if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out + +class HunyuanImage3(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device, unet_model = comfy.ldm.hunyuan_image_3.model.HunyuanImage3ForCausalMM) class Hunyuan3Dv2_1(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 3142a7fc3..246596167 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -482,6 +482,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["patch_size"] = 2 dit_config["text_emb_dim"] = 2048 return dit_config + + if "{}layers.32.mlp.gate_and_up_proj.weight".format(key_prefix) in state_dict_keys: + dit_config = {} + dit_config["image_model"] = "hunyuan_image_3" + dit_config["hidden_size"] = 4096 + dit_config["max_position_embeddings"] = 12800 + dit_config["num_attention_heads"] = 32 + dit_config['rms_norm_eps'] = 1e-05 + dit_config["num_hidden_layers"] = 32 + dit_config["attention_head_dim"] = 128 + return dit_config if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 dit_config = {} diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 4064bdae1..5ae2baef0 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1332,6 +1332,17 @@ class QwenImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) +class HunyuanImage3(supported_models_base.BASE): + unet_config = { + "image_model": "hunyuan_image_3", + } + latent_format = latent_formats.HunyuanImage3 + + def get_model(self, state_dict, prefix="", device=None): + return model_base.HunyuanImage3(self, device = device) + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImage3Tokenizer, comfy.text_encoders.hunyuan_image.DummyClip) + class HunyuanImage21(HunyuanVideo): unet_config = { "image_model": "hunyuan_video", @@ -1374,6 +1385,6 @@ class HunyuanImage21Refiner(HunyuanVideo): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, 
Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanImage3, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index ff04726e1..ab3512201 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -5,6 +5,14 @@ from transformers import ByT5Tokenizer import os import re +class DummyClip: + def __init__(*args, **kwargs): + pass + +class HunyuanImage3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, tokenizer_path="hunyuan_image_3", max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=..., has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data=..., tokenizer_args=...): + super().__init__(tokenizer_path, max_length, pad_with_end, embedding_directory, embedding_size, embedding_key, tokenizer_class, has_start_token, has_end_token, pad_to_max_length, min_length, pad_token, end_token, min_padding, tokenizer_data, tokenizer_args) + class ByT5SmallTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer") diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py new file mode 100644 index 000000000..e02702995 --- /dev/null +++ b/comfy_extras/nodes_hunyuan_image.py @@ -0,0 +1,122 @@ +import torch +import comfy.model_management +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + +COMPUTED_RESO_GROUPS = ['512x2048', '512x1984', '512x1920', '512x1856', '512x1792', '512x1728', '512x1664', '512x1600', '512x1536', '576x1472', '640x1408', '704x1344', '768x1280', '832x1216', '896x1152', '960x1088', '1024x1024', '1088x960', '1152x896', '1216x832', '1280x768', '1344x704', '1408x640', '1472x576', '1536x512', '1600x512', '1664x512', '1728x512', '1792x512', '1856x512', '1920x512', '1984x512', '2048x512'] +RATIOS = [torch.tensor(int(r.split("x")[0]) / int(r.split("x")[1])) for r in COMPUTED_RESO_GROUPS] +def get_target_size(height, width): + ratio = height / width + idx = torch.argmin(torch.abs(torch.tensor(RATIOS) - ratio)) + reso = COMPUTED_RESO_GROUPS[idx] + return reso.split("x") + +class EmptyLatentHunyuanImage3(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyLatentHunyuanImage3", + display_name="EmptyLatentHunyuanImage3", + category="image/latent", 
+ inputs = [ + io.Int.Input("height", min = 1, default = 512), + io.Int.Input("width", min = 1, default = 512), + io.Int.Input("batch_size", min = 1, max = 48_000, default = 1), + io.Clip.Input("clip"), + io.Model.Input("model") + ], + outputs=[io.Latent.Output(display_name="latent")] + ) + @classmethod + def execute(cls, height, width, batch_size, clip, model): + encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids + special_fn = clip.tokenizer.tokenizer.added_tokens_encoder + + # may convert clip.tokenizer -> clip. + word_embed = model.wte + patch_embed = model.patch_embed + t_embed = model.time_embed + + height, width = get_target_size(height, width) + latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device()) + + latent, tk_height, tk_width = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size))) + + def tk_fn(token): + return torch.tensor([token], device = latent.device, dtype = latent.dtype).unsqueeze(1).expand(batch_size, 1, latent.size(-1)) + + def fn(string, func = encode_fn): + return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\ + .unsqueeze(0).expand(batch_size, -1, -1) + + latent = torch.cat([fn(""), fn("", func = special_fn), fn(f"", special_fn), fn("", special_fn), latent, fn("")], dim = 1) + latent = torch.cat([latent, tk_fn(tk_height), tk_fn(tk_width)], dim = 1) + return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, ) + +class HunyuanImage3Conditioning(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanImage3Conditioning", + display_name="HunyuanImage3Conditioning", + category="conditioning/video_models", + inputs = [ + io.Conditioning.Input("vae_encoding"), + io.Conditioning.Input("vit_encoding"), + io.Conditioning.Input("text_encoding_positive"), + io.Clip.Input("clip"), + io.Model.Input("model"), + io.Conditioning.Input("text_encoding_negative", optional = True), + ], + outputs=[io.Conditioning.Output(display_name= "positive"), io.Conditioning.Output(display_name="negative")] + ) + + @classmethod + def execute(cls, vae_encoding, vit_encoding, text_encoding, clip, model, text_encoding_negative=None): + encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids + special_fn = clip.tokenizer.tokenizer.added_tokens_encoder + + word_embed = model.wte + patch_embed = model.patch_embed + t_embed = model.time_embed + batch_size, _, hidden_size = vit_encoding.shape + + def fn(string, func = encode_fn): + return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\ + .view(1, 1, hidden_size).expand(batch_size, -1, hidden_size) + + text_tokens = text_encoding[0][0] + vae_encoding, _, _ = patch_embed(vae_encoding, t_embed(torch.tensor([0]).repeat(vae_encoding.size(0)))) + # should dynamically change in model logic + joint_image = torch.cat([fn(""), fn("", special_fn), fn("", special_fn), fn("", special_fn), vae_encoding, fn(""), vit_encoding, fn("")], dim = 1) + + vae_mask = torch.ones(joint_image.size(1)) + vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:])) + + ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.to(joint_image.dtype)]) + + uncond_ragged_tensors = None + if text_encoding_negative is not None: + uncond_ragged_tensors, _ 
= cls.execute(vae_encoding, vit_encoding, text_encoding_negative, clip=clip, text_encoding_negative = None) + else: + uncond_ragged_tensors = torch.nested.nested_tensor([torch.zeros_like(t) for t in ragged_tensors.unbind()]) + + if uncond_ragged_tensors is not None: + positive = [[ragged_tensors, {}]] + negative = [[uncond_ragged_tensors, {}]] + else: + positive = ragged_tensors + negative = uncond_ragged_tensors + + return positive, negative + +class Image3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + HunyuanImage3Conditioning, + EmptyLatentHunyuanImage3 + ] + +async def comfy_entrypoint() -> Image3Extension: + return Image3Extension() diff --git a/nodes.py b/nodes.py index 030371633..230f08221 100644 --- a/nodes.py +++ b/nodes.py @@ -2326,6 +2326,7 @@ async def init_builtin_extra_nodes(): "nodes_ace.py", "nodes_string.py", "nodes_camera_trajectory.py", + "nodes_hunyuan_image.py", "nodes_edit_model.py", "nodes_tcfg.py", "nodes_context_windows.py",
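# A minimal sketch (assumes only torch; abridged bucket list, helper names illustrative)
# of the aspect-ratio bucketing that get_target_size() in comfy_extras/nodes_hunyuan_image.py
# performs before the 32-channel, 16x-downsampled latent is allocated: the requested size
# is snapped to the bucket whose height/width ratio is closest.
import torch

_BUCKETS = ["512x2048", "768x1280", "1024x1024", "1280x768", "2048x512"]
_RATIOS = torch.tensor([int(b.split("x")[0]) / int(b.split("x")[1]) for b in _BUCKETS])

def _sketch_nearest_bucket(height: int, width: int) -> tuple[int, int]:
    idx = torch.argmin(torch.abs(_RATIOS - height / width)).item()
    h, w = (int(v) for v in _BUCKETS[idx].split("x"))
    return h, w

# e.g. a 900x700 request (ratio ~1.29) snaps to 1024x1024, giving a 64x64 latent grid.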