mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-01 09:10:16 +08:00
Add functions for tensor input handling and casting
This commit is contained in:
parent
fcbd22b514
commit
c825bc526e
76
comfy/ops.py
76
comfy/ops.py
@ -24,6 +24,62 @@ import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import json
|
||||
|
||||
def _dw_find_input_tensor(args, kwargs):
|
||||
"""Return first tensor-like input (supports QuantizedTensor)."""
|
||||
def check(obj):
|
||||
if torch.is_tensor(obj):
|
||||
return obj
|
||||
if isinstance(obj, QuantizedTensor):
|
||||
return obj
|
||||
if isinstance(obj, (list, tuple)):
|
||||
for it in obj:
|
||||
r = check(it)
|
||||
if r is not None:
|
||||
return r
|
||||
if isinstance(obj, dict):
|
||||
for it in obj.values():
|
||||
r = check(it)
|
||||
if r is not None:
|
||||
return r
|
||||
return None
|
||||
for a in args:
|
||||
r = check(a)
|
||||
if r is not None:
|
||||
return r
|
||||
return check(kwargs)
|
||||
|
||||
def _dw_disk_weights_enabled() -> bool:
|
||||
# Delayed import avoids eager circular imports.
|
||||
from comfy import disk_weights as _dw
|
||||
return _dw.disk_weights_enabled()
|
||||
|
||||
def _dw_requires_temporary_cast(module, args, kwargs) -> bool:
|
||||
"""
|
||||
When disk_weights is enabled, route ops through the comfy_cast path when
|
||||
weights/bias are not directly usable (dtype/device mismatch or meta).
|
||||
"""
|
||||
if not _dw_disk_weights_enabled():
|
||||
return False
|
||||
inp = _dw_find_input_tensor(args, kwargs)
|
||||
if inp is None:
|
||||
return False
|
||||
w = getattr(module, "weight", None)
|
||||
if w is None:
|
||||
return False
|
||||
if isinstance(inp, QuantizedTensor):
|
||||
req_dtype = inp.params.orig_dtype
|
||||
req_dev = inp.device
|
||||
else:
|
||||
req_dtype = inp.dtype
|
||||
req_dev = inp.device
|
||||
if w.device.type == "meta" or w.device != req_dev or w.dtype != req_dtype:
|
||||
return True
|
||||
b = getattr(module, "bias", None)
|
||||
if b is not None and (b.device.type == "meta" or b.device != req_dev or b.dtype != req_dtype):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def run_every_op():
|
||||
if torch.compiler.is_compiling():
|
||||
return
|
||||
@ -163,7 +219,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -180,7 +236,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -197,7 +253,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -223,7 +279,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -240,7 +296,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -262,7 +318,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -286,7 +342,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -310,7 +366,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -334,7 +390,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -356,7 +412,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
if "out_dtype" in kwargs:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user