diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 3cbca46cd..9682a270f 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -1,15 +1,17 @@
 import os
-import gc
 import math
+import time
 import torch
 import psutil
+import asyncio
+import threading
 import torch.nn as nn
 from pathlib import Path
+import concurrent.futures
 from einops import rearrange
 import torch.nn.functional as F
 from collections import OrderedDict
 from safetensors import safe_open
-from contextlib import contextmanager
 from transformers.cache_utils import StaticCache
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
@@ -19,13 +21,13 @@

 INIT_MOE = torch.cuda.device_count() != 1

 if not INIT_MOE:
     MOE_LAYER_SIZE = (1024**3) * 2.65 # approx
-    CPU_MOE_RATIO = None
     torch.cuda.set_device(0)
     props = torch.cuda.get_device_properties(0)
-    INIT_CUDA_MEM = (props.total_memory - torch.cuda.memory_reserved()) * 0.9
-    ADDITIONAL_LAYERS_IN_GPU = math.floor(INIT_CUDA_MEM / MOE_LAYER_SIZE)
+    LAYERS_IN_CPU = math.floor((int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
+                                    - psutil.Process(os.getpid()).memory_info().rss
+                                    - (2*1024**3)) * 0.50) / MOE_LAYER_SIZE)


 class HunyuanStaticCache(StaticCache):
@@ -286,20 +288,15 @@ def topkgating(
     logits = logits.float()
     gates = F.softmax(logits, dim=1)

-    extra = ADDITIONAL_LAYERS_IN_GPU
-
-    values_all, indices_all = torch.topk(gates, topk + extra, dim=1)
+    values_all, indices_all = torch.topk(gates, topk, dim=1)
     expert_weight = values_all[:, :topk]
     expert_index = indices_all[:, :topk]

-    _, cpu_expert_index = torch.topk(gates, int(CPU_MOE_RATIO * 64), dim = 1)
-    cpu_expert_index = cpu_expert_index[:, (8 + ADDITIONAL_LAYERS_IN_GPU):]
-
    if norm_topk_prob and topk > 1:
         denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps)
         expert_weight = expert_weight / denom

-    return expert_weight, expert_index, cpu_expert_index, indices_all
+    return expert_weight, expert_index

 class HunyuanRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -452,7 +449,7 @@ class HunyuanTopKGate(nn.Module):
         return gate_output

 class HunyuanMLP(nn.Module):
-    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False):
+    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -462,8 +459,8 @@ class HunyuanMLP(nn.Module):
         self.act_fn = torch.nn.functional.silu
         self.intermediate_size *= 2 # SwiGLU
-        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False)
+        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device)
+        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device)

     def forward(self, x):
         self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
         if x.ndim == 2:
@@ -474,204 +471,143 @@
         return down_proj

 class MoELRUCache(nn.Module):
-    def __init__(self, cpu_mem: int = 50, safety_buffer_bytes = 3*(1024**3), max_gpu_eviction_attempts = 8):
+    def __init__(self):
         super().__init__()
-        global CPU_MOE_RATIO
-
-        _, total = torch.cuda.mem_get_info()
-        max_gpu_mem_gb = max((total - 2 * safety_buffer_bytes) / (1024**3), 1)
-
-        self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3)
-        self.MAX_CPU_MEM = int(cpu_mem * 1024**3)
         self.gpu_cache = OrderedDict()
         self.cpu_cache = OrderedDict()
+        self.offload_stream = torch.cuda.Stream()
+        self.load_stream = torch.cuda.Stream()
-        self.gpu_mem_usage = 0
-        self.cpu_mem_usage = 0
-        # 50% for system and headroom
-        try:
-            self.MAX_CPU_MEM = int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
-                                   - psutil.Process(os.getpid()).memory_info().rss
-                                   - safety_buffer_bytes) * 0.55
-        except:
-            self.MAX_CPU_MEM = int(cpu_mem * (1024**3) * 0.5) # TODO
+        self.last_offload_event = None
+        self._loop = asyncio.new_event_loop()
+        threading.Thread(target=self._loop.run_forever, daemon=True).start()

-        ADDITIONAL_LAYERS_IN_CPU = math.floor((50 * (1024**3)) / MOE_LAYER_SIZE)
-        CPU_MOE_RATIO = (min(64 - ADDITIONAL_LAYERS_IN_GPU, ADDITIONAL_LAYERS_IN_CPU)) / 64
+    async def _async_offload_to_cpu(self, layer_idx):
+        # async offload from gpu (removed)
-        self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3)
-        self.SAFETY_BUFFER = int(safety_buffer_bytes)
-        self.MAX_GPU_EVICT_ATTEMPTS = max_gpu_eviction_attempts
+        num_experts = 64
+        moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
+                     for i in range(num_experts)
+                     if (layer_idx * num_experts + i) in self.gpu_cache]
+        event = torch.cuda.Event()

-    def _gpu_free_bytes(self):
-        free, total = torch.cuda.mem_get_info()
-        return int(free)
-
-    def _estimate_size(self, moe):
-        # include parameters + buffers
-        size = 0
-        for p in moe.parameters():
-            size += p.numel() * p.element_size()
-        for b in moe.buffers():
-            size += b.numel() * b.element_size()
-        return int(size)
+        with torch.cuda.stream(self.offload_stream):
+            for index, moe in moe_group:
+                moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
+                for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
+                    if p_gpu.device.type == "meta":
+                        continue
+                    with torch.no_grad():
+                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
+                        p_cpu.copy_(p_gpu, non_blocking=True)

-    def _evict_until_free(self, required_bytes, max_attempts=16):
-        attempts = 0
-        while self._gpu_free_bytes() < required_bytes and attempts < max_attempts:
-            evicted = self._evict_from_gpu()
-            if not evicted:
-                break
-            attempts += 1
-        return self._gpu_free_bytes() >= required_bytes
+                self.cpu_cache[index] = moe_cpu

-    @contextmanager
-    def ensure_headroom(self, required_bytes):
+            self.offload_stream.record_event(event)

-        safety = getattr(self, "SAFETY_BUFFER", 0)
-        target_free = int(required_bytes + safety)
+        self.last_offload_event = event

-        if getattr(self, "_headroom", None) is not None:
-            try:
-                del self._headroom
-            except Exception:
-                pass
-            self._headroom = None
+        def finalize_offload_layer():
+            event.synchronize()
+            for index, moe in moe_group:
+                moe.to("meta")
+                self.gpu_cache.pop(index, None)
+                del moe
+            torch.cuda.empty_cache()

-        ok = self._evict_until_free(target_free)
-        if not ok and self._gpu_free_bytes() < target_free:
-            # last ditch
-            try:
-                torch.cuda.empty_cache()
-            except Exception:
-                pass
+        threading.Thread(target=finalize_offload_layer, daemon=True).start()

-        try:
-            yield
-        finally:
-            if getattr(self, "_headroom", None) is None:
-                try:
-                    self._headroom = torch.empty((self._headroom_bytes,), dtype=torch.uint8, device="cuda:0")
-                except Exception:
-                    self._headroom = None
+    async def _async_load_to_gpu(self, index, moe):

-    def add_gpu(self, moe, index, allowed_retries=3):
-        size = self._estimate_size(moe)
-
-        while self.gpu_mem_usage + size > self.MAX_GPU_MEM:
-            if not self._evict_from_gpu():
+        # if enough memory load, otherwise wait for offload
+        while True:
+            free_bytes, _ = torch.cuda.mem_get_info()
+            if free_bytes > 2 * MOE_LAYER_SIZE:
                 break
-        attempts = 0
-        while self._gpu_free_bytes() < size + self.SAFETY_BUFFER and attempts < self.MAX_GPU_EVICT_ATTEMPTS:
-            if not self._evict_from_gpu():
-                break
-            attempts += 1
+            self.last_offload_event.synchronize()
+            torch.cuda.empty_cache()
+            await asyncio.sleep(0.01)

-        for _ in range(allowed_retries):
-            try:
-                moe_cuda = moe.to("cuda:0")
-                break
-            except RuntimeError as e:
-                if "out of memory" not in str(e).lower():
-                    raise
-                evicted = self._evict_from_gpu()
-                if not evicted: # can't evict
-                    raise
-        else:
-            raise RuntimeError("Failed to move expert to GPU after evictions")
+        # async loading from cpu -> gpu
+        with torch.cuda.stream(self.load_stream):
+            moe_gpu = HunyuanMLP(moe.config).to("cuda", non_blocking=True)
+            for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
+                with torch.no_grad():
+                    p_gpu.data = torch.empty_like(p_cpu, device="cuda")
+                    p_gpu.copy_(p_cpu, non_blocking=True)

-        self.gpu_cache[index] = moe_cuda
-        self.gpu_cache.move_to_end(index)
-        self.gpu_mem_usage += size
+        def finalize_load():
+            self.gpu_cache[index] = moe_gpu
+            self.cpu_cache.pop(index, None)

-        return
+        threading.Thread(target=finalize_load, daemon=True).start()

     def add_cpu(self, moe, index):
-        size = self._estimate_size(moe)
-        while self.cpu_mem_usage + size > self.MAX_CPU_MEM:
-            if not self._evict_from_cpu():
-                break
         moe_cpu = moe.to("cpu")
+
+        for _, p in moe_cpu.named_parameters():
+            if not p.is_pinned():
+                p.data = p.data.pin_memory()
+
         self.cpu_cache[index] = moe_cpu
         self.cpu_cache.move_to_end(index)
-        self.cpu_mem_usage += size
-
-    def get_from_device(self, index):
-        if index in self.gpu_cache:
-            moe = self.gpu_cache[index]
-            self.gpu_cache.move_to_end(index)
-            return moe
-        if index in self.cpu_cache:
-            moe = self.cpu_cache.pop(index)
-            self.cpu_mem_usage = max(0, self.cpu_mem_usage - self._estimate_size(moe))
-            try:
-                self.add_gpu(moe, index)
-                return self.gpu_cache[index]
-            except RuntimeError:
-                self.cpu_cache[index] = moe
-                self.cpu_cache.move_to_end(index)
-                self.cpu_mem_usage += self._estimate_size(moe)
-                raise
-
-        return None # load from disk
-
-    def _evict_from_gpu(self):
-        if not self.gpu_cache:
-            return False
-
-        idx, moe = self.gpu_cache.popitem(last=False)
-        size = self._estimate_size(moe)
-        self.gpu_mem_usage = max(0, self.gpu_mem_usage - size)
-
-        if self.cpu_mem_usage + size <= self.MAX_CPU_MEM:
-            try:
-                moe_cpu = moe.to("cpu")
-            except Exception:
-                # drop the model if cpu is full
-                del moe
-                return True
-            self.cpu_cache[idx] = moe_cpu
-            self.cpu_cache.move_to_end(idx)
-            self.cpu_mem_usage += size
-            return True
-        else:
-            del moe
-            return True
-
-    def _evict_from_cpu(self):
-        if not self.cpu_cache:
-            return False
-        _, moe = self.cpu_cache.popitem(last=False)
-        size = self._estimate_size(moe)
-        self.cpu_mem_usage = max(0, self.cpu_mem_usage - size)
-        del moe
-        gc.collect()
-        return True

 class LazyMoELoader(nn.Module):
-    def __init__(self, device):
+    def __init__(self, cache, config):
         super().__init__()
-        self.device = device
+        self.cache = cache
+        self.config = config
+        self._loop = cache._loop

-    def lazy_init(self, config, layer_idx, expert_idx):
+    def get_checkpoint(self):
         comfyui_dir = Path.home() / "ComfyUI"
         checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
         checkpoint = checkpoint.resolve()
         if not os.path.exists(checkpoint):
             raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")
+        return checkpoint

+    def lazy_init(self, layer_idx, expert_idx):
+        checkpoint = self.get_checkpoint()
         prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}."
         additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight"
         sd = {}
-        with safe_open(checkpoint, framework="pt", device=self.device) as f:
+        with safe_open(checkpoint, framework="pt", device="cpu") as f:
             for k in f.keys():
                 if k.startswith(prefix) or k.startswith(additional_prefix):
                     new_k = k.split(f"experts.{expert_idx}.", 1)[1]
                     sd[new_k] = f.get_tensor(k)

-        return HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd).to(self.deivce)
+        # load_state_dict returns a key-matching report, not the module, so return the loaded expert itself
+        moe = HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True)
+        moe.load_state_dict(sd)
+        return moe
+
+    async def lazy_load_from_disk(self, layer_idx, expert_idx):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
+
+    def _schedule_disk_load(self, layer_idx, expert_idx):
+        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
+        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
+
+        def _on_disk_loaded(fut):
+            moe_cpu = fut.result()
+            def _add_cpu_in_main_thread():
+                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
+
+                asyncio.run_coroutine_threadsafe(
+                    self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
+                    self.cache._loop
+                )
+            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()
+
+        future.add_done_callback(_on_disk_loaded)
+        return future
+
+def enough_vram(required_bytes):
+    free, total = torch.cuda.mem_get_info()
+    return free > required_bytes

 class HunyuanMoE(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
@@ -687,7 +623,7 @@ class HunyuanMoE(nn.Module):
                 [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
             )
         else:
-            self.experts = None
+            self.experts = []
         self.moe_lru = moe_lru

     def forward(self, hidden_states):
@@ -702,38 +638,10 @@ class HunyuanMoE(nn.Module):
         reshaped_input = hidden_states.reshape(-1, hidden_size)

         with torch.cuda.nvtx.range("MoE"):
-            expert_weight, expert_index, cpu_expert_index, indices_all = self.gate(hidden_states)
-            if not INIT_MOE:
-                if ADDITIONAL_LAYERS_IN_GPU > 0:
-                    additional_expert_index = indices_all[:, expert_index.size(1): expert_index.size(1) + ADDITIONAL_LAYERS_IN_GPU]
-
-                    flat = additional_expert_index.reshape(-1).to("cpu")
-                    counts = torch.bincount(flat, minlength=self.num_experts)
-                    top_extra = torch.topk(counts, k=min(ADDITIONAL_LAYERS_IN_GPU, (counts>0).sum().item())).indices.tolist()
-
-                    for expert_id in top_extra:
-                        if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
-                            expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
-                            self.moe_lru.add_gpu(expert_cpu, expert_id + self.layer_idx)
-
-                if cpu_expert_index is not None and cpu_expert_index.numel() > 0:
-                    for expert_id in torch.unique(cpu_expert_index).cpu().tolist():
-                        if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
-                            expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
-                            self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx)
+            expert_weight, expert_index = self.gate(hidden_states)

             combined_output = torch.zeros_like(reshaped_input)

-            experts_list = []
-            for e in range(self.num_experts):
-                token_mask = (expert_index == e)
-                if not token_mask.any():
-                    continue
-                expert = self.moe_lru.get_from_device(e + self.layer_idx)
-                if expert is None:
-                    expert = LazyMoELoader()
-                    expert = expert.lazy_init(self.config, self.layer_idx, e)
-                    self.moe_lru.add_gpu(expert, e + self.layer_idx)
-                experts_list.append((e, expert))
+            experts_list = [(i, expert) for i, expert in enumerate(self.experts)]

             per_pos, per_tokens, per_weights = [], [], []
             for e, _ in experts_list:
@@ -761,6 +669,8 @@ class HunyuanMoE(nn.Module):

             l1, l2 = [], []
             for _, expert in experts_list:
+                if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
+                    expert = expert.result()
                 l1.append(expert.gate_and_up_proj)
                 l2.append(expert.down_proj)
@@ -769,6 +679,12 @@ class HunyuanMoE(nn.Module):

             W1_T = W1.transpose(1, 2)
             W2_T = W2.transpose(1, 2)
+
+            # wait for enough vram for the computations
+            while not enough_vram(5*(1024 ** 3)):
+                event = self.moe_lru.last_offload_event
+                if event is not None and not event.query():
+                    time.sleep(0.001)

             x = torch.bmm(tokens_padded, W1_T)
             x = F.silu(x)
@@ -781,7 +697,7 @@ class HunyuanMoE(nn.Module):
             for i, token_positions in enumerate(per_pos):
                 Ni = lengths[i]
                 out_i = out_padded[i, :Ni]
-                combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i)
+                combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype))

             #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
             #chunks = dispatched_input.chunk(self.num_experts, dim=0)
@@ -933,6 +849,7 @@ class HunyuanImage3Model(nn.Module):
         super().__init__()
         self.padding_idx = 128009
         self.vocab_size = 133120
+        self.config = config
         self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
         self.layers = nn.ModuleList(
             [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
         )
@@ -941,6 +858,7 @@ class HunyuanImage3Model(nn.Module):
         self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])

         self.shared_tensor = None
+        self.moe_lru = moe_lru

     def forward(
         self,
@@ -962,19 +880,45 @@ class HunyuanImage3Model(nn.Module):
         hidden_states = inputs_embeds
         next_decoder_cache = None

-        for layer_idx, decoder_layer in enumerate(self.layers):
+        next_layers = 0
+        sparse_interval = max(1, len(self.layers) // 3)

-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                use_cache=use_cache,
-                custom_pos_emb=custom_pos_emb,
-                mode=mode,
-                first_step=first_step,
-                gen_timestep_scatter_index=gen_timestep_scatter_index,
-            )
+        if len(self.layers[0].mlp.experts) == 0:
+            experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+            self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]
+
+        for layer_idx, decoder_layer in enumerate(self.layers):
+
+            if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
+
+            if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
+                if len(self.layers[next_layers].mlp.experts) > 0: # for testing
+                    raise ValueError("Problem with offloading")
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+                next_layers += 1
+
+            with torch.no_grad():
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    use_cache=use_cache,
+                    custom_pos_emb=custom_pos_emb,
+                    mode=mode,
+                    first_step=first_step,
+                    gen_timestep_scatter_index=gen_timestep_scatter_index,
+                )
+
+            if layer_idx >= 0:
+                asyncio.run_coroutine_threadsafe(
+                    self.moe_lru._async_offload_to_cpu(layer_idx),
+                    self.moe_lru._loop
+                )
+                self.layers[layer_idx].mlp.experts = []

             hidden_states = layer_outputs[0]
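Note (not part of the patch): the rewritten forward loop above is essentially a prefetch/offload pipeline driven by an asyncio event loop running on a daemon thread. The stand-alone sketch below illustrates only that scheduling pattern under simplified assumptions; fake_load_experts and fake_offload_layer are hypothetical stand-ins for LazyMoELoader._schedule_disk_load and MoELRUCache._async_offload_to_cpu, and there is no CUDA or model code involved.

import asyncio
import concurrent.futures
import threading
import time

# background event loop, mirroring MoELRUCache._loop
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def fake_load_experts(layer_idx):
    # stand-in for disk -> pinned CPU -> GPU expert loading
    await asyncio.sleep(0.05)
    return f"experts-for-layer-{layer_idx}"

async def fake_offload_layer(layer_idx):
    # stand-in for offloading a finished layer's experts in the background
    await asyncio.sleep(0.02)

def schedule_load(layer_idx) -> concurrent.futures.Future:
    # run_coroutine_threadsafe returns a concurrent.futures.Future,
    # which is why the MoE forward accepts Future objects in experts_list
    return asyncio.run_coroutine_threadsafe(fake_load_experts(layer_idx), loop)

num_layers = 4
pending = {0: schedule_load(0)}                  # prefetch the first layer up front

for layer_idx in range(num_layers):
    if layer_idx + 1 < num_layers:               # prefetch the next layer while this one runs
        pending[layer_idx + 1] = schedule_load(layer_idx + 1)

    experts = pending.pop(layer_idx).result()    # block only if the prefetch is late
    time.sleep(0.01)                             # stand-in for the decoder layer forward pass
    print(f"layer {layer_idx} ran with {experts}")

    # hand the finished layer back to the background loop for offloading
    asyncio.run_coroutine_threadsafe(fake_offload_layer(layer_idx), loop)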