Merge branch 'comfyanonymous:master' into fix/secure-combo

This commit is contained in:
Dr.Lt.Data 2023-07-16 22:58:09 +09:00 committed by GitHub
commit 1397c17781
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 36 deletions

View File

@ -233,10 +233,9 @@ def unload_model():
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
model_accelerated = False
current_loaded_model.unpatch_model()
current_loaded_model.model.to(current_loaded_model.offload_device)
current_loaded_model.model_patches_to(current_loaded_model.offload_device)
current_loaded_model.unpatch_model()
current_loaded_model = None
if vram_state != VRAMState.HIGH_VRAM:
soft_empty_cache()
@ -258,15 +257,11 @@ def load_model_gpu(model):
if model is current_loaded_model:
return
unload_model()
try:
real_model = model.patch_model()
except Exception as e:
model.unpatch_model()
raise e
torch_dev = model.load_device
model.model_patches_to(torch_dev)
model.model_patches_to(model.model_dtype())
current_loaded_model = model
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
@ -280,21 +275,29 @@ def load_model_gpu(model):
if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
current_loaded_model = model
real_model = model.model
if vram_set_state == VRAMState.DISABLED:
pass
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
model_accelerated = False
real_model.to(torch_dev)
else:
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
try:
real_model = model.patch_model()
except Exception as e:
model.unpatch_model()
unload_model()
raise e
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
return current_loaded_model
def load_controlnet_gpu(control_models):

View File

@ -340,7 +340,7 @@ class ModelPatcher:
weight = model_sd[key]
if key not in self.backup:
self.backup[key] = weight.clone()
self.backup[key] = weight.to(self.offload_device, copy=True)
temp_weight = weight.to(torch.float32, copy=True)
weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
@ -367,15 +367,16 @@ class ModelPatcher:
else:
weight += alpha * w1.type(weight.dtype).to(weight.device)
elif len(v) == 4: #lora/locon
mat1 = v[0]
mat2 = v[1]
mat1 = v[0].float().to(weight.device)
mat2 = v[1].float().to(weight.device)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
mat3 = v[3].float().to(weight.device)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
@ -389,20 +390,24 @@ class ModelPatcher:
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
else:
w1 = w1.float().to(weight.device)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float(), w2_b.float())
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
else:
w2 = w2.float().to(weight.device)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
else: #loha
w1a = v[0]
w1b = v[1]
@ -413,21 +418,27 @@ class ModelPatcher:
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
else:
m1 = torch.mm(w1a.float(), w1b.float())
m2 = torch.mm(w2a.float(), w2b.float())
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
return weight
def unpatch_model(self):
model_sd = self.model_state_dict()
keys = list(self.backup.keys())
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev
for k in keys:
model_sd[k][:] = self.backup[k]
del self.backup[k]
set_attr(self.model, k, self.backup[k])
self.backup = {}

View File

@ -4,18 +4,20 @@ import struct
import comfy.checkpoint_pickle
import safetensors.torch
def load_torch_file(ckpt, safe_load=False):
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors"):
sd = safetensors.torch.load_file(ckpt, device="cpu")
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:

View File

@ -497,7 +497,9 @@ class LoraLoader:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
del self.loaded_lora
temp = self.loaded_lora
self.loaded_lora = None
del temp
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)