From a3ac798d4e46a6a3bc6eb061bf91920727d2a0cf Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Sat, 22 Nov 2025 23:02:42 +0200
Subject: [PATCH] decrease peak memory in moe forward
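
The MoE forward previously stacked the gate/up weights (W1) and the down
weights (W2) of all active experts before running either projection, so both
weight stacks plus every intermediate were resident at peak. Reorder the
computation: stack W1, run the first bmm, free it, and only then stack W2
for the second bmm, dropping each tensor as soon as it is dead. Also remove
the duplicated expert-gathering loop, replace the per-expert asyncio
scheduling with a single shared LazyMoELoader that submits whole-layer loads
to a bounded thread pool, prefetch experts two decoder layers ahead, drop
the unused psutil import, update the approximate MOE_LAYER_SIZE constant,
and fix the `self.self.additional_layers_set` typo.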
---
 comfy/ldm/hunyuan_image_3/model.py | 92 ++++++++++++++----------------
 1 file changed, 42 insertions(+), 50 deletions(-)

diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 78dd3154d..59310f0ab 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -2,7 +2,6 @@ import os
 import math
 import time
 import torch
-import psutil
 import asyncio
 import threading
 import torch.nn as nn
@@ -13,12 +12,13 @@ import torch.nn.functional as F
 from collections import OrderedDict
 from safetensors import safe_open
 from transformers.cache_utils import StaticCache
+from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock
 
 INIT_MOE = torch.cuda.device_count() != 1
-MOE_LAYER_SIZE = (1024**3) * 2.65 # approx
+MOE_LAYER_SIZE = (1024**3) * 5.15 # approx
 
 
 class HunyuanStaticCache(StaticCache):
@@ -576,7 +576,7 @@ def parse_layer_expert(key):
     return layer, expert
 
 class LazyMoELoader(nn.Module):
-    def __init__(self, cache, config):
+    def __init__(self, cache, config, max_workers=16, max_concurrent_loads=32):
         super().__init__()
         self.cache = cache
         self.config = config
@@ -586,6 +586,9 @@ class LazyMoELoader(nn.Module):
         self._file = safe_open(self._checkpoint, framework="pt", device="cpu", mmap=True)
 
         self.expert_pool = self.build_meta_experts()
+        self._executor = ThreadPoolExecutor(max_workers=max_workers)
+        self._semaphore = threading.Semaphore(max_concurrent_loads)
+
     def build_meta_experts(self):
         pool = {}
         for layer, experts in self.expert_key_index.items():
@@ -629,28 +632,17 @@ class LazyMoELoader(nn.Module):
             getattr(model, name).data = tensor
         return model
 
-    async def lazy_load_from_disk(self, layer_idx, expert_idx):
-        loop = asyncio.get_event_loop()
-        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
-
-    def _schedule_disk_load(self, layer_idx, expert_idx):
+    def _load_single_expert(self, layer_idx, expert_idx):
+        # the semaphore caps how many expert reads are in flight at once
+        with self._semaphore:
+            return self.lazy_init(layer_idx, expert_idx)
 
-        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
-        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
-
-        def _on_disk_loaded(fut):
-            moe_cpu = fut.result()
-            def _add_cpu_in_main_thread():
-                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
-
-            asyncio.run_coroutine_threadsafe(
-                self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
-                self.cache._loop
-            )
-            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()
-
-        future.add_done_callback(_on_disk_loaded)
-        return future
+    def schedule_layer_load(self, layer_idx, num_experts=64):
+        # submit the whole layer's experts to the thread pool; the returned
+        # concurrent.futures.Future objects are resolved with .result() in
+        # HunyuanMoE.forward
+        return [
+            self._executor.submit(self._load_single_expert, layer_idx, i)
+            for i in range(num_experts)
+        ]
 
 def enough_vram(required_bytes):
     free, total = torch.cuda.mem_get_info()
@@ -700,16 +692,7 @@ class HunyuanMoE(nn.Module):
                 pass
             else:
                 tokens_padded = dispatched_input[used_indices]
-
-                l1_layers, l2_layers = [], []
-                for i in used_indices:
-                    expert = self.experts[i]
-                    if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
-                        expert = expert.result()
-                        expert = expert.to(device)
-                    l1_layers.append(expert.gate_and_up_proj)
-                    l2_layers.append(expert.down_proj)
-
+
                 l1, l2 = [], []
                 for i in used_indices:
                     expert = self.experts[i]
@@ -720,24 +703,32 @@ class HunyuanMoE(nn.Module):
                     l2.append(expert.down_proj)
 
                 compute_device = hidden_states.device
-                l1 = [m.to(compute_device) for m in l1]
-                l2 = [m.to(compute_device) for m in l2]
+                l1 = [m.to(compute_device) for m in l1]
                 W1 = torch.stack([m.weight for m in l1], dim=0)
-                W2 = torch.stack([m.weight for m in l2], dim=0)
-
+                del l1
                 W1_T = W1.transpose(1, 2)
+
+                del W1
+                x = torch.bmm(tokens_padded, W1_T)
+                del W1_T, tokens_padded
+
+                x1, x2 = x.chunk(2, dim=2)
+                gated = x1 * F.silu(x2)
+
+                # only now materialize W2; W1 and its transpose are already freed
+                l2 = [m.to(compute_device) for m in l2]
+                W2 = torch.stack([m.weight for m in l2], dim=0)
+                del l2
                 W2_T = W2.transpose(1, 2)
+                del W2
+                out_padded = torch.bmm(gated, W2_T)
+                del W2_T
 
                 while not enough_vram(3*(1024 ** 3)):
                     event = self.moe_lru.last_offload_event
                     if event is not None and not event.query():
                         time.sleep(0.001)
 
-                x = torch.bmm(tokens_padded, W1_T)
-                x1, x2 = x.chunk(2, dim=2)
-                gated = x1 * F.silu(x2)
-                out_padded = torch.bmm(gated, W2_T)
 
                 combine_weights_used = combine_weights[:, used_indices, :]
@@ -746,7 +737,7 @@ class HunyuanMoE(nn.Module):
                     out_padded
                 )
 
-                del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded
+                del x, x1, x2, gated, out_padded
 
                 combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
 
@@ -899,7 +890,8 @@ class HunyuanImage3Model(nn.Module):
         self.shared_tensor = None
         self.moe_lru = moe_lru
 
-        self.self.additional_layers_set = False
+        self.additional_layers_set = False
+        self.moe_loader = LazyMoELoader(self.moe_lru, self.config)
 
     def forward(
         self,
@@ -922,19 +914,19 @@ class HunyuanImage3Model(nn.Module):
            sparse_interval = max(1, len(self.layers) // additional_layers)
 
        if len(self.layers[0].mlp.experts) == 0:
-           experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-           self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]
+           self.layers[0].mlp.experts = self.moe_loader.schedule_layer_load(0)
 
        for layer_idx, decoder_layer in enumerate(self.layers):
 
            if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
-               experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-               self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
+               self.layers[layer_idx+1].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 1)
+
+           if layer_idx + 2 < len(self.layers) and len(self.layers[layer_idx + 2].mlp.experts) == 0: # prefetch two layers ahead
+               self.layers[layer_idx+2].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 2)
 
            if not self.additional_layers_set:
                if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
-                   experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-                   self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+                   self.layers[next_layers].mlp.experts = self.moe_loader.schedule_layer_load(next_layers)
                    next_layers += 1
 
        with torch.no_grad():
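
Note (illustrative, not part of the patch): the peak-VRAM saving in the
forward hunk comes from never holding the stacked gate/up weights (W1) and
the stacked down weights (W2) at the same time. A self-contained sketch of
that ordering, with made-up names and shapes (E experts, T tokens per
expert, hidden size H, intermediate size I):

    import torch
    import torch.nn.functional as F

    def interleaved_expert_ffn(tokens, gate_up_weights, down_weights):
        # tokens: (E, T, H); gate_up_weights: E tensors of shape (2*I, H);
        # down_weights: E tensors of shape (H, I) -- nn.Linear weight layout
        W1 = torch.stack(gate_up_weights, dim=0)      # (E, 2*I, H)
        x = torch.bmm(tokens, W1.transpose(1, 2))     # (E, T, 2*I)
        del W1                                        # freed before W2 exists
        x1, x2 = x.chunk(2, dim=2)
        gated = x1 * F.silu(x2)                       # (E, T, I)
        W2 = torch.stack(down_weights, dim=0)         # (E, H, I)
        return torch.bmm(gated, W2.transpose(1, 2))   # (E, T, H)

With the old ordering both stacks were built up front, so the resident
weight term was |W1| + |W2|; interleaving drops it to the larger of the
two, which is what the del statements in the hunk enforce.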