Add functions for tensor input handling and casting

ifilipis 2026-01-21 17:55:44 +00:00
parent fcbd22b514
commit c825bc526e


@@ -24,6 +24,62 @@ import comfy.float
 import comfy.rmsnorm
 import json
+
+def _dw_find_input_tensor(args, kwargs):
+    """Return first tensor-like input (supports QuantizedTensor)."""
+    def check(obj):
+        if torch.is_tensor(obj):
+            return obj
+        if isinstance(obj, QuantizedTensor):
+            return obj
+        if isinstance(obj, (list, tuple)):
+            for it in obj:
+                r = check(it)
+                if r is not None:
+                    return r
+        if isinstance(obj, dict):
+            for it in obj.values():
+                r = check(it)
+                if r is not None:
+                    return r
+        return None
+    for a in args:
+        r = check(a)
+        if r is not None:
+            return r
+    return check(kwargs)
+
+def _dw_disk_weights_enabled() -> bool:
+    # Delayed import avoids eager circular imports.
+    from comfy import disk_weights as _dw
+    return _dw.disk_weights_enabled()
+
+def _dw_requires_temporary_cast(module, args, kwargs) -> bool:
+    """
+    When disk_weights is enabled, route ops through the comfy_cast path when
+    weights/bias are not directly usable (dtype/device mismatch or meta).
+    """
+    if not _dw_disk_weights_enabled():
+        return False
+    inp = _dw_find_input_tensor(args, kwargs)
+    if inp is None:
+        return False
+    w = getattr(module, "weight", None)
+    if w is None:
+        return False
+    if isinstance(inp, QuantizedTensor):
+        req_dtype = inp.params.orig_dtype
+        req_dev = inp.device
+    else:
+        req_dtype = inp.dtype
+        req_dev = inp.device
+    if w.device.type == "meta" or w.device != req_dev or w.dtype != req_dtype:
+        return True
+    b = getattr(module, "bias", None)
+    if b is not None and (b.device.type == "meta" or b.device != req_dev or b.dtype != req_dtype):
+        return True
+    return False
+
 def run_every_op():
     if torch.compiler.is_compiling():
         return
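A minimal sketch of what the recursive walk in _dw_find_input_tensor does, assuming comfy.ops exposes the helper exactly as added above (the nested-args shape here is purely illustrative):

import torch
from comfy.ops import _dw_find_input_tensor

x = torch.randn(1, 4)
# The helper descends through lists/tuples and dict values and returns the
# first torch.Tensor (or QuantizedTensor) it encounters, else None.
args = ([None, {"sample": x}],)
found = _dw_find_input_tensor(args, {})
assert found is x  # the tensor nested inside list -> dict is located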
@@ -163,7 +219,7 @@ class disable_weight_init:
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -180,7 +236,7 @@ class disable_weight_init:
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -197,7 +253,7 @@ class disable_weight_init:
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -223,7 +279,7 @@ class disable_weight_init:
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -240,7 +296,7 @@ class disable_weight_init:
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -262,7 +318,7 @@ class disable_weight_init:
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -286,7 +342,7 @@ class disable_weight_init:
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -310,7 +366,7 @@ class disable_weight_init:
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -334,7 +390,7 @@ class disable_weight_init:
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -356,7 +412,7 @@ class disable_weight_init:
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 if "out_dtype" in kwargs: