Prs/lora reservations (reduce massive Lora reservations especially on Flux2) (#11069)

* mp: only count the offload cost of math once

The offload reservation previously bundled the combined weight storage and
computation cost into every per-stream slot, so the computation cost was
counted once per stream instead of once overall.
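
A rough, hypothetical illustration of the effect (the sizes and stream count below are invented, not measured from Flux2):

# Hypothetical numbers only, to show the shape of the change.
NUM_STREAMS = 2
module_mem = 0.2          # GB: storage cost of one module's weights
module_offload_mem = 3.2  # GB: storage plus dequant/LoRA-patching scratch space

old_reservation = module_offload_mem * (NUM_STREAMS + 1)         # 9.6 GB
new_reservation = module_offload_mem + NUM_STREAMS * module_mem  # 3.6 GB

The patching scratch space is now reserved once, while each extra stream only reserves room for the plain weight copy.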

* ops: put all post-async-transfer compute on the main stream

Some models have massive weights that need either complex dequantization
or LoRA patching. Don't do this patching on the offload stream; do it on
the main stream instead, so the potentially large VRAM spikes from these
compute steps are synchronized with the main stream's work. This avoids
having to assume a worst-case scenario of multiple offload streams all
spiking VRAM in parallel with whatever the main stream is doing.
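
A minimal sketch of this pattern in plain PyTorch, assuming CUDA streams (names like patch_weight are placeholders, not ComfyUI APIs):

import torch

def patch_weight(w):
    # Stand-in for dequantization / LoRA patching; the real work allocates
    # large temporaries, which is the VRAM spike described above.
    return w * 1.0

weight_cpu = torch.randn(4096, 4096, pin_memory=True)
offload_stream = torch.cuda.Stream()
main_stream = torch.cuda.current_stream()

# Copy the raw weight to the GPU on the offload stream.
with torch.cuda.stream(offload_stream):
    weight_gpu = weight_cpu.to("cuda", non_blocking=True)

# Make the main stream wait for the transfer, then run the heavy patching on
# the main stream so its allocation spike is serialized with the rest of the
# compute instead of happening on several offload streams at once.
main_stream.wait_stream(offload_stream)
weight_patched = patch_weight(weight_gpu)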
rattus authored 2025-12-03 17:28:45 +10:00 · committed by GitHub
parent 861817d22d
commit 519c941165
2 changed files with 24 additions and 19 deletions

comfy/model_patcher.py

@@ -704,7 +704,7 @@ class ModelPatcher:
 
             lowvram_weight = False
 
-            potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+            potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem))
             lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
 
             weight_key = "{}.weight".format(n)
@@ -883,7 +883,7 @@ class ModelPatcher:
                 break
 
             module_offload_mem, module_mem, n, m, params = unload
-            potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
+            potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)
             lowvram_possible = hasattr(m, "comfy_cast_weights")
 
             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:

comfy/ops.py

@@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
 
     if s.bias is not None:
         bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
-        if bias_has_function:
-            with wf_context:
-                for f in s.bias_function:
-                    bias = f(bias)
+
+    comfy.model_management.sync_stream(device, offload_stream)
+
+    bias_a = bias
+    weight_a = weight
+    if s.bias is not None:
+        for f in s.bias_function:
+            bias = f(bias)
 
     if weight_has_function or weight.dtype != dtype:
-        with wf_context:
-            weight = weight.to(dtype=dtype)
-            if isinstance(weight, QuantizedTensor):
-                weight = weight.dequantize()
-            for f in s.weight_function:
-                weight = f(weight)
+        weight = weight.to(dtype=dtype)
+        if isinstance(weight, QuantizedTensor):
+            weight = weight.dequantize()
+        for f in s.weight_function:
+            weight = f(weight)
 
-    comfy.model_management.sync_stream(device, offload_stream)
     if offloadable:
-        return weight, bias, offload_stream
+        return weight, bias, (offload_stream, weight_a, bias_a)
     else:
         #Legacy function signature
         return weight, bias
@@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
 def uncast_bias_weight(s, weight, bias, offload_stream):
     if offload_stream is None:
         return
-    if weight is not None:
-        device = weight.device
+    os, weight_a, bias_a = offload_stream
+    if os is None:
+        return
+    if weight_a is not None:
+        device = weight_a.device
     else:
-        if bias is None:
+        if bias_a is None:
             return
-        device = bias.device
-    offload_stream.wait_stream(comfy.model_management.current_stream(device))
+        device = bias_a.device
+    os.wait_stream(comfy.model_management.current_stream(device))
 
 
 class CastWeightBiasOp:
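
For context, a sketch of how a caller pairs the two helpers after this change (illustrative only; the real call sites, e.g. the Linear op's forward, may differ in detail): cast_bias_weight(..., offloadable=True) now returns the (offload_stream, weight_a, bias_a) tuple, which is handed straight back to uncast_bias_weight once the math is done.

def forward_comfy_cast_weights(self, input):
    # Illustrative caller, assuming the surrounding comfy.ops module context
    # (torch and the two helpers above are in scope).
    weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
    out = torch.nn.functional.linear(input, weight, bias)
    # uncast_bias_weight unpacks the tuple: weight_a / bias_a are only used to
    # find the device, then the offload stream is made to wait on the main
    # stream before its staging buffers are reused for the next transfer.
    uncast_bias_weight(self, weight, bias, offload_stream)
    return out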