diff --git a/comfy/ops.py b/comfy/ops.py
index c260b6645..a30d58cd7 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -24,6 +24,62 @@ import comfy.float
 import comfy.rmsnorm
 import json
 
+def _dw_find_input_tensor(args, kwargs):
+    """Return first tensor-like input (supports QuantizedTensor)."""
+    def check(obj):
+        if torch.is_tensor(obj):
+            return obj
+        if isinstance(obj, QuantizedTensor):
+            return obj
+        if isinstance(obj, (list, tuple)):
+            for it in obj:
+                r = check(it)
+                if r is not None:
+                    return r
+        if isinstance(obj, dict):
+            for it in obj.values():
+                r = check(it)
+                if r is not None:
+                    return r
+        return None
+    for a in args:
+        r = check(a)
+        if r is not None:
+            return r
+    return check(kwargs)
+
+def _dw_disk_weights_enabled() -> bool:
+    # Delayed import avoids eager circular imports.
+    from comfy import disk_weights as _dw
+    return _dw.disk_weights_enabled()
+
+def _dw_requires_temporary_cast(module, args, kwargs) -> bool:
+    """
+    When disk_weights is enabled, route ops through the comfy_cast path when
+    weights/bias are not directly usable (dtype/device mismatch or meta).
+    """
+    if not _dw_disk_weights_enabled():
+        return False
+    inp = _dw_find_input_tensor(args, kwargs)
+    if inp is None:
+        return False
+    w = getattr(module, "weight", None)
+    if w is None:
+        return False
+    if isinstance(inp, QuantizedTensor):
+        req_dtype = inp.params.orig_dtype
+        req_dev = inp.device
+    else:
+        req_dtype = inp.dtype
+        req_dev = inp.device
+    if w.device.type == "meta" or w.device != req_dev or w.dtype != req_dtype:
+        return True
+    b = getattr(module, "bias", None)
+    if b is not None and (b.device.type == "meta" or b.device != req_dev or b.dtype != req_dtype):
+        return True
+    return False
+
+
 def run_every_op():
     if torch.compiler.is_compiling():
         return
@@ -163,7 +219,7 @@ class disable_weight_init:
 
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -180,7 +236,7 @@ class disable_weight_init:
 
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -197,7 +253,7 @@ class disable_weight_init:
 
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -223,7 +279,7 @@ class disable_weight_init:
 
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -240,7 +296,7 @@ class disable_weight_init:
 
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -262,7 +318,7 @@ class disable_weight_init:
 
        def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -286,7 +342,7 @@ class disable_weight_init:
 
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -310,7 +366,7 @@ class disable_weight_init:
 
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -334,7 +390,7 @@ class disable_weight_init:
 
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -356,7 +412,7 @@ class disable_weight_init:
 
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 if "out_dtype" in kwargs:
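Note for reviewers: every `forward` gate in this patch reduces to the dtype/device check in `_dw_requires_temporary_cast`. Below is a minimal, self-contained sketch (not part of the patch) of the plain-tensor branch of that check; `DummyLinear` and `needs_cast` are hypothetical names invented for illustration, and the `QuantizedTensor` branch is omitted.

```python
import torch

class DummyLinear(torch.nn.Module):
    """Hypothetical stand-in for a comfy op whose weight is not yet usable."""
    def __init__(self):
        super().__init__()
        # Weight left on the meta device, as a disk-weights setup would leave
        # it before the real tensor is materialized from disk.
        self.weight = torch.nn.Parameter(torch.empty(4, 4, device="meta"))
        self.bias = None

def needs_cast(module, inp):
    # Mirrors the plain-tensor branch of _dw_requires_temporary_cast:
    # a meta-device weight, or any device/dtype mismatch with the input,
    # means the weight cannot be used directly.
    w = module.weight
    return w.device.type == "meta" or w.device != inp.device or w.dtype != inp.dtype

m = DummyLinear()
x = torch.randn(2, 4)
print(needs_cast(m, x))  # True: the weight is still on the meta device
```

The effect, per the helper's docstring, is that such modules are routed through `forward_comfy_cast_weights`, where the weight can be loaded and cast on the fly instead of being assumed resident and ready.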