From 82e70aa3c2c753bc58ec4d8a9908b69f1b5b8a58 Mon Sep 17 00:00:00 2001
From: ifilipis <40601736+ifilipis@users.noreply.github.com>
Date: Mon, 19 Jan 2026 15:18:06 +0200
Subject: [PATCH] Fix disk weight movement and pinned inflight tracking

---
 comfy/disk_weights.py     | 361 +++++++++++++++++++++++++++++++-------
 comfy/model_management.py |  69 +++-----
 2 files changed, 313 insertions(+), 117 deletions(-)

diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py
index f3b3ec31f..5f75f6b09 100644
--- a/comfy/disk_weights.py
+++ b/comfy/disk_weights.py
@@ -96,6 +96,7 @@ class CacheEntry:
     name: str
     size_bytes: int
     is_buffer: bool
+    device_type: str
 
 
 class DiskWeightCache:
@@ -112,16 +113,25 @@ class DiskWeightCache:
         return (id(module), name)
 
     def record(self, module: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool):
-        if tensor.device.type != "cpu":
+        if tensor.device.type == "meta":
             return
         size_bytes = tensor.numel() * tensor.element_size()
         key = self._entry_key(module, name)
         if key in self._entries:
             entry = self._entries.pop(key)
-            self.current_bytes -= entry.size_bytes
+            if entry.device_type == "cpu":
+                self.current_bytes -= entry.size_bytes
         module_ref = weakref.ref(module, self._drop_module_entries)
-        self._entries[key] = CacheEntry(module_ref=module_ref, name=name, size_bytes=size_bytes, is_buffer=is_buffer)
-        self.current_bytes += size_bytes
+        device_type = tensor.device.type
+        self._entries[key] = CacheEntry(
+            module_ref=module_ref,
+            name=name,
+            size_bytes=size_bytes,
+            is_buffer=is_buffer,
+            device_type=device_type,
+        )
+        if device_type == "cpu":
+            self.current_bytes += size_bytes
         self._evict_if_needed()
 
     def touch(self, module: torch.nn.Module, name: str):
@@ -133,9 +143,10 @@ class DiskWeightCache:
     def evict_bytes(self, bytes_to_free: int):
         freed = 0
         while self._entries and freed < bytes_to_free:
-            _, entry = self._entries.popitem(last=False)
+            entry = self.pop_lru(torch.device("cpu"))
+            if entry is None:
+                break
             freed += entry.size_bytes
-            self.current_bytes -= entry.size_bytes
             module = entry.module_ref()
             if module is not None:
                 _evict_module_weight(module, entry.name, entry.is_buffer)
@@ -148,8 +159,26 @@ class DiskWeightCache:
                 to_remove.append(key)
         for key in to_remove:
             entry = self._entries.pop(key)
+            if entry.device_type == "cpu":
+                self.current_bytes -= entry.size_bytes
+
+    def remove_entry(self, module: torch.nn.Module, name: str):
+        key = self._entry_key(module, name)
+        entry = self._entries.pop(key, None)
+        if entry is None:
+            return
+        if entry.device_type == "cpu":
             self.current_bytes -= entry.size_bytes
 
+    def pop_lru(self, device: torch.device) -> Optional[CacheEntry]:
+        for key, entry in self._entries.items():
+            if entry.device_type == device.type:
+                self._entries.pop(key)
+                if entry.device_type == "cpu":
+                    self.current_bytes -= entry.size_bytes
+                return entry
+        return None
+
     def _drop_module_entries(self, module_ref: weakref.ReferenceType):
         to_remove = []
         for key, entry in self._entries.items():
@@ -157,12 +186,14 @@ class DiskWeightCache:
                 to_remove.append(key)
         for key in to_remove:
             entry = self._entries.pop(key)
-            self.current_bytes -= entry.size_bytes
+            if entry.device_type == "cpu":
+                self.current_bytes -= entry.size_bytes
 
     def _evict_if_needed(self):
         while self._entries and self.current_bytes > self.max_bytes:
