diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 55d25d897..7987c8cf2 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -16,7 +16,7 @@ from comfy.ldm.lightricks.model import ( from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import comfy.ldm.common_dit -import comfy.ops +import comfy.model_prefetch class CompressedTimestep: """Store video timestep embeddings in compressed form using per-frame indexing.""" @@ -908,11 +908,11 @@ class LTXAVModel(LTXVModel): """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) - prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks)) + prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device) # Process transformer blocks for i, block in enumerate(self.transformer_blocks): - comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, block) + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block) if ("double_block", i) in blocks_replace: def block_wrap(args): @@ -985,7 +985,7 @@ class LTXAVModel(LTXVModel): a_prompt_timestep=a_prompt_timestep, ) - comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, None) + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None) return [vx, ax] diff --git a/comfy/lora.py b/comfy/lora.py index 63ee85323..5be361a4a 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -17,6 +17,7 @@ """ from __future__ import annotations +import comfy.memory_management import comfy.utils import comfy.model_management import comfy.model_base @@ -467,3 +468,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori weight = old_weight return weight + +def prefetch_prepared_value(value, allocate_buffer, stream): + if isinstance(value, torch.Tensor): + dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) + comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) + return comfy.memory_management.interpret_gathered_like([value], dest)[0] + elif isinstance(value, weight_adapter.WeightAdapterBase): + return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) + elif isinstance(value, tuple): + return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) + elif isinstance(value, list): + return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] + + return value diff --git a/comfy/model_management.py b/comfy/model_management.py index 73c18c139..edd73ebfc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,7 +31,6 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops -import comfy_aimdo.model_vbar import comfy_aimdo.vram_buffer class VRAMState(Enum): @@ -1185,7 +1184,6 @@ STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) STREAM_AIMDO_CAST_BUFFERS = {} LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) -PREFETCH_QUEUES = [] DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 @@ -1228,31 +1226,9 @@ def get_aimdo_cast_buffer(offload_stream, device): STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer return cast_buffer - - -def cleanup_prefetched_modules(comfy_modules): - for s in comfy_modules: - prefetch = getattr(s, "_prefetch", None) - if prefetch is None: - continue - if prefetch["signature"] is not None: - comfy_aimdo.model_vbar.vbar_unpin(s._v) - delattr(s, "_prefetch") - - -def cleanup_prefetch_queue(queue): - for entry in queue: - if entry is None or not isinstance(entry, tuple): - continue - _, prefetch_state = entry - comfy_modules = prefetch_state[1] - if comfy_modules is not None: - cleanup_prefetched_modules(comfy_modules) - def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT - global PREFETCH_QUEUES LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) @@ -1261,10 +1237,6 @@ def reset_cast_buffers(): offload_stream.synchronize() synchronize() - for queue in PREFETCH_QUEUES: - cleanup_prefetch_queue(queue) - PREFETCH_QUEUES = [] - STREAM_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear() soft_empty_cache() diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e259aed63..7d2d6883f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -121,9 +121,20 @@ class LowVramPatch: self.patches = patches self.convert_func = convert_func # TODO: remove self.set_func = set_func + self.prepared_patches = None + + def prepare(self, allocate_buffer, stream): + self.prepared_patches = [ + (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) + for patch in self.patches[self.key] + ] + + def clear_prepared(self): + self.prepared_patches = None def __call__(self, weight): - return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) + patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key] + return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype) LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py new file mode 100644 index 000000000..982c599b4 --- /dev/null +++ b/comfy/model_prefetch.py @@ -0,0 +1,68 @@ +import comfy_aimdo.model_vbar +import comfy.model_management +import comfy.ops + +PREFETCH_QUEUES = [] + + +def cleanup_prefetched_modules(comfy_modules): + for s in comfy_modules: + prefetch = getattr(s, "_prefetch", None) + if prefetch is None: + continue + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() + if prefetch["signature"] is not None: + comfy_aimdo.model_vbar.vbar_unpin(s._v) + delattr(s, "_prefetch") +def cleanup_prefetch_queues(): + global PREFETCH_QUEUES + + for queue in PREFETCH_QUEUES: + for entry in queue: + if entry is None: + continue + _, prefetch_state = entry + comfy_modules = prefetch_state[1] + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + PREFETCH_QUEUES = [] + + +def prefetch_queue_pop(queue, device, module): + if queue is None: + return + + consumed = queue.pop(0) + if consumed is not None: + offload_stream, prefetch_state = consumed + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + _, comfy_modules = prefetch_state + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + + prefetch = queue[0] + if prefetch is not None: + comfy_modules = [] + for s in prefetch.modules(): + if hasattr(s, "_v"): + comfy_modules.append(s) + + offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True) + if offload_stream is None: + queue[0] = None + return + comfy.model_management.sync_stream(device, offload_stream) + queue[0] = (offload_stream, (prefetch, comfy_modules)) + +def make_prefetch_queue(queue, device): + if (comfy.model_management.NUM_STREAMS == 0 + or comfy.model_management.is_device_cpu(device) + or not comfy.model_management.device_supports_non_blocking(device)): + return None + + queue = [None] + queue + [None] + PREFETCH_QUEUES.append(queue) + return queue diff --git a/comfy/ops.py b/comfy/ops.py index 34b70a1bf..6693c11c5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -92,6 +92,38 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin cast_buffer = None cast_buffer_offset = 0 + def ensure_offload_stream(module, required_size, check_largest): + nonlocal offload_stream + nonlocal cast_buffer + + if offload_stream is None: + offload_stream = comfy.model_management.get_offload_stream(device) + if offload_stream is None or not check_largest or len(comfy_modules) != 1: + return + + current_size = 0 if cast_buffer is None else cast_buffer.size() + if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = None + if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: + comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size) + + def get_cast_buffer(buffer_size): + nonlocal offload_stream + nonlocal cast_buffer + nonlocal cast_buffer_offset + + if buffer_size == 0: + return None + + if offload_stream is None: + return torch.empty((buffer_size,), dtype=torch.uint8, device=device) + + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device) + cast_buffer_offset += buffer_size + return buffer + for s in comfy_modules: signature = comfy_aimdo.model_vbar.vbar_fault(s._v) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) @@ -126,23 +158,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) - if offload_stream is None: - offload_stream = comfy.model_management.get_offload_stream(device) - if xfer_dest is None and offload_stream is not None and cast_buffer is None: - cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) - if len(comfy_modules) == 1: - if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: - offload_stream = comfy.model_management.get_offload_stream(device) - cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) - if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: - comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size) + ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True) if xfer_dest is None: - if cast_buffer is not None: - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size, cast_buffer_offset), device) - cast_buffer_offset += dest_size - else: - xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) - offload_stream = None + xfer_dest = get_cast_buffer(dest_size) if signature is None and pin is None: comfy.pinned_memory.pin_memory(s) @@ -155,6 +173,13 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin xfer_source = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + ensure_offload_stream(s, cast_buffer_offset, False) + lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) + prefetch["xfer_dest"] = xfer_dest prefetch["cast_dest"] = cast_dest prefetch["cast_geometry"] = cast_geometry @@ -164,47 +189,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin return offload_stream -def cast_prefetch_all(module, device): - prefetch_state = (module, None) - if (comfy.model_management.is_device_cpu(device) - or not comfy.model_management.device_supports_non_blocking(device)): - return (None, prefetch_state) - - comfy_modules = [] - for s in module.modules(): - if hasattr(s, "_v"): - comfy_modules.append(s) - - offload_stream = cast_modules_with_vbar(comfy_modules, None, device, None, True) - comfy.model_management.sync_stream(device, offload_stream) - return (offload_stream, (module, comfy_modules)) - - -def uncast_prefetch_all(prefetch_state): - _, comfy_modules = prefetch_state - if comfy_modules is not None: - comfy.model_management.cleanup_prefetched_modules(comfy_modules) - - -def prefetch_queue_pop(queue, device, module): - consumed = queue.pop(0) - if consumed is not None: - offload_stream, prefetch_state = consumed - if offload_stream is not None: - offload_stream.wait_stream(comfy.model_management.current_stream(device)) - uncast_prefetch_all(prefetch_state) - - prefetch = queue[0] - if prefetch is not None: - queue[0] = cast_prefetch_all(prefetch, device) - - -def make_prefetch_queue(queue): - queue = [None] + queue + [None] - comfy.model_management.PREFETCH_QUEUES.append(queue) - return queue - - def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): del non_blocking @@ -237,6 +221,9 @@ def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requ lowvram_fn = getattr(s, param_key + "_lowvram_function", None) fns = getattr(s, param_key + "_function", []) + if x is None: + return None + orig = x def to_dequant(tensor, dtype): @@ -266,8 +253,7 @@ def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requ update_weight = prefetch["signature"] is not None weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight) - bias = None - if s.bias is not None: + if bias is not None: bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight) return weight, bias @@ -321,6 +307,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if not prefetched: if getattr(s, "_prefetch")["signature"] is not None: offload_device = device + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() delattr(s, "_prefetch") return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable) diff --git a/execution.py b/execution.py index e15eb4bda..de4b14a2b 100644 --- a/execution.py +++ b/execution.py @@ -15,6 +15,7 @@ import torch from comfy.cli_args import args import comfy.memory_management import comfy.model_management +import comfy.model_prefetch import comfy_aimdo.model_vbar from latent_preview import set_preview_method @@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if args.verbose == "DEBUG": comfy_aimdo.control.analyze() comfy.model_management.reset_cast_buffers() + comfy.model_prefetch.cleanup_prefetch_queues() comfy_aimdo.model_vbar.vbars_reset_watermark_limits() if has_pending_tasks: