mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-09 13:50:49 +08:00
ops: Fix offloading with FP8MM performance (#11697)
This logic was checking comfy_cast_weights, and going straight to to the forward_comfy_cast_weights implementation without attempting to downscale input to fp8 in the event comfy_cast_weights is set. The main reason comfy_cast_weights would be set would be for async offload, which is not a good reason to nix FP8MM. So instead, and together the underlying exclusions for FP8MM which are: * having a weight_function (usually LowVramPatch) * force_cast_weights (compute dtype override) * the weight is not Quantized * the input is already quantized * the model or layer has MM explictily disabled. If you get past all of those exclusions, quantize the input tensor. Then hand the new input, quantized or not off to forward_comfy_cast_weights to handle it. If the weight is offloaded but input is quantized you will get an offloaded MM8.
This commit is contained in:
parent
25bc1b5b57
commit
b6c79a648a
@ -718,6 +718,7 @@ class ModelPatcher:
|
||||
continue
|
||||
|
||||
cast_weight = self.force_cast_weights
|
||||
m.comfy_force_cast_weights = self.force_cast_weights
|
||||
if lowvram_weight:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.weight_function = []
|
||||
|
||||
30
comfy/ops.py
30
comfy/ops.py
@ -654,29 +654,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
run_every_op()
|
||||
|
||||
input_shape = input.shape
|
||||
tensor_3d = input.ndim == 3
|
||||
|
||||
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
reshaped_3d = False
|
||||
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
if tensor_3d:
|
||||
input = input.reshape(-1, input_shape[2])
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
|
||||
if input.ndim != 2:
|
||||
# Fall back to comfy_cast_weights for non-2D tensors
|
||||
return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
|
||||
# Fall back to non-quantized for non-2D tensors
|
||||
if input_reshaped.ndim == 2:
|
||||
reshaped_3d = input.ndim == 3
|
||||
# dtype is now implicit in the layout class
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
# dtype is now implicit in the layout class
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
|
||||
|
||||
output = self._forward(input, self.weight, self.bias)
|
||||
output = self.forward_comfy_cast_weights(input)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if tensor_3d:
|
||||
if reshaped_3d:
|
||||
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
|
||||
|
||||
return output
|
||||
|
||||
Loading…
Reference in New Issue
Block a user