mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-06 01:37:45 +08:00
ModelPatcherDynamic: Force load all non-comfy weights (#12739)
* model_management: Remove non-comfy dynamic _v caster
* Force pre-load non-comfy weights to GPU in ModelPatcherDynamic

Non-comfy weights may expect to be pre-cast to the target device without in-model casting. Previously they were allocated in the vbar with `_v`, which required the `_v` fault path in `cast_to`. Instead, back up the original CPU weight and move it directly to GPU at load time.
This commit is contained in:
parent
dff0a4a158
commit
09bcbddfcf
@ -32,9 +32,6 @@ import comfy.memory_management
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
|
|
||||||
import comfy_aimdo.torch
|
|
||||||
import comfy_aimdo.model_vbar
|
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
DISABLED = 0 #No vram present: no need to move models to vram
|
DISABLED = 0 #No vram present: no need to move models to vram
|
||||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||||
@ -1206,43 +1203,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
|||||||
|
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||||
if hasattr(weight, "_v"):
|
|
||||||
#Unexpected usage patterns. There is no reason these don't work but they
|
|
||||||
#have no testing and no callers do this.
|
|
||||||
assert r is None
|
|
||||||
assert stream is None
|
|
||||||
|
|
||||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
|
|
||||||
|
|
||||||
if dtype is None:
|
|
||||||
dtype = weight._model_dtype
|
|
||||||
|
|
||||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
|
||||||
if signature is not None:
|
|
||||||
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
|
||||||
v_tensor = weight._v_tensor
|
|
||||||
else:
|
|
||||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
|
||||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
|
||||||
weight._v_tensor = v_tensor
|
|
||||||
weight._v_signature = signature
|
|
||||||
#Send it over
|
|
||||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
|
||||||
return v_tensor.to(dtype=dtype)
|
|
||||||
|
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
|
|
||||||
#Offloaded casting could skip this, however it would make the quantizations
|
|
||||||
#inconsistent between loaded and offloaded weights. So force the double casting
|
|
||||||
#that would happen in regular flow to make offload deterministic.
|
|
||||||
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
|
|
||||||
cast_buffer.copy_(weight, non_blocking=non_blocking)
|
|
||||||
weight = cast_buffer
|
|
||||||
r.copy_(weight, non_blocking=non_blocking)
|
|
||||||
|
|
||||||
return r
|
|
||||||
|
|
||||||
if device is None or weight.device == device:
|
if device is None or weight.device == device:
|
||||||
if not copy:
|
if not copy:
|
||||||
if dtype is None or weight.dtype == dtype:
|
if dtype is None or weight.dtype == dtype:
|
||||||
|
|||||||
@ -1435,10 +1435,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
|
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
||||||
#this is now way more dynamic and we dont support the same base model for both Dynamic
|
|
||||||
#and non-dynamic patchers.
|
|
||||||
if hasattr(self.model, "model_loaded_weight_memory"):
|
|
||||||
del self.model.model_loaded_weight_memory
|
|
||||||
if not hasattr(self.model, "dynamic_vbars"):
|
if not hasattr(self.model, "dynamic_vbars"):
|
||||||
self.model.dynamic_vbars = {}
|
self.model.dynamic_vbars = {}
|
||||||
self.non_dynamic_delegate_model = None
|
self.non_dynamic_delegate_model = None
|
||||||
@ -1461,9 +1457,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
|
|
||||||
def loaded_size(self):
|
def loaded_size(self):
|
||||||
vbar = self._vbar_get()
|
vbar = self._vbar_get()
|
||||||
if vbar is None:
|
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
|
||||||
return 0
|
|
||||||
return vbar.loaded_size()
|
|
||||||
|
|
||||||
def get_free_memory(self, device):
|
def get_free_memory(self, device):
|
||||||
#NOTE: on high condition / batch counts, estimate should have already vacated
|
#NOTE: on high condition / batch counts, estimate should have already vacated
|
||||||
@ -1504,6 +1498,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
|
|
||||||
num_patches = 0
|
num_patches = 0
|
||||||
allocated_size = 0
|
allocated_size = 0
|
||||||
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
self.unpatch_hooks()
|
self.unpatch_hooks()
|
||||||
@ -1512,10 +1507,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
if vbar is not None:
|
if vbar is not None:
|
||||||
vbar.prioritize()
|
vbar.prioritize()
|
||||||
|
|
||||||
#We force reserve VRAM for the non-comfy weights so we don't have to deal
|
|
||||||
#with pin and unpin synchronization which can be expensive for small weights
|
|
||||||
#with a high layer rate (e.g. autoregressive LLMs).
|
|
||||||
#prioritize the non-comfy weights (note the order reverse).
|
|
||||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||||
loading.sort(reverse=True)
|
loading.sort(reverse=True)
|
||||||
|
|
||||||
@ -1558,6 +1549,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
if key in self.backup:
|
if key in self.backup:
|
||||||
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||||
self.patch_weight_to_device(key, device_to=device_to)
|
self.patch_weight_to_device(key, device_to=device_to)
|
||||||
|
weight, _, _ = get_key_weight(self.model, key)
|
||||||
|
if weight is not None:
|
||||||
|
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
|
||||||
|
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
@ -1583,21 +1577,15 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
for param in params:
|
for param in params:
|
||||||
key = key_param_name_to_key(n, param)
|
key = key_param_name_to_key(n, param)
|
||||||
weight, _, _ = get_key_weight(self.model, key)
|
weight, _, _ = get_key_weight(self.model, key)
|
||||||
weight.seed_key = key
|
if key not in self.backup:
|
||||||
set_dirty(weight, dirty)
|
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
|
||||||
geometry = weight
|
comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to))
|
||||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
|
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
|
||||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
|
||||||
weight_size = geometry.numel() * geometry.element_size()
|
|
||||||
if vbar is not None and not hasattr(weight, "_v"):
|
|
||||||
weight._v = vbar.alloc(weight_size)
|
|
||||||
weight._model_dtype = model_dtype
|
|
||||||
allocated_size += weight_size
|
|
||||||
vbar.set_watermark_limit(allocated_size)
|
|
||||||
|
|
||||||
move_weight_functions(m, device_to)
|
move_weight_functions(m, device_to)
|
||||||
|
|
||||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
||||||
|
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||||
|
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||||
@ -1613,7 +1601,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
assert self.load_device != torch.device("cpu")
|
assert self.load_device != torch.device("cpu")
|
||||||
|
|
||||||
vbar = self._vbar_get()
|
vbar = self._vbar_get()
|
||||||
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||||
|
|
||||||
|
if freed < memory_to_free:
|
||||||
|
for key in list(self.backup.keys()):
|
||||||
|
bk = self.backup.pop(key)
|
||||||
|
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||||
|
freed += self.model.model_loaded_weight_memory
|
||||||
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
|
return freed
|
||||||
|
|
||||||
def partially_unload_ram(self, ram_to_unload):
|
def partially_unload_ram(self, ram_to_unload):
|
||||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||||
@ -1640,11 +1637,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
move_weight_functions(m, device_to)
|
move_weight_functions(m, device_to)
|
||||||
|
|
||||||
keys = list(self.backup.keys())
|
|
||||||
for k in keys:
|
|
||||||
bk = self.backup[k]
|
|
||||||
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
|
||||||
|
|
||||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||||
assert not force_patch_weights #See above
|
assert not force_patch_weights #See above
|
||||||
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user