mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-05 19:12:41 +08:00
returned kv cache for image generation
This commit is contained in:
parent
10a17dc85d
commit
ca119c44fb
@ -832,6 +832,7 @@ class HunyuanImage3Attention(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
past_key_value,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
|
custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
|
||||||
@ -858,6 +859,11 @@ class HunyuanImage3Attention(nn.Module):
|
|||||||
query_states = query_states.to(value_states.dtype)
|
query_states = query_states.to(value_states.dtype)
|
||||||
key_states = key_states.to(value_states.dtype)
|
key_states = key_states.to(value_states.dtype)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
cache_kwargs = {"cache_position": position_ids}
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
query_states = query_states.to(key_states.dtype)
|
||||||
|
|
||||||
key_states = torch.repeat_interleave(key_states, dim=1, repeats = self.num_key_value_groups)
|
key_states = torch.repeat_interleave(key_states, dim=1, repeats = self.num_key_value_groups)
|
||||||
value_states = torch.repeat_interleave(value_states, dim=1, repeats = self.num_key_value_groups)
|
value_states = torch.repeat_interleave(value_states, dim=1, repeats = self.num_key_value_groups)
|
||||||
|
|
||||||
@ -870,7 +876,7 @@ class HunyuanImage3Attention(nn.Module):
|
|||||||
|
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
return attn_output
|
return attn_output, past_key_value
|
||||||
|
|
||||||
class HunyuanImage3DecoderLayer(nn.Module):
|
class HunyuanImage3DecoderLayer(nn.Module):
|
||||||
def __init__(self, config, layer_idx: int, moe_lru=None):
|
def __init__(self, config, layer_idx: int, moe_lru=None):
|
||||||
@ -900,7 +906,7 @@ class HunyuanImage3DecoderLayer(nn.Module):
|
|||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
hidden_states = self.self_attn(
|
hidden_states, past_key_value = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@ -917,7 +923,7 @@ class HunyuanImage3DecoderLayer(nn.Module):
|
|||||||
|
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states, past_key_value)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@ -1039,6 +1045,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, condition, timestep, **kwargs):
|
def forward(self, x, condition, timestep, **kwargs):
|
||||||
|
|
||||||
|
cond, uncond = condition[:4], condition[4:]
|
||||||
|
joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
|
||||||
|
|
||||||
if self.kv_cache is None:
|
if self.kv_cache is None:
|
||||||
# TODO: should change when higgsv2 gets merged
|
# TODO: should change when higgsv2 gets merged
|
||||||
self.kv_cache = HunyuanStaticCache(
|
self.kv_cache = HunyuanStaticCache(
|
||||||
@ -1049,10 +1058,8 @@ class HunyuanImage3ForCausalMM(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
image_mask = torch.ones(x.size(1))
|
image_mask = torch.ones(x.size(1))
|
||||||
image_mask[:, :5] = torch.zeros(5); image_mask[:, -4:] = torch.zeros(4)
|
image_mask[:, :3] = torch.zeros(5); image_mask[:, -1] = torch.zeros(0)
|
||||||
gen_timestep_scatter_index = 4
|
gen_timestep_scatter_index = 4
|
||||||
cond, uncond = condition[:4], condition[4:]
|
|
||||||
joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
|
|
||||||
joint_image[:, 2] = x[:, 2] # updates image ratio
|
joint_image[:, 2] = x[:, 2] # updates image ratio
|
||||||
|
|
||||||
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
|
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
|
||||||
@ -1126,6 +1133,11 @@ class HunyuanImage3ForCausalMM(nn.Module):
|
|||||||
)
|
)
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
|
|
||||||
|
# safety no-op
|
||||||
|
past_key_value = outputs[1]
|
||||||
|
if past_key_value is not None:
|
||||||
|
self.kv_cache = past_key_value
|
||||||
|
|
||||||
hidden_states = hidden_states.to(input_ids.device)
|
hidden_states = hidden_states.to(input_ids.device)
|
||||||
diffusion_prediction = self.ragged_final_layer(
|
diffusion_prediction = self.ragged_final_layer(
|
||||||
hidden_states, image_mask, timestep, token_h, token_w, self.first_step)
|
hidden_states, image_mask, timestep, token_h, token_w, self.first_step)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user