mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-01 21:02:30 +08:00
fix: Handle un-inited meta-tensors in models (fixes a CPU TE crash) (CORE-67) (#13578)
This commit is contained in:
parent
3cbf015578
commit
b47f15f25a
@ -31,6 +31,7 @@ import comfy.float
|
|||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.ops
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from comfy.comfy_types import UnetWrapperFunction
|
from comfy.comfy_types import UnetWrapperFunction
|
||||||
@ -856,7 +857,9 @@ class ModelPatcher:
|
|||||||
if m.comfy_patched_weights == True:
|
if m.comfy_patched_weights == True:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for param in params:
|
for param, param_value in params.items():
|
||||||
|
if hasattr(m, "comfy_cast_weights") and getattr(param_value, "is_meta", False):
|
||||||
|
comfy.ops.disable_weight_init._zero_init_parameter(m, param)
|
||||||
key = key_param_name_to_key(n, param)
|
key = key_param_name_to_key(n, param)
|
||||||
self.unpin_weight(key)
|
self.unpin_weight(key)
|
||||||
self.patch_weight_to_device(key, device_to=device_to)
|
self.patch_weight_to_device(key, device_to=device_to)
|
||||||
|
|||||||
16
comfy/ops.py
16
comfy/ops.py
@ -79,14 +79,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
|||||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
|
|
||||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
def materialize_meta_param(s, param_keys):
|
||||||
|
for param_key in param_keys:
|
||||||
|
param = getattr(s, param_key, None)
|
||||||
|
if param is not None and getattr(param, "is_meta", False):
|
||||||
|
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
|
||||||
|
|
||||||
|
|
||||||
|
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||||
#If you are a custom node author reading this, please move your layer to the GPU
|
#If you are a custom node author reading this, please move your layer to the GPU
|
||||||
#or declare your ModelPatcher as CPU in the first place.
|
#or declare your ModelPatcher as CPU in the first place.
|
||||||
if comfy.model_management.is_device_cpu(device):
|
if comfy.model_management.is_device_cpu(device):
|
||||||
|
materialize_meta_param(s, ["weight", "bias"])
|
||||||
weight = s.weight.to(dtype=dtype, copy=True)
|
weight = s.weight.to(dtype=dtype, copy=True)
|
||||||
if isinstance(weight, QuantizedTensor):
|
if isinstance(weight, QuantizedTensor):
|
||||||
weight = weight.dequantize()
|
weight = weight.dequantize()
|
||||||
@ -108,6 +115,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||||
|
|
||||||
if not resident:
|
if not resident:
|
||||||
|
materialize_meta_param(s, ["weight", "bias"])
|
||||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||||
cast_dest = None
|
cast_dest = None
|
||||||
|
|
||||||
@ -306,6 +314,12 @@ class CastWeightBiasOp:
|
|||||||
bias_function = []
|
bias_function = []
|
||||||
|
|
||||||
class disable_weight_init:
|
class disable_weight_init:
|
||||||
|
@staticmethod
|
||||||
|
def _zero_init_parameter(module, name):
|
||||||
|
param = getattr(module, name)
|
||||||
|
device = None if getattr(param, "is_meta", False) else param.device
|
||||||
|
setattr(module, name, torch.nn.Parameter(torch.zeros(param.shape, device=device, dtype=param.dtype), requires_grad=False))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
|
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
|
||||||
missing_keys, unexpected_keys, weight_shape,
|
missing_keys, unexpected_keys, weight_shape,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user