Mirror of https://github.com/comfyanonymous/ComfyUI.git
Support generating attention masks for left padded text encoders. (#12454)
commit 831351a29e
parent e1add563f9
@@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
     def process_tokens(self, tokens, device):
         end_token = self.special_tokens.get("end", None)
+        pad_token = self.special_tokens.get("pad", -1)
         if end_token is None:
-            cmp_token = self.special_tokens.get("pad", -1)
+            cmp_token = pad_token
         else:
             cmp_token = end_token
 
@@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             other_embeds = []
             eos = False
             index = 0
+            left_pad = False
             for y in x:
                 if isinstance(y, numbers.Integral):
-                    if eos:
+                    token = int(y)
+                    if index == 0 and token == pad_token:
+                        left_pad = True
+
+                    if eos or (left_pad and token == pad_token):
                         attention_mask.append(0)
                     else:
                         attention_mask.append(1)
-                    token = int(y)
+                        left_pad = False
+
                     tokens_temp += [token]
-                    if not eos and token == cmp_token:
+                    if not eos and token == cmp_token and not left_pad:
                         if end_token is None:
                             attention_mask[-1] = 0
                         eos = True
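In effect, process_tokens now treats a sequence whose first token is the pad token as left-padded and zeroes the mask over the whole leading pad run, while the "not left_pad" guard keeps those pads from being mistaken for the end-of-sequence marker. A minimal standalone sketch of the same rule (the function name and token ids here are illustrative, not part of the commit):

def build_attention_mask(token_ids, pad_token, end_token=None):
    # Condensed version of the logic above: cmp_token marks "sequence over".
    cmp_token = pad_token if end_token is None else end_token
    mask, eos = [], False
    left_pad = bool(token_ids) and token_ids[0] == pad_token  # leading pad => left padded
    for token in token_ids:
        if eos or (left_pad and token == pad_token):
            mask.append(0)    # mask left padding and everything past EOS
        else:
            mask.append(1)
            left_pad = False  # first real token closes the left-pad run
        if not eos and token == cmp_token and not left_pad:
            if end_token is None:
                mask[-1] = 0  # pad doubles as the end marker: mask it too
            eos = True
    return mask

print(build_attention_mask([0, 0, 101, 7592, 102, 0], pad_token=0, end_token=102))
# -> [0, 0, 1, 1, 1, 0]  (left pads masked, EOS kept visible, trailing pad masked)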
@@ -10,7 +10,6 @@ import comfy.utils
 def sample_manual_loop_no_classes(
     model,
     ids=None,
-    paddings=[],
     execution_dtype=None,
     cfg_scale: float = 2.0,
     temperature: float = 0.85,
@@ -36,9 +35,6 @@ def sample_manual_loop_no_classes(
 
     embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
     embeds_batch = embeds.shape[0]
-    for i, t in enumerate(paddings):
-        attention_mask[i, :t] = 0
-        attention_mask[i, t:] = 1
 
     output_audio_codes = []
     past_key_values = []
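This removes the manual mask override: callers used to pass explicit left-pad counts and the sampler rewrote each batch row, but process_tokens now returns a mask with the left pads already zeroed, so the loop became dead code. For reference, a sketch of what the deleted lines computed (tensor shape and pad counts are made up for illustration):

import torch

attention_mask = torch.ones(2, 6, dtype=torch.long)
paddings = [2, 0]  # hypothetical left-pad counts for the positive/negative rows
for i, t in enumerate(paddings):
    attention_mask[i, :t] = 0  # zero out the left padding
    attention_mask[i, t:] = 1  # attend to every remaining position
print(attention_mask)
# tensor([[0, 0, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])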
@@ -135,13 +131,11 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
         pos_pad = (len(negative) - len(positive))
         positive = [model.special_tokens["pad"]] * pos_pad + positive
 
-        paddings = [pos_pad, neg_pad]
         ids = [positive, negative]
     else:
-        paddings = []
         ids = [positive]
 
-    return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+    return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
 
 
 class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
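The call site keeps only the part CFG actually needs: left-padding the shorter prompt so both rows share one sequence length, with masking now handled entirely inside process_tokens. A hedged sketch of that padding step, with made-up token ids standing in for real tokenizer output:

pad = 0  # stand-in for model.special_tokens["pad"]
positive = [101, 7592, 102]
negative = [101, 2054, 2003, 1037, 3231, 102]
pos_pad = len(negative) - len(positive)
positive = [pad] * pos_pad + positive  # left-pad the shorter prompt
ids = [positive, negative]             # equal-length rows, ready for process_tokens
assert len(ids[0]) == len(ids[1])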