Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-12 15:20:51 +08:00)
fixed some mistakes/errors
This commit is contained in: parent b4001bbd27, commit 823870db53
@@ -493,42 +493,44 @@ class MoELRUCache(nn.Module):
        self.last_offload_event = None
        self._loop = asyncio.new_event_loop()
        self._gpu_sem = asyncio.Semaphore(1)  # maybe 2
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    async def _async_offload_to_cpu(self, layer_idx):
        # async offload from gpu (removed)
        async with self._gpu_sem:
            num_experts = 64
            moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
                         for i in range(num_experts)
                         if (layer_idx * num_experts + i) in self.gpu_cache]
            event = torch.cuda.Event()

            with torch.cuda.stream(self.offload_stream):
                for index, moe in moe_group:
                    moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
                    for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
                        if p_gpu.device.type == "meta":
                            continue
                        with torch.no_grad():
                            p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
                            p_cpu.copy_(p_gpu, non_blocking=True)

                    self.cpu_cache[index] = moe_cpu

                self.offload_stream.record_event(event)

            self.last_offload_event = event

            def finalize_offload_layer():
                event.synchronize()
                for index, moe in moe_group:
                    moe.to("meta")
                    self.gpu_cache.pop(index, None)
                    del moe
                torch.cuda.empty_cache()

            threading.Thread(target=finalize_offload_layer, daemon=True).start()

    async def _async_load_to_gpu(self, index, moe):
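The offload path above keeps the compute stream free by staging the device-to-host copy on `self.offload_stream` into pinned CPU buffers, and letting a helper thread wait on the recorded event before the GPU copy is dropped. A stripped-down sketch of that pattern with hypothetical names (`async_offload`, `on_done`), guarded so it only executes where CUDA is present:

```python
import threading
import torch

def async_offload(t_gpu, offload_stream, on_done):
    """Minimal sketch: copy t_gpu into pinned host memory on a side stream,
    then invoke on_done(t_cpu) from a helper thread once the copy has landed."""
    t_cpu = torch.empty(t_gpu.shape, dtype=t_gpu.dtype, device="cpu", pin_memory=True)
    event = torch.cuda.Event()
    offload_stream.wait_stream(torch.cuda.current_stream())   # t_gpu must be ready first
    with torch.cuda.stream(offload_stream):
        t_cpu.copy_(t_gpu, non_blocking=True)                 # async device-to-host copy
        event.record(offload_stream)

    def finalize():
        event.synchronize()    # blocks this helper thread only, not the compute stream
        on_done(t_cpu)         # safe to use the host copy / drop the GPU tensor now

    threading.Thread(target=finalize, daemon=True).start()
    return event

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")
    async_offload(x, stream, on_done=lambda t: print("offloaded", tuple(t.shape)))
```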
@@ -632,37 +634,29 @@ class LazyMoELoader(nn.Module):
            getattr(model, name).data = tensor
        return model

    def _register_expert_sync(self, layer_idx, expert_idx, moe_cpu):
        self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
        asyncio.run_coroutine_threadsafe(
            self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
            self.cache._loop
        )

    async def lazy_load_from_disk(self, layer_idx, expert_idx):
        loop = asyncio.get_event_loop()
        async with self._semaphore:
            moe_cpu = await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
            self._loop.call_soon_threadsafe(self._register_expert_sync, layer_idx, expert_idx, moe_cpu)
            return moe_cpu
            return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)

    async def schedule_layer_load_progressive(self, layer_idx, num_experts = 64):
        tasks = [asyncio.create_task(self.lazy_load_from_disk(layer_idx, i)) for i in range(num_experts)]
        results = await asyncio.gather(*tasks, return_exceptions=False)
        return results

    def schedule_layer_load(self, layer_idx, num_experts = 64):
        fut = asyncio.run_coroutine_threadsafe(
            self.schedule_layer_load_progressive(layer_idx, num_experts),
            self._loop
        )
        return fut

    async def schedule_layer_load_(self, layer_idx):
        tasks = [self.lazy_load_from_disk(layer_idx, i) for i in range(64)]
        experts = await asyncio.gather(*tasks)
        return experts

    def _schedule_disk_load(self, layer_idx, expert_idx):
        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)

        def _on_disk_loaded(fut):
            moe_cpu = fut.result()

            def _add_cpu_in_main_thread():
                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
                asyncio.run_coroutine_threadsafe(
                    self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
                    self.cache._loop
                )

            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()

        future.add_done_callback(_on_disk_loaded)
        return future

def enough_vram(required_bytes):
    free, total = torch.cuda.mem_get_info()
    return free > required_bytes
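`_schedule_disk_load` above submits a coroutine to the loader's background event loop from the calling thread and reacts to completion through a done callback, so the forward pass never blocks on disk. A self-contained sketch of that submit-and-callback pattern with stand-in names (`load_expert`, `schedule_disk_load`):

```python
import asyncio
import threading

# A background event loop running in its own daemon thread, as in the loader above.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def load_expert(layer_idx: int, expert_idx: int):
    """Stand-in for a blocking disk read pushed to the default executor."""
    return await asyncio.get_running_loop().run_in_executor(
        None, lambda: f"expert {layer_idx}-{expert_idx}"
    )

def schedule_disk_load(layer_idx: int, expert_idx: int):
    # Submit the coroutine to the background loop from the calling thread.
    future = asyncio.run_coroutine_threadsafe(load_expert(layer_idx, expert_idx), loop)

    def on_done(fut):
        # Runs once the coroutine finishes, without blocking the caller.
        print("loaded:", fut.result())

    future.add_done_callback(on_done)
    return future  # concurrent.futures.Future; .result() blocks until done

schedule_disk_load(0, 3).result()
```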
@@ -748,7 +742,6 @@ class HunyuanMoE(nn.Module):
        if event is not None and not event.query():
            time.sleep(0.001)

        combine_weights_used = combine_weights[:, used_indices, :]

        combined_output = torch.einsum("suc,ucm->sm",
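The `einsum` above (shown truncated in the hunk) is the standard MoE combine step: `"suc,ucm->sm"` sums expert outputs over the expert and capacity-slot axes, weighted by the router's combine weights, to give one mixed vector per token. A small illustration with made-up sizes:

```python
import torch

s, u, c, m = 5, 4, 2, 8                   # tokens, used experts, capacity slots, model dim
combine_weights = torch.rand(s, u, c)     # routing weights per token/expert/slot
expert_outputs = torch.randn(u, c, m)     # what each expert produced per slot

combined = torch.einsum("suc,ucm->sm", combine_weights, expert_outputs)
print(combined.shape)                     # torch.Size([5, 8]) -- one mixed vector per token

# Equivalent spelled out with a reshape + matmul:
manual = combine_weights.reshape(s, u * c) @ expert_outputs.reshape(u * c, m)
print(torch.allclose(combined, manual))   # True
```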
@@ -900,6 +893,7 @@ class HunyuanImage3Model(nn.Module):
        self.padding_idx = 128009
        self.vocab_size = 133120
        self.config = config
        self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
        self.layers = nn.ModuleList(
            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
        )
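`nn.Embedding`'s third positional argument is `padding_idx`: that row starts as zeros and never receives gradient, which is what the hunk relies on for token id 128009. A quick check with a toy vocabulary:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4, 2)           # vocab 10, dim 4, padding_idx=2
print(emb.weight[2])                    # the padding row starts as zeros

out = emb(torch.tensor([1, 2, 3])).sum()
out.backward()
print(emb.weight.grad[2])               # stays zero: no gradient flows to padding_idx
```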
@@ -932,20 +926,27 @@ class HunyuanImage3Model(nn.Module):
        sparse_interval = max(1, len(self.layers) // additional_layers)

        if len(self.layers[0].mlp.experts) == 0:
            self.layers[0].mlp.experts = self.moe_loader.schedule_layer_load(0)
            experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
            self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]

        for layer_idx, decoder_layer in enumerate(self.layers):

            if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
                self.layers[layer_idx+1].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 1)
            # maybe the second layer loading should depend on how much gpu memory is there
            next_layer = layer_idx + 1 if isinstance(self.layers[layer_idx + 1].mlp.experts, list) else layer_idx + 2
            second_next_layer = next_layer + 1 if isinstance(self.layers[layer_idx + 2].mlp.experts, list) else next_layer + 2

            if layer_idx + 2 < len(self.layers) and len(self.layers[layer_idx + 2].mlp.experts) == 0: # load first and second layers
                self.layers[layer_idx+2].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 2)

            if not self.additional_layers_set:
                if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
                    self.layers[next_layers].mlp.experts = self.moe_loader.schedule_layer_load(next_layers)
                    next_layers += 1

            if next_layer < len(self.layers) and len(self.layers[next_layer].mlp.experts) == 0: # not loaded
                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                self.layers[next_layer].mlp.experts = [expert._schedule_disk_load(next_layer, i) for i, expert in enumerate(experts)]

            if second_next_layer < len(self.layers) and len(self.layers[second_next_layer].mlp.experts) == 0: # not loaded
                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                self.layers[second_next_layer].mlp.experts = [expert._schedule_disk_load(second_next_layer, i) for i, expert in enumerate(experts)]

            if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
                next_layers += 1

            with torch.no_grad():
                layer_outputs = decoder_layer(
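The loop above is a lookahead prefetch: while layer `layer_idx` runs, expert loads for the next one or two layers are already in flight as futures, and an extra layer is scheduled every `sparse_interval` steps. Stripped of the model specifics, the pattern looks roughly like this (hypothetical `schedule_load`/`compute` stand-ins):

```python
from concurrent.futures import ThreadPoolExecutor
import time

def schedule_load(layer_idx):
    time.sleep(0.01)                      # stand-in for a disk -> CPU -> GPU load
    return f"weights[{layer_idx}]"

def compute(layer_idx, weights):
    time.sleep(0.02)                      # stand-in for the layer's forward pass
    return f"out[{layer_idx}] via {weights}"

num_layers, lookahead = 8, 2
pool = ThreadPoolExecutor(max_workers=2)
pending = {i: pool.submit(schedule_load, i) for i in range(lookahead)}  # warm start

for i in range(num_layers):
    for j in (i + 1, i + 2):              # keep the next two layers in flight
        if j < num_layers and j not in pending:
            pending[j] = pool.submit(schedule_load, j)
    weights = pending.pop(i).result()     # waits only if the prefetch hasn't landed yet
    print(compute(i, weights))
```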
@@ -1020,7 +1021,8 @@ class HunyuanImage3ForCausalMM(nn.Module):
        self.first_step = True

        self.kv_cache = None
        self.token_dims = ()
        self.encode_tok = None
        self.special_tok = None

    @staticmethod
    def get_pos_emb(custom_pos_emb, position_ids):
@@ -1043,24 +1045,36 @@ class HunyuanImage3ForCausalMM(nn.Module):
        joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()

        bsz, seq_len, n_embd = inputs_embeds.shape
        cond_exists = (joint_image[:, 0, :] != -100.0).any(dim=1).any()

        height, width = x.size(2) * 16, x.size(3) * 16
        gen_timestep_scatter_index = 4

        def fn(string, func = self.encode_tok):
            return self.model.wte(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=inputs_embeds.device))\
                .unsqueeze(0).expand(bsz, -1, -1)

        if cond_exists:
            with torch.no_grad():
                joint_image[:, 2:3, :] = x[:, 2:3, :] # updates image ratio

        if self.first_step:
            token_height, token_width = x[:, -2:, 0].tolist()[0]
            self.token_dims = (int(token_height), int(token_width))
            x = x[:, :-2, :]
        else:
            token_height, token_width = self.token_dims
            joint_image[:, 2:3, :] = fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok) # updates image ratio

        img_slices = []
        bsz, seq_len, n_embd = inputs_embeds.shape

        cond_timestep = torch.zeros(x.size(0))
        t_emb = self.time_embed(timestep)

        if self.first_step:
            x, token_height, token_width = self.patch_embed(x, t_emb)
            x = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = self.special_tok), fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok), fn("<timestep>", self.special_tok), x, fn("<eoi>")], dim = 1)
            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
        else:
            x, token_height, token_width = self.patch_embed(x, t_emb)
            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
            inputs_embeds = torch.cat([timestep_emb, x], dim=1)

        input_args = [inputs_embeds, x] if self.first_step else [inputs_embeds]

        for i in range(x.size(0)):
            gen_offset = seq_len + x.size(1)
            if cond_exists:
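The `fn` helper above maps a token string to its vocabulary id, either through a callable tokenizer method or the added-tokens dict, looks up its embedding, and expands it across the batch without copying. A toy version with a hand-rolled vocabulary (hypothetical ids and sizes):

```python
import torch
import torch.nn as nn

vocab = {"<boi>": 0, "<eoi>": 1, "<timestep>": 2, "<img_size_1024>": 3}   # toy ids
wte = nn.Embedding(len(vocab), 8)
bsz = 2

def fn(string, func=vocab):
    ids = func(string) if not isinstance(func, dict) else func[string]
    # (n_embd,) -> (1, n_embd) -> (bsz, 1, n_embd); expand adds the batch dim without copying
    return wte(torch.tensor(ids)).unsqueeze(0).expand(bsz, -1, -1)

seq = torch.cat([fn("<boi>"), fn("<timestep>"), fn("<eoi>")], dim=1)
print(seq.shape)   # torch.Size([2, 3, 8])
```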
@@ -1085,27 +1099,13 @@ class HunyuanImage3ForCausalMM(nn.Module):
                rope_img = [(img_s[0], (token_height, token_width))]
                rope_image_info = [rope_img if len(joint_slices_i) == 0 else rope_img + [(img_s[1], (384 // 16, 384 // 16)), (img_s[2], (256 // 16, 256 // 16))]]

        cond_timestep = torch.zeros(inputs_embeds.size(0))
        t_emb = self.time_embed(cond_timestep)

        if self.first_step:
            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
        else:
            t_emb = self.time_embed(timestep)
            x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb)
            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
            inputs_embeds = torch.cat([timestep_emb, x], dim=1)

        input_args = [inputs_embeds, x] if self.first_step else [inputs_embeds]

        #/////////////
        # cond_vae_images

        # cond_timestep_scatter_index
        if cond_exists:
            with torch.no_grad():
                joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
                joint_image[:, 3:4, :] = self.timestep_emb(cond_timestep.reshape(-1)).reshape(bsz, -1, n_embd)

            inputs_embeds = torch.cat([*input_args, joint_image], dim = 1)
        else:
@@ -1342,6 +1342,11 @@ class HunyuanImage3(supported_models_base.BASE):
        state_dict["text_encoders.wte"] = state_dict["model.model.wte"]
        state_dict.pop("model.model.wte", None)
        model = model_base.HunyuanImage3(self, device = device)

        temp_tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImage3Tokenizer()
        model.encode_tok = temp_tokenizer.tokenizer.convert_tokens_to_ids
        model.special_tok = temp_tokenizer.tokenizer.added_tokens_encoder

        return model

    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImage3Tokenizer, comfy.text_encoders.hunyuan_image.HunyuanImage3)
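The hunk re-homes the word-embedding weights under the `text_encoders.` prefix and attaches the tokenizer's id-lookup callables to the model before loading. The key-remapping idiom on its own, with a tiny stand-in tensor:

```python
import torch

state_dict = {"model.model.wte": torch.zeros(8, 4)}   # tiny stand-in for the real embedding weight

# Re-home the tensor under the prefix the text encoder expects, then drop the old key.
state_dict["text_encoders.wte"] = state_dict["model.model.wte"]
state_dict.pop("model.model.wte", None)

print(sorted(state_dict.keys()))   # ['text_encoders.wte']
```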
@@ -22,34 +22,13 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
                io.Int.Input("height", min = 1, default = 512),
                io.Int.Input("width", min = 1, default = 512),
                io.Int.Input("batch_size", min = 1, max = 48_000, default = 1),
                io.Clip.Input("clip"),
                io.Model.Input("model")
            ],
            outputs=[io.Latent.Output(display_name="latent")]
        )

    @classmethod
    def execute(cls, height, width, batch_size, clip, model):
        encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
        special_fn = clip.tokenizer.tokenizer.added_tokens_encoder

        word_embed = clip.wte
        patch_embed = model.patch_embed
        t_embed = model.time_embed

    def execute(cls, height, width, batch_size):
        height, width = get_target_size(height, width)
        latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device())

        latent, tk_height, tk_width = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size)))

        def tk_fn(token):
            return torch.tensor([token], device = latent.device, dtype = latent.dtype).unsqueeze(1).expand(batch_size, 1, latent.size(-1))

        def fn(string, func = encode_fn):
            return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
                .unsqueeze(0).expand(batch_size, -1, -1)

        latent = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = special_fn), fn(f"<img_ratio_{int(height) // int(width)}>", special_fn), fn("<timestep>", special_fn), latent, fn("<eoi>")], dim = 1)
        latent = torch.cat([latent, tk_fn(tk_height), tk_fn(tk_width)], dim = 1)
        return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )


class HunyuanImage3Conditioning(io.ComfyNode):
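The node draws a 32-channel latent at one sixteenth of the requested pixel resolution, then appends two extra rows carrying the token-grid height and width returned by the patch embed so later steps can recover them. The shape arithmetic for a 1024x1024 request, as a quick check:

```python
import torch

height, width, batch_size = 1024, 1024, 1

# 32-channel latent at 1/16th of the pixel resolution.
latent = torch.randn(batch_size, 32, height // 16, width // 16)
print(latent.shape)   # torch.Size([1, 32, 64, 64])
```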