Fix loras not working on mixed fp8. (#10899)

comfyanonymous 2025-11-25 21:07:58 -08:00 committed by GitHub
parent 0e24dbb19f
commit bdb10a583f
4 changed files with 37 additions and 9 deletions


@@ -132,7 +132,7 @@ class LowVramPatch:
     def __call__(self, weight):
         intermediate_dtype = weight.dtype
         if self.convert_func is not None:
-            weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
+            weight = self.convert_func(weight, inplace=False)
 
         if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
             intermediate_dtype = torch.float32
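A minimal sketch (assumed helper names, not the repository code) of why handing the raw weight to convert_func matters on mixed fp8: the weight attached to a quantized layer may be a QuantizedTensor, and a blanket float32 cast is not the same as letting the layer's own converter dequantize it.

import torch

def convert_func(weight, inplace=False):
    # Hypothetical layer converter: knows how to expand its own storage
    # format (e.g. an fp8 QuantizedTensor) into a plain patchable tensor.
    if hasattr(weight, "dequantize"):
        return weight.dequantize()
    return weight if inplace else weight.clone()

def low_vram_patch_call(weight):
    intermediate_dtype = weight.dtype
    # The patch now hands the raw weight to convert_func instead of forcing
    # a float32 copy first, so quantized storage is expanded correctly.
    weight = convert_func(weight, inplace=False)
    if intermediate_dtype not in (torch.float32, torch.float16, torch.bfloat16):
        intermediate_dtype = torch.float32
    return weight, intermediate_dtype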


@@ -117,6 +117,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if weight_has_function or weight.dtype != dtype:
         with wf_context:
             weight = weight.to(dtype=dtype)
+            if isinstance(weight, QuantizedTensor):
+                weight = weight.dequantize()
             for f in s.weight_function:
                 weight = f(weight)
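A rough illustration of the added dequantize step (helper names assumed): weight functions such as attached LoRA patches operate on ordinary tensors, so a quantized weight is expanded before they run.

def run_weight_functions(weight, weight_functions):
    # If the cast weight is still a quantized wrapper, expand it first so the
    # LoRA/patch callables below receive a plain tensor they can do math on.
    if hasattr(weight, "dequantize"):      # stand-in for isinstance(weight, QuantizedTensor)
        weight = weight.dequantize()
    for f in weight_functions:
        weight = f(weight)
    return weight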
@@ -502,7 +504,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
                 weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                 return weight
             else:
-                return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+                return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
 
         def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
             weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
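A small sketch of the non-inplace convert path above, under the assumption that scale_weight is a scalar tensor: upcasting the fp8 weight before the multiply keeps the descaled result in float32, so downstream patch math does not run at fp8 precision. (Requires a PyTorch build with float8 dtypes.)

import torch

def convert_scaled_fp8_weight(weight_fp8, scale_weight):
    # Upcast first, then rescale; the product stays in float32.
    return weight_fp8.to(dtype=torch.float32) * scale_weight.to(
        device=weight_fp8.device, dtype=torch.float32)

w = torch.randn(4, 4).to(torch.float8_e4m3fn)                 # pretend stored weight
print(convert_scaled_fp8_weight(w, torch.tensor(2.0)).dtype)  # torch.float32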
@@ -643,6 +645,24 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
                     not isinstance(input, QuantizedTensor)):
                 input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
             return self._forward(input, self.weight, self.bias)
+
+        def convert_weight(self, weight, inplace=False, **kwargs):
+            if isinstance(weight, QuantizedTensor):
+                return weight.dequantize()
+            else:
+                return weight
+
+        def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
+            if getattr(self, 'layout_type', None) is not None:
+                weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+            else:
+                weight = weight.to(self.weight.dtype)
+
+            if return_weight:
+                return weight
+
+            assert inplace_update is False  # TODO: eventually remove the inplace_update stuff
+            self.weight = torch.nn.Parameter(weight, requires_grad=False)
 
     return MixedPrecisionOps
 
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
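Taken together, convert_weight and set_weight give the mixed-precision ops the same dequantize → patch → requantize round trip the scaled-fp8 ops already had, which is what the LoRA path needs. A hedged sketch of that flow (the patch_layer_with_lora helper is illustrative, not ComfyUI API):

def patch_layer_with_lora(layer, lora_diff, seed=0):
    # 1. Expand the stored QuantizedTensor into a plain high-precision tensor.
    weight = layer.convert_weight(layer.weight, inplace=False)
    # 2. Apply the LoRA delta in that precision.
    weight = weight + lora_diff.to(weight.dtype)
    # 3. Re-quantize back into the layer's storage format; the seed enables
    #    stochastic rounding in TensorCoreFP8Layout.quantize.
    layer.set_weight(weight, seed=seed)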


@@ -1,6 +1,7 @@
 import torch
 import logging
 from typing import Tuple, Dict
+import comfy.float
 
 _LAYOUT_REGISTRY = {}
 _GENERIC_UTILS = {}
@@ -393,7 +394,7 @@ class TensorCoreFP8Layout(QuantizedLayout):
       - orig_dtype: Original dtype before quantization (for casting back)
     """
 
     @classmethod
-    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
+    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
         orig_dtype = tensor.dtype
         if scale is None:
@@ -403,17 +404,23 @@ class TensorCoreFP8Layout(QuantizedLayout):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)
 
-        tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
-        # TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
-        lp_amax = torch.finfo(dtype).max
-        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
-        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
+        if inplace_ops:
+            tensor *= (1.0 / scale).to(tensor.dtype)
+        else:
+            tensor = tensor * (1.0 / scale).to(tensor.dtype)
+
+        if stochastic_rounding > 0:
+            tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
+        else:
+            lp_amax = torch.finfo(dtype).max
+            torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
+
+        tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
 
         layout_params = {
             'scale': scale,
             'orig_dtype': orig_dtype
         }
-        return qdata, layout_params
+        return tensor, layout_params
 
     @staticmethod
     def dequantize(qdata, scale, orig_dtype, **kwargs):
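The new stochastic_rounding branch trades the deterministic clamp-and-cast for comfy.float.stochastic_rounding when re-quantizing patched weights. The idea, shown below with plain integer rounding rather than the fp8 grid, is that each value rounds up with probability equal to its fractional part, so the result is unbiased in expectation instead of always snapping to the nearest representable value.

import torch

def stochastic_round_to_int(x, generator=None):
    # Round down, then round up with probability equal to the fractional part.
    lo = torch.floor(x)
    frac = x - lo
    round_up = torch.rand(x.shape, generator=generator, device=x.device) < frac
    return lo + round_up.to(x.dtype)

x = torch.full((100_000,), 0.25)
print(torch.round(x).mean().item())              # 0.0 (deterministic rounding is biased here)
print(stochastic_round_to_int(x).mean().item())  # ~0.25 in expectation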


@@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase):
             lora_diff = torch.mm(
                 mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
             ).reshape(weight.shape)
+            del mat1, mat2
             if dora_scale is not None:
                 weight = weight_decompose(
                     dora_scale,
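The added del is a peak-memory tweak: once the dense lora_diff has been materialized, the low-rank factors are no longer needed, and dropping them before the optional DoRA weight_decompose lets that memory be reclaimed sooner. A self-contained sketch of the pattern with illustrative names:

import torch

def apply_lora(weight, up, down, strength=1.0):
    mat1 = up.to(torch.float32)      # temporaries used only for the matmul
    mat2 = down.to(torch.float32)
    lora_diff = torch.mm(mat1.flatten(start_dim=1),
                         mat2.flatten(start_dim=1)).reshape(weight.shape)
    del mat1, mat2   # release the local references before any further large allocations
    return weight + strength * lora_diff.to(weight.dtype)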