ops: prioritize mem transfer

The async offload streams reason for existence is to transfer from
RAM to GPU. The post processing compute steps are a bonus on the side
stream, but if the compute stream is running a long kernel, it can
stall the side stream, as it wait to type-cast the bias before
transferring the weight. So do a pure xfer of the weight straight up,
then do everything bias, then go back to fix the weight type and do
weight patches.
This commit is contained in:
Rattus 2025-11-02 01:32:03 +10:00
parent fcdb4a5661
commit 7436674198

View File

@ -95,20 +95,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
else:
wf_context = contextlib.nullcontext()
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = len(s.bias_function) > 0
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
if bias_has_function:
with wf_context:
for f in s.bias_function:
bias = f(bias)
has_function = len(s.weight_function) > 0
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
weight = weight.to(dtype=dtype)
if weight_has_function:
with wf_context:
for f in s.weight_function:
weight = f(weight)