Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add transformer_options for attention

commit 0920cdcf63
parent 8fd49be716
@@ -108,13 +108,13 @@ class SelfAttention(nn.Module):
         v = self.to_value(x)

         shape = q.shape[:-1]
-        q = q.reshape(*shape, self.num_heads, -1)
-        k = k.reshape(*shape, self.num_heads, -1)
-        v = v.reshape(*shape, self.num_heads, -1)
+        q = q.view(*shape, self.num_heads, -1)
+        k = k.view(*shape, self.num_heads, -1)
+        v = v.view(*shape, self.num_heads, -1)

         return q, k, v

-    def forward(self, x, freqs):
+    def forward(self, x, freqs, transformer_options={}):
         q, k, v = self.get_qkv(x)

         q = apply_rope1(self.query_norm(q), freqs)
@@ -124,9 +124,10 @@ class SelfAttention(nn.Module):
             q.flatten(-2, -1),
             k.flatten(-2, -1),
             v.flatten(-2, -1),
-            heads=self.num_heads)
-        out = self.out_layer(out)
-        return out
+            heads=self.num_heads,
+            transformer_options=transformer_options
+        )
+        return self.out_layer(out)


 class CrossAttention(nn.Module):
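Note on the attention call change above: transformer_options has to reach the attention helper itself (not just the block forwards) because attention patches carried in that dict can only act on a call they can actually see. A minimal sketch of the pattern, assuming a hypothetical attention() helper and an "attention_override" key; this is illustration only, not ComfyUI's actual optimized_attention API:

import torch
import torch.nn.functional as F

def attention(q, k, v, heads, transformer_options={}):
    # q, k, v: (batch, seq, heads * dim_head), as produced by flatten(-2, -1) above
    b, s, _ = q.shape
    q, k, v = [t.view(b, s, heads, -1).transpose(1, 2) for t in (q, k, v)]
    override = transformer_options.get("attention_override")  # hypothetical key
    if override is not None:
        out = override(q, k, v, transformer_options=transformer_options)
    else:
        out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, s, -1)

# Because every forward() now forwards the dict, a patch placed in it is visible here.
opts = {"attention_override": lambda q, k, v, transformer_options: F.scaled_dot_product_attention(q, k, v)}
x = torch.randn(1, 16, 8 * 64)
print(attention(x, x, x, heads=8, transformer_options=opts).shape)  # torch.Size([1, 16, 512])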
@@ -145,18 +146,18 @@ class CrossAttention(nn.Module):
         self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

     def get_qkv(self, x, cond):
-        query = self.to_query(x)
-        key = self.to_key(cond)
-        value = self.to_value(cond)
+        q = self.to_query(x)
+        k = self.to_key(cond)
+        v = self.to_value(cond)

-        shape, cond_shape = query.shape[:-1], key.shape[:-1]
-        query = query.reshape(*shape, self.num_heads, -1)
-        key = key.reshape(*cond_shape, self.num_heads, -1)
-        value = value.reshape(*cond_shape, self.num_heads, -1)
+        shape, cond_shape = q.shape[:-1], k.shape[:-1]
+        q = q.view(*shape, self.num_heads, -1)
+        k = k.view(*cond_shape, self.num_heads, -1)
+        v = v.view(*cond_shape, self.num_heads, -1)

-        return query, key, value
+        return q, k, v

-    def forward(self, x, cond):
+    def forward(self, x, cond, transformer_options={}):
         q, k, v = self.get_qkv(x, cond)
         q = self.query_norm(q)
         k = self.key_norm(k)
@@ -165,10 +166,11 @@ class CrossAttention(nn.Module):
             q.flatten(-2, -1),
             k.flatten(-2, -1),
             v.flatten(-2, -1),
-            heads=self.num_heads)
+            heads=self.num_heads,
+            transformer_options=transformer_options
+        )

-        out = self.out_layer(out)
-        return out
+        return self.out_layer(out)


 class FeedForward(nn.Module):
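Side note on the reshape -> view swap in get_qkv: .view() never copies but requires a memory layout compatible with the requested shape, while .reshape() silently falls back to a copy when a view is impossible. The q/k/v tensors come straight out of Linear projections and only have their last dimension split, so view is always valid there. A standalone PyTorch check of the difference (general PyTorch behavior, independent of this commit):

import torch

x = torch.randn(2, 16, 512)        # contiguous, like a Linear output
q = x.view(2, 16, 8, 64)           # fine: splitting the last dim never needs a copy

y = x.transpose(1, 2)              # (2, 512, 16), non-contiguous
z = y.reshape(2, 512 * 16)         # reshape silently copies here
try:
    y.view(2, 512 * 16)            # view refuses to flatten non-contiguous dims
except RuntimeError as e:
    print("view failed:", e)
print(q.shape, z.shape)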
@@ -220,11 +222,11 @@ class TransformerEncoderBlock(nn.Module):
         self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

-    def forward(self, x, time_embed, freqs):
+    def forward(self, x, time_embed, freqs, transformer_options={}):
         self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
         shift, scale, gate = get_shift_scale_gate(self_attn_params)
         out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
-        out = self.self_attention(out, freqs)
+        out = self.self_attention(out, freqs, transformer_options=transformer_options)
         x = apply_gate_sum(x, out, gate)

         shift, scale, gate = get_shift_scale_gate(ff_params)
@@ -249,19 +251,17 @@ class TransformerDecoderBlock(nn.Module):
         self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

-    def forward(self, visual_embed, text_embed, time_embed, freqs):
-        self_attn_params, cross_attn_params, ff_params = torch.chunk(
-            self.visual_modulation(time_embed), 3, dim=-1
-        )
+    def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
+        self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
         # self attention
         shift, scale, gate = get_shift_scale_gate(self_attn_params)
         visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
-        visual_out = self.self_attention(visual_out, freqs)
+        visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
         # cross attention
         shift, scale, gate = get_shift_scale_gate(cross_attn_params)
         visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
-        visual_out = self.cross_attention(visual_out, text_embed)
+        visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
         # feed forward
         shift, scale, gate = get_shift_scale_gate(ff_params)
@@ -374,13 +374,13 @@ class Kandinsky5(nn.Module):
         time_embed = self.time_embeddings(timestep) + self.pooled_text_embeddings(y)

         for block in self.text_transformer_blocks:
-            context = block(context, time_embed, freqs_text)
+            context = block(context, time_embed, freqs_text, transformer_options=transformer_options)

         visual_embed = self.visual_embeddings(x)
         visual_shape = visual_embed.shape[:-1]

         for block in self.visual_transformer_blocks:
-            visual_embed = block(visual_embed.flatten(1, -2), context, time_embed, freqs=freqs)
+            visual_embed = block(visual_embed.flatten(1, -2), context, time_embed, freqs=freqs, transformer_options=transformer_options)

         visual_embed = visual_embed.reshape(*visual_shape, -1)
         return self.out_layer(visual_embed, time_embed)
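A general Python caveat about the new signatures, not a problem with this commit: transformer_options={} is a mutable default argument and is created once at function definition time. That is fine while the dict is only read and forwarded, as in the forwards above, but any in-place write would leak state across calls. Standalone illustration:

def forward_bad(x, transformer_options={}):
    # mutating the default dict persists across calls
    transformer_options.setdefault("calls", 0)
    transformer_options["calls"] += 1
    return transformer_options["calls"]

print(forward_bad(0), forward_bad(0))  # 1 2 -- the same dict is reused

def forward_ok(x, transformer_options=None):
    # the usual defensive pattern if the dict ever needs to be written to
    transformer_options = {} if transformer_options is None else transformer_options
    return len(transformer_options)

print(forward_ok(0), forward_ok(0))    # 0 0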
@@ -6,12 +6,11 @@ from .llama import Qwen25_7BVLI
 class Kandinsky5Tokenizer(QwenImageTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-        # yes the typo "promt" was also in the original template...
+        # yes the typos "promt", "scren" were also in the original template... todo: check if it matters
         self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
         self.llama_template_image2video = "<|im_start|>system\nYou are a promt engineer. Your task is to create a highly detailed and effective video description based on a provided input image.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe main characters actions.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

-
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
         out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
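The {} placeholder at the end of each template is presumably filled with the user prompt by the parent QwenImageTokenizer; the exact call site is not part of this diff, so the plain str.format below is an assumption for illustration only:

# Illustration only: expanding a "{}" prompt template with str.format.
# The real formatting happens inside QwenImageTokenizer and is assumed here.
llama_template = (
    "<|im_start|>system\nYou are a promt engineer. Describe the video in detail."
    "<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
)
print(llama_template.format("a red fox runs through fresh snow at dawn"))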
@@ -34,12 +33,8 @@ class Kandinsky5TEModel(QwenImageTEModel):
         self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)

     def encode_token_weights(self, token_weight_pairs):
-        #tok_pairs = token_weight_pairs["qwen25_7b"][0]
-        token_weight_pairs_l = token_weight_pairs["l"]
-        template_end = -1
-
-        cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=template_end)
-        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
+        cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
+        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])

         return cond, l_pooled, extra