Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-02-10 21:42:37 +08:00
Merge branch 'comfyanonymous:master' into bugfix/extra_data
Commit 031a8c7275
@@ -3,7 +3,7 @@ import os
 import yaml
 
 import folder_paths
-from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
+from comfy.sd import load_checkpoint
 import os.path as osp
 import re
 import torch
@@ -216,11 +216,6 @@ current_gpu_controlnets = []
 
 model_accelerated = False
 
-def unet_offload_device():
-    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
-        return get_torch_device()
-    else:
-        return torch.device("cpu")
 
 def unload_model():
     global current_loaded_model
@@ -234,8 +229,8 @@ def unload_model():
             model_accelerated = False
 
 
-        current_loaded_model.model.to(unet_offload_device())
-        current_loaded_model.model_patches_to(unet_offload_device())
+        current_loaded_model.model.to(current_loaded_model.offload_device)
+        current_loaded_model.model_patches_to(current_loaded_model.offload_device)
         current_loaded_model.unpatch_model()
         current_loaded_model = None
 
@@ -260,10 +255,14 @@ def load_model_gpu(model):
         model.unpatch_model()
         raise e
 
-    torch_dev = get_torch_device()
+    torch_dev = model.load_device
     model.model_patches_to(torch_dev)
 
-    vram_set_state = vram_state
+    if is_device_cpu(torch_dev):
+        vram_set_state = VRAMState.DISABLED
+    else:
+        vram_set_state = vram_state
 
     if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
         model_size = model.model_size()
         current_free_mem = get_free_memory(torch_dev)
@@ -277,14 +276,14 @@ def load_model_gpu(model):
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
-        real_model.to(get_torch_device())
+        real_model.to(torch_dev)
     else:
         if vram_set_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_set_state == VRAMState.LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
         model_accelerated = True
     return current_loaded_model
 
@@ -327,8 +326,34 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
+def unet_offload_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
+def text_encoder_offload_device():
+    if args.gpu_only or vram_state == VRAMState.SHARED:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def text_encoder_device():
-    if args.gpu_only:
+    if args.gpu_only or vram_state == VRAMState.SHARED:
+        return get_torch_device()
+    elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
+        if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
+            return get_torch_device()
+        else:
+            return torch.device("cpu")
+    else:
+        return torch.device("cpu")
+
+def vae_device():
+    return get_torch_device()
+
+def vae_offload_device():
+    if args.gpu_only or vram_state == VRAMState.SHARED:
         return get_torch_device()
     else:
         return torch.device("cpu")
@@ -422,10 +447,20 @@ def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS
 
-def should_use_fp16():
+def is_device_cpu(device):
+    if hasattr(device, 'type'):
+        if (device.type == 'cpu' or device.type == 'mps'):
+            return True
+    return False
+
+def should_use_fp16(device=None):
     global xpu_available
     global directml_enabled
 
+    if device is not None: #TODO
+        if is_device_cpu(device):
+            return False
+
     if FORCE_FP32:
         return False
 
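For orientation, a minimal sketch (not part of this commit) of how the new device-aware signature can be used; pick_dtype is a hypothetical helper and assumes the functions from the hunk above are in scope:

import torch

# Hypothetical helper, not repository code: choose a weight dtype for a given
# device. should_use_fp16(device) now returns False for CPU/MPS devices via
# is_device_cpu(), so weights stay in float32 there.
def pick_dtype(device):
    if should_use_fp16(device=device):
        return torch.float16
    return torch.float32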
comfy/sd.py (56 lines changed)
@@ -308,13 +308,15 @@ def model_lora_keys(model, key_map={}):
 
 
 class ModelPatcher:
-    def __init__(self, model, size=0):
+    def __init__(self, model, load_device, offload_device, size=0):
         self.size = size
         self.model = model
         self.patches = []
         self.backup = {}
         self.model_options = {"transformer_options":{}}
         self.model_size()
+        self.load_device = load_device
+        self.offload_device = offload_device
 
     def model_size(self):
         if self.size > 0:
@@ -329,7 +331,7 @@ class ModelPatcher:
         return size
 
     def clone(self):
-        n = ModelPatcher(self.model, self.size)
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
         n.patches = self.patches[:]
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
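As a usage illustration only (the unet variable and the surrounding imports are assumed, mirroring what load_checkpoint() does further down), a ModelPatcher now carries its device policy and clone() preserves it:

# Illustrative sketch, not repository code: construct a patcher with an explicit
# load/offload device pair and clone it; the clone shares the model and devices.
load_device = model_management.get_torch_device()
offload_device = model_management.unet_offload_device()
patcher = ModelPatcher(unet, load_device=load_device, offload_device=offload_device)
patcher_copy = patcher.clone()  # same model, same load_device/offload_device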
@@ -341,6 +343,9 @@ class ModelPatcher:
         else:
             self.model_options["sampler_cfg_function"] = sampler_cfg_function
 
+    def set_model_unet_function_wrapper(self, unet_wrapper_function):
+        self.model_options["model_function_wrapper"] = unet_wrapper_function
+
     def set_model_patch(self, patch, name):
         to = self.model_options["transformer_options"]
         if "patches" not in to:
@@ -525,13 +530,17 @@ class CLIP:
         clip = target.clip
         tokenizer = target.tokenizer
 
-        self.device = model_management.text_encoder_device()
-        params["device"] = self.device
+        load_device = model_management.text_encoder_device()
+        offload_device = model_management.text_encoder_offload_device()
         self.cond_stage_model = clip(**(params))
-        self.cond_stage_model = self.cond_stage_model.to(self.device)
+        #TODO: make sure this doesn't have a quality loss before enabling.
+        # if model_management.should_use_fp16(load_device):
+        #     self.cond_stage_model.half()
+
+        self.cond_stage_model = self.cond_stage_model.to()
+
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
-        self.patcher = ModelPatcher(self.cond_stage_model)
+        self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
 
     def clone(self):
@@ -540,7 +549,6 @@ class CLIP:
         n.cond_stage_model = self.cond_stage_model
         n.tokenizer = self.tokenizer
         n.layer_idx = self.layer_idx
-        n.device = self.device
         return n
 
     def load_from_state_dict(self, sd):
@@ -558,18 +566,12 @@ class CLIP:
     def encode_from_tokens(self, tokens, return_pooled=False):
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
-        try:
-            self.patch_model()
-            cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
-            self.unpatch_model()
-        except Exception as e:
-            self.unpatch_model()
-            raise e
 
-        cond_out = cond
+        model_management.load_model_gpu(self.patcher)
+        cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
         if return_pooled:
-            return cond_out, pooled
-        return cond_out
+            return cond, pooled
+        return cond
 
     def encode(self, text):
         tokens = self.tokenize(text)
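A hedged sketch of the resulting call flow (assuming a CLIP instance named clip): the try/patch/unpatch bookkeeping is gone, and device movement is delegated to model_management:

# Sketch only: encode a prompt with the rewritten flow. load_model_gpu() moves
# clip.patcher's model to its load_device; a later unload_model() returns it to
# the offload_device chosen in CLIP.__init__.
tokens = clip.tokenize("a photo of a cat")
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)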
@@ -603,8 +605,9 @@ class VAE:
             self.first_stage_model.load_state_dict(sd, strict=False)
 
         if device is None:
-            device = model_management.get_torch_device()
+            device = model_management.vae_device()
         self.device = device
+        self.offload_device = model_management.vae_offload_device()
 
     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
         steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
@@ -649,7 +652,7 @@ class VAE:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)
 
-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples
 
@@ -657,7 +660,7 @@ class VAE:
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
         output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
         return output.movedim(1,-1)
 
     def encode(self, pixel_samples):
@@ -677,7 +680,7 @@ class VAE:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
             samples = self.encode_tiled_(pixel_samples)
 
-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
         return samples
 
     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -685,7 +688,7 @@ class VAE:
         self.first_stage_model = self.first_stage_model.to(self.device)
         pixel_samples = pixel_samples.movedim(-1,1)
         samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
         return samples
 
     def get_sd(self):
@@ -1093,6 +1096,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     if fp16:
         model = model.half()
 
+    offload_device = model_management.unet_offload_device()
+    model = model.to(offload_device)
     model.load_model_weights(state_dict, "model.diffusion_model.")
 
     if output_vae:
@@ -1115,7 +1120,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
         w.cond_stage_model = clip.cond_stage_model
         load_clip_weights(w, state_dict)
 
-    return (ModelPatcher(model), clip, vae)
+    return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
 
 
 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
@@ -1140,8 +1145,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if output_clipvision:
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
 
+    offload_device = model_management.unet_offload_device()
     model = model_config.get_model(sd)
-    model = model.to(model_management.unet_offload_device())
+    model = model.to(offload_device)
     model.load_model_weights(sd, "model.diffusion_model.")
 
     if output_vae:
@@ -1162,7 +1168,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if len(left_over) > 0:
         print("left over keys:", left_over)
 
-    return (ModelPatcher(model), clip, vae, clipvision)
+    return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)
 
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     try:
@@ -5,24 +5,34 @@ import comfy.ops
 import torch
 import traceback
 import zipfile
+from . import model_management
+import contextlib
 
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        z_empty, _ = self.encode(self.empty_tokens)
-        output = []
-        first_pooled = None
+        to_encode = list(self.empty_tokens)
         for x in token_weight_pairs:
-            tokens = [list(map(lambda a: a[0], x))]
-            z, pooled = self.encode(tokens)
-            if first_pooled is None:
-                first_pooled = pooled
+            tokens = list(map(lambda a: a[0], x))
+            to_encode.append(tokens)
+
+        out, pooled = self.encode(to_encode)
+        z_empty = out[0:1]
+        if pooled.shape[0] > 1:
+            first_pooled = pooled[1:2]
+        else:
+            first_pooled = pooled[0:1]
+
+        output = []
+        for i in range(1, out.shape[0]):
+            z = out[i:i+1]
             for i in range(len(z)):
                 for j in range(len(z[i])):
-                    weight = x[j][1]
+                    weight = token_weight_pairs[i - 1][j][1]
                     z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
-            output += [z]
+            output.append(z)
 
         if (len(output) == 0):
-            return self.encode(self.empty_tokens)
+            return z_empty, first_pooled
         return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
 
 class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
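To make the re-weighting step concrete, here is a self-contained sketch on dummy tensors (shapes and names are illustrative, not from the repository): all prompt chunks are encoded in one batch together with the empty prompt, then each token embedding is scaled around the empty-prompt embedding by its weight.

import torch

out = torch.randn(3, 77, 768)   # row 0: empty prompt, rows 1-2: prompt chunks
weights = torch.ones(2, 77)     # per-token weights for the two chunks
z_empty = out[0:1]

output = []
for i in range(1, out.shape[0]):
    z = out[i:i+1].clone()
    # same math as z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
    z[0] = (z[0] - z_empty[0]) * weights[i - 1].unsqueeze(-1) + z_empty[0]
    output.append(z)

cond = torch.cat(output, dim=-2)  # chunks concatenated along the token axis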
@@ -46,7 +56,6 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             with modeling_utils.no_init_weights():
                 self.transformer = CLIPTextModel(config)
 
-        self.device = device
         self.max_length = max_length
         if freeze:
             self.freeze()
@@ -95,7 +104,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             out_tokens += [tokens_temp]
 
         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=self.device)
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
            new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
            n = token_dict_size
            for x in embedding_weights:
@@ -106,24 +115,32 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
     def forward(self, tokens):
         backup_embeds = self.transformer.get_input_embeddings()
+        device = backup_embeds.weight.device
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(self.device)
-        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
-        self.transformer.set_input_embeddings(backup_embeds)
+        tokens = torch.LongTensor(tokens).to(device)
 
-        if self.layer == "last":
-            z = outputs.last_hidden_state
-        elif self.layer == "pooled":
-            z = outputs.pooler_output[:, None, :]
+        if backup_embeds.weight.dtype != torch.float32:
+            precision_scope = torch.autocast
         else:
-            z = outputs.hidden_states[self.layer_idx]
-            if self.layer_norm_hidden_state:
-                z = self.transformer.text_model.final_layer_norm(z)
+            precision_scope = contextlib.nullcontext
 
-        pooled_output = outputs.pooler_output
-        if self.text_projection is not None:
-            pooled_output = pooled_output @ self.text_projection
-        return z, pooled_output
+        with precision_scope(model_management.get_autocast_device(device)):
+            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+            self.transformer.set_input_embeddings(backup_embeds)
+
+            if self.layer == "last":
+                z = outputs.last_hidden_state
+            elif self.layer == "pooled":
+                z = outputs.pooler_output[:, None, :]
+            else:
+                z = outputs.hidden_states[self.layer_idx]
+                if self.layer_norm_hidden_state:
+                    z = self.transformer.text_model.final_layer_norm(z)
+
+            pooled_output = outputs.pooler_output
+            if self.text_projection is not None:
+                pooled_output = pooled_output @ self.text_projection
+        return z.float(), pooled_output.float()
 
     def encode(self, tokens):
         return self(tokens)
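The core of the forward() change is the precision scope; below is a minimal standalone sketch of that pattern (the device string and dtype are placeholders, not values from the repository):

import contextlib
import torch

weight_dtype = torch.float16  # placeholder for backup_embeds.weight.dtype
if weight_dtype != torch.float32:
    precision_scope = torch.autocast
else:
    precision_scope = contextlib.nullcontext

with precision_scope("cuda"):
    pass  # the transformer forward pass runs here, under autocast when needed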