mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 15:20:51 +08:00
improving performance and fixing race condition
This commit is contained in:
parent
12cc6924ac
commit
d731c58353
@ -548,7 +548,10 @@ class MoELRUCache(nn.Module):
|
||||
|
||||
for _, p in moe_cpu.named_parameters():
|
||||
if not p.is_pinned():
|
||||
p.data = p.data.pin_memory()
|
||||
if p.device.type == "cpu":
|
||||
p.data = p.data.pin_memory()
|
||||
else:
|
||||
return
|
||||
|
||||
self.cpu_cache[index] = moe_cpu
|
||||
self.cpu_cache.move_to_end(index)
|
||||
@ -643,74 +646,81 @@ class HunyuanMoE(nn.Module):
|
||||
|
||||
with torch.cuda.nvtx.range("MoE"):
|
||||
expert_weight, expert_index = self.gate(hidden_states)
|
||||
|
||||
combined_output = torch.zeros_like(reshaped_input)
|
||||
experts_list = [(i, expert) for i, expert in enumerate(self.experts)]
|
||||
|
||||
per_pos, per_tokens, per_weights = [], [], []
|
||||
for e, _ in experts_list:
|
||||
device = hidden_states.device
|
||||
dtype = reshaped_input.dtype
|
||||
|
||||
combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype)
|
||||
|
||||
per_pos = [None] * self.num_experts
|
||||
per_tokens = [None] * self.num_experts
|
||||
per_weights = [None] * self.num_experts
|
||||
for e in range(self.num_experts):
|
||||
token_mask = (expert_index == e)
|
||||
|
||||
token_ids = token_mask.nonzero(as_tuple=False)
|
||||
if token_ids.numel() == 0:
|
||||
continue
|
||||
token_positions = token_ids[:, 0]
|
||||
topk_slot = token_ids[:, 1]
|
||||
per_pos[e] = token_positions
|
||||
per_tokens[e] = reshaped_input[token_positions]
|
||||
per_weights[e] = expert_weight[token_positions, topk_slot]
|
||||
|
||||
used = [i for i, t in enumerate(per_tokens) if t is not None]
|
||||
if len(used) == 0:
|
||||
pass
|
||||
else:
|
||||
tokens_list = [per_tokens[i] for i in used]
|
||||
weights_list = [per_weights[i] for i in used]
|
||||
lengths = [t.shape[0] for t in tokens_list]
|
||||
U = len(tokens_list)
|
||||
L = max(lengths)
|
||||
H = hidden_size
|
||||
|
||||
tokens_padded = torch.zeros((U, L, H), device=device, dtype=dtype)
|
||||
weights_padded = torch.zeros((U, L), device=device, dtype=weights_list[0].dtype)
|
||||
for idx, t in enumerate(tokens_list):
|
||||
n = t.shape[0]
|
||||
tokens_padded[idx, :n] = t
|
||||
weights_padded[idx, :n] = weights_list[idx]
|
||||
|
||||
l1, l2 = [], []
|
||||
for i in used:
|
||||
expert = self.experts[i]
|
||||
if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
|
||||
expert = expert.result()
|
||||
expert = expert.to(device)
|
||||
l1.append(expert.gate_and_up_proj)
|
||||
l2.append(expert.down_proj)
|
||||
|
||||
tokens = reshaped_input[token_positions]
|
||||
weights = expert_weight[token_positions, topk_slot]
|
||||
compute_device = hidden_states.device
|
||||
l1 = [m.to(compute_device) for m in l1]
|
||||
l2 = [m.to(compute_device) for m in l2]
|
||||
|
||||
per_pos.append(token_positions)
|
||||
per_tokens.append(tokens)
|
||||
per_weights.append(weights)
|
||||
|
||||
lengths = [t.shape[0] for t in per_tokens]
|
||||
E = len(per_tokens)
|
||||
L = max(lengths)
|
||||
tokens_padded = torch.zeros((E, L, hidden_size), device=hidden_states.device, dtype=reshaped_input.dtype)
|
||||
weights_padded = torch.zeros((E, L), device=hidden_states.device, dtype=per_weights[0].dtype)
|
||||
for i, t in enumerate(per_tokens):
|
||||
tokens_padded[i, : t.shape[0]] = t
|
||||
weights_padded[i, : t.shape[0]] = per_weights[i]
|
||||
|
||||
l1, l2 = [], []
|
||||
for _, expert in experts_list:
|
||||
if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
|
||||
expert = expert.result()
|
||||
l1.append(expert.gate_and_up_proj)
|
||||
l2.append(expert.down_proj)
|
||||
|
||||
W1 = torch.stack([l.weight for l in l1]).to(hidden_states.device)
|
||||
W2 = torch.stack([l.weight for l in l2]).to(hidden_states.device)
|
||||
|
||||
W1_T = W1.transpose(1, 2)
|
||||
W2_T = W2.transpose(1, 2)
|
||||
|
||||
# wait for enough vram for the computations
|
||||
while not enough_vram(5*(1024 ** 3)):
|
||||
event = self.moe_lru.last_offload_event
|
||||
if event is not None and not event.query():
|
||||
time.sleep(0.001)
|
||||
|
||||
x = torch.bmm(tokens_padded, W1_T)
|
||||
x = F.silu(x)
|
||||
|
||||
x1, x2 = x.chunk(2, dim=2)
|
||||
out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
|
||||
|
||||
out_padded = out_padded * weights_padded.unsqueeze(-1)
|
||||
|
||||
for i, token_positions in enumerate(per_pos):
|
||||
Ni = lengths[i]
|
||||
out_i = out_padded[i, :Ni]
|
||||
combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype))
|
||||
|
||||
#dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
|
||||
#chunks = dispatched_input.chunk(self.num_experts, dim=0)
|
||||
#expert_outputs = []
|
||||
#for chunk, expert in zip(chunks, self.experts):
|
||||
# expert_outputs.append(expert(chunk))
|
||||
|
||||
#expert_output = torch.cat(expert_outputs, dim=0)
|
||||
#combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
|
||||
W1 = torch.stack([m.weight for m in l1], dim=0)
|
||||
W2 = torch.stack([m.weight for m in l2], dim=0)
|
||||
|
||||
W1_T = W1.transpose(1, 2)
|
||||
W2_T = W2.transpose(1, 2)
|
||||
|
||||
while not enough_vram(3*(1024 ** 3)):
|
||||
event = self.moe_lru.last_offload_event
|
||||
if event is not None and not event.query():
|
||||
time.sleep(0.001)
|
||||
|
||||
x = torch.bmm(tokens_padded, W1_T)
|
||||
x = F.silu(x)
|
||||
x1, x2 = x.chunk(2, dim=2)
|
||||
out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
|
||||
|
||||
out_padded = out_padded * weights_padded.unsqueeze(-1)
|
||||
|
||||
for idx, orig_expert_idx in enumerate(used):
|
||||
pos = per_pos[orig_expert_idx]
|
||||
n = lengths[idx]
|
||||
out_i = out_padded[idx, :n]
|
||||
combined_output.index_add_(0, pos.to(device), out_i.to(combined_output.dtype))
|
||||
|
||||
del tokens_padded, weights_padded, W1, W2, W1_T, W2_T, x, x1, x2, out_padded
|
||||
|
||||
combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
|
||||
|
||||
@ -863,6 +873,7 @@ class HunyuanImage3Model(nn.Module):
|
||||
|
||||
self.shared_tensor = None
|
||||
self.moe_lru = moe_lru
|
||||
self.self.additional_layers_set = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -885,7 +896,8 @@ class HunyuanImage3Model(nn.Module):
|
||||
|
||||
next_decoder_cache = None
|
||||
next_layers = 0
|
||||
sparse_interval = max(1, len(self.layers) // 3)
|
||||
additional_layers = torch.cuda.mem_get_info()[0] // (MOE_LAYER_SIZE * 2)
|
||||
sparse_interval = max(1, len(self.layers) // additional_layers)
|
||||
|
||||
if len(self.layers[0].mlp.experts) == 0:
|
||||
experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
|
||||
@ -896,13 +908,12 @@ class HunyuanImage3Model(nn.Module):
|
||||
if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
|
||||
experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
|
||||
self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
|
||||
|
||||
if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
|
||||
if len(self.layers[next_layers].mlp.experts) > 0: # for testing
|
||||
raise ValueError("Problem with offloading")
|
||||
experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
|
||||
self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
|
||||
next_layers += 1
|
||||
|
||||
if not additional_layers_set:
|
||||
if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
|
||||
experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
|
||||
self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
|
||||
next_layers += 1
|
||||
|
||||
with torch.no_grad():
|
||||
layer_outputs = decoder_layer(
|
||||
@ -918,11 +929,15 @@ class HunyuanImage3Model(nn.Module):
|
||||
)
|
||||
|
||||
if layer_idx >= 0:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.moe_lru._async_offload_to_cpu(layer_idx),
|
||||
self.moe_lru._loop
|
||||
)
|
||||
self.layers[layer_idx].mlp.experts = []
|
||||
if self.additional_layers_set and layer_idx <= self.additional_layers_set:
|
||||
pass
|
||||
else:
|
||||
torch.cuda.synchronize()
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.moe_lru._async_offload_to_cpu(layer_idx),
|
||||
self.moe_lru._loop
|
||||
)
|
||||
self.layers[layer_idx].mlp.experts = []
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -932,7 +947,7 @@ class HunyuanImage3Model(nn.Module):
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = next_decoder_cache
|
||||
|
||||
self.additional_layers_set = True
|
||||
return tuple(v for v in [hidden_states, next_cache] if v is not None)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user