diff --git a/comfy/model_management.py b/comfy/model_management.py index ca4318a45..145a32080 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -524,7 +524,7 @@ def ensure_pin_budget(size, evict_active=False): shortfall += PIN_PRESSURE_HYSTERESIS for loaded_model in reversed(current_loaded_models): model = loaded_model.model - if model is not None and model.is_dynamic() and (evict_active or not model.dynamic_pins[model.load_device]["active"]): + if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]): shortfall -= model.partially_unload_ram(shortfall) if shortfall <= 0: break @@ -1237,9 +1237,9 @@ def reset_cast_buffers(): for loaded_model in current_loaded_models: model = loaded_model.model if model is not None and model.is_dynamic(): - model.dynamic_pins[model.load_device]["active"] = False + model.model.dynamic_pins[model.load_device]["active"] = False model.partially_unload_ram(1e30, subsets=[ "patches" ]) - model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0), []) + model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), []) STREAM_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear() diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d18f76456..4f200d6b6 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1508,8 +1508,8 @@ class ModelPatcherDynamic(ModelPatcher): self.model.dynamic_pins = {} if self.load_device not in self.model.dynamic_pins: self.model.dynamic_pins[self.load_device] = { - "weights": (comfy_aimdo.host_buffer.HostBuffer(0), []), - "patches": (comfy_aimdo.host_buffer.HostBuffer(0), []), + "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024), []), + "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), []), "failed": False, "active": False, } @@ -1602,7 +1602,9 @@ class ModelPatcherDynamic(ModelPatcher): if key in self.patches: if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape: return (True, 0) - setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + lowvram_patch = LowVramPatch(key, self.patches) + lowvram_patch._pin_state = pin_state + setattr(m, param_key + "_lowvram_function", lowvram_patch) num_patches += 1 else: setattr(m, param_key + "_lowvram_function", None) diff --git a/comfy/ops.py b/comfy/ops.py index 6b873cc66..9ffd418a6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -183,34 +183,45 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin if xfer_dest is None: xfer_dest = get_cast_buffer(dest_size) - if pin is None: - if signature is None: - comfy.pinned_memory.pin_memory(s) - pin = comfy.pinned_memory.get_pin(s) - if pin is not None: - comfy.model_management.cast_to_gathered(xfer_source, pin) - xfer_source = [ pin ] - if pin is None: - pin_offset = get_stream_pin_buffer_offset(dest_size) - if pin_offset is not None: - stream_pin_queue.append((xfer_source, pin_offset, dest_size, xfer_dest)) - xfer_source = None + def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream): + if xfer_source is not None: + if getattr(xfer_source, "is_lowvram_patch", False): + xfer_source.prepare(xfer_dest, stream, copy=True, commit=False) + else: + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream) - if xfer_source is not None: - comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + def handle_pin_miss(m, source, dest, subset="weights", size=None): + pin = None + if signature is None: + comfy.pinned_memory.pin_memory(m, subset=subset, size=size) + pin = comfy.pinned_memory.get_pin(m, subset=subset) + if pin is not None: + cast_maybe_lowvram_patch(source, pin, None) + return [ pin ] + if pin is None: + pin_offset = get_stream_pin_buffer_offset(size) + if pin_offset is not None: + stream_pin_queue.append((source, pin_offset, size, dest)) + return None + return source + + if pin is None: + xfer_source = handle_pin_miss(s, xfer_source, xfer_dest, size=dest_size) + + cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream) for param_key in ("weight", "bias"): - lowvram_fn = getattr(s, param_key + "_lowvram_function", None) - if lowvram_fn is not None: + lowvram_source = getattr(s, param_key + "_lowvram_function", None) + if lowvram_source is not None: ensure_offload_stream(s, cast_buffer_offset, False) - lowvram_size = lowvram_fn.memory_required() + lowvram_size = lowvram_source.memory_required() lowvram_dest = get_cast_buffer(lowvram_size) - lowvram_fn.prepare(lowvram_dest, None, copy=False, commit=True) - pin_offset = get_stream_pin_buffer_offset(lowvram_size) - if pin_offset is not None: - stream_pin_queue.append((lowvram_fn, pin_offset, lowvram_size, lowvram_dest)) - else: - lowvram_fn.prepare(lowvram_dest, offload_stream, copy=True, commit=True) + lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True) + + pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches") + lowvram_source = handle_pin_miss(lowvram_source, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) if pin is None else [ pin ] + + cast_maybe_lowvram_patch(lowvram_source, lowvram_dest, offload_stream) prefetch["xfer_dest"] = xfer_dest prefetch["cast_dest"] = cast_dest @@ -218,12 +229,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin prefetch["needs_cast"] = needs_cast s._prefetch = prefetch - def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream): - if getattr(xfer_source, "is_lowvram_patch", False): - xfer_source.prepare(xfer_dest, stream, copy=True, commit=False) - else: - comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream) - if stream_pin_offset > 0: if stream_pin_hostbuf.size < stream_pin_offset: if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):