diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 6f2ba41ef..3fb87b4a3 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import ( from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import comfy.ldm.common_dit +import comfy.model_prefetch class CompressedTimestep: """Store video timestep embeddings in compressed form using per-frame indexing.""" @@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel): """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options) # Process transformer blocks for i, block in enumerate(self.transformer_blocks): + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block) if ("double_block", i) in blocks_replace: def block_wrap(args): @@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel): a_prompt_timestep=a_prompt_timestep, ) + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None) + return [vx, ax] def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): diff --git a/comfy/lora.py b/comfy/lora.py index e4337c729..db8f16bcb 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -17,6 +17,7 @@ """ from __future__ import annotations +import comfy.memory_management import comfy.utils import comfy.model_management import comfy.model_base @@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori weight = old_weight return weight + +def prefetch_prepared_value(value, allocate_buffer, stream): + if isinstance(value, torch.Tensor): + dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) + comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) + return comfy.memory_management.interpret_gathered_like([value], dest)[0] + elif isinstance(value, weight_adapter.WeightAdapterBase): + return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) + elif isinstance(value, tuple): + return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) + elif isinstance(value, list): + return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] + + return value diff --git a/comfy/model_base.py b/comfy/model_base.py index 50dab5782..b61a2aa09 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -214,6 +214,11 @@ class BaseModel(torch.nn.Module): if "latent_shapes" in extra_conds: xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) + transformer_options = transformer_options.copy() + transformer_options["prefetch_dynamic_vbars"] = ( + self.current_patcher is not None and self.current_patcher.is_dynamic() + ) + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) if len(model_output) > 1 and not torch.is_tensor(model_output): model_output, _ = utils.pack_latents(model_output) diff --git a/comfy/model_management.py b/comfy/model_management.py index f86e2a4aa..02ad66656 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,7 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops +import comfy_aimdo.vram_buffer class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -1175,6 +1176,10 @@ stream_counters = {} STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) +STREAM_AIMDO_CAST_BUFFERS = {} +LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + +DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 def get_cast_buffer(offload_stream, device, size, ref): global LARGEST_CASTED_WEIGHT @@ -1208,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref): return cast_buffer +def get_aimdo_cast_buffer(offload_stream, device): + cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None: + cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) + STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer + + return cast_buffer def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT + global LARGEST_AIMDO_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) - for offload_stream in STREAM_CAST_BUFFERS: - offload_stream.synchronize() + LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): + if offload_stream is not None: + offload_stream.synchronize() synchronize() + STREAM_CAST_BUFFERS.clear() + STREAM_AIMDO_CAST_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e259aed63..7d2d6883f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -121,9 +121,20 @@ class LowVramPatch: self.patches = patches self.convert_func = convert_func # TODO: remove self.set_func = set_func + self.prepared_patches = None + + def prepare(self, allocate_buffer, stream): + self.prepared_patches = [ + (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) + for patch in self.patches[self.key] + ] + + def clear_prepared(self): + self.prepared_patches = None def __call__(self, weight): - return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) + patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key] + return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype) LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py new file mode 100644 index 000000000..0ad35deb5 --- /dev/null +++ b/comfy/model_prefetch.py @@ -0,0 +1,65 @@ +import comfy_aimdo.model_vbar +import comfy.model_management +import comfy.ops + +PREFETCH_QUEUES = [] + +def cleanup_prefetched_modules(comfy_modules): + for s in comfy_modules: + prefetch = getattr(s, "_prefetch", None) + if prefetch is None: + continue + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() + if prefetch["signature"] is not None: + comfy_aimdo.model_vbar.vbar_unpin(s._v) + delattr(s, "_prefetch") + +def cleanup_prefetch_queues(): + global PREFETCH_QUEUES + + for queue in PREFETCH_QUEUES: + for entry in queue: + if entry is None or not isinstance(entry, tuple): + continue + _, prefetch_state = entry + comfy_modules = prefetch_state[1] + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + PREFETCH_QUEUES = [] + +def prefetch_queue_pop(queue, device, module): + if queue is None: + return + + consumed = queue.pop(0) + if consumed is not None: + offload_stream, prefetch_state = consumed + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + _, comfy_modules = prefetch_state + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + + prefetch = queue[0] + if prefetch is not None: + comfy_modules = [] + for s in prefetch.modules(): + if hasattr(s, "_v"): + comfy_modules.append(s) + + offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True) + comfy.model_management.sync_stream(device, offload_stream) + queue[0] = (offload_stream, (prefetch, comfy_modules)) + +def make_prefetch_queue(queue, device, transformer_options): + if (not transformer_options.get("prefetch_dynamic_vbars", False) + or comfy.model_management.NUM_STREAMS == 0 + or comfy.model_management.is_device_cpu(device) + or not comfy.model_management.device_supports_non_blocking(device)): + return None + + queue = [None] + queue + [None] + PREFETCH_QUEUES.append(queue) + return queue diff --git a/comfy/ops.py b/comfy/ops.py index 050f7cda0..96db1411c 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys): setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad)) -def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): - #vbar doesn't support CPU weights, but some custom nodes have weird paths - #that might switch the layer to the CPU and expect it to work. We have to take - #a clone conservatively as we are mmapped and some SFT files are packed misaligned - #If you are a custom node author reading this, please move your layer to the GPU - #or declare your ModelPatcher as CPU in the first place. - if comfy.model_management.is_device_cpu(device): - materialize_meta_param(s, ["weight", "bias"]) - weight = s.weight.to(dtype=dtype, copy=True) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - bias = None - if s.bias is not None: - bias = s.bias.to(dtype=bias_dtype, copy=True) - return weight, bias, (None, None, None) - +# FIXME: add n=1 cache hit fast path +def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking): offload_stream = None - xfer_dest = None + cast_buffer = None + cast_buffer_offset = 0 + + def ensure_offload_stream(module, required_size, check_largest): + nonlocal offload_stream + nonlocal cast_buffer + + if offload_stream is None: + offload_stream = comfy.model_management.get_offload_stream(device) + if offload_stream is None or not check_largest or len(comfy_modules) != 1: + return + + current_size = 0 if cast_buffer is None else cast_buffer.size() + if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = None + if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: + comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size) + + def get_cast_buffer(buffer_size): + nonlocal offload_stream + nonlocal cast_buffer + nonlocal cast_buffer_offset + + if buffer_size == 0: + return None + + if offload_stream is None: + return torch.empty((buffer_size,), dtype=torch.uint8, device=device) + + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device) + cast_buffer_offset += buffer_size + return buffer + + for s in comfy_modules: + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + prefetch = { + "signature": signature, + "resident": resident, + } - signature = comfy_aimdo.model_vbar.vbar_fault(s._v) - resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) - if signature is not None: if resident: - weight = s._v_weight - bias = s._v_bias - else: - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + s._prefetch = prefetch + continue - if not resident: materialize_meta_param(s, ["weight", "bias"]) + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_dest = None + needs_cast = False xfer_source = [ s.weight, s.bias ] @@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu if data is None: continue if data.dtype != geometry.dtype: + needs_cast = True cast_dest = xfer_dest - if cast_dest is None: - cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) xfer_dest = None break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) - offload_stream = comfy.model_management.get_offload_stream(device) - if xfer_dest is None and offload_stream is not None: - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) - if xfer_dest is None: - offload_stream = comfy.model_management.get_offload_stream(device) - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True) if xfer_dest is None: - xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) - offload_stream = None + xfer_dest = get_cast_buffer(dest_size) if signature is None and pin is None: comfy.pinned_memory.pin_memory(s) @@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_source = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) - comfy.model_management.sync_stream(device, offload_stream) - if cast_dest is not None: + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + ensure_offload_stream(s, cast_buffer_offset, False) + lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) + + prefetch["xfer_dest"] = xfer_dest + prefetch["cast_dest"] = cast_dest + prefetch["cast_geometry"] = cast_geometry + prefetch["needs_cast"] = needs_cast + s._prefetch = prefetch + + return offload_stream + + +def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant): + + prefetch = getattr(s, "_prefetch", None) + + if prefetch["resident"]: + weight = s._v_weight + bias = s._v_bias + else: + xfer_dest = prefetch["xfer_dest"] + if prefetch["needs_cast"]: + cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device) for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), - comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)): if post_cast is not None: post_cast.copy_(pre_cast) xfer_dest = cast_dest - params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest) weight = params[0] bias = params[1] - if signature is not None: + if prefetch["signature"] is not None: s._v_weight = weight s._v_bias = bias - s._v_signature=signature + s._v_signature = prefetch["signature"] def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) fns = getattr(s, param_key + "_function", []) + if x is None: + return None + orig = x def to_dequant(tensor, dtype): @@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu x = f(x) return x - update_weight = signature is not None + update_weight = prefetch["signature"] is not None + weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight) + if bias is not None: + bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight) - weight = post_cast(s, "weight", weight, dtype, resident, update_weight) - if s.bias is not None: - bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) - - #FIXME: weird offload return protocol - return weight, bias, (offload_stream, device if signature is not None else None, None) + return weight, bias def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False): @@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + def format_return(result, offloadable): + weight, bias, offload_stream = result + return (weight, bias, offload_stream) if offloadable else (weight, bias) + non_blocking = comfy.model_management.device_supports_non_blocking(device) if hasattr(s, "_v"): - return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) + + #vbar doesn't support CPU weights, but some custom nodes have weird paths + #that might switch the layer to the CPU and expect it to work. We have to take + #a clone conservatively as we are mmapped and some SFT files are packed misaligned + #If you are a custom node author reading this, please move your layer to the GPU + #or declare your ModelPatcher as CPU in the first place. + if comfy.model_management.is_device_cpu(device): + materialize_meta_param(s, ["weight", "bias"]) + weight = s.weight.to(dtype=dtype, copy=True) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None + return format_return((weight, bias, (None, None, None)), offloadable) + + prefetched = hasattr(s, "_prefetch") + offload_stream = None + offload_device = None + if not prefetched: + offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking) + comfy.model_management.sync_stream(device, offload_stream) + + weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant) + + if not prefetched: + if getattr(s, "_prefetch")["signature"] is not None: + offload_device = device + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() + delattr(s, "_prefetch") + return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable) + if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): @@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of for f in s.weight_function: weight = f(weight) - if offloadable: - return weight, bias, (offload_stream, weight_a, bias_a) - else: - #Legacy function signature - return weight, bias + return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable) def uncast_bias_weight(s, weight, bias, offload_stream): diff --git a/execution.py b/execution.py index 5a6d3404c..654db8426 100644 --- a/execution.py +++ b/execution.py @@ -15,6 +15,7 @@ import torch from comfy.cli_args import args import comfy.memory_management import comfy.model_management +import comfy.model_prefetch import comfy_aimdo.model_vbar from latent_preview import set_preview_method @@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if args.verbose == "DEBUG": comfy_aimdo.control.analyze() comfy.model_management.reset_cast_buffers() + comfy.model_prefetch.cleanup_prefetch_queues() comfy_aimdo.model_vbar.vbars_reset_watermark_limits() if has_pending_tasks: diff --git a/nodes.py b/nodes.py index 99dc07227..710cccffe 100644 --- a/nodes.py +++ b/nodes.py @@ -1694,26 +1694,27 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" + def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) + dtype = comfy.model_management.intermediate_dtype() + device = comfy.model_management.intermediate_device() + components = InputImpl.VideoFromFile(image_path).get_components() if components.images.shape[0] > 0: - return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu")) + return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device)) + # This code is left here to handle animated webp which pyav does not support loading img = node_helpers.pillow(Image.open, image_path) output_images = [] output_masks = [] w, h = None, None - dtype = comfy.model_management.intermediate_dtype() - for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") if len(output_images) == 0: @@ -1728,25 +1729,15 @@ class LoadImage: if 'A' in i.getbands(): mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = 1. - torch.from_numpy(mask) - elif i.mode == 'P' and 'transparency' in i.info: - mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") output_images.append(image.to(dtype=dtype)) output_masks.append(mask.unsqueeze(0).to(dtype=dtype)) - if img.format == "MPO": - break # ignore all frames except the first one for MPO format + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) - if len(output_images) > 1: - output_image = torch.cat(output_images, dim=0) - output_mask = torch.cat(output_masks, dim=0) - else: - output_image = output_images[0] - output_mask = output_masks[0] - - return (output_image, output_mask) + return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype)) @classmethod def IS_CHANGED(s, image):