Fix race condition in --async-offload that can cause corruption (#10501)

* mm: factor out the current stream getter

Make this a reusable function.

* ops: sync the offload stream with the consumption of w&b

This sync is necessary because PyTorch queues CUDA async frees on the
same stream that created the tensor. In the case of async offload, this
is the offload stream.

Weights and biases can go out of scope in Python, which triggers the
PyTorch garbage collector to queue the free operation on the offload
stream, possibly before the compute stream has used the weight. This
causes a use-after-free on the weight data, leading to total corruption
of some workflows.

So sync the offload stream with the compute stream after the weight has
been used, so that the queued free cannot run until the compute stream
is done with the weight.
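
For illustration, a minimal standalone sketch of the hazard and the fix
(not ComfyUI code; tensor names and sizes are made up):

import torch

x = torch.randn(8, 64, device="cuda")
cpu_weight = torch.randn(64, 64, pin_memory=True)

compute = torch.cuda.current_stream()
offload = torch.cuda.Stream()

with torch.cuda.stream(offload):
    # Allocated on the offload stream, so its async free is queued there too.
    weight = cpu_weight.to("cuda", non_blocking=True)

compute.wait_stream(offload)                 # compute waits for the H2D copy
out = torch.nn.functional.linear(x, weight)  # enqueued on the compute stream

# Without the next line, dropping the last reference to weight lets the
# allocator recycle its memory based only on offload-stream progress,
# racing with the pending linear() on the compute stream.
offload.wait_stream(compute)
del weight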

cast_bias_weight is extended in a backwards-compatible way, with the new
behaviour opt-in via a defaulted parameter. This handles custom node
packs that call cast_bias_weight with the legacy signature by disabling
async-offload for them (as they do not handle the race).

The pattern is now:

cast_bias_weight(..., offloadable=True) # This might be offloaded
thing(weight, bias, ...)
uncast_bias_weight(...)
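
For example, a hypothetical custom Linear op (MyLinear is made up)
adopting the new sequence, mirroring what this commit does to the
built-in ops:

import torch
import comfy.ops

class MyLinear(torch.nn.Linear, comfy.ops.CastWeightBiasOp):
    def forward_comfy_cast_weights(self, input):
        # Opting in returns a third value: the offload stream that must be
        # synchronized after the last use of the weight/bias.
        weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
        x = torch.nn.functional.linear(input, weight, bias)
        comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
        return x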

* controlnet: adopt new cast_bias_weight synchronization scheme

This is necessary for safe async weight offloading.

* mm: sync the last stream in the queue, not the next

Currently this peeks ahead to sync the next stream in the queue of
streams with the compute stream. This doesn't allow much
parallelization, as the end result is that you can only get one weight
load ahead regardless of how many streams you have.

Rotate the loop logic here to synchronize the end of the queue before
returning the next stream. This allows weights to be loaded ahead of the
compute stream's position.
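
Simplified (per-device bookkeeping stripped out), the rotation changes
roughly like this; see the model_management.py hunks below for the real
code:

# Before: hand out the current stream, then pre-sync the *next* stream
# in the ring with compute, so only one weight load can run ahead.
def get_offload_stream_before(ss, counter, compute):
    s = ss[counter]
    counter = (counter + 1) % len(ss)
    ss[counter].wait_stream(compute)
    return s, counter

# After: sync the *oldest* stream in the ring with compute just before
# it comes around for reuse, then hand out the next stream. The other
# streams are free to load weights ahead of the compute stream's
# position.
def get_offload_stream_after(ss, counter, compute):
    ss[counter].wait_stream(compute)
    counter = (counter + 1) % len(ss)
    return ss[counter], counter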
rattus committed 2025-10-30 07:17:46 +10:00 (via GitHub)
commit ab7ab5be23, parent ec4fc2a09a
3 changed files with 114 additions and 52 deletions

comfy/controlnet.py

@@ -310,11 +310,13 @@ class ControlLoraOps:
                 self.bias = None

             def forward(self, input):
-                weight, bias = comfy.ops.cast_bias_weight(self, input)
+                weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
                 if self.up is not None:
-                    return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
+                    x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
                 else:
-                    return torch.nn.functional.linear(input, weight, bias)
+                    x = torch.nn.functional.linear(input, weight, bias)
+                comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+                return x

         class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
             def __init__(
@@ -350,12 +352,13 @@ class ControlLoraOps:
             def forward(self, input):
-                weight, bias = comfy.ops.cast_bias_weight(self, input)
+                weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
                 if self.up is not None:
-                    return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
+                    x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
                 else:
-                    return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+                    x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+                comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+                return x

 class ControlLora(ControlNet):
     def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

comfy/model_management.py

@@ -1013,6 +1013,16 @@ if args.async_offload:
         NUM_STREAMS = 2
     logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))

+def current_stream(device):
+    if device is None:
+        return None
+    if is_device_cuda(device):
+        return torch.cuda.current_stream()
+    elif is_device_xpu(device):
+        return torch.xpu.current_stream()
+    else:
+        return None
+
 stream_counters = {}
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
@@ -1021,21 +1031,17 @@ def get_offload_stream(device):
     if device in STREAMS:
         ss = STREAMS[device]
-        s = ss[stream_counter]
+        #Sync the oldest stream in the queue with the current
+        ss[stream_counter].wait_stream(current_stream(device))
         stream_counter = (stream_counter + 1) % len(ss)
-        if is_device_cuda(device):
-            ss[stream_counter].wait_stream(torch.cuda.current_stream())
-        elif is_device_xpu(device):
-            ss[stream_counter].wait_stream(torch.xpu.current_stream())
         stream_counters[device] = stream_counter
-        return s
+        return ss[stream_counter]
     elif is_device_cuda(device):
         ss = []
         for k in range(NUM_STREAMS):
             ss.append(torch.cuda.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
         return s
     elif is_device_xpu(device):
@@ -1044,18 +1050,14 @@ def get_offload_stream(device):
             ss.append(torch.xpu.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
         return s
     return None

 def sync_stream(device, stream):
-    if stream is None:
+    if stream is None or current_stream(device) is None:
         return
-    if is_device_cuda(device):
-        torch.cuda.current_stream().wait_stream(stream)
-    elif is_device_xpu(device):
-        torch.xpu.current_stream().wait_stream(stream)
+    current_stream(device).wait_stream(stream)

 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:

comfy/ops.py

@@ -70,8 +70,12 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

 @torch.compiler.disable()
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
+    # NOTE: offloadable=False is legacy behaviour; if you are a custom node author reading this, please pass
+    # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
+    # will add async-offload support to your cast and improve performance.
     if input is not None:
         if dtype is None:
             dtype = input.dtype
@@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         if device is None:
             device = input.device

-    offload_stream = comfy.model_management.get_offload_stream(device)
+    if offloadable:
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
     if offload_stream is not None:
         wf_context = offload_stream
     else:
@@ -105,7 +113,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
             weight = f(weight)

     comfy.model_management.sync_stream(device, offload_stream)
-    return weight, bias
+    if offloadable:
+        return weight, bias, offload_stream
+    else:
+        #Legacy function signature
+        return weight, bias
+
+def uncast_bias_weight(s, weight, bias, offload_stream):
+    if offload_stream is None:
+        return
+    if weight is not None:
+        device = weight.device
+    else:
+        if bias is None:
+            return
+        device = bias.device
+    offload_stream.wait_stream(comfy.model_management.current_stream(device))

 class CastWeightBiasOp:
     comfy_cast_weights = False
@@ -118,8 +143,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -133,8 +160,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -148,8 +177,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -172,8 +203,10 @@ class disable_weight_init:
             return super()._conv_forward(input, weight, bias, *args, **kwargs)

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -187,8 +220,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -203,11 +238,14 @@ class disable_weight_init:

         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
                 bias = None
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+                offload_stream = None
+            x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -223,11 +261,15 @@ class disable_weight_init:

         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
-            return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
-            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+                bias = None
+                offload_stream = None
+            x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
+            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -246,10 +288,12 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)

-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose2d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose2d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -268,10 +312,12 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)

-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose1d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose1d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -289,8 +335,11 @@ class disable_weight_init:
                 output_dtype = out_dtype
             if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                 out_dtype = None
-            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
+            x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -361,7 +410,7 @@ def fp8_linear(self, input):
     input_dtype = input.dtype

     if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
+        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)

         scale_weight = self.scale_weight
         scale_input = self.scale_input

@@ -382,6 +431,8 @@ def fp8_linear(self, input):
             quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
             o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

+        uncast_bias_weight(self, w, bias, offload_stream)
+
         if tensor_2d:
             return o.reshape(input_shape[0], -1)
         return o.reshape((-1, input_shape[1], self.weight.shape[0]))
@@ -404,8 +455,10 @@ class fp8_ops(manual_cast):
             except Exception as e:
                 logging.info("Exception during fp8 op: {}".format(e))

-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
     logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -433,12 +486,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
                 if out is not None:
                     return out

-            weight, bias = cast_bias_weight(self, input)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)

             if weight.numel() < input.numel(): #TODO: optimize
-                return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+                x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
             else:
-                return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+                x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def convert_weight(self, weight, inplace=False, **kwargs):
             if inplace:
@@ -577,8 +632,10 @@ class MixedPrecisionOps(disable_weight_init):
             return torch.nn.functional.linear(input, weight, bias)

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, input, *args, **kwargs):
             run_every_op()