From ab7ab5be23fb9b71d1790f424e7dcf91dc1fe0cc Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 30 Oct 2025 07:17:46 +1000 Subject: [PATCH] Fix Race condition in --async-offload that can cause corruption (#10501) * mm: factor out the current stream getter Make this a reusable function. * ops: sync the offload stream with the consumption of w&b This sync is nessacary as pytorch will queue cuda async frees on the same stream as created to tensor. In the case of async offload, this will be on the offload stream. Weights and biases can go out of scope in python which then triggers the pytorch garbage collector to queue the free operation on the offload stream possible before the compute stream has used the weight. This causes a use after free on weight data leading to total corruption of some workflows. So sync the offload stream with the compute stream after the weight has been used so the free has to wait for the weight to be used. The cast_bias_weight is extended in a backwards compatible way with the new behaviour opt-in on a defaulted parameter. This handles custom node packs calling cast_bias_weight and defeatures async-offload for them (as they do not handle the race). The pattern is now: cast_bias_weight(... , offloadable=True) #This might be offloaded thing(weight, bias, ...) uncast_bias_weight(...) * controlnet: adopt new cast_bias_weight synchronization scheme This is nessacary for safe async weight offloading. * mm: sync the last stream in the queue, not the next Currently this peeks ahead to sync the next stream in the queue of streams with the compute stream. This doesnt allow a lot of parallelization, as then end result is you can only get one weight load ahead regardless of how many streams you have. Rotate the loop logic here to synchronize the end of the queue before returning the next stream. This allows weights to be loaded ahead of the compute streams position. --- comfy/controlnet.py | 17 +++--- comfy/model_management.py | 28 +++++---- comfy/ops.py | 121 ++++++++++++++++++++++++++++---------- 3 files changed, 114 insertions(+), 52 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index f08ff4b36..0b5e30f52 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -310,11 +310,13 @@ class ControlLoraOps: self.bias = None def forward(self, input): - weight, bias = comfy.ops.cast_bias_weight(self, input) + weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True) if self.up is not None: - return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) + x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) else: - return torch.nn.functional.linear(input, weight, bias) + x = torch.nn.functional.linear(input, weight, bias) + comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream) + return x class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__( @@ -350,12 +352,13 @@ class ControlLoraOps: def forward(self, input): - weight, bias = comfy.ops.cast_bias_weight(self, input) + weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True) if self.up is not None: - return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) + x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) - + x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) + comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream) + return x class ControlLora(ControlNet): def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options diff --git a/comfy/model_management.py b/comfy/model_management.py index 3e5b977d4..79c0dfdb4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1013,6 +1013,16 @@ if args.async_offload: NUM_STREAMS = 2 logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) +def current_stream(device): + if device is None: + return None + if is_device_cuda(device): + return torch.cuda.current_stream() + elif is_device_xpu(device): + return torch.xpu.current_stream() + else: + return None + stream_counters = {} def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) @@ -1021,21 +1031,17 @@ def get_offload_stream(device): if device in STREAMS: ss = STREAMS[device] - s = ss[stream_counter] + #Sync the oldest stream in the queue with the current + ss[stream_counter].wait_stream(current_stream(device)) stream_counter = (stream_counter + 1) % len(ss) - if is_device_cuda(device): - ss[stream_counter].wait_stream(torch.cuda.current_stream()) - elif is_device_xpu(device): - ss[stream_counter].wait_stream(torch.xpu.current_stream()) stream_counters[device] = stream_counter - return s + return ss[stream_counter] elif is_device_cuda(device): ss = [] for k in range(NUM_STREAMS): ss.append(torch.cuda.Stream(device=device, priority=0)) STREAMS[device] = ss s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter return s elif is_device_xpu(device): @@ -1044,18 +1050,14 @@ def get_offload_stream(device): ss.append(torch.xpu.Stream(device=device, priority=0)) STREAMS[device] = ss s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter return s return None def sync_stream(device, stream): - if stream is None: + if stream is None or current_stream(device) is None: return - if is_device_cuda(device): - torch.cuda.current_stream().wait_stream(stream) - elif is_device_xpu(device): - torch.xpu.current_stream().wait_stream(stream) + current_stream(device).wait_stream(stream) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): if device is None or weight.device == device: diff --git a/comfy/ops.py b/comfy/ops.py index 93731eedf..71ca7a2bd 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -70,8 +70,12 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) + @torch.compiler.disable() -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): + # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass + # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This + # will add async-offload support to your cast and improve performance. if input is not None: if dtype is None: dtype = input.dtype @@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if device is None: device = input.device - offload_stream = comfy.model_management.get_offload_stream(device) + if offloadable: + offload_stream = comfy.model_management.get_offload_stream(device) + else: + offload_stream = None + if offload_stream is not None: wf_context = offload_stream else: @@ -105,7 +113,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): weight = f(weight) comfy.model_management.sync_stream(device, offload_stream) - return weight, bias + if offloadable: + return weight, bias, offload_stream + else: + #Legacy function signature + return weight, bias + + +def uncast_bias_weight(s, weight, bias, offload_stream): + if offload_stream is None: + return + if weight is not None: + device = weight.device + else: + if bias is None: + return + device = bias.device + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + class CastWeightBiasOp: comfy_cast_weights = False @@ -118,8 +143,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.linear(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -133,8 +160,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -148,8 +177,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -172,8 +203,10 @@ class disable_weight_init: return super()._conv_forward(input, weight, bias, *args, **kwargs) def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -187,8 +220,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -203,11 +238,14 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): if self.weight is not None: - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) else: weight = None bias = None - return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + offload_stream = None + x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -223,11 +261,15 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): if self.weight is not None: - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) else: weight = None - return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated - # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + bias = None + offload_stream = None + x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated + # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -246,10 +288,12 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.conv_transpose2d( + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.conv_transpose2d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -268,10 +312,12 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.conv_transpose1d( + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.conv_transpose1d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -289,8 +335,11 @@ class disable_weight_init: output_dtype = out_dtype if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16: out_dtype = None - weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype) - return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) + weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True) + x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) + uncast_bias_weight(self, weight, bias, offload_stream) + return x + def forward(self, *args, **kwargs): run_every_op() @@ -361,7 +410,7 @@ def fp8_linear(self, input): input_dtype = input.dtype if len(input.shape) == 3: - w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) scale_weight = self.scale_weight scale_input = self.scale_input @@ -382,6 +431,8 @@ def fp8_linear(self, input): quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) + uncast_bias_weight(self, w, bias, offload_stream) + if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) @@ -404,8 +455,10 @@ class fp8_ops(manual_cast): except Exception as e: logging.info("Exception during fp8 op: {}".format(e)) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.linear(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) @@ -433,12 +486,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None if out is not None: return out - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) if weight.numel() < input.numel(): #TODO: optimize - return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) + x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) else: - return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) + x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def convert_weight(self, weight, inplace=False, **kwargs): if inplace: @@ -577,8 +632,10 @@ class MixedPrecisionOps(disable_weight_init): return torch.nn.functional.linear(input, weight, bias) def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, input, *args, **kwargs): run_every_op()