fix: Handle un-inited meta-tensors in models (fixes a CPU TE crash) (CORE-67) (#13578)

This commit is contained in:
rattus 2026-04-28 12:22:31 +10:00 committed by GitHub
parent 3cbf015578
commit b47f15f25a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 2 deletions

View File

@ -31,6 +31,7 @@ import comfy.float
import comfy.hooks
import comfy.lora
import comfy.model_management
import comfy.ops
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
@ -856,7 +857,9 @@ class ModelPatcher:
if m.comfy_patched_weights == True:
continue
for param in params:
for param, param_value in params.items():
if hasattr(m, "comfy_cast_weights") and getattr(param_value, "is_meta", False):
comfy.ops.disable_weight_init._zero_init_parameter(m, param)
key = key_param_name_to_key(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)

View File

@ -79,14 +79,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
def materialize_meta_param(s, param_keys):
for param_key in param_keys:
param = getattr(s, param_key, None)
if param is not None and getattr(param, "is_meta", False):
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
@ -108,6 +115,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident:
materialize_meta_param(s, ["weight", "bias"])
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
@ -306,6 +314,12 @@ class CastWeightBiasOp:
bias_function = []
class disable_weight_init:
@staticmethod
def _zero_init_parameter(module, name):
param = getattr(module, name)
device = None if getattr(param, "is_meta", False) else param.device
setattr(module, name, torch.nn.Parameter(torch.zeros(param.shape, device=device, dtype=param.dtype), requires_grad=False))
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape,