ops: limit return of requants (#12506)
This check was far too broad, and the dtype is not a reliable indicator of whether the caller wants the requant, since a QuantizedTensor (QT) reports the compute dtype as its dtype. So explicitly plumb through whether the fp8 matmul path (fp8mm) wants the requant or not.
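To see why the old check misfires: a QuantizedTensor advertises the compute dtype through its .dtype attribute, so comparing orig.dtype against the requested dtype says nothing about whether the caller can actually consume a requantized weight. A minimal sketch of the failure mode, using a hypothetical FakeQuantizedTensor stand-in (not ComfyUI's actual class):

import torch

# Hypothetical stand-in: stores an int8 payload but reports the
# compute dtype, mirroring the behaviour the commit message describes.
class FakeQuantizedTensor:
    def __init__(self, payload: torch.Tensor, compute_dtype: torch.dtype):
        self._payload = payload     # quantized storage
        self.dtype = compute_dtype  # what .dtype comparisons see

orig = FakeQuantizedTensor(torch.zeros(4, dtype=torch.int8), torch.bfloat16)
dtype = torch.bfloat16  # the dtype requested by the cast

# Old gate: passes whenever the compute dtype matches, even for callers
# that never asked for a requantized weight.
print(orig.dtype == dtype)  # True

The fix below replaces this inference with an explicit want_requant flag set by the fp8 matmul path.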
This commit is contained in:
parent 19236edfa4
commit 58dcc97dcf

1 changed file: comfy/ops.py (17 changed lines)
@@ -79,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
 
 
-def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
+def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
     offload_stream = None
     xfer_dest = None
 
@@ -170,10 +170,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
             #FIXME: this is not accurate, we need to be sensitive to the compute dtype
             x = lowvram_fn(x)
         if (isinstance(orig, QuantizedTensor) and
-            (orig.dtype == dtype and len(fns) == 0 or update_weight)):
+            (want_requant and len(fns) == 0 or update_weight)):
             seed = comfy.utils.string_to_seed(s.seed_key)
             y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
-            if orig.dtype == dtype and len(fns) == 0:
+            if want_requant and len(fns) == 0:
                 #The layer actually wants our freshly saved QT
                 x = y
             elif update_weight:
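Spelled out as a predicate, the new gate builds a fresh QuantizedTensor in exactly two situations: the caller explicitly asked for a requant and no patch functions remain to be applied, or the stored weight itself needs updating. A sketch with a hypothetical helper name (the diff inlines this expression):

def should_build_requant(want_requant: bool, num_fns: int, update_weight: bool) -> bool:
    # `and` binds tighter than `or`, so this reads as:
    # (want_requant and num_fns == 0) or update_weight
    return want_requant and num_fns == 0 or update_weight

Only the first case hands the fresh tensor back to the layer (x = y); the update_weight case merely refreshes the stored copy.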
@@ -194,7 +194,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
     return weight, bias, (offload_stream, device if signature is not None else None, None)
 
 
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
     # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
     # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
     # will add async-offload support to your cast and improve performance.
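For custom node authors, the NOTE above translates into the following usage pattern, sketched here under the assumption that the functions are reached as comfy.ops.cast_bias_weight / comfy.ops.uncast_bias_weight (module path inferred from the file name; error handling omitted):

import torch
import comfy.ops

def forward(self, input):
    # offloadable=True opts in to async offload; the third return value
    # carries the offload bookkeeping that uncast_bias_weight needs.
    weight, bias, offload_stream = comfy.ops.cast_bias_weight(
        self, input, offloadable=True)
    out = torch.nn.functional.linear(input, weight, bias)
    # Release after the last use of weight/bias, per the NOTE.
    comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
    return out

This mirrors the pattern forward_comfy_cast_weights uses further down in this diff.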
@@ -212,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
 
     if hasattr(s, "_v"):
-        return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
+        return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
 
     if offloadable and (device != s.weight.device or
                         (s.bias is not None and device != s.bias.device)):
@@ -850,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
         def _forward(self, input, weight, bias):
             return torch.nn.functional.linear(input, weight, bias)
 
-        def forward_comfy_cast_weights(self, input, compute_dtype=None):
-            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
+        def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
             x = self._forward(input, weight, bias)
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x
@@ -881,8 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 scale = comfy.model_management.cast_to_device(scale, input.device, None)
                 input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
 
-            output = self.forward_comfy_cast_weights(input, compute_dtype)
-
+            output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
 
             # Reshape output back to 3D if input was 3D
             if reshaped_3d:
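Putting the hunks together: the flag is derived once at the call site and threaded down unchanged. A sketch of the plumbing (signatures abridged, call chain reconstructed from the hunks above):

# fp8 matmul forward path (mixed_precision_ops Linear):
# want_requant is true exactly when the activation was quantized,
# i.e. the layer can actually consume a requantized weight.
output = self.forward_comfy_cast_weights(
    input, compute_dtype,
    want_requant=isinstance(input, QuantizedTensor))
#   -> cast_bias_weight(..., want_requant=want_requant)
#   -> cast_bias_weight_with_vbar(..., want_requant)  # gate evaluated here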