mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 21:42:37 +08:00
Merge branch 'comfyanonymous:master' into feature/preview-latent
This commit is contained in:
commit
9f643d845c
@ -281,19 +281,23 @@ def load_model_gpu(model):
|
|||||||
vram_set_state = VRAMState.LOW_VRAM
|
vram_set_state = VRAMState.LOW_VRAM
|
||||||
|
|
||||||
real_model = model.model
|
real_model = model.model
|
||||||
|
patch_model_to = None
|
||||||
if vram_set_state == VRAMState.DISABLED:
|
if vram_set_state == VRAMState.DISABLED:
|
||||||
pass
|
pass
|
||||||
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
||||||
model_accelerated = False
|
model_accelerated = False
|
||||||
real_model.to(torch_dev)
|
patch_model_to = torch_dev
|
||||||
|
|
||||||
try:
|
try:
|
||||||
real_model = model.patch_model()
|
real_model = model.patch_model(device_to=patch_model_to)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model.unpatch_model()
|
model.unpatch_model()
|
||||||
unload_model()
|
unload_model()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
if patch_model_to is not None:
|
||||||
|
real_model.to(torch_dev)
|
||||||
|
|
||||||
if vram_set_state == VRAMState.NO_VRAM:
|
if vram_set_state == VRAMState.NO_VRAM:
|
||||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
|
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
|
||||||
|
|||||||
@ -248,7 +248,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
|
|
||||||
c['transformer_options'] = transformer_options
|
c['transformer_options'] = transformer_options
|
||||||
|
|
||||||
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
|
if 'model_function_wrapper' in model_options:
|
||||||
|
output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||||
|
else:
|
||||||
|
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
|
||||||
del input_x
|
del input_x
|
||||||
|
|
||||||
model_management.throw_exception_if_processing_interrupted()
|
model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|||||||
@ -338,7 +338,7 @@ class ModelPatcher:
|
|||||||
sd.pop(k)
|
sd.pop(k)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def patch_model(self):
|
def patch_model(self, device_to=None):
|
||||||
model_sd = self.model_state_dict()
|
model_sd = self.model_state_dict()
|
||||||
for key in self.patches:
|
for key in self.patches:
|
||||||
if key not in model_sd:
|
if key not in model_sd:
|
||||||
@ -350,7 +350,10 @@ class ModelPatcher:
|
|||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
self.backup[key] = weight.to(self.offload_device)
|
self.backup[key] = weight.to(self.offload_device)
|
||||||
|
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
if device_to is not None:
|
||||||
|
temp_weight = weight.float().to(device_to, copy=True)
|
||||||
|
else:
|
||||||
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||||
set_attr(self.model, key, out_weight)
|
set_attr(self.model, key, out_weight)
|
||||||
del temp_weight
|
del temp_weight
|
||||||
|
|||||||
@ -37,7 +37,7 @@ def get_gpu_names():
|
|||||||
return set()
|
return set()
|
||||||
|
|
||||||
def cuda_malloc_supported():
|
def cuda_malloc_supported():
|
||||||
blacklist = {"GeForce GTX 960M", "GeForce GTX 950M", "GeForce 945M", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745"}
|
blacklist = {"GeForce GTX 960", "GeForce GTX 950", "GeForce 945M", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745"}
|
||||||
try:
|
try:
|
||||||
names = get_gpu_names()
|
names = get_gpu_names()
|
||||||
except:
|
except:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user