# Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-10 21:42:37 +08:00)
from typing import Optional, Callable
import os

import torch
from torch import nn
from transformers import AutoTokenizer

from comfy import sd1_clip
from comfy.text_encoders.bert import (
    BertIntermediate as ClapTextIntermediate,
    BertOutput as ClapTextOutput,
)
from comfy.ldm.modules.attention import optimized_attention


def apply_chunking_to_forward(
    forward_fn: Callable[..., torch.Tensor],
    chunk_size: int,
    chunk_dim: int,
    *input_tensors,
) -> torch.Tensor:
    if chunk_size > 0:
        # chunk into tuples and apply forward_fn to each element in the tuple
        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)

        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
        return torch.cat(output_chunks, dim=chunk_dim)

    return forward_fn(*input_tensors)

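
# Illustrative usage sketch (editor-added, not part of the upstream module and never
# called): for a feed-forward function that treats each sequence position independently,
# chunking along the sequence dimension produces the same result as a single call,
# trading peak memory for extra invocations.
def _example_apply_chunking_to_forward():
    ff = torch.nn.Linear(4, 4)
    x = torch.randn(2, 8, 4)
    full = ff(x)
    chunked = apply_chunking_to_forward(ff, 2, 1, x)  # 4 chunks of length 2 along dim 1
    assert torch.allclose(full, chunked)
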

# from modeling_roberta
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx

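
# Illustrative worked example (editor-added, not part of the upstream module): with
# RoBERTa's pad token id of 1, non-padding tokens are numbered consecutively starting
# at padding_idx + 1, while padding positions stay at padding_idx.
def _example_create_position_ids():
    ids = torch.tensor([[0, 42, 7, 1, 1]])  # 1 is the pad token
    position_ids = create_position_ids_from_input_ids(ids, padding_idx=1)
    assert position_ids.tolist() == [[2, 3, 4, 1, 1]]
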

class ClapProjectionLayer(nn.Module):
    def __init__(self, hidden_size, projection_dim, device, dtype, operations):
        super().__init__()
        self.linear1 = operations.Linear(hidden_size, projection_dim, device=device, dtype=dtype)
        self.activation = torch.nn.ReLU()
        self.linear2 = operations.Linear(projection_dim, projection_dim, device=device, dtype=dtype)

    def forward(self, hidden_states):
        hidden_states = self.linear1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.linear2(hidden_states)
        return hidden_states

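
# Illustrative shape sketch (editor-added, not part of the upstream module): torch.nn
# stands in here for the `operations` namespace that ComfyUI normally supplies, since
# only Linear is needed. The projection maps pooled features from hidden_size to
# projection_dim with a two-layer MLP.
def _example_projection_layer():
    proj = ClapProjectionLayer(768, 512, device="cpu", dtype=torch.float32, operations=torch.nn)
    pooled = torch.randn(2, 768)
    assert proj(pooled).shape == (2, 512)
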

# same as RobertaEmbeddings
class ClapTextEmbeddings(nn.Module):
    def __init__(self, vocab_size, hidden_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps, device, dtype, operations):
        super().__init__()
        self.word_embeddings = operations.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id, device=device, dtype=dtype)
        self.position_embeddings = operations.Embedding(max_position_embeddings, hidden_size, device=device, dtype=dtype)
        self.token_type_embeddings = operations.Embedding(type_vocab_size, hidden_size, device=device, dtype=dtype)

        self.LayerNorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings, device=device, dtype=dtype).expand((1, -1)), persistent=True
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long, device=device), persistent=True
        )

        self.padding_idx = pad_token_id
        # overrides the plain position embedding created above with one that reserves
        # padding_idx, mirroring RobertaEmbeddings
        self.position_embeddings = operations.Embedding(
            max_position_embeddings, hidden_size, padding_idx=self.padding_idx, device=device, dtype=dtype
        )

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if position_ids is None:
            if input_ids is not None:
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        position_embeddings = self.position_embeddings(position_ids)
        embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)

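
# Illustrative shape sketch (editor-added, not part of the upstream module): torch.nn
# stands in for the `operations` namespace, since only Embedding and LayerNorm are
# needed. Word, position and token-type embeddings are summed and layer-normalized.
def _example_text_embeddings():
    emb = ClapTextEmbeddings(
        vocab_size=100, hidden_size=32, pad_token_id=1, max_position_embeddings=20,
        type_vocab_size=1, layer_norm_eps=1e-12, device="cpu", dtype=torch.float32,
        operations=torch.nn,
    )
    out = emb(input_ids=torch.tensor([[0, 42, 7, 1]]))
    assert out.shape == (1, 4, 32)
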

# same as AlignTextSelfAttention
class ClapTextSelfAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size, device, dtype, operations):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype)
        self.key = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype)
        self.value = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype)

        self.scaling = self.attention_head_size**-0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention_head_size)

        query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states, key_states, value_states = [t.contiguous() for t in (query_states, key_states, value_states)]
        if attention_mask is not None:
            attention_mask = attention_mask.to(query_states.dtype)
        attn_output = optimized_attention(query_states, key_states, value_states, self.num_attention_heads, mask=attention_mask, skip_output_reshape=True, skip_reshape=True)
        attn_output = attn_output.transpose(1, 2).contiguous()
        return attn_output.reshape(*input_shape, -1).contiguous()

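
# Illustrative shape sketch (editor-added, not part of the upstream module): torch.nn
# stands in for the `operations` namespace, since only Linear is needed. The projections
# are reshaped to (batch, heads, seq, head_size) for optimized_attention and flattened
# back to (batch, seq, hidden_size) on return.
def _example_self_attention():
    attn = ClapTextSelfAttention(num_attention_heads=2, hidden_size=32, device="cpu", dtype=torch.float32, operations=torch.nn)
    out = attn(torch.randn(2, 16, 32))
    assert out.shape == (2, 16, 32)
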

class ClapTextSelfOutput(nn.Module):
    def __init__(self, hidden_size, layer_norm_eps, device, dtype, operations):
        super().__init__()
        self.dense = operations.Linear(hidden_size, hidden_size, device=device, dtype=dtype)
        self.LayerNorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


# same as AlignTextAttention
class ClapTextAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size, layer_norm_eps, device, dtype, operations):
        super().__init__()
        self.self = ClapTextSelfAttention(num_attention_heads, hidden_size, device=device, dtype=dtype, operations=operations)
        self.output = ClapTextSelfOutput(hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        self_outputs = self.self(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            **kwargs,
        )
        return self.output(self_outputs, hidden_states)


# same as AlignTextLayer
class ClapTextLayer(nn.Module):
    def __init__(self, num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, device, dtype, operations):
        super().__init__()
        # maybe we could allow chunking dynamically
        self.chunk_size_feed_forward = 0
        self.seq_len_dim = 1
        self.attention = ClapTextAttention(num_attention_heads, hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)
        self.intermediate = ClapTextIntermediate(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations)
        self.output = ClapTextOutput(intermediate_size, hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            **kwargs,
        )
        attention_output = self_attention_outputs

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return layer_output

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


# same as AlignTextEncoder
class ClapTextEncoder(nn.Module):
    def __init__(self, num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, num_hidden_layers, device, dtype, operations):
        super().__init__()
        self.layer = nn.ModuleList([ClapTextLayer(num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)
                                    for _ in range(num_hidden_layers)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        for layer_module in self.layer:
            hidden_states = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                **kwargs,
            )
        return hidden_states


class ClapTextPooler(nn.Module):
    def __init__(self, hidden_size, device, dtype, operations):
        super().__init__()
        self.dense = operations.Linear(hidden_size, hidden_size, device=device, dtype=dtype)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class ClapTextModel(nn.Module):
    def __init__(self, num_attention_heads, vocab_size, hidden_size, intermediate_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps, num_hidden_layers,
                 device, dtype, operations, add_pooling_layer=True):
        super().__init__()
        self.embeddings = ClapTextEmbeddings(vocab_size, hidden_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps,
                                             device=device, dtype=dtype, operations=operations)
        self.encoder = ClapTextEncoder(num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, num_hidden_layers, device=device, dtype=dtype, operations=operations)
        self.pooler = ClapTextPooler(hidden_size, device=device, dtype=dtype, operations=operations) if add_pooling_layer else None

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        embeds: Optional[torch.Tensor] = None,
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        elif embeds is not None:
            input_shape = embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else embeds.device

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
        )
        sequence_output = encoder_outputs
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return sequence_output, pooled_output

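
# Illustrative end-to-end sketch (editor-added, not part of the upstream module) with a
# deliberately tiny configuration; torch.nn again stands in for the `operations`
# namespace that ComfyUI normally supplies. The model returns per-token hidden states
# and the tanh-pooled first-token representation.
def _example_text_model():
    model = ClapTextModel(
        num_attention_heads=2, vocab_size=100, hidden_size=32, intermediate_size=64,
        pad_token_id=1, max_position_embeddings=20, type_vocab_size=1,
        layer_norm_eps=1e-12, num_hidden_layers=2,
        device="cpu", dtype=torch.float32, operations=torch.nn,
    )
    sequence_output, pooled_output = model(input_ids=torch.tensor([[0, 42, 7, 1]]))
    assert sequence_output.shape == (1, 4, 32)
    assert pooled_output.shape == (1, 32)
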

class ClapTextModelWithProjection(nn.Module):
    def __init__(
        self,
        config,
        dtype=None,
        device=None,
        operations=None,
        hidden_size: int = 768,
        intermediate_size: int = 3072,
        layer_norm_eps: float = 1e-12,
        max_position_embeddings: int = 514,
        num_attention_heads: int = 12,
        num_hidden_layers: int = 12,
        projection_dim: int = 512,
        type_vocab_size: int = 1,
        vocab_size: int = 50265,
        pad_token_id: int = 1,
    ):
        super().__init__()
        self.num_layers = num_hidden_layers
        self.text_model = ClapTextModel(num_attention_heads, vocab_size, hidden_size, intermediate_size, pad_token_id, max_position_embeddings,
                                        type_vocab_size, layer_norm_eps, num_hidden_layers, device=device, dtype=dtype, operations=operations)
        self.text_projection = ClapProjectionLayer(hidden_size, projection_dim, device=device, dtype=dtype, operations=operations)

    def get_input_embeddings(self):
        return self.text_model.embeddings.word_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        embeds=None,
        **kwargs,
    ):
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            embeds=embeds,
        )

        pooled_output = text_outputs[1]
        text_embeds = self.text_projection(pooled_output)

        return text_outputs[0], torch.tensor([]), text_embeds

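
# The 3-tuple returned by ClapTextModelWithProjection.forward appears to follow the
# convention sd1_clip.SDClipModel reads from its wrapped text model: outputs[0] as the
# last hidden state, outputs[1] as the intermediate-layer slot (unused here with
# layer="last", hence the empty tensor), and outputs[2] as the pooled/projected embedding.
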

class ClapTextEncoderModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        self.dtypes = set([dtype])
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 1}, layer_norm_hidden_state=False, model_class=ClapTextModelWithProjection, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


class ClapLargeTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clap_tokenizer")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='clap_l', tokenizer_class=AutoTokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=1, tokenizer_data=tokenizer_data)