-            _, entry = self._entries.popitem(last=False)
-            self.current_bytes -= entry.size_bytes
+            entry = self.pop_lru(torch.device("cpu"))
+            if entry is None:
+                break
             module = entry.module_ref()
             if module is not None:
                 _evict_module_weight(module, entry.name, entry.is_buffer)
@@ -282,7 +313,7 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = 
             meta = state_dict.meta(key)
             ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
             REGISTRY.register(submodule, name, ref)
-            if param.device.type == "cpu":
+            if param.device.type != "meta":
                 CACHE.record(submodule, name, param, is_buffer=False)
         for name, buf in submodule.named_buffers(recurse=False):
             key = f"{module_prefix}{name}" if module_prefix else name
@@ -290,7 +321,7 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = 
             meta = state_dict.meta(key)
             ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
             REGISTRY.register(submodule, name, ref)
-            if buf.device.type == "cpu":
+            if buf.device.type != "meta":
                 CACHE.record(submodule, name, buf, is_buffer=True)
 
 
@@ -354,6 +385,23 @@ def _meta_tensor(meta, dtype_override: Optional[torch.dtype] = None) -> torch.Te
     return torch.empty(shape, dtype=dtype, device="meta")
 
 
+def _attach_disk_identity(tensor: torch.Tensor, module: torch.nn.Module, name: str, is_buffer: bool):
+    tensor._disk_weights_module_ref = weakref.ref(module)
+    tensor._disk_weights_name = name
+    tensor._disk_weights_is_buffer = is_buffer
+
+
+def materialize_meta_tensor(tensor: torch.Tensor, target_device: torch.device, dtype_override: Optional[torch.dtype]):
+    module_ref = getattr(tensor, "_disk_weights_module_ref", None)
+    name = getattr(tensor, "_disk_weights_name", None)
+    if module_ref is None or name is None:
+        raise RuntimeError("Meta tensor missing disk weight identity")
+    module = module_ref()
+    if module is None:
+        raise RuntimeError("Disk weight module reference expired")
+    return load_module_tensor(module, name, target_device, dtype_override=dtype_override, temporary=False)
+
+
 def _state_dict_meta(state_dict: MutableMapping, key: str):
     if hasattr(state_dict, "meta"):
         return state_dict.meta(key)
@@ -466,7 +514,6 @@ def _ensure_free_memory(device: torch.device, required_bytes: int, headroom_byte
         )
     safetensors_stream._reap_pinned_inflight()
     from . import model_management
-    model_management._reap_pinned_inflight()
    model_management.free_memory(required_bytes + headroom_bytes, device)
     free_after = _device_free_memory(device)
     freed = max(0, free_after - free_before)
@@ -650,6 +697,7 @@ def register_lazy_modules(model: torch.nn.Module, state_dict):
 
 def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     safetensors_stream._reap_pinned_inflight()
+    from . import model_management
     lazy_state = LAZY_MODULE_STATE.get(module)
     if lazy_state is not None:
         CACHE.remove_module(module)
@@ -657,6 +705,19 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     if refs:
         state = _get_materialization_state(module)
         for ref_name, disk_ref in refs.items():
+            if ref_name in module._parameters:
+                current = module._parameters[ref_name]
+            elif ref_name in module._buffers:
+                current = module._buffers[ref_name]
+            else:
+                current = None
+            if (
+                current is not None
+                and current.device.type == "cpu"
+                and current.data_ptr() in model_management.PINNED_MEMORY
+            ):
+                model_management.wait_for_pinned_tensor(current)
+                model_management.unpin_memory(current)
             shape = getattr(disk_ref.meta, "shape", None)
             dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
             if shape is None or dtype is None:
@@ -664,8 +725,11 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
             meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
             if disk_ref.is_buffer:
                 module._buffers[ref_name] = meta_tensor
+                _attach_disk_identity(meta_tensor, module, ref_name, True)
             else:
-                module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+                param = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+                module._parameters[ref_name] = param
+                _attach_disk_identity(param, module, ref_name, False)
             nbytes = _meta_nbytes(disk_ref.meta)
             if nbytes is not None:
                 state.loaded_keys.discard(ref_name)
