Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-21 11:50:16 +08:00)
input correction and improvements
parent 86e9a7a669
commit b4001bbd27
@@ -587,7 +587,7 @@ class LazyMoELoader(nn.Module):
         self.expert_pool = self.build_meta_experts()
         self._executor = ThreadPoolExecutor(max_workers=max_workers)
-        self._semaphore = threading.Semaphore(max_concurrent_loads)
+        self._semaphore = asyncio.Semaphore(max_concurrent_loads)

     def build_meta_experts(self):
         pool = {}
@@ -632,17 +632,36 @@ class LazyMoELoader(nn.Module):
             getattr(model, name).data = tensor
         return model

-    def _load_single_expert(self, layer_idx, expert_idx):
-        with self._semaphore:
-            return self.lazy_init(layer_idx, expert_idx)
+    def _register_expert_sync(self, layer_idx, expert_idx, moe_cpu):
+        self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
+        asyncio.run_coroutine_threadsafe(
+            self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
+            self.cache._loop
+        )
+
+    async def lazy_load_from_disk(self, layer_idx, expert_idx):
+        loop = asyncio.get_event_loop()
+        async with self._semaphore:
+            moe_cpu = await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
+        self._loop.call_soon_threadsafe(self._register_expert_sync, layer_idx, expert_idx, moe_cpu)
+        return moe_cpu
+
+    async def schedule_layer_load_progressive(self, layer_idx, num_experts = 64):
+        tasks = [asyncio.create_task(self.lazy_load_from_disk(layer_idx, i)) for i in range(num_experts)]
+        results = await asyncio.gather(*tasks, return_exceptions=False)
+        return results

     def schedule_layer_load(self, layer_idx, num_experts = 64):
-        futures = []
-        for i in range(num_experts):
-            coro = asyncio.get_event_loop().run_in_executor(self._executor, self._load_single_expert, layer_idx, i)
-            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
-            futures.append(fut)
-        return futures
+        fut = asyncio.run_coroutine_threadsafe(
+            self.schedule_layer_load_progressive(layer_idx, num_experts),
+            self._loop
+        )
+        return fut
+
+    async def schedule_layer_load_(self, layer_idx):
+        tasks = [self.lazy_load_from_disk(layer_idx, i) for i in range(64)]
+        experts = await asyncio.gather(*tasks)
+        return experts

 def enough_vram(required_bytes):
     free, total = torch.cuda.mem_get_info()
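The hunk above replaces the thread-pool scheme (worker threads gated by a threading.Semaphore) with coroutines on a dedicated event loop: asyncio.Semaphore bounds concurrent reads without parking OS threads, run_in_executor keeps the blocking disk read off the loop, and run_coroutine_threadsafe lets the synchronous caller hand a whole layer's load to the loop's thread. Below is a minimal, self-contained sketch of that bridging pattern, assuming a loop running on its own daemon thread; every name in it is illustrative, not the commit's actual loader/cache API.

import asyncio
import threading

def blocking_disk_read(layer_idx, expert_idx):
    # Stand-in for the real tensor deserialization work.
    return f"expert-{layer_idx}-{expert_idx}"

class LoaderSketch:
    def __init__(self, max_concurrent_loads=4):
        self.max_concurrent_loads = max_concurrent_loads
        # A dedicated event loop on a background thread, so the synchronous
        # forward pass never blocks on I/O scheduling.
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    async def _load_one(self, sem, layer_idx, expert_idx):
        # asyncio.Semaphore suspends the coroutine while at capacity instead
        # of parking an OS thread the way threading.Semaphore would.
        async with sem:
            loop = asyncio.get_running_loop()
            # run_in_executor pushes the blocking read onto a worker thread
            # and yields control until it completes.
            return await loop.run_in_executor(None, blocking_disk_read, layer_idx, expert_idx)

    async def _load_layer(self, layer_idx, num_experts):
        sem = asyncio.Semaphore(self.max_concurrent_loads)
        tasks = [asyncio.create_task(self._load_one(sem, layer_idx, i)) for i in range(num_experts)]
        return await asyncio.gather(*tasks)

    def schedule_layer_load(self, layer_idx, num_experts=8):
        # Bridge from a plain thread into the loop's thread; the returned
        # concurrent.futures.Future can be polled, or awaited via .result().
        return asyncio.run_coroutine_threadsafe(self._load_layer(layer_idx, num_experts), self._loop)

fut = LoaderSketch().schedule_layer_load(layer_idx=0)
print(fut.result())  # blocks until all experts for layer 0 are "loaded"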
@@ -945,7 +964,6 @@ class HunyuanImage3Model(nn.Module):
             if self.additional_layers_set and layer_idx <= self.additional_layers_set:
                 pass
             else:
-                torch.cuda.synchronize()
                 asyncio.run_coroutine_threadsafe(
                     self.moe_lru._async_offload_to_cpu(layer_idx),
                     self.moe_lru._loop
@@ -1025,7 +1043,7 @@ class HunyuanImage3ForCausalMM(nn.Module):

         joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()

-        cond_exists = (joint_image[:, 0, :] == -100.0).any(dim=1).any()
+        cond_exists = (joint_image[:, 0, :] != -100.0).any(dim=1).any()

         gen_timestep_scatter_index = 4

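The flipped comparison is the "input correction" of the commit title: the conditioning node (second file below) fills joint_image with a -100.0 pad token when no reference image is attached, so the old `== -100.0` test answered "condition is absent", the opposite of the variable's name. A tiny check with illustrative shapes:

import torch

hidden = 8
# The conditioning node emits a pad row filled with -100.0 when no
# VAE/ViT image encoding is attached.
pad_joint = torch.full((1, 2, hidden), -100.0)  # unconditioned: pad token first
real_joint = torch.randn(1, 5, hidden)          # conditioned: real embeddings

for joint_image in (pad_joint, real_joint):
    # Old check: (== -100.0).any() was True exactly for the pad case,
    # i.e. it flagged "condition exists" when there was no condition.
    cond_exists = (joint_image[:, 0, :] != -100.0).any(dim=1).any()
    print(bool(cond_exists))  # False for pad_joint, True for real_joint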
@@ -1048,19 +1066,24 @@ class HunyuanImage3ForCausalMM(nn.Module):
             if cond_exists:
                 vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0]
                 vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1
-                vit_start = vae_end + 1 + gen_offset
+                vae_start += gen_offset
+                vae_end += gen_offset
+
+                vit_start = vae_end + 1
                 vit_end = joint_image.size(1) - 1 + gen_offset

                 joint_slices_i = [
                     slice(vae_start, vae_end),
                     slice(vit_start, vit_end),
                 ]
+            else:
+                joint_slices_i = []
             gen_slices_i = [slice(seq_len, gen_offset)]
             img_slices.append(gen_slices_i + joint_slices_i)

         img_s = img_slices[0]
-        rope_image_info = [[(img_s[0], (token_height, token_width)), (img_s[1], (384 // 16, 384 // 16)), (img_s[2], (256 // 16, 256 // 16))]]
+        rope_img = [(img_s[0], (token_height, token_width))]
+        rope_image_info = [rope_img if len(joint_slices_i) == 0 else rope_img + [(img_s[1], (384 // 16, 384 // 16)), (img_s[2], (256 // 16, 256 // 16))]]

         cond_timestep = torch.zeros(inputs_embeds.size(0))
         t_emb = self.time_embed(cond_timestep)
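The offset fix in this hunk is easiest to see with concrete numbers: previously the VAE slice bounds stayed in mask-local coordinates while vit_start alone received gen_offset, so slice(vae_start, vae_end) pointed at the wrong region of the packed sequence. Now both bounds are shifted once and vit_start = vae_end + 1 simply inherits the shift. A toy walk-through (all values made up):

gen_offset, token_count = 100, 36  # illustrative values only

# Local positions of the VAE run inside cond_vae_image_mask[i].
vae_start, vae_end = 4, 4 + token_count

# New behavior: shift both bounds into packed-sequence coordinates once...
vae_start += gen_offset            # 104
vae_end += gen_offset              # 140
vit_start = vae_end + 1            # 141 -- already shifted, no "+ gen_offset" here

# Old behavior: slice(4, 40) selected mask-local positions, while
# vit_start = vae_end + 1 + gen_offset applied the shift only to the ViT run.
print(slice(vae_start, vae_end), vit_start)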
@@ -1072,7 +1095,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
         t_emb = self.time_embed(timestep)
         x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb)
         timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
-        x = torch.cat([timestep_emb, x], dim=1)
+        inputs_embeds = torch.cat([timestep_emb, x], dim=1)

+        input_args = [inputs_embeds, x] if self.first_step else [inputs_embeds]
+
         #/////////////
         # cond_vae_images
@@ -1082,9 +1107,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
             with torch.no_grad():
                 joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)

-            inputs_embeds = torch.cat([inputs_embeds, x, joint_image], dim = 1)
+            inputs_embeds = torch.cat([*input_args, joint_image], dim = 1)
         else:
-            inputs_embeds = torch.cat([inputs_embeds, x, joint_image[:, 1:, :]], dim = 1) # joint_image == eos_token
+            inputs_embeds = torch.cat([*input_args, joint_image[:, 1:, :]], dim = 1) # joint_image == eos_token

         attention_mask = torch.ones(inputs_embeds.shape[1], inputs_embeds.shape[1], dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
         for i in range(bsz):
@@ -1097,7 +1122,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
         cos, sin = build_batch_2d_rope(
             image_infos=rope_image_info,
             seq_len=inputs_embeds.shape[1],
-            n_elem=self.config["hidden_size"] // self.config["num_attention_heads"],
+            n_elem=128, # head dim
             base=10000.0,
         )
         custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device))
@@ -79,13 +79,14 @@ class HunyuanImage3Conditioning(io.ComfyNode):
         patch_embed = model.patch_embed
         t_embed = model.time_embed

+        text_tokens = text_encoding[0][0]
+        batch_size, _, hidden_size = text_tokens.shape
+
         def fn(string, func = encode_fn):
             return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
                 .view(1, 1, hidden_size).expand(batch_size, -1, hidden_size)

-        text_tokens = text_encoding[0][0]
         text_tokens = torch.cat([fn("<|startoftext|>"), text_tokens], dim = 1)
-        batch_size, _, hidden_size = text_tokens.shape

         if vae_encoding is not None or vit_encoding is not None:
             vae_encoding, _, _ = patch_embed(vae_encoding, t_embed(torch.tensor([0]).repeat(vae_encoding.size(0))))
@@ -93,11 +94,13 @@ class HunyuanImage3Conditioning(io.ComfyNode):
             joint_image = torch.cat([fn("<boi>"), fn("<img_size_1024>", special_fn), fn("<img_ratio_3>", special_fn), fn("<timestep>", special_fn), vae_encoding, fn("<joint_img_sep>"), vit_encoding, fn("<eoi>"), fn("<|endoftext|>")], dim = 1)
             vae_mask = torch.ones(joint_image.size(1))
             vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:]))
+            vae_mask = vae_mask.unsqueeze(0).unsqueeze(-1)
         else:
             pad_token = torch.tensor([-100.0]).view(1, 1, 1).expand(batch_size, 1, hidden_size)
             joint_image = torch.cat([pad_token, fn("<|endoftext|>")], dim = 1)
+            vae_mask = torch.empty_like(joint_image)

-        ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.to(joint_image.dtype)])
+        ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask, text_tokens.to(joint_image.dtype)])

         uncond_ragged_tensors = None
         if text_encoding_negative is not None:
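torch.nested.nested_tensor requires every component tensor to have the same number of dimensions, so this hunk gives vae_mask its final rank inside each branch (a same-rank empty_like placeholder in the unconditioned case) instead of unsqueezing at the call site, where the unconditioned path had no properly shaped mask to unsqueeze. A small sketch of the repaired construction; the shapes are illustrative, not the model's real token counts:

import torch

hidden = 8

def build(cond: bool):
    if cond:
        joint_image = torch.randn(1, 6, hidden)
        vae_mask = torch.ones(joint_image.size(1))
        vae_mask[:3] = 0.0
        vae_mask = vae_mask.unsqueeze(0).unsqueeze(-1)  # (1, 6, 1): rank fixed here
    else:
        joint_image = torch.full((1, 2, hidden), -100.0)  # pad-token row
        vae_mask = torch.empty_like(joint_image)          # placeholder, matching rank
    text_tokens = torch.randn(1, 4, hidden)
    # All three components are rank-3, so the nested tensor builds in both
    # branches; unsqueezing at this call site would break the pad path.
    return torch.nested.nested_tensor([joint_image, vae_mask, text_tokens])

print(build(True).size(0), build(False).size(0))  # 3 3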