Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-12 15:20:51 +08:00
async cache revamp
Added async loading and offloading of MoE layers, keeping GPU memory usage consistent and avoiding OOM errors. Previously the run hit an OOM error after the third layer on a 24 GB GPU; now it runs to the end with consistent memory usage and minimal added latency.
This commit is contained in:
parent
44346c4251
commit
7b4c1e8031
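The idea described in the commit message — prefetch the next layer's expert weights onto the GPU on a side CUDA stream while the current layer computes, then drop the GPU copies once they have been used — can be sketched roughly as below. This is a minimal, hypothetical illustration with toy weight matrices (names such as prefetch and cpu_weights are made up here), not the commit's actual HunyuanMLP / MoELRUCache code; it assumes PyTorch and an available CUDA device.

# Illustrative sketch only, not part of the commit: double-buffered expert prefetch.
import torch

def main():
    assert torch.cuda.is_available()
    device = torch.device("cuda:0")
    load_stream = torch.cuda.Stream()          # side stream for host-to-device copies

    num_layers, dim = 4, 1024
    # Expert weights stay on the CPU in pinned memory so copies can run asynchronously.
    cpu_weights = [torch.randn(dim, dim).pin_memory() for _ in range(num_layers)]
    gpu_weights = [None] * num_layers
    ready = [torch.cuda.Event() for _ in range(num_layers)]

    def prefetch(i):
        # Enqueue the H2D copy on the side stream; compute on the default stream
        # is not blocked while this copy is in flight.
        with torch.cuda.stream(load_stream):
            buf = torch.empty_like(cpu_weights[i], device=device)
            buf.copy_(cpu_weights[i], non_blocking=True)
            ready[i].record(load_stream)
        gpu_weights[i] = buf

    x = torch.randn(8, dim, device=device)
    prefetch(0)
    for i in range(num_layers):
        if i + 1 < num_layers:
            prefetch(i + 1)                                # overlap the next layer's copy
        torch.cuda.current_stream().wait_event(ready[i])   # wait only for layer i's weights
        x = x @ gpu_weights[i]
        # Tell the caching allocator the buffer was used on the default stream before
        # dropping it, so its memory is not recycled while the matmul is still running.
        gpu_weights[i].record_stream(torch.cuda.current_stream())
        gpu_weights[i] = None
    torch.cuda.synchronize()
    print(x.shape)

if __name__ == "__main__":
    main()

Pinned host memory is what lets copy_(..., non_blocking=True) actually overlap with compute on the default stream; without it the transfer falls back to a synchronous copy, which is also why the commit pins the CPU-side expert weights.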
@@ -1,15 +1,17 @@
import os
import gc
import math
import time
import torch
import psutil
import asyncio
import threading
import torch.nn as nn
from pathlib import Path
import concurrent.futures
from einops import rearrange
import torch.nn.functional as F
from collections import OrderedDict
from safetensors import safe_open
from contextlib import contextmanager
from transformers.cache_utils import StaticCache
from typing import Optional, Tuple, Any, List, Dict
from comfy.ldm.modules.attention import optimized_attention
@@ -19,13 +21,13 @@ INIT_MOE = torch.cuda.device_count() != 1

if not INIT_MOE:
    MOE_LAYER_SIZE = (1024**3) * 2.65 # approx
    CPU_MOE_RATIO = None

    torch.cuda.set_device(0)
    props = torch.cuda.get_device_properties(0)

    INIT_CUDA_MEM = (props.total_memory - torch.cuda.memory_reserved()) * 0.9
    ADDITIONAL_LAYERS_IN_GPU = math.floor(INIT_CUDA_MEM / MOE_LAYER_SIZE)
    LAYERS_IN_CPU = math.floor((int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
                                    - psutil.Process(os.getpid()).memory_info().rss
                                    - (2*1024**3)) * 0.50) / MOE_LAYER_SIZE)

class HunyuanStaticCache(StaticCache):

@@ -286,20 +288,15 @@ def topkgating(
    logits = logits.float()
    gates = F.softmax(logits, dim=1)

    extra = ADDITIONAL_LAYERS_IN_GPU

    values_all, indices_all = torch.topk(gates, topk + extra, dim=1)
    values_all, indices_all = torch.topk(gates, topk, dim=1)
    expert_weight = values_all[:, :topk]
    expert_index = indices_all[:, :topk]

    _, cpu_expert_index = torch.topk(gates, int(CPU_MOE_RATIO * 64), dim = 1)
    cpu_expert_index = cpu_expert_index[:, (8 + ADDITIONAL_LAYERS_IN_GPU):]

    if norm_topk_prob and topk > 1:
        denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps)
        expert_weight = expert_weight / denom

    return expert_weight, expert_index, cpu_expert_index, indices_all
    return expert_weight, expert_index

class HunyuanRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
@@ -452,7 +449,7 @@ class HunyuanTopKGate(nn.Module):
        return gate_output

class HunyuanMLP(nn.Module):
    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False):
    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
@@ -462,8 +459,8 @@ class HunyuanMLP(nn.Module):

        self.act_fn = torch.nn.functional.silu
        self.intermediate_size *= 2 # SwiGLU
        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False)
        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device)
        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device)
    def forward(self, x):
        self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
        if x.ndim == 2:
@@ -474,204 +471,143 @@ class HunyuanMLP(nn.Module):
        return down_proj

class MoELRUCache(nn.Module):
    def __init__(self, cpu_mem: int = 50, safety_buffer_bytes = 3*(1024**3), max_gpu_eviction_attempts = 8):
    def __init__(self):
        super().__init__()
        global CPU_MOE_RATIO

        _, total = torch.cuda.mem_get_info()
        max_gpu_mem_gb = max((total - 2 * safety_buffer_bytes) / (1024**3), 1)

        self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3)
        self.MAX_CPU_MEM = int(cpu_mem * 1024**3)
        self.gpu_cache = OrderedDict()
        self.cpu_cache = OrderedDict()
        self.offload_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

        self.gpu_mem_usage = 0
        self.cpu_mem_usage = 0
        # 50% for system and headroom
        try:
            self.MAX_CPU_MEM = int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
                                   - psutil.Process(os.getpid()).memory_info().rss
                                   - safety_buffer_bytes) * 0.55
        except:
            self.MAX_CPU_MEM = int(cpu_mem * (1024**3) * 0.5) # TODO
        self.last_offload_event = None
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

        ADDITIONAL_LAYERS_IN_CPU = math.floor((50 * (1024**3)) / MOE_LAYER_SIZE)
        CPU_MOE_RATIO = (min(64 - ADDITIONAL_LAYERS_IN_GPU, ADDITIONAL_LAYERS_IN_CPU)) / 64
    async def _async_offload_to_cpu(self, layer_idx):
        # async offload from gpu (removed)

        self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3)
        self.SAFETY_BUFFER = int(safety_buffer_bytes)
        self.MAX_GPU_EVICT_ATTEMPTS = max_gpu_eviction_attempts
        num_experts = 64
        moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
                     for i in range(num_experts)
                     if (layer_idx * num_experts + i) in self.gpu_cache]
        event = torch.cuda.Event()

    def _gpu_free_bytes(self):
        free, total = torch.cuda.mem_get_info()
        return int(free)

    def _estimate_size(self, moe):
        # include parameters + buffers
        size = 0
        for p in moe.parameters():
            size += p.numel() * p.element_size()
        for b in moe.buffers():
            size += b.numel() * b.element_size()
        return int(size)
        with torch.cuda.stream(self.offload_stream):
            for index, moe in moe_group:
                moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
                for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
                    if p_gpu.device.type == "meta":
                        continue
                    with torch.no_grad():
                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
                        p_cpu.copy_(p_gpu, non_blocking=True)

    def _evict_until_free(self, required_bytes, max_attempts=16):
        attempts = 0
        while self._gpu_free_bytes() < required_bytes and attempts < max_attempts:
            evicted = self._evict_from_gpu()
            if not evicted:
                break
            attempts += 1
        return self._gpu_free_bytes() >= required_bytes
                self.cpu_cache[index] = moe_cpu

    @contextmanager
    def ensure_headroom(self, required_bytes):
        self.offload_stream.record_event(event)

        safety = getattr(self, "SAFETY_BUFFER", 0)
        target_free = int(required_bytes + safety)
        self.last_offload_event = event

        if getattr(self, "_headroom", None) is not None:
            try:
                del self._headroom
            except Exception:
                pass
            self._headroom = None
        def finalize_offload_layer():
            event.synchronize()
            for index, moe in moe_group:
                moe.to("meta")
                self.gpu_cache.pop(index, None)
                del moe
            torch.cuda.empty_cache()

        ok = self._evict_until_free(target_free)
        if not ok and self._gpu_free_bytes() < target_free:
            # last ditch
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        threading.Thread(target=finalize_offload_layer, daemon=True).start()

        try:
            yield
        finally:
            if getattr(self, "_headroom", None) is None:
                try:
                    self._headroom = torch.empty((self._headroom_bytes,), dtype=torch.uint8, device="cuda:0")
                except Exception:
                    self._headroom = None
    async def _async_load_to_gpu(self, index, moe):

    def add_gpu(self, moe, index, allowed_retries=3):
        size = self._estimate_size(moe)

        while self.gpu_mem_usage + size > self.MAX_GPU_MEM:
            if not self._evict_from_gpu():
        # if enough memory load, otherwise wait for offload
        while True:
            free_bytes, _ = torch.cuda.mem_get_info()
            if free_bytes > 2 * MOE_LAYER_SIZE:
                break

        attempts = 0
        while self._gpu_free_bytes() < size + self.SAFETY_BUFFER and attempts < self.MAX_GPU_EVICT_ATTEMPTS:
            if not self._evict_from_gpu():
                break
            attempts += 1
            self.last_offload_event.synchronize()
            torch.cuda.empty_cache()
            await asyncio.sleep(0.01)

        for _ in range(allowed_retries):
            try:
                moe_cuda = moe.to("cuda:0")
                break
            except RuntimeError as e:
                if "out of memory" not in str(e).lower():
                    raise
                evicted = self._evict_from_gpu()
                if not evicted: # can't evict
                    raise
        else:
            raise RuntimeError("Failed to move expert to GPU after evictions")
        # async loading from cpu -> gpu
        with torch.cuda.stream(self.load_stream):
            moe_gpu = HunyuanMLP(moe.config).to("cuda", non_blocking=True)
            for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
                with torch.no_grad():
                    p_gpu.data = torch.empty_like(p_cpu, device="cuda")
                    p_gpu.copy_(p_cpu, non_blocking=True)

        self.gpu_cache[index] = moe_cuda
        self.gpu_cache.move_to_end(index)
        self.gpu_mem_usage += size
        def finalize_load():
            self.gpu_cache[index] = moe_gpu
            self.cpu_cache.pop(index, None)

        return
        threading.Thread(target=finalize_load, daemon=True).start()

    def add_cpu(self, moe, index):
        size = self._estimate_size(moe)
        while self.cpu_mem_usage + size > self.MAX_CPU_MEM:
            if not self._evict_from_cpu():
                break
        moe_cpu = moe.to("cpu")

        for _, p in moe_cpu.named_parameters():
            if not p.is_pinned():
                p.data = p.data.pin_memory()

        self.cpu_cache[index] = moe_cpu
        self.cpu_cache.move_to_end(index)
        self.cpu_mem_usage += size

    def get_from_device(self, index):
        if index in self.gpu_cache:
            moe = self.gpu_cache[index]
            self.gpu_cache.move_to_end(index)
            return moe
        if index in self.cpu_cache:
            moe = self.cpu_cache.pop(index)
            self.cpu_mem_usage = max(0, self.cpu_mem_usage - self._estimate_size(moe))
            try:
                self.add_gpu(moe, index)
                return self.gpu_cache[index]
            except RuntimeError:
                self.cpu_cache[index] = moe
                self.cpu_cache.move_to_end(index)
                self.cpu_mem_usage += self._estimate_size(moe)
                raise

        return None # load from disk

    def _evict_from_gpu(self):
        if not self.gpu_cache:
            return False

        idx, moe = self.gpu_cache.popitem(last=False)
        size = self._estimate_size(moe)
        self.gpu_mem_usage = max(0, self.gpu_mem_usage - size)

        if self.cpu_mem_usage + size <= self.MAX_CPU_MEM:
            try:
                moe_cpu = moe.to("cpu")
            except Exception:
                # drop the model if cpu is full
                del moe
                return True
            self.cpu_cache[idx] = moe_cpu
            self.cpu_cache.move_to_end(idx)
            self.cpu_mem_usage += size
            return True
        else:
            del moe
            return True

    def _evict_from_cpu(self):
        if not self.cpu_cache:
            return False
        _, moe = self.cpu_cache.popitem(last=False)
        size = self._estimate_size(moe)
        self.cpu_mem_usage = max(0, self.cpu_mem_usage - size)
        del moe
        gc.collect()
        return True

class LazyMoELoader(nn.Module):
    def __init__(self, device):
    def __init__(self, cache, config):
        super().__init__()
        self.device = device
        self.cache = cache
        self.config = config
        self._loop = cache._loop

    def lazy_init(self, config, layer_idx, expert_idx):
    def get_checkpoint(self):
        comfyui_dir = Path.home() / "ComfyUI"
        checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
        checkpoint = checkpoint.resolve()
        if not os.path.exists(checkpoint):
            raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")
        return checkpoint

    def lazy_init(self, layer_idx, expert_idx):
        checkpoint = self.get_checkpoint()
        prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}."
        additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight"
        sd = {}

        with safe_open(checkpoint, framework="pt", device=self.device) as f:
        with safe_open(checkpoint, framework="pt", device="cpu") as f:
            for k in f.keys():
                if k.startswith(prefix) or k.startswith(additional_prefix):
                    new_k = k.split(f"experts.{expert_idx}.", 1)[1]
                    sd[new_k] = f.get_tensor(k)

        return HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd).to(self.deivce)
        return HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd)

    async def lazy_load_from_disk(self, layer_idx, expert_idx):
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)

    def _schedule_disk_load(self, layer_idx, expert_idx):

        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)

        def _on_disk_loaded(fut):
            moe_cpu = fut.result()
            def _add_cpu_in_main_thread():
                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)

            asyncio.run_coroutine_threadsafe(
                self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
                self.cache._loop
            )
            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()

        future.add_done_callback(_on_disk_loaded)
        return future

def enough_vram(required_bytes):
    free, total = torch.cuda.mem_get_info()
    return free > required_bytes

class HunyuanMoE(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
@@ -687,7 +623,7 @@ class HunyuanMoE(nn.Module):
            [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
        )
        else:
            self.experts = None
            self.experts = []
        self.moe_lru = moe_lru

    def forward(self, hidden_states):
@@ -702,38 +638,10 @@ class HunyuanMoE(nn.Module):
        reshaped_input = hidden_states.reshape(-1, hidden_size)

        with torch.cuda.nvtx.range("MoE"):
            expert_weight, expert_index, cpu_expert_index, indices_all = self.gate(hidden_states)
            if not INIT_MOE:
                if ADDITIONAL_LAYERS_IN_GPU > 0:
                    additional_expert_index = indices_all[:, expert_index.size(1): expert_index.size(1) + ADDITIONAL_LAYERS_IN_GPU]

                    flat = additional_expert_index.reshape(-1).to("cpu")
                    counts = torch.bincount(flat, minlength=self.num_experts)
                    top_extra = torch.topk(counts, k=min(ADDITIONAL_LAYERS_IN_GPU, (counts>0).sum().item())).indices.tolist()

                    for expert_id in top_extra:
                        if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
                            expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
                            self.moe_lru.add_gpu(expert_cpu, expert_id + self.layer_idx)

                if cpu_expert_index is not None and cpu_expert_index.numel() > 0:
                    for expert_id in torch.unique(cpu_expert_index).cpu().tolist():
                        if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
                            expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
                            self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx)
            expert_weight, expert_index = self.gate(hidden_states)

            combined_output = torch.zeros_like(reshaped_input)
            experts_list = []
            for e in range(self.num_experts):
                token_mask = (expert_index == e)
                if not token_mask.any():
                    continue
                expert = self.moe_lru.get_from_device(e + self.layer_idx)
                if expert is None:
                    expert = LazyMoELoader()
                    expert = expert.lazy_init(self.config, self.layer_idx, e)
                    self.moe_lru.add_gpu(expert, e + self.layer_idx)
                experts_list.append((e, expert))
            experts_list = [(i, expert) for i, expert in enumerate(self.experts)]

            per_pos, per_tokens, per_weights = [], [], []
            for e, _ in experts_list:
@@ -761,6 +669,8 @@ class HunyuanMoE(nn.Module):

            l1, l2 = [], []
            for _, expert in experts_list:
                if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
                    expert = expert.result()
                l1.append(expert.gate_and_up_proj)
                l2.append(expert.down_proj)

@@ -769,6 +679,12 @@ class HunyuanMoE(nn.Module):

            W1_T = W1.transpose(1, 2)
            W2_T = W2.transpose(1, 2)

            # wait for enough vram for the computations
            while not enough_vram(5*(1024 ** 3)):
                event = self.moe_lru.last_offload_event
                if event is not None and not event.query():
                    time.sleep(0.001)

            x = torch.bmm(tokens_padded, W1_T)
            x = F.silu(x)
@@ -781,7 +697,7 @@ class HunyuanMoE(nn.Module):
            for i, token_positions in enumerate(per_pos):
                Ni = lengths[i]
                out_i = out_padded[i, :Ni]
                combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i)
                combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype))

            #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
            #chunks = dispatched_input.chunk(self.num_experts, dim=0)
@@ -933,6 +849,7 @@ class HunyuanImage3Model(nn.Module):
        super().__init__()
        self.padding_idx = 128009
        self.vocab_size = 133120
        self.config = config
        self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
        self.layers = nn.ModuleList(
            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
@@ -941,6 +858,7 @@ class HunyuanImage3Model(nn.Module):
        self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])

        self.shared_tensor = None
        self.moe_lru = moe_lru

    def forward(
        self,
@@ -962,19 +880,45 @@ class HunyuanImage3Model(nn.Module):
        hidden_states = inputs_embeds

        next_decoder_cache = None
        for layer_idx, decoder_layer in enumerate(self.layers):
        next_layers = 0
        sparse_interval = max(1, len(self.layers) // 3)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                custom_pos_emb=custom_pos_emb,
                mode=mode,
                first_step=first_step,
                gen_timestep_scatter_index=gen_timestep_scatter_index,
            )
        if len(self.layers[0].mlp.experts) == 0:
            experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
            self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]

        for layer_idx, decoder_layer in enumerate(self.layers):

            if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]

            if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
                if len(self.layers[next_layers].mlp.experts) > 0: # for testing
                    raise ValueError("Problem with offloading")
                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
                next_layers += 1

            with torch.no_grad():
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    custom_pos_emb=custom_pos_emb,
                    mode=mode,
                    first_step=first_step,
                    gen_timestep_scatter_index=gen_timestep_scatter_index,
                )

            if layer_idx >= 0:
                asyncio.run_coroutine_threadsafe(
                    self.moe_lru._async_offload_to_cpu(layer_idx),
                    self.moe_lru._loop
                )
                self.layers[layer_idx].mlp.experts = []

            hidden_states = layer_outputs[0]