Lowercase T for torch.tensor

This commit is contained in:
Max Tretikov 2024-06-14 15:58:59 -06:00
parent bbfb2b4950
commit cbe69364db

View File

@ -132,7 +132,7 @@ class SDClipModel(torch.nn.Module):
backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.Tensor(tokens, dtype=torch.long).to(device)
tokens = torch.tensor(tokens, dtype=torch.long).to(device)
attention_mask = None
if self.enable_attention_masks: