diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index 6436dd304..63cc150d4 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -108,13 +108,13 @@ class SelfAttention(nn.Module):
         v = self.to_value(x)
 
         shape = q.shape[:-1]
-        q = q.reshape(*shape, self.num_heads, -1)
-        k = k.reshape(*shape, self.num_heads, -1)
-        v = v.reshape(*shape, self.num_heads, -1)
+        q = q.view(*shape, self.num_heads, -1)
+        k = k.view(*shape, self.num_heads, -1)
+        v = v.view(*shape, self.num_heads, -1)
 
         return q, k, v
 
-    def forward(self, x, freqs):
+    def forward(self, x, freqs, transformer_options={}):
         q, k, v = self.get_qkv(x)
 
         q = apply_rope1(self.query_norm(q), freqs)
@@ -124,9 +124,10 @@ class SelfAttention(nn.Module):
             q.flatten(-2, -1),
             k.flatten(-2, -1),
             v.flatten(-2, -1),
-            heads=self.num_heads)
-        out = self.out_layer(out)
-        return out
+            heads=self.num_heads,
+            transformer_options=transformer_options
+        )
+        return self.out_layer(out)
 
 
 class CrossAttention(nn.Module):
@@ -145,18 +146,18 @@ class CrossAttention(nn.Module):
         self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
     def get_qkv(self, x, cond):
-        query = self.to_query(x)
-        key = self.to_key(cond)
-        value = self.to_value(cond)
+        q = self.to_query(x)
+        k = self.to_key(cond)
+        v = self.to_value(cond)
 
-        shape, cond_shape = query.shape[:-1], key.shape[:-1]
-        query = query.reshape(*shape, self.num_heads, -1)
-        key = key.reshape(*cond_shape, self.num_heads, -1)
-        value = value.reshape(*cond_shape, self.num_heads, -1)
+        shape, cond_shape = q.shape[:-1], k.shape[:-1]
+        q = q.view(*shape, self.num_heads, -1)
+        k = k.view(*cond_shape, self.num_heads, -1)
+        v = v.view(*cond_shape, self.num_heads, -1)
 
-        return query, key, value
+        return q, k, v
 
-    def forward(self, x, cond):
+    def forward(self, x, cond, transformer_options={}):
         q, k, v = self.get_qkv(x, cond)
         q = self.query_norm(q)
         k = self.key_norm(k)
@@ -165,10 +166,11 @@ class CrossAttention(nn.Module):
             q.flatten(-2, -1),
             k.flatten(-2, -1),
             v.flatten(-2, -1),
-            heads=self.num_heads)
+            heads=self.num_heads,
+            transformer_options=transformer_options
+        )
 
-        out = self.out_layer(out)
-        return out
+        return self.out_layer(out)
 
 
 class FeedForward(nn.Module):
@@ -220,11 +222,11 @@ class TransformerEncoderBlock(nn.Module):
         self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
 
-    def forward(self, x, time_embed, freqs):
+    def forward(self, x, time_embed, freqs, transformer_options={}):
         self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
         shift, scale, gate = get_shift_scale_gate(self_attn_params)
         out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
-        out = self.self_attention(out, freqs)
+        out = self.self_attention(out, freqs, transformer_options=transformer_options)
         x = apply_gate_sum(x, out, gate)
 
         shift, scale, gate = get_shift_scale_gate(ff_params)
@@ -249,19 +251,17 @@ class TransformerDecoderBlock(nn.Module):
         self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
 
-    def forward(self, visual_embed, text_embed, time_embed, freqs):
-        self_attn_params, cross_attn_params, ff_params = torch.chunk(
-            self.visual_modulation(time_embed), 3, dim=-1
-        )
+    def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
+        self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
         # self attention
         shift, scale, gate = get_shift_scale_gate(self_attn_params)
         visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
-        visual_out = self.self_attention(visual_out, freqs)
+        visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
         # cross attention
         shift, scale, gate = get_shift_scale_gate(cross_attn_params)
         visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
-        visual_out = self.cross_attention(visual_out, text_embed)
+        visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
         # feed forward
         shift, scale, gate = get_shift_scale_gate(ff_params)
@@ -374,13 +374,13 @@ class Kandinsky5(nn.Module):
         time_embed = self.time_embeddings(timestep) + self.pooled_text_embeddings(y)
 
         for block in self.text_transformer_blocks:
-            context = block(context, time_embed, freqs_text)
+            context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
 
         visual_embed = self.visual_embeddings(x)
         visual_shape = visual_embed.shape[:-1]
 
         for block in self.visual_transformer_blocks:
-            visual_embed = block(visual_embed.flatten(1, -2), context, time_embed, freqs=freqs)
+            visual_embed = block(visual_embed.flatten(1, -2), context, time_embed, freqs=freqs, transformer_options=transformer_options)
 
         visual_embed = visual_embed.reshape(*visual_shape, -1)
         return self.out_layer(visual_embed, time_embed)
diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py
index 8c25e8b4e..adb35f0ea 100644
--- a/comfy/text_encoders/kandinsky5.py
+++ b/comfy/text_encoders/kandinsky5.py
@@ -6,12 +6,11 @@ from .llama import Qwen25_7BVLI
 class Kandinsky5Tokenizer(QwenImageTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-        # yes the typo "promt" was also in the original template...
+        # yes the typos "promt", "scren" were also in the original template... todo: check if it matters
         self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
         self.llama_template_image2video = "<|im_start|>system\nYou are a promt engineer. Your task is to create a highly detailed and effective video description based on a provided input image.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe main characters actions.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
 
-
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
         out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
@@ -34,12 +33,8 @@ class Kandinsky5TEModel(QwenImageTEModel):
         self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
 
     def encode_token_weights(self, token_weight_pairs):
-        #tok_pairs = token_weight_pairs["qwen25_7b"][0]
-        token_weight_pairs_l = token_weight_pairs["l"]
-        template_end = -1
-
-        cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=template_end)
-        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
+        cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
+        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
         return cond, l_pooled, extra
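Two notes on what the patch changes, each with a small self-contained sketch (plain PyTorch; shapes and helper names are illustrative, not taken from the repository).

The `reshape` → `view` swap in `get_qkv` does not change behavior for these tensors: splitting the last dimension of a fresh `nn.Linear` output into `(num_heads, head_dim)` can always alias the existing storage. The difference is the failure mode: `view` never copies and raises if aliasing is impossible, while `reshape` would silently fall back to a copy.

```python
import torch

x = torch.randn(2, 16, 512)           # e.g. a contiguous linear-projection output
q = x.view(2, 16, 8, 64)              # splits the last dim into (heads, head_dim)
assert q.data_ptr() == x.data_ptr()   # no copy: same underlying storage

t = torch.randn(4, 6).t()             # transposed, non-contiguous
try:
    t.view(24)                        # view refuses: it can only alias, never copy
except RuntimeError:
    flat = t.reshape(24)              # reshape silently materializes a copy instead
```

The bulk of the diff threads a `transformer_options` dict from `Kandinsky5.forward` through every encoder/decoder block into the shared attention call, so whatever the attention helper does with those options also applies inside this model. The sketch below shows only that plumbing pattern; `attention_fn` and the `"attn_override"` key are made-up names for illustration, not ComfyUI's actual option schema.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def attention_fn(q, k, v, heads, transformer_options={}):
    # Hypothetical dispatcher: a callable stored in the options dict may
    # replace the default attention for this one call.
    override = transformer_options.get("attn_override")
    if override is not None:
        return override(q, k, v, heads)
    b, s, _ = q.shape
    def split_heads(t):
        return t.view(b, -1, heads, t.shape[-1] // heads).transpose(1, 2)
    out = F.scaled_dot_product_attention(split_heads(q), split_heads(k), split_heads(v))
    return out.transpose(1, 2).reshape(b, s, -1)

class ToyBlock(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.out = nn.Linear(dim, dim)

    def forward(self, x, transformer_options={}):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # The block never inspects the dict; it only forwards it, mirroring how
        # the patched blocks pass it down to the attention helper.
        return self.out(attention_fn(q, k, v, self.heads, transformer_options=transformer_options))

x = torch.randn(1, 16, 64)
block = ToyBlock(64, 8)
default_out = block(x)                # default attention path
patched_out = block(x, transformer_options={"attn_override": lambda q, k, v, h: torch.zeros_like(q)})
```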