@@ -679,7 +743,19 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     ref = REGISTRY.get(module)
     if not ref or name not in ref:
         return
+    CACHE.remove_entry(module, name)
     disk_ref = ref[name]
+    if is_buffer:
+        current = module._buffers.get(name)
+    else:
+        current = module._parameters.get(name)
+    if (
+        current is not None
+        and current.device.type == "cpu"
+        and current.data_ptr() in model_management.PINNED_MEMORY
+    ):
+        model_management.wait_for_pinned_tensor(current)
+        model_management.unpin_memory(current)
     shape = getattr(disk_ref.meta, "shape", None)
     dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
     if shape is None or dtype is None:
@@ -687,8 +763,11 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
     if is_buffer:
         module._buffers[name] = meta_tensor
+        _attach_disk_identity(meta_tensor, module, name, True)
     else:
-        module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+        param = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+        module._parameters[name] = param
+        _attach_disk_identity(param, module, name, False)
     state = _get_materialization_state(module)
     nbytes = _meta_nbytes(disk_ref.meta)
     if nbytes is not None:
@@ -777,6 +856,9 @@ def ensure_module_materialized(
             _set_future_dtype(module, name, dtype_override)
     _rebuild_materialization_state(module, refs, state)
     free_mem_start = _device_free_memory(target_device)
+    from . import model_management
+    non_blocking = model_management.device_supports_non_blocking(target_device)
+    offload_stream = model_management.get_offload_stream(target_device) if non_blocking else None
     for name in sorted(refs.keys()):
         disk_ref = refs[name]
         if name in module._parameters:
@@ -793,8 +875,7 @@ def ensure_module_materialized(
         if current.device.type != "meta" and current.device == target_device and (
             target_dtype is None or current.dtype == target_dtype
         ):
-            if current.device.type == "cpu":
-                CACHE.touch(module, name)
+            CACHE.touch(module, name)
             continue
         meta_nbytes = _meta_nbytes(disk_ref.meta)
         if meta_nbytes is None:
@@ -803,7 +884,6 @@ def ensure_module_materialized(
         if target_device.type == "cpu":
             _ensure_free_memory(target_device, required_bytes, RAM_HEADROOM_BYTES)
         else:
-            from . import model_management
             _ensure_free_memory(target_device, required_bytes, model_management.extra_reserved_memory())
         target_for_load = target_device
         if current.device.type == "meta":
@@ -813,16 +893,37 @@ def ensure_module_materialized(
                 PIN_IF_CPU,
                 dtype_override=target_dtype,
             )
+            if tensor.device != target_for_load or (target_dtype is not None and tensor.dtype != target_dtype):
+                tensor = model_management.cast_to(
+                    tensor,
+                    device=target_for_load,
+                    dtype=target_dtype,
+                    non_blocking=non_blocking,
+                    stream=offload_stream,
+                )
+            if non_blocking and offload_stream is not None:
+                model_management.sync_stream(target_for_load, offload_stream)
         else:
-            if target_dtype is not None and current.dtype != target_dtype:
-                tensor = current.to(device=target_for_load, dtype=target_dtype)
-            else:
-                tensor = current.to(device=target_for_load)
+            if (
+                current.device.type == "cpu"
+                and current.data_ptr() in model_management.PINNED_MEMORY
+            ):
+                model_management.wait_for_pinned_tensor(current)
+                model_management.unpin_memory(current)
+            tensor = model_management.cast_to(
+                current,
+                device=target_for_load,
+                dtype=target_dtype if target_dtype is not None else current.dtype,
+                non_blocking=non_blocking,
+                stream=offload_stream,
+            )
+            if non_blocking and offload_stream is not None:
+                model_management.sync_stream(target_for_load, offload_stream)
         if is_buffer:
             module._buffers[name] = tensor
         else:
             module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
-        if tensor.device.type == "cpu":
+        if tensor.device.type != "meta":
             CACHE.record(module, name, tensor, is_buffer=is_buffer)
     _rebuild_materialization_state(module, refs, state)
     _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
@@ -834,10 +935,13 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
     input_dtype = _find_tensor_dtype(args, kwargs)
     manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
     dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
-    if getattr(module, "comfy_cast_weights", False):
+    input_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
+    if getattr(module, "comfy_patched_weights", False):
+        target_device = input_device
+    elif getattr(module, "comfy_cast_weights", False):
         target_device = torch.device("cpu")
     else:
-        target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
+        target_device = input_device
     ensure_module_materialized(
         module,
         target_device,
@@ -859,11 +963,73 @@ def evict_ram_cache(bytes_to_free: int):
     if bytes_to_free <= 0:
         return 0
     safetensors_stream._reap_pinned_inflight()
-    from . import model_management
-    model_management._reap_pinned_inflight()
     return CACHE.evict_bytes(bytes_to_free)
 
 
+def _move_cache_entry_to_cpu(entry: CacheEntry):
+    module = entry.module_ref()
+    if module is None:
+        return
+    if entry.is_buffer:
+        current = module._buffers.get(entry.name)
+    else:
+        current = module._parameters.get(entry.name)
+    if current is None or current.device.type == "meta":
+        return
+    from . import model_management
+    non_blocking = model_management.device_supports_non_blocking(torch.device("cpu"))
+    offload_stream = model_management.get_offload_stream(torch.device("cpu")) if non_blocking else None
+    tensor = model_management.cast_to(
+        current,
+        device=torch.device("cpu"),
+        dtype=current.dtype,
+        non_blocking=non_blocking,
+        stream=offload_stream,
+    )
+    if non_blocking and offload_stream is not None:
+        model_management.sync_stream(current.device, offload_stream)
+    if entry.is_buffer:
+        module._buffers[entry.name] = tensor
+    else:
+        module._parameters[entry.name] = torch.nn.Parameter(tensor, requires_grad=current.requires_grad)
+    CACHE.record(module, entry.name, tensor, is_buffer=entry.is_buffer)
+
+
+def _evict_cpu_entry_to_meta(entry: CacheEntry):
+    module = entry.module_ref()
+    if module is None:
+        return
+    _evict_module_weight(module, entry.name, entry.is_buffer)
+    CACHE.remove_entry(module, entry.name)
+
+
+def evict_for_budget(target_device: torch.device, required_bytes: int):
+    if not disk_weights_enabled() or required_bytes <= 0:
+        return
+    from . import model_management
+    free = model_management.get_free_memory(target_device)
+    if free >= required_bytes:
+        return
+    cpu_device = torch.device("cpu")
+    if target_device.type != "cpu":
+        while free < required_bytes:
+            entry = CACHE.pop_lru(target_device)
+            if entry is None:
+                break
+            free_cpu = model_management.get_free_memory(cpu_device)
+            if free_cpu < RAM_HEADROOM_BYTES:
+                CACHE.evict_bytes(RAM_HEADROOM_BYTES - free_cpu)
+            _move_cache_entry_to_cpu(entry)
+            free = model_management.get_free_memory(target_device)
+    else:
+        while free < required_bytes:
+            entry = CACHE.pop_lru(cpu_device)
+            if entry is None:
+                break
+            _evict_cpu_entry_to_meta(entry)
+            free = model_management.get_free_memory(target_device)
+
+
 def materialize_module_tree(module: torch.nn.Module, target_device: torch.device):
     if not disk_weights_enabled():
         return
@@ -901,8 +1067,34 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
     return None
 
 
-def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
-    ensure_module_materialized(module, device_to, dtype_override=dtype_override)
+def move_module_tensors(
+    module: torch.nn.Module,
+    device_to: torch.device,
+    dtype_override: Optional[torch.dtype] = None,
+    non_blocking: bool = False,
+):
+    from . import model_management
+    offload_stream = None
+    if non_blocking and model_management.device_supports_non_blocking(device_to):
+        offload_stream = model_management.get_offload_stream(device_to)
+
+    def apply_fn(tensor):
+        if tensor is None or tensor.device.type == "meta":
+            return tensor
+        target_dtype = dtype_override or tensor.dtype
+        if tensor.device == device_to and tensor.dtype == target_dtype:
+            return tensor
+        return model_management.cast_to(
+            tensor,
+            device=device_to,
+            dtype=target_dtype,
+            non_blocking=non_blocking,
+            stream=offload_stream,
+        )
+
+    module._apply(apply_fn)
+    if non_blocking and offload_stream is not None:
+        model_management.sync_stream(device_to, offload_stream)
     return module
 
 
@@ -946,41 +1138,34 @@ def module_to(
     target_device = device or arg_device
     if target_device is None:
         target_device = _find_existing_device(module) or torch.device("cpu")
+    dtype_override = dtype or arg_dtype
     if target_device.type == "meta":
+        cpu_device = torch.device("cpu")
         for submodule in module.modules():
             offload_module_weights(submodule)
+            move_module_tensors(
+                submodule,
+                cpu_device,
+                dtype_override=dtype_override,
+                non_blocking=non_blocking,
+            )
+        return module
+    if not allow_materialize:
+        move_module_tensors(
+            module,
+            target_device,
+            dtype_override=dtype_override,
+            non_blocking=non_blocking,
+        )
         return module
-    dtype_override = dtype or arg_dtype
-    to_kwargs = {}
-    if non_blocking:
-        to_kwargs["non_blocking"] = non_blocking
-    if memory_format is not None:
-        to_kwargs["memory_format"] = memory_format
     for submodule in module.modules():
         ensure_module_materialized(submodule, target_device, dtype_override=dtype_override)
-        refs = REGISTRY.get(submodule) or {}
-        for name, param in submodule.named_parameters(recurse=False):
-            if name in refs:
-                continue
-            if param is None or param.device.type == "meta":
-                continue
-            if param.device != target_device or (dtype_override is not None and param.dtype != dtype_override):
-                if dtype_override is not None:
-                    tensor = param.to(device=target_device, dtype=dtype_override, **to_kwargs)
-                else:
-                    tensor = param.to(device=target_device, **to_kwargs)
-                submodule._parameters[name] = torch.nn.Parameter(tensor, requires_grad=param.requires_grad)
-        for name, buf in submodule.named_buffers(recurse=False):
-            if name in refs:
-                continue
-            if buf is None or buf.device.type == "meta":
-                continue
-            if buf.device != target_device or (dtype_override is not None and buf.dtype != dtype_override):
-                if dtype_override is not None:
-                    tensor = buf.to(device=target_device, dtype=dtype_override, **to_kwargs)
-                else:
-                    tensor = buf.to(device=target_device, **to_kwargs)
-                submodule._buffers[name] = tensor
+    move_module_tensors(
+        module,
+        target_device,
+        dtype_override=dtype_override,
+        non_blocking=non_blocking,
+    )
     return module
     base_kwargs = dict(kwargs)
     if device is not None and arg_device is None:
@@ -1024,11 +1209,24 @@ def load_module_tensor(
         from . import model_management
         headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
         _ensure_free_memory(device, _tensor_nbytes(current), headroom)
-        if target_dtype is not None and current.dtype != target_dtype:
-            tensor = current.to(device=device, dtype=target_dtype)
-        else:
-            tensor = current.to(device=device)
+        non_blocking = model_management.device_supports_non_blocking(device)
+        offload_stream = model_management.get_offload_stream(device) if non_blocking else None
+        tensor = model_management.cast_to(
+            current,
+            device=device,
+            dtype=target_dtype if target_dtype is not None else current.dtype,
+            non_blocking=non_blocking,
+            stream=offload_stream,
+        )
+        if non_blocking and offload_stream is not None:
+            model_management.sync_stream(device, offload_stream)
        if not temporary:
+            if (
+                current.device.type == "cpu"
+                and current.data_ptr() in model_management.PINNED_MEMORY
+            ):
+                model_management.wait_for_pinned_tensor(current)
+                model_management.unpin_memory(current)
             if is_buffer:
                 module._buffers[name] = tensor
             else:
                 module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
@@ -1044,15 +1242,26 @@ def load_module_tensor(
     from . import model_management
     headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
     _ensure_free_memory(device, required_bytes, headroom)
-
+    non_blocking = model_management.device_supports_non_blocking(device)
+    offload_stream = model_management.get_offload_stream(device) if non_blocking else None
     tensor = disk_ref.load(device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
+    if tensor.device != device or (target_dtype is not None and tensor.dtype != target_dtype):
+        tensor = model_management.cast_to(
+            tensor,
+            device=device,
+            dtype=target_dtype if target_dtype is not None else tensor.dtype,
+            non_blocking=non_blocking,
+            stream=offload_stream,
+        )
+    if non_blocking and offload_stream is not None:
+        model_management.sync_stream(device, offload_stream)
     if temporary:
         return tensor
     if is_buffer:
         module._buffers[name] = tensor
     else:
         module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
-    if tensor.device.type == "cpu" and record_cache:
+    if tensor.device.type != "meta" and record_cache:
         CACHE.record(module, name, tensor, is_buffer=is_buffer)
     state = _get_materialization_state(module)
     _rebuild_materialization_state(module, refs, state)
@@ -1068,8 +1277,11 @@ def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_
     attr = parts[-1]
     if is_buffer:
         module._buffers[attr] = tensor
+        return tensor
     else:
-        module._parameters[attr] = torch.nn.Parameter(tensor, requires_grad=requires_grad)
+        param = torch.nn.Parameter(tensor, requires_grad=requires_grad)
+        module._parameters[attr] = param
+        return param
 
 
 def _materialize_module_from_state_dict(
@@ -1142,14 +1354,25 @@ def _materialize_module_from_state_dict(
             module.factory_kwargs["device"] = factory_device
     if len(error_msgs) > 0:
         raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
+    for name, disk_ref in refs.items():
+        if name in module._parameters:
+            tensor = module._parameters[name]
+            is_buffer = False
+        elif name in module._buffers:
+            tensor = module._buffers[name]
+            is_buffer = True
+        else:
+            continue
+        if tensor is not None and tensor.device.type == "meta":
+            _attach_disk_identity(tensor, module, name, is_buffer)
     _rebuild_materialization_state(module, refs, state)
     lazy_state.loaded = True
     _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
     for name, param in module.named_parameters(recurse=False):
-        if param.device.type == "cpu":
+        if param.device.type != "meta":
             CACHE.record(module, name, param, is_buffer=False)
     for name, buf in module.named_buffers(recurse=False):
-        if buf is not None and buf.device.type == "cpu":
+        if buf is not None and buf.device.type != "meta":
             CACHE.record(module, name, buf, is_buffer=True)
 
 
@@ -1178,14 +1401,16 @@ def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = Fals
             continue
         meta = state_dict.meta(name)
         meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
-        _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
+        stored = _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
+        _attach_disk_identity(stored, model, name, False)
     for name, buf in model.named_buffers(recurse=True):
         if buf is None or name not in state_keys:
            continue
         meta = state_dict.meta(name)
         meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
-        _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
+        stored = _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
+        _attach_disk_identity(stored, model, name, True)
 
     register_module_weights(model, state_dict)
     register_lazy_modules(model, state_dict)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index a2f5a62ef..27c723f1b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -609,7 +609,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
         free_before,
         device,
     )
-    _reap_pinned_inflight()
+    comfy.disk_weights.evict_for_budget(device, memory_required)
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -1159,6 +1159,10 @@ def sync_stream(device, stream):
     current_stream(device).wait_stream(stream)
 
 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
+    if comfy.disk_weights.disk_weights_enabled() and weight.device.type == "meta":
+        target_device = device if device is not None else torch.device("cpu")
+        target_dtype = dtype if dtype is not None else weight.dtype
+        weight = comfy.disk_weights.materialize_meta_tensor(weight, target_device, target_dtype)
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype:
@@ -1171,7 +1175,6 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
                 return weight.to(dtype=dtype, copy=copy)
         return weight.to(dtype=dtype, copy=copy)
-
     if stream is not None:
         wf_context = stream
         if hasattr(wf_context, "as_context"):
@@ -1182,15 +1185,13 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
     else:
         r = torch.empty_like(weight, dtype=dtype, device=device)
         r.copy_(weight, non_blocking=non_blocking)
-    if (
-        non_blocking
-        and is_device_cuda(device)
-        and is_device_cpu(weight.device)
-        and weight.is_pinned()
-    ):
-        record_stream = stream if stream is not None else current_stream(device)
+    if non_blocking and (is_device_cuda(device) or is_device_cuda(weight.device)):
+        record_stream = stream if stream is not None else current_stream(device if is_device_cuda(device) else weight.device)
         if record_stream is not None:
-            _track_pinned_inflight(record_stream, weight)
+            if is_device_cpu(weight.device) and weight.is_pinned():
+                _record_pinned_event(weight.data_ptr(), record_stream)
+            if is_device_cpu(r.device) and r.is_pinned():
+                _record_pinned_event(r.data_ptr(), record_stream)
     return r
 
 def cast_to_device(tensor, device, dtype, copy=False):
@@ -1201,8 +1202,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
 PINNED_MEMORY = {}
 TOTAL_PINNED_MEMORY = 0
 MAX_PINNED_MEMORY = -1
-PINNED_INFLIGHT = collections.deque()
-DEFERRED_UNPIN = collections.deque()
+PINNED_IN_FLIGHT = {}
 if not args.disable_pinned_memory:
     if is_nvidia() or is_amd():
         if WINDOWS:
@@ -1270,17 +1270,17 @@ def pin_memory(tensor):
     return False
 
 
-def _track_pinned_inflight(stream, tensor):
+def _record_pinned_event(ptr, stream):
+    events = PINNED_IN_FLIGHT.setdefault(ptr, [])
     event = torch.cuda.Event()
     event.record(stream)
-    PINNED_INFLIGHT.append((event, tensor))
+    events.append(event)
 
 
-def _tensor_inflight(tensor):
+def wait_for_pinned_tensor(tensor):
     ptr = tensor.data_ptr()
-    for _, inflight_tensor in PINNED_INFLIGHT:
-        if inflight_tensor.data_ptr() == ptr:
-            return True
-    return False
+    events = PINNED_IN_FLIGHT.pop(ptr, [])
+    for event in events:
+        event.synchronize()
 
 def _unpin_memory_now(tensor):
     global TOTAL_PINNED_MEMORY
@@ -1313,32 +1313,6 @@ def _unpin_memory_now(tensor):
     return False
 
 
-def _retry_deferred_unpins():
-    if not DEFERRED_UNPIN:
-        return
-    remaining = collections.deque()
-    while DEFERRED_UNPIN:
-        tensor = DEFERRED_UNPIN.popleft()
-        if _tensor_inflight(tensor):
-            remaining.append(tensor)
-            continue
-        if not _unpin_memory_now(tensor):
-            remaining.append(tensor)
-    DEFERRED_UNPIN.extend(remaining)
-
-def _reap_pinned_inflight():
-    if not PINNED_INFLIGHT:
-        _retry_deferred_unpins()
-        return
-    remaining = collections.deque()
-    while PINNED_INFLIGHT:
-        event, tensor = PINNED_INFLIGHT.popleft()
-        if event.query():
-            continue
-        remaining.append((event, tensor))
-    PINNED_INFLIGHT.extend(remaining)
-    _retry_deferred_unpins()
-
 def unpin_memory(tensor):
     global TOTAL_PINNED_MEMORY
     if MAX_PINNED_MEMORY <= 0:
@@ -1347,10 +1321,7 @@ def unpin_memory(tensor):
         return False
     if not is_device_cpu(tensor.device):
         return False
-    _reap_pinned_inflight()
-    if _tensor_inflight(tensor):
-        DEFERRED_UNPIN.append(tensor)
-        return False
+    wait_for_pinned_tensor(tensor)
     return _unpin_memory_now(tensor)
 
 def sage_attention_enabled():