mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-23 10:03:36 +08:00
Merge branch 'master' into feat/add-helios
This commit is contained in:
commit
f9d26fc23f
@ -149,6 +149,9 @@ class Attention(nn.Module):
|
|||||||
seq_img = hidden_states.shape[1]
|
seq_img = hidden_states.shape[1]
|
||||||
seq_txt = encoder_hidden_states.shape[1]
|
seq_txt = encoder_hidden_states.shape[1]
|
||||||
|
|
||||||
|
transformer_patches = transformer_options.get("patches", {})
|
||||||
|
extra_options = transformer_options.copy()
|
||||||
|
|
||||||
# Project and reshape to BHND format (batch, heads, seq, dim)
|
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||||
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
@ -167,15 +170,22 @@ class Attention(nn.Module):
|
|||||||
joint_key = torch.cat([txt_key, img_key], dim=2)
|
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||||
joint_value = torch.cat([txt_value, img_value], dim=2)
|
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||||
|
|
||||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
|
||||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
|
||||||
|
|
||||||
if encoder_hidden_states_mask is not None:
|
if encoder_hidden_states_mask is not None:
|
||||||
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
|
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
|
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
|
||||||
else:
|
else:
|
||||||
attn_mask = None
|
attn_mask = None
|
||||||
|
|
||||||
|
extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
|
||||||
|
if "attn1_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn1_patch"]
|
||||||
|
for p in patch:
|
||||||
|
out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
|
||||||
|
joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
|
||||||
|
|
||||||
|
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||||
|
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||||
|
|
||||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||||
attn_mask, transformer_options=transformer_options,
|
attn_mask, transformer_options=transformer_options,
|
||||||
skip_reshape=True)
|
skip_reshape=True)
|
||||||
@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
|
|
||||||
timestep_zero_index = None
|
timestep_zero_index = None
|
||||||
if ref_latents is not None:
|
if ref_latents is not None:
|
||||||
|
ref_num_tokens = []
|
||||||
h = 0
|
h = 0
|
||||||
w = 0
|
w = 0
|
||||||
index = 0
|
index = 0
|
||||||
@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
|
ref_num_tokens.append(kontext.shape[1])
|
||||||
if timestep_zero:
|
if timestep_zero:
|
||||||
if index > 0:
|
if index > 0:
|
||||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||||
timestep_zero_index = num_embeds
|
timestep_zero_index = num_embeds
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
|
transformer_options["reference_image_num_tokens"] = ref_num_tokens
|
||||||
|
|
||||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
|
||||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
|
||||||
del ids, txt_ids, img_ids
|
|
||||||
|
|
||||||
hidden_states = self.img_in(hidden_states)
|
hidden_states = self.img_in(hidden_states)
|
||||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||||
@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
|
if "post_input" in patches:
|
||||||
|
for p in patches["post_input"]:
|
||||||
|
out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||||
|
hidden_states = out["img"]
|
||||||
|
encoder_hidden_states = out["txt"]
|
||||||
|
img_ids = out["img_ids"]
|
||||||
|
txt_ids = out["txt_ids"]
|
||||||
|
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||||
|
del ids, txt_ids, img_ids
|
||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||||
transformer_options["block_type"] = "double"
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
|||||||
@ -599,6 +599,27 @@ class ModelPatcher:
|
|||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
def model_patches_call_function(self, function_name="cleanup", arguments=None):
    """Call an optional hook method on every registered model patch object.

    Looks at three places in ``self.model_options``:

    * each object in every list under ``transformer_options["patches"]``,
    * each value in every mapping under
      ``transformer_options["patches_replace"]``,
    * the object stored under ``"model_function_wrapper"``, if that key exists.

    Any of those objects that has an attribute named ``function_name`` gets it
    called with ``arguments`` expanded as keyword arguments. Objects without
    the attribute are skipped silently, so the hook is opt-in.

    Args:
        function_name: Name of the attribute to look up and call on each
            patch object. Defaults to ``"cleanup"``.
        arguments: Mapping of keyword arguments forwarded to every call.
            ``None`` (the default) means "no keyword arguments".

    Raises:
        KeyError: If ``self.model_options`` has no ``"transformer_options"``
            entry.
    """
    # A dict literal as a default would be shared between all calls; use None
    # as the sentinel and build the empty mapping per call instead.
    kwargs = {} if arguments is None else arguments

    def _call_if_present(target):
        # Only objects that define the hook are invoked.
        if hasattr(target, function_name):
            getattr(target, function_name)(**kwargs)

    to = self.model_options["transformer_options"]

    # "patches" maps a patch name to a list of patch objects.
    if "patches" in to:
        for patch_list in to["patches"].values():
            for patch in patch_list:
                _call_if_present(patch)

    # "patches_replace" maps a patch name to a mapping of replacement objects.
    if "patches_replace" in to:
        for replacements in to["patches_replace"].values():
            for patch in replacements.values():
                _call_if_present(patch)

    # The wrapper is a single object rather than a collection.
    if "model_function_wrapper" in self.model_options:
        _call_if_present(self.model_options["model_function_wrapper"])
|
||||||
|
|
||||||
def model_dtype(self):
|
def model_dtype(self):
|
||||||
if hasattr(self.model, "get_dtype"):
|
if hasattr(self.model, "get_dtype"):
|
||||||
return self.model.get_dtype()
|
return self.model.get_dtype()
|
||||||
@ -1062,6 +1083,7 @@ class ModelPatcher:
|
|||||||
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
|
self.model_patches_call_function(function_name="cleanup")
|
||||||
self.clean_hooks()
|
self.clean_hooks()
|
||||||
if hasattr(self.model, "current_patcher"):
|
if hasattr(self.model, "current_patcher"):
|
||||||
self.model.current_patcher = None
|
self.model.current_patcher = None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user