Improve performance and fix a race condition

Yousef Rafat 2025-11-16 16:19:39 +02:00
parent 12cc6924ac
commit d731c58353


@ -548,7 +548,10 @@ class MoELRUCache(nn.Module):
 for _, p in moe_cpu.named_parameters():
     if not p.is_pinned():
-        p.data = p.data.pin_memory()
+        if p.device.type == "cpu":
+            p.data = p.data.pin_memory()
+        else:
+            return
 self.cpu_cache[index] = moe_cpu
 self.cpu_cache.move_to_end(index)
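
For context, `Tensor.pin_memory()` is only defined for CPU tensors, so the new guard checks `p.device.type` before pinning and returns early if a parameter is not actually on the host. A minimal standalone sketch of that pattern (the helper name and the usage below are illustrative, not part of the repository):

import torch
import torch.nn as nn

def pin_cpu_params(module: nn.Module) -> bool:
    # Pin every CPU-resident parameter so later .to("cuda", non_blocking=True)
    # copies can overlap with compute; stop early if something is not on the
    # host, mirroring the early return added in this hunk.
    for _, p in module.named_parameters():
        if p.is_pinned():
            continue
        if p.device.type != "cpu":
            return False
        p.data = p.data.pin_memory()
    return True

mlp = nn.Linear(8, 8)
if pin_cpu_params(mlp) and torch.cuda.is_available():
    weight_gpu = mlp.weight.data.to("cuda", non_blocking=True)
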
@ -643,74 +646,81 @@ class HunyuanMoE(nn.Module):
 with torch.cuda.nvtx.range("MoE"):
     expert_weight, expert_index = self.gate(hidden_states)
-    combined_output = torch.zeros_like(reshaped_input)
-    experts_list = [(i, expert) for i, expert in enumerate(self.experts)]
-    per_pos, per_tokens, per_weights = [], [], []
-    for e, _ in experts_list:
+    device = hidden_states.device
+    dtype = reshaped_input.dtype
+    combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype)
+    per_pos = [None] * self.num_experts
+    per_tokens = [None] * self.num_experts
+    per_weights = [None] * self.num_experts
+    for e in range(self.num_experts):
         token_mask = (expert_index == e)
         token_ids = token_mask.nonzero(as_tuple=False)
         if token_ids.numel() == 0:
             continue
         token_positions = token_ids[:, 0]
         topk_slot = token_ids[:, 1]
+        per_pos[e] = token_positions
+        per_tokens[e] = reshaped_input[token_positions]
+        per_weights[e] = expert_weight[token_positions, topk_slot]
+    used = [i for i, t in enumerate(per_tokens) if t is not None]
+    if len(used) == 0:
+        pass
+    else:
+        tokens_list = [per_tokens[i] for i in used]
+        weights_list = [per_weights[i] for i in used]
+        lengths = [t.shape[0] for t in tokens_list]
+        U = len(tokens_list)
+        L = max(lengths)
+        H = hidden_size
+        tokens_padded = torch.zeros((U, L, H), device=device, dtype=dtype)
+        weights_padded = torch.zeros((U, L), device=device, dtype=weights_list[0].dtype)
+        for idx, t in enumerate(tokens_list):
+            n = t.shape[0]
+            tokens_padded[idx, :n] = t
+            weights_padded[idx, :n] = weights_list[idx]
+        l1, l2 = [], []
+        for i in used:
+            expert = self.experts[i]
+            if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
+                expert = expert.result()
+            expert = expert.to(device)
+            l1.append(expert.gate_and_up_proj)
+            l2.append(expert.down_proj)
-        tokens = reshaped_input[token_positions]
-        weights = expert_weight[token_positions, topk_slot]
-        compute_device = hidden_states.device
-        l1 = [m.to(compute_device) for m in l1]
-        l2 = [m.to(compute_device) for m in l2]
-        per_pos.append(token_positions)
-        per_tokens.append(tokens)
-        per_weights.append(weights)
-    lengths = [t.shape[0] for t in per_tokens]
-    E = len(per_tokens)
-    L = max(lengths)
-    tokens_padded = torch.zeros((E, L, hidden_size), device=hidden_states.device, dtype=reshaped_input.dtype)
-    weights_padded = torch.zeros((E, L), device=hidden_states.device, dtype=per_weights[0].dtype)
-    for i, t in enumerate(per_tokens):
-        tokens_padded[i, : t.shape[0]] = t
-        weights_padded[i, : t.shape[0]] = per_weights[i]
-    l1, l2 = [], []
-    for _, expert in experts_list:
-        if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
-            expert = expert.result()
-        l1.append(expert.gate_and_up_proj)
-        l2.append(expert.down_proj)
-    W1 = torch.stack([l.weight for l in l1]).to(hidden_states.device)
-    W2 = torch.stack([l.weight for l in l2]).to(hidden_states.device)
-    W1_T = W1.transpose(1, 2)
-    W2_T = W2.transpose(1, 2)
-    # wait for enough vram for the computations
-    while not enough_vram(5*(1024 ** 3)):
-        event = self.moe_lru.last_offload_event
-        if event is not None and not event.query():
-            time.sleep(0.001)
-    x = torch.bmm(tokens_padded, W1_T)
-    x = F.silu(x)
-    x1, x2 = x.chunk(2, dim=2)
-    out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
-    out_padded = out_padded * weights_padded.unsqueeze(-1)
-    for i, token_positions in enumerate(per_pos):
-        Ni = lengths[i]
-        out_i = out_padded[i, :Ni]
-        combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype))
     #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
     #chunks = dispatched_input.chunk(self.num_experts, dim=0)
     #expert_outputs = []
     #for chunk, expert in zip(chunks, self.experts):
     # expert_outputs.append(expert(chunk))
     #expert_output = torch.cat(expert_outputs, dim=0)
     #combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
+        W1 = torch.stack([m.weight for m in l1], dim=0)
+        W2 = torch.stack([m.weight for m in l2], dim=0)
+        W1_T = W1.transpose(1, 2)
+        W2_T = W2.transpose(1, 2)
+        while not enough_vram(3*(1024 ** 3)):
+            event = self.moe_lru.last_offload_event
+            if event is not None and not event.query():
+                time.sleep(0.001)
+        x = torch.bmm(tokens_padded, W1_T)
+        x = F.silu(x)
+        x1, x2 = x.chunk(2, dim=2)
+        out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
+        out_padded = out_padded * weights_padded.unsqueeze(-1)
+        for idx, orig_expert_idx in enumerate(used):
+            pos = per_pos[orig_expert_idx]
+            n = lengths[idx]
+            out_i = out_padded[idx, :n]
+            combined_output.index_add_(0, pos.to(device), out_i.to(combined_output.dtype))
+        del tokens_padded, weights_padded, W1, W2, W1_T, W2_T, x, x1, x2, out_padded
     combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
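
The rewritten block above is the performance half of the commit: tokens are bucketed per expert, padded to one common length, pushed through a single pair of batched `torch.bmm` calls against the stacked expert weights, and scattered back with `index_add_`, instead of looping over experts with one matmul each. A self-contained sketch of that gather/pad/bmm/scatter idea (shapes, random weights, the routing tensors, and the plain-SiLU expert MLP are placeholders, not the model's actual gated projection):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
S, H, E, K = 16, 32, 4, 2                    # tokens, hidden size, experts, top-k
tokens = torch.randn(S, H)
expert_weight = torch.rand(S, K)             # routing weights (placeholder)
expert_index = torch.randint(0, E, (S, K))   # routed expert ids (placeholder)
W1 = torch.randn(E, 4 * H, H)                # stacked up-projections
W2 = torch.randn(E, H, 4 * H)                # stacked down-projections

combined = torch.zeros_like(tokens)

# 1) Gather token positions and routing weights per expert.
per_pos, per_tok, per_w = [None] * E, [None] * E, [None] * E
for e in range(E):
    ids = (expert_index == e).nonzero(as_tuple=False)
    if ids.numel() == 0:
        continue
    pos, slot = ids[:, 0], ids[:, 1]
    per_pos[e], per_tok[e], per_w[e] = pos, tokens[pos], expert_weight[pos, slot]

used = [e for e in range(E) if per_tok[e] is not None]
lengths = [per_tok[e].shape[0] for e in used]
L = max(lengths)

# 2) Pad every expert's token batch to the common length L.
tok_pad = torch.zeros(len(used), L, H)
w_pad = torch.zeros(len(used), L)
for i, e in enumerate(used):
    tok_pad[i, :lengths[i]] = per_tok[e]
    w_pad[i, :lengths[i]] = per_w[e]

# 3) One batched matmul per projection instead of a Python loop over experts.
h = F.silu(torch.bmm(tok_pad, W1[used].transpose(1, 2)))
out_pad = torch.bmm(h, W2[used].transpose(1, 2)) * w_pad.unsqueeze(-1)

# 4) Scatter-add each expert's (unpadded) output back to its token positions.
for i, e in enumerate(used):
    combined.index_add_(0, per_pos[e], out_pad[i, :lengths[i]])
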
@ -863,6 +873,7 @@ class HunyuanImage3Model(nn.Module):
     self.shared_tensor = None
     self.moe_lru = moe_lru
+    self.additional_layers_set = False

 def forward(
     self,
@ -885,7 +896,8 @@ class HunyuanImage3Model(nn.Module):
 next_decoder_cache = None
 next_layers = 0
-sparse_interval = max(1, len(self.layers) // 3)
+additional_layers = torch.cuda.mem_get_info()[0] // (MOE_LAYER_SIZE * 2)
+sparse_interval = max(1, len(self.layers) // additional_layers)
 if len(self.layers[0].mlp.experts) == 0:
     experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
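
With this change, `sparse_interval` is derived from how many extra MoE layers fit in currently free GPU memory rather than from a fixed divisor. A rough sketch of that budgeting (the 8 GiB `MOE_LAYER_SIZE` value and the extra `max(1, ...)` guard against a zero budget are assumptions for illustration, not something this diff defines):

import torch

MOE_LAYER_SIZE = 8 * (1024 ** 3)   # assumed per-layer expert footprint in bytes

def compute_sparse_interval(num_layers: int) -> int:
    # torch.cuda.mem_get_info() returns (free_bytes, total_bytes) for the device.
    free_bytes, _total = torch.cuda.mem_get_info()
    # Keep roughly 2x headroom per resident layer (weights plus transient buffers).
    additional_layers = max(1, free_bytes // (MOE_LAYER_SIZE * 2))
    return max(1, num_layers // additional_layers)

# e.g. 32 layers and room for 4 extra resident layers -> prefetch one every 8 layers.
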
@ -896,13 +908,12 @@ class HunyuanImage3Model(nn.Module):
 if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
     experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
     self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
-if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
-    if len(self.layers[next_layers].mlp.experts) > 0: # for testing
-        raise ValueError("Problem with offloading")
-    experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-    self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
-    next_layers += 1
+if not self.additional_layers_set:
+    if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
+        experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+        self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+        next_layers += 1
 with torch.no_grad():
     layer_outputs = decoder_layer(
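
This loop overlaps compute and I/O: while layer `layer_idx` runs, the experts of layer `layer_idx + 1` have already been scheduled for an asynchronous disk load, and `HunyuanMoE.forward` later resolves those futures right before it needs the weights. A minimal thread-pool version of that prefetch idea (the `load_expert_from_disk` loader and the two-layer setup are illustrative, not the repository's `LazyMoELoader`):

from concurrent.futures import Future, ThreadPoolExecutor

import torch

pool = ThreadPoolExecutor(max_workers=4)

def load_expert_from_disk(layer_idx: int, expert_idx: int) -> torch.nn.Module:
    # Stand-in for reading one expert's weights from disk.
    return torch.nn.Linear(8, 8)

def schedule_layer(layer_idx: int, num_experts: int = 4) -> list[Future]:
    # Kick off async loads for every expert of a layer and return the futures.
    return [pool.submit(load_expert_from_disk, layer_idx, e) for e in range(num_experts)]

# Prefetch layer 1 while "computing" layer 0, then resolve the futures when needed,
# mirroring the isinstance(..., Future) check in HunyuanMoE.forward.
pending = {1: schedule_layer(1)}
experts_layer1 = [f.result() if isinstance(f, Future) else f for f in pending[1]]
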
@ -918,11 +929,15 @@ class HunyuanImage3Model(nn.Module):
     )
 if layer_idx >= 0:
-    asyncio.run_coroutine_threadsafe(
-        self.moe_lru._async_offload_to_cpu(layer_idx),
-        self.moe_lru._loop
-    )
-    self.layers[layer_idx].mlp.experts = []
+    if self.additional_layers_set and layer_idx <= self.additional_layers_set:
+        pass
+    else:
+        torch.cuda.synchronize()
+        asyncio.run_coroutine_threadsafe(
+            self.moe_lru._async_offload_to_cpu(layer_idx),
+            self.moe_lru._loop
+        )
+        self.layers[layer_idx].mlp.experts = []
 hidden_states = layer_outputs[0]
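
The `torch.cuda.synchronize()` added before scheduling the offload is the race-condition half of the commit: without it, the background offload task could start moving expert weights that the just-launched layer kernels are still reading. A simplified sketch of that pattern, using a plain thread pool in place of the repository's `MoELRUCache` event loop (names such as `offload_to_cpu` and `run_layer_then_offload` are illustrative):

import torch
import torch.nn as nn
from concurrent.futures import ThreadPoolExecutor

offload_pool = ThreadPoolExecutor(max_workers=1)

def offload_to_cpu(layer: nn.Module) -> None:
    # Move the finished layer's expert weights back to host memory.
    for p in layer.parameters():
        p.data = p.data.to("cpu")

def run_layer_then_offload(layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
    out = layer(x)                              # CUDA kernels are only enqueued here
    torch.cuda.synchronize()                    # wait until they have actually finished...
    offload_pool.submit(offload_to_cpu, layer)  # ...before another thread moves the weights
    return out
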
@ -932,7 +947,7 @@ class HunyuanImage3Model(nn.Module):
 next_cache = None
 if use_cache:
     next_cache = next_decoder_cache
+self.additional_layers_set = True
 return tuple(v for v in [hidden_states, next_cache] if v is not None)