ops: sync the offload stream with the consumption of w&b

This sync is necessary because PyTorch queues CUDA async frees on the
same stream that created the tensor. In the case of async offload, that
is the offload stream.

Weights and biases can go out of scope in Python, which triggers the
PyTorch garbage collector to queue the free operation on the offload
stream, possibly before the compute stream has used the weight. This
causes a use-after-free on the weight data, leading to total corruption
of some workflows.

So sync the offload stream with the compute stream after the weight has
been used, so that the queued free has to wait until the compute stream
is done with the weight.
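
For illustration, here is a minimal sketch of the race and the fix using
raw torch.cuda streams rather than the comfy.model_management helpers
(tensor names and shapes are made up):

import torch

x = torch.randn(16, 1024, device="cuda")
w_cpu = torch.randn(1024, 1024, pin_memory=True)

compute = torch.cuda.current_stream()
offload = torch.cuda.Stream()

with torch.cuda.stream(offload):
    # allocation + async H2D copy are queued on the offload stream
    w_gpu = w_cpu.to("cuda", non_blocking=True)
compute.wait_stream(offload)  # compute must wait for the copy

y = torch.nn.functional.linear(x, w_gpu)  # weight consumed on the compute stream

# The fix: make the offload stream wait for the compute stream, so the
# async free queued when w_gpu goes out of scope cannot run before the
# matmul above has read the weight.
offload.wait_stream(compute)
del w_gpu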

cast_bias_weight is extended in a backwards-compatible way, with the new
behaviour opt-in via a defaulted parameter. This handles custom node
packs that call cast_bias_weight directly and disables async offload for
them (as they do not handle the race).

The pattern is now:

cast_bias_weight(..., offloadable=True)  # This might be offloaded
thing(weight, bias, ...)
uncast_bias_weight(...)
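
For custom node authors, a sketch of the opted-in form (the class name is
hypothetical; the helpers are the ones touched by this commit, assumed to
be importable from comfy.ops):

import torch
from comfy.ops import CastWeightBiasOp, cast_bias_weight, uncast_bias_weight

class MyCustomLinear(torch.nn.Linear, CastWeightBiasOp):
    def forward_comfy_cast_weights(self, input):
        # opting in returns the offload stream alongside weight and bias
        weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
        x = torch.nn.functional.linear(input, weight, bias)
        # last use of weight/bias is above; sync so the queued free waits
        uncast_bias_weight(self, weight, bias, offload_stream)
        return x
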
Rattus 2025-10-27 18:27:41 +10:00
parent 313638d13a
commit 2d69bdfabe


@@ -70,8 +70,12 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
 
 @torch.compiler.disable()
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
+    # NOTE: offloadable=False is a legacy default and if you are a custom node author reading this please pass
+    # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
+    # will add async-offload support to your cast and improve performance.
     if input is not None:
         if dtype is None:
             dtype = input.dtype
@@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         if device is None:
             device = input.device
 
-    offload_stream = comfy.model_management.get_offload_stream(device)
+    if offloadable:
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
+
     if offload_stream is not None:
         wf_context = offload_stream
     else:
@@ -105,7 +113,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
             weight = f(weight)
 
     comfy.model_management.sync_stream(device, offload_stream)
-    return weight, bias
+    if offloadable:
+        return weight, bias, offload_stream
+    else:
+        #Legacy function signature
+        return weight, bias
+
+
+def uncast_bias_weight(s, weight, bias, offload_stream):
+    if offload_stream is None:
+        return
+
+    if weight is not None:
+        device = weight.device
+    else:
+        if bias is None:
+            return
+        device = bias.device
+    offload_stream.wait_stream(comfy.model_management.current_stream(device))
 
 class CastWeightBiasOp:
     comfy_cast_weights = False
@@ -118,8 +143,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -133,8 +160,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -148,8 +177,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -172,8 +203,10 @@ class disable_weight_init:
                 return super()._conv_forward(input, weight, bias, *args, **kwargs)
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -187,8 +220,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -203,11 +238,14 @@ class disable_weight_init:
 
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
                 bias = None
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+                offload_stream = None
+            x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -223,11 +261,15 @@ class disable_weight_init:
 
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
-            return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
-            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+                bias = None
+                offload_stream = None
+            x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
+            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -246,10 +288,12 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)
 
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose2d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose2d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -268,10 +312,12 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)
 
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose1d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose1d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -289,8 +335,11 @@ class disable_weight_init:
             output_dtype = out_dtype
             if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                 out_dtype = None
-            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
+            x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -361,7 +410,7 @@ def fp8_linear(self, input):
     input_dtype = input.dtype
 
     if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
+        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
 
         scale_weight = self.scale_weight
         scale_input = self.scale_input
@@ -382,6 +431,8 @@ def fp8_linear(self, input):
             quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
             o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
+        uncast_bias_weight(self, w, bias, offload_stream)
+
         if tensor_2d:
             return o.reshape(input_shape[0], -1)
         return o.reshape((-1, input_shape[1], self.weight.shape[0]))
@@ -404,8 +455,10 @@ class fp8_ops(manual_cast):
             except Exception as e:
                 logging.info("Exception during fp8 op: {}".format(e))
 
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
     logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -433,12 +486,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
                 if out is not None:
                     return out
 
-            weight, bias = cast_bias_weight(self, input)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             if weight.numel() < input.numel(): #TODO: optimize
-                return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+                x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
             else:
-                return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+                x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def convert_weight(self, weight, inplace=False, **kwargs):
             if inplace:
@@ -577,8 +632,10 @@ class MixedPrecisionOps(disable_weight_init):
             return torch.nn.functional.linear(input, weight, bias)
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, input, *args, **kwargs):
             run_every_op()