quant ops: Dequantize weight in-place (#10935)

In flux2 these weights are huge (200 MB). Since plain_tensor is a throw-away
deep copy, do the scale multiplication in-place to save VRAM instead of
allocating a second full-size tensor.
rattus 2025-11-28 02:06:30 +10:00 committed by GitHub
parent f17251bec6
commit 3f382a4f98


@@ -425,7 +425,8 @@ class TensorCoreFP8Layout(QuantizedLayout):
     @staticmethod
     def dequantize(qdata, scale, orig_dtype, **kwargs):
         plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
-        return plain_tensor * scale
+        plain_tensor.mul_(scale)
+        return plain_tensor
 
     @classmethod
     def get_plain_tensors(cls, qtensor):
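
For illustration only (not part of the commit): a minimal sketch of why the
in-place multiply lowers peak memory, assuming a PyTorch build with
float8_e4m3fn support. The weight shape and scale value here are made up.

import torch

qdata = torch.randn(4096, 4096).to(torch.float8_e4m3fn)  # stand-in FP8 weight
scale = torch.tensor(0.02)                                # hypothetical scale
orig_dtype = torch.bfloat16

# Before: the upcast copy and the product exist at the same time, so peak
# usage is roughly two full-precision copies of the weight.
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
dequantized = plain_tensor * scale

# After: plain_tensor is a throw-away deep copy of qdata, so it can be
# scaled in its own storage and returned directly, with no second buffer.
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
plain_tensor.mul_(scale)
dequantized = plain_tensor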