ops: implement lora requanting for non QuantizedTensor fp8 (#12668)

Allow a non-QuantizedTensor layer to set want_requant so that the post-lora
calculation result is stochastically cast back down to the original input dtype.

This is then used by the legacy fp8 Linear implementation to set the
compute_dtype to the preferred lora dtype, while using want_requant to cast
the result back down to fp8.

This fixes the issue where --fast fp8_matrix_mult combined with
--fast dynamic_vram breaks when applying a lora to an fp8 non-QT model.
This commit is contained in:
rattus 2026-02-27 16:05:51 -08:00 committed by GitHub
parent 25ec3d96a3
commit e721e24136
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -167,17 +167,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = to_dequant(x, dtype) x = to_dequant(x, dtype)
if not resident and lowvram_fn is not None: if not resident and lowvram_fn is not None:
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x) x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and if (want_requant and len(fns) == 0 or update_weight):
(want_requant and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key) seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) if isinstance(orig, QuantizedTensor):
if want_requant and len(fns) == 0: y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
#The layer actually wants our freshly saved QT else:
x = y y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
elif update_weight: if want_requant and len(fns) == 0:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key)) x = y
if update_weight: if update_weight:
orig.copy_(y) orig.copy_(y)
for f in fns: for f in fns:
@ -617,7 +615,8 @@ def fp8_linear(self, input):
if input.ndim != 2: if input.ndim != 2:
return None return None
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device)
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32) scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
scale_input = torch.ones((), device=input.device, dtype=torch.float32) scale_input = torch.ones((), device=input.device, dtype=torch.float32)