ops: move cpu handler up to the caller

This commit is contained in:
Rattus 2026-04-24 19:24:13 +10:00
parent 312f08e44f
commit 132c9f3ac6

View File

@ -87,21 +87,6 @@ def materialize_meta_param(s, param_keys):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
#plan = []
#Some sort of loop here like what you did
#for module in comfy_modules: ...
@ -247,7 +232,18 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if hasattr(s, "_v"):
#that cpu guard goes here.
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
return (weight, bias, (None, None, None)) if offloadable else (weight, bias)
#check for a prefetch result here. Something like:
#if not prefetch: