diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index ee56f8523..e259aed63 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -31,6 +31,7 @@ import comfy.float import comfy.hooks import comfy.lora import comfy.model_management +import comfy.ops import comfy.patcher_extension import comfy.utils from comfy.comfy_types import UnetWrapperFunction @@ -856,7 +857,9 @@ class ModelPatcher: if m.comfy_patched_weights == True: continue - for param in params: + for param, param_value in params.items(): + if hasattr(m, "comfy_cast_weights") and getattr(param_value, "is_meta", False): + comfy.ops.disable_weight_init._zero_init_parameter(m, param) key = key_param_name_to_key(n, param) self.unpin_weight(key) self.patch_weight_to_device(key, device_to=device_to) diff --git a/comfy/ops.py b/comfy/ops.py index 7a9b4b84c..050f7cda0 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -79,14 +79,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): +def materialize_meta_param(s, param_keys): + for param_key in param_keys: + param = getattr(s, param_key, None) + if param is not None and getattr(param, "is_meta", False): + setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad)) + +def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): #vbar doesn't support CPU weights, but some custom nodes have weird paths #that might switch the layer to the CPU and expect it to work. We have to take #a clone conservatively as we are mmapped and some SFT files are packed misaligned #If you are a custom node author reading this, please move your layer to the GPU #or declare your ModelPatcher as CPU in the first place. if comfy.model_management.is_device_cpu(device): + materialize_meta_param(s, ["weight", "bias"]) weight = s.weight.to(dtype=dtype, copy=True) if isinstance(weight, QuantizedTensor): weight = weight.dequantize() @@ -108,6 +115,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if not resident: + materialize_meta_param(s, ["weight", "bias"]) cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_dest = None @@ -306,6 +314,12 @@ class CastWeightBiasOp: bias_function = [] class disable_weight_init: + @staticmethod + def _zero_init_parameter(module, name): + param = getattr(module, name) + device = None if getattr(param, "is_meta", False) else param.device + setattr(module, name, torch.nn.Parameter(torch.zeros(param.shape, device=device, dtype=param.dtype), requires_grad=False)) + @staticmethod def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata, missing_keys, unexpected_keys, weight_shape,