diff --git a/comfy/ops.py b/comfy/ops.py index 415c39e92..6aec1cdaf 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -102,7 +102,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of bias = None if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) comfy.model_management.sync_stream(device, offload_stream) @@ -110,6 +110,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight_a = weight if s.bias is not None: + bias = bias.to(dtype=bias_dtype) for f in s.bias_function: bias = f(bias)