mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 16:50:17 +08:00
Fix disk-weight tiering and meta handling
This commit is contained in:
parent
95ca11fe25
commit
c3eaea0429
@ -179,7 +179,24 @@ def configure(*, allow_gds: bool, pin_if_cpu: bool, ram_headroom_bytes: int, ena
|
||||
PIN_IF_CPU = pin_if_cpu
|
||||
DISK_WEIGHTS_ENABLED = enabled
|
||||
RAM_HEADROOM_BYTES = max(0, int(ram_headroom_bytes))
|
||||
CACHE.set_limit(0 if enabled else 0)
|
||||
if enabled:
|
||||
from . import model_management
|
||||
cpu_capacity_bytes = max(0, model_management.get_total_memory(torch.device("cpu")) - RAM_HEADROOM_BYTES)
|
||||
CACHE.set_limit(cpu_capacity_bytes)
|
||||
LOGGER.debug(
|
||||
"Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
|
||||
enabled,
|
||||
RAM_HEADROOM_BYTES,
|
||||
cpu_capacity_bytes,
|
||||
)
|
||||
else:
|
||||
CACHE.set_limit(0)
|
||||
LOGGER.debug(
|
||||
"Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
|
||||
enabled,
|
||||
RAM_HEADROOM_BYTES,
|
||||
0,
|
||||
)
|
||||
if enabled:
|
||||
install_monkeypatches()
|
||||
else:
|
||||
@ -449,12 +466,14 @@ def _ensure_free_memory(device: torch.device, required_bytes: int, headroom_byte
|
||||
)
|
||||
safetensors_stream._reap_pinned_inflight()
|
||||
from . import model_management
|
||||
model_management._reap_pinned_inflight()
|
||||
model_management.free_memory(required_bytes + headroom_bytes, device)
|
||||
free_after = _device_free_memory(device)
|
||||
freed = max(0, free_after - free_before)
|
||||
LOGGER.debug(
|
||||
"Disk weight memory freed: freed=%d free=%d device=%s",
|
||||
"Disk weight memory freed: freed=%d free_before=%d free_after=%d device=%s",
|
||||
freed,
|
||||
free_before,
|
||||
free_after,
|
||||
device,
|
||||
)
|
||||
@ -840,6 +859,8 @@ def evict_ram_cache(bytes_to_free: int):
|
||||
if bytes_to_free <= 0:
|
||||
return 0
|
||||
safetensors_stream._reap_pinned_inflight()
|
||||
from . import model_management
|
||||
model_management._reap_pinned_inflight()
|
||||
return CACHE.evict_bytes(bytes_to_free)
|
||||
|
||||
|
||||
@ -926,22 +947,41 @@ def module_to(
|
||||
if target_device is None:
|
||||
target_device = _find_existing_device(module) or torch.device("cpu")
|
||||
if target_device.type == "meta":
|
||||
offload_module_weights(module)
|
||||
for submodule in module.modules():
|
||||
offload_module_weights(submodule)
|
||||
return module
|
||||
if allow_materialize:
|
||||
materialize_module_tree(module, target_device)
|
||||
base_kwargs = dict(kwargs)
|
||||
if device is not None and arg_device is None:
|
||||
base_kwargs["device"] = device
|
||||
if dtype is not None and arg_dtype is None:
|
||||
base_kwargs["dtype"] = dtype
|
||||
if non_blocking:
|
||||
base_kwargs["non_blocking"] = non_blocking
|
||||
if memory_format is not None:
|
||||
base_kwargs["memory_format"] = memory_format
|
||||
return BASE_MODULE_TO(module, *args, **base_kwargs)
|
||||
dtype_override = dtype or arg_dtype
|
||||
return move_module_tensors(module, target_device, dtype_override=dtype_override)
|
||||
to_kwargs = {}
|
||||
if non_blocking:
|
||||
to_kwargs["non_blocking"] = non_blocking
|
||||
if memory_format is not None:
|
||||
to_kwargs["memory_format"] = memory_format
|
||||
for submodule in module.modules():
|
||||
ensure_module_materialized(submodule, target_device, dtype_override=dtype_override)
|
||||
refs = REGISTRY.get(submodule) or {}
|
||||
for name, param in submodule.named_parameters(recurse=False):
|
||||
if name in refs:
|
||||
continue
|
||||
if param is None or param.device.type == "meta":
|
||||
continue
|
||||
if param.device != target_device or (dtype_override is not None and param.dtype != dtype_override):
|
||||
if dtype_override is not None:
|
||||
tensor = param.to(device=target_device, dtype=dtype_override, **to_kwargs)
|
||||
else:
|
||||
tensor = param.to(device=target_device, **to_kwargs)
|
||||
submodule._parameters[name] = torch.nn.Parameter(tensor, requires_grad=param.requires_grad)
|
||||
for name, buf in submodule.named_buffers(recurse=False):
|
||||
if name in refs:
|
||||
continue
|
||||
if buf is None or buf.device.type == "meta":
|
||||
continue
|
||||
if buf.device != target_device or (dtype_override is not None and buf.dtype != dtype_override):
|
||||
if dtype_override is not None:
|
||||
tensor = buf.to(device=target_device, dtype=dtype_override, **to_kwargs)
|
||||
else:
|
||||
tensor = buf.to(device=target_device, **to_kwargs)
|
||||
submodule._buffers[name] = tensor
|
||||
return module
|
||||
base_kwargs = dict(kwargs)
|
||||
if device is not None and arg_device is None:
|
||||
base_kwargs["device"] = device
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
import collections
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import torch
|
||||
@ -524,18 +525,14 @@ class LoadedModel:
|
||||
return True
|
||||
return False
|
||||
|
||||
def model_unload(self, memory_to_free=None, unpatch_weights=True):
|
||||
def model_unload(self, memory_to_free=None, unpatch_weights=True, offload_device=None):
|
||||
target_offload_device = self.model.offload_device if offload_device is None else offload_device
|
||||
if memory_to_free is not None:
|
||||
if memory_to_free < self.model.loaded_size():
|
||||
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
freed = self.model.partially_unload(target_offload_device, memory_to_free)
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
offload_device = None
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
offload_device = torch.device("meta")
|
||||
self.model.detach(unpatch_weights, offload_device=offload_device)
|
||||
if offload_device is not None and offload_device.type == "meta":
|
||||
logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
|
||||
self.model.detach(unpatch_weights, offload_device=target_offload_device)
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
@ -589,27 +586,30 @@ def minimum_inference_memory():
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
cleanup_models_gc()
|
||||
if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
free_before = get_free_memory(device)
|
||||
headroom = comfy.disk_weights.ram_headroom_bytes()
|
||||
if free_before < memory_required:
|
||||
logging.debug(
|
||||
"RAM pressure: required=%d free=%d headroom=%d",
|
||||
memory_required,
|
||||
free_before,
|
||||
headroom,
|
||||
)
|
||||
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
|
||||
freed_disk = 0
|
||||
if freed_cache < memory_required:
|
||||
freed_disk = evict_ram_to_disk(memory_required - freed_cache)
|
||||
free_after = get_free_memory(device)
|
||||
freed_total = max(0, free_after - free_before)
|
||||
logging.debug(
|
||||
"RAM freed: freed=%d free=%d",
|
||||
freed_total if freed_total > 0 else freed_cache + freed_disk,
|
||||
free_after,
|
||||
)
|
||||
if is_device_cpu(device):
|
||||
headroom = comfy.disk_weights.ram_headroom_bytes()
|
||||
if free_before < memory_required:
|
||||
logging.debug(
|
||||
"Disk weights RAM pressure: required=%d free=%d headroom=%d device=%s",
|
||||
memory_required,
|
||||
free_before,
|
||||
headroom,
|
||||
device,
|
||||
)
|
||||
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
|
||||
if freed_cache < memory_required:
|
||||
evict_ram_to_disk(memory_required - freed_cache, keep_loaded=keep_loaded)
|
||||
elif is_device_cuda(device) or is_device_xpu(device):
|
||||
if free_before < memory_required:
|
||||
logging.debug(
|
||||
"Disk weights VRAM pressure: required=%d free=%d device=%s",
|
||||
memory_required,
|
||||
free_before,
|
||||
device,
|
||||
)
|
||||
_reap_pinned_inflight()
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@ -630,7 +630,10 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
break
|
||||
memory_to_free = memory_required - free_mem
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
if current_loaded_models[i].model_unload(memory_to_free):
|
||||
offload_device = None
|
||||
if comfy.disk_weights.disk_weights_enabled() and is_device_cpu(device):
|
||||
offload_device = torch.device("meta")
|
||||
if current_loaded_models[i].model_unload(memory_to_free, offload_device=offload_device):
|
||||
unloaded_model.append(i)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
@ -643,6 +646,16 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
|
||||
if mem_free_torch > mem_free_total * 0.25:
|
||||
soft_empty_cache()
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
free_after = get_free_memory(device)
|
||||
freed_total = max(0, free_after - free_before)
|
||||
logging.debug(
|
||||
"Disk weights free_memory: device=%s free_before=%d free_after=%d freed=%d",
|
||||
device,
|
||||
free_before,
|
||||
free_after,
|
||||
freed_total,
|
||||
)
|
||||
return unloaded_models
|
||||
|
||||
|
||||
@ -826,8 +839,6 @@ def dtype_size(dtype):
|
||||
return dtype_size
|
||||
|
||||
def unet_offload_device():
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
return torch.device("meta")
|
||||
if vram_state == VRAMState.HIGH_VRAM:
|
||||
return get_torch_device()
|
||||
else:
|
||||
@ -932,8 +943,6 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
||||
return torch.float32
|
||||
|
||||
def text_encoder_offload_device():
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
return torch.device("meta")
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
else:
|
||||
@ -994,8 +1003,6 @@ def vae_device():
|
||||
return get_torch_device()
|
||||
|
||||
def vae_offload_device():
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
return torch.device("meta")
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
else:
|
||||
@ -1175,6 +1182,15 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
else:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
if (
|
||||
non_blocking
|
||||
and is_device_cuda(device)
|
||||
and is_device_cpu(weight.device)
|
||||
and weight.is_pinned()
|
||||
):
|
||||
record_stream = stream if stream is not None else current_stream(device)
|
||||
if record_stream is not None:
|
||||
_track_pinned_inflight(record_stream, weight)
|
||||
return r
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
@ -1185,6 +1201,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
PINNED_MEMORY = {}
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
MAX_PINNED_MEMORY = -1
|
||||
PINNED_INFLIGHT = collections.deque()
|
||||
DEFERRED_UNPIN = collections.deque()
|
||||
if not args.disable_pinned_memory:
|
||||
if is_nvidia() or is_amd():
|
||||
if WINDOWS:
|
||||
@ -1252,7 +1270,19 @@ def pin_memory(tensor):
|
||||
|
||||
return False
|
||||
|
||||
def unpin_memory(tensor):
|
||||
def _track_pinned_inflight(stream, tensor):
|
||||
event = torch.cuda.Event()
|
||||
event.record(stream)
|
||||
PINNED_INFLIGHT.append((event, tensor))
|
||||
|
||||
def _tensor_inflight(tensor):
|
||||
ptr = tensor.data_ptr()
|
||||
for _, inflight_tensor in PINNED_INFLIGHT:
|
||||
if inflight_tensor.data_ptr() == ptr:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _unpin_memory_now(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
@ -1283,6 +1313,46 @@ def unpin_memory(tensor):
|
||||
|
||||
return False
|
||||
|
||||
def _retry_deferred_unpins():
|
||||
if not DEFERRED_UNPIN:
|
||||
return
|
||||
remaining = collections.deque()
|
||||
while DEFERRED_UNPIN:
|
||||
tensor = DEFERRED_UNPIN.popleft()
|
||||
if _tensor_inflight(tensor):
|
||||
remaining.append(tensor)
|
||||
continue
|
||||
if not _unpin_memory_now(tensor):
|
||||
remaining.append(tensor)
|
||||
DEFERRED_UNPIN.extend(remaining)
|
||||
|
||||
def _reap_pinned_inflight():
|
||||
if not PINNED_INFLIGHT:
|
||||
_retry_deferred_unpins()
|
||||
return
|
||||
remaining = collections.deque()
|
||||
while PINNED_INFLIGHT:
|
||||
event, tensor = PINNED_INFLIGHT.popleft()
|
||||
if event.query():
|
||||
continue
|
||||
remaining.append((event, tensor))
|
||||
PINNED_INFLIGHT.extend(remaining)
|
||||
_retry_deferred_unpins()
|
||||
|
||||
def unpin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
_reap_pinned_inflight()
|
||||
if _tensor_inflight(tensor):
|
||||
DEFERRED_UNPIN.append(tensor)
|
||||
return False
|
||||
return _unpin_memory_now(tensor)
|
||||
|
||||
def sage_attention_enabled():
|
||||
return args.use_sage_attention
|
||||
|
||||
|
||||
@ -621,6 +621,16 @@ class ModelPatcher:
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
inplace_update = self.weight_inplace_update or inplace_update
|
||||
|
||||
if comfy.disk_weights.disk_weights_enabled() and weight is not None and weight.device.type == "meta":
|
||||
parts = key.split(".")
|
||||
param_name = parts[-1]
|
||||
module = self.model
|
||||
for part in parts[:-1]:
|
||||
module = getattr(module, part)
|
||||
target_device = device_to or self.offload_device or torch.device("cpu")
|
||||
comfy.disk_weights.load_module_tensor(module, param_name, device=target_device)
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||
|
||||
@ -650,8 +660,8 @@ class ModelPatcher:
|
||||
def unpin_weight(self, key):
|
||||
if key in self.pinned:
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
comfy.model_management.unpin_memory(weight)
|
||||
self.pinned.remove(key)
|
||||
if comfy.model_management.unpin_memory(weight):
|
||||
self.pinned.remove(key)
|
||||
|
||||
def unpin_all_weights(self):
|
||||
for key in list(self.pinned):
|
||||
@ -885,9 +895,16 @@ class ModelPatcher:
|
||||
NS = comfy.model_management.NUM_STREAMS
|
||||
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
|
||||
remaining_ram = None
|
||||
cpu_device = torch.device("cpu")
|
||||
if device_to is not None and comfy.model_management.is_device_cpu(device_to):
|
||||
remaining_ram = comfy.model_management.get_free_memory(device_to)
|
||||
|
||||
def offload_module_tree(module):
|
||||
freed = 0
|
||||
for submodule in module.modules():
|
||||
freed += comfy.disk_weights.offload_module_weights(submodule)
|
||||
return freed
|
||||
|
||||
for unload in unload_list:
|
||||
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
|
||||
break
|
||||
@ -922,15 +939,28 @@ class ModelPatcher:
|
||||
cast_weight = self.force_cast_weights
|
||||
freed_bytes = module_mem
|
||||
if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
|
||||
freed_bytes = comfy.disk_weights.offload_module_weights(m)
|
||||
freed_bytes = offload_module_tree(m)
|
||||
if freed_bytes == 0:
|
||||
freed_bytes = module_mem
|
||||
else:
|
||||
if remaining_ram is not None and remaining_ram < module_mem and comfy.disk_weights.disk_weights_enabled():
|
||||
logging.info("Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.", n, module_mem / (1024 * 1024), remaining_ram / (1024 * 1024))
|
||||
freed_bytes = comfy.disk_weights.offload_module_weights(m)
|
||||
if freed_bytes == 0:
|
||||
freed_bytes = module_mem
|
||||
if remaining_ram is not None and comfy.disk_weights.disk_weights_enabled():
|
||||
required_bytes = module_mem
|
||||
headroom = comfy.disk_weights.ram_headroom_bytes()
|
||||
comfy.model_management.free_memory(required_bytes + headroom, cpu_device, keep_loaded=[self])
|
||||
remaining_ram = comfy.model_management.get_free_memory(cpu_device)
|
||||
if remaining_ram < required_bytes:
|
||||
logging.info(
|
||||
"Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.",
|
||||
n,
|
||||
required_bytes / (1024 * 1024),
|
||||
remaining_ram / (1024 * 1024),
|
||||
)
|
||||
freed_bytes = offload_module_tree(m)
|
||||
if freed_bytes == 0:
|
||||
freed_bytes = module_mem
|
||||
else:
|
||||
comfy.disk_weights.move_module_tensors(m, device_to)
|
||||
remaining_ram = max(0, remaining_ram - required_bytes)
|
||||
else:
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
comfy.disk_weights.move_module_tensors(m, device_to)
|
||||
@ -980,16 +1010,22 @@ class ModelPatcher:
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
||||
offload_device = self.offload_device
|
||||
if comfy.disk_weights.disk_weights_enabled() and device_to is not None:
|
||||
if comfy.model_management.is_device_cpu(device_to):
|
||||
offload_device = torch.device("meta")
|
||||
else:
|
||||
offload_device = torch.device("cpu")
|
||||
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
|
||||
# TODO: force_patch_weights should not unload + reload full model
|
||||
used = self.model.model_loaded_weight_memory
|
||||
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
|
||||
self.unpatch_model(offload_device, unpatch_weights=unpatch_weights)
|
||||
if unpatch_weights:
|
||||
extra_memory += (used - self.model.model_loaded_weight_memory)
|
||||
|
||||
self.patch_model(load_weights=False)
|
||||
if extra_memory < 0 and not unpatch_weights:
|
||||
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
|
||||
self.partially_unload(offload_device, -extra_memory, force_patch_weights=force_patch_weights)
|
||||
return 0
|
||||
full_load = False
|
||||
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
||||
|
||||
32
comfy/ops.py
32
comfy/ops.py
@ -535,10 +535,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
key = f"{prefix}{param_name}"
|
||||
value = state_dict.pop(key, None)
|
||||
if value is not None:
|
||||
if value.device.type != "meta":
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return value
|
||||
|
||||
@ -555,16 +554,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None and layer_conf.device.type != "meta":
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
elif layer_conf is not None:
|
||||
layer_conf = None
|
||||
|
||||
if layer_conf is None:
|
||||
if weight.device.type == "meta":
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
@ -610,13 +604,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||
|
||||
if weight.device.type == "meta":
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
@ -626,10 +617,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
if _v.device.type == "meta":
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False))
|
||||
else:
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user