mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 07:10:52 +08:00
Recast disk weights with dtype overrides
This commit is contained in:
parent
1ec01dd023
commit
5c60954448
@ -665,8 +665,17 @@ def ensure_module_materialized(
|
||||
if current is None:
|
||||
continue
|
||||
if current.device.type != "meta" and current.device == target_device:
|
||||
if current.device.type == "cpu":
|
||||
CACHE.touch(module, name)
|
||||
if dtype_override is not None and current.dtype != dtype_override:
|
||||
tensor = current.to(device=target_device, dtype=dtype_override)
|
||||
if is_buffer:
|
||||
module._buffers[name] = tensor
|
||||
else:
|
||||
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
|
||||
if tensor.device.type == "cpu":
|
||||
CACHE.record(module, name, tensor, is_buffer=is_buffer)
|
||||
else:
|
||||
if current.device.type == "cpu":
|
||||
CACHE.touch(module, name)
|
||||
continue
|
||||
meta_nbytes = _meta_nbytes(disk_ref.meta)
|
||||
if meta_nbytes is None:
|
||||
@ -704,7 +713,7 @@ def ensure_module_materialized(
|
||||
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
|
||||
|
||||
|
||||
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
|
||||
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
|
||||
if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE:
|
||||
return
|
||||
input_dtype = _find_tensor_dtype(args, kwargs)
|
||||
@ -740,11 +749,15 @@ def evict_ram_cache(bytes_to_free: int):
|
||||
return CACHE.evict_bytes(bytes_to_free)
|
||||
|
||||
|
||||
def materialize_module_tree(module: torch.nn.Module, target_device: torch.device):
|
||||
def materialize_module_tree(
|
||||
module: torch.nn.Module,
|
||||
target_device: torch.device,
|
||||
dtype_override: Optional[torch.dtype] = None,
|
||||
):
|
||||
if not disk_weights_enabled():
|
||||
return
|
||||
for submodule in module.modules():
|
||||
ensure_module_materialized(submodule, target_device)
|
||||
ensure_module_materialized(submodule, target_device, dtype_override=dtype_override)
|
||||
|
||||
|
||||
def _extract_to_device(args, kwargs) -> Optional[torch.device]:
|
||||
@ -758,6 +771,15 @@ def _extract_to_device(args, kwargs) -> Optional[torch.device]:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_to_dtype(args, kwargs) -> Optional[torch.dtype]:
|
||||
if "dtype" in kwargs and kwargs["dtype"] is not None:
|
||||
return kwargs["dtype"]
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.dtype):
|
||||
return arg
|
||||
return None
|
||||
|
||||
|
||||
def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
|
||||
for param in module.parameters(recurse=True):
|
||||
if param is not None and param.device.type != "meta":
|
||||
@ -771,9 +793,12 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
|
||||
def module_to(module: torch.nn.Module, *args, **kwargs):
|
||||
if disk_weights_enabled():
|
||||
target_device = _extract_to_device(args, kwargs)
|
||||
dtype_override = _extract_to_dtype(args, kwargs)
|
||||
if target_device is None:
|
||||
target_device = _find_existing_device(module) or torch.device("cpu")
|
||||
materialize_module_tree(module, target_device)
|
||||
if dtype_override is None:
|
||||
dtype_override = getattr(module, "manual_cast_dtype", None)
|
||||
materialize_module_tree(module, target_device, dtype_override=dtype_override)
|
||||
return module.to(*args, **kwargs)
|
||||
|
||||
|
||||
@ -976,18 +1001,21 @@ def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = Fals
|
||||
if error_msgs:
|
||||
raise RuntimeError("Error(s) in loading state_dict:\n\t{}".format("\n\t".join(error_msgs)))
|
||||
|
||||
dtype_override = getattr(model, "manual_cast_dtype", None)
|
||||
for name, param in model.named_parameters(recurse=True):
|
||||
if name not in state_keys:
|
||||
continue
|
||||
meta = state_dict.meta(name)
|
||||
meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
|
||||
meta_dtype = dtype_override or meta.dtype
|
||||
meta_tensor = torch.empty(meta.shape, dtype=meta_dtype, device="meta")
|
||||
_replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
|
||||
|
||||
for name, buf in model.named_buffers(recurse=True):
|
||||
if buf is None or name not in state_keys:
|
||||
continue
|
||||
meta = state_dict.meta(name)
|
||||
meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
|
||||
meta_dtype = dtype_override or meta.dtype
|
||||
meta_tensor = torch.empty(meta.shape, dtype=meta_dtype, device="meta")
|
||||
_replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
|
||||
|
||||
register_module_weights(model, state_dict)
|
||||
|
||||
@ -180,3 +180,31 @@ def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch):
|
||||
assert len(calls) == 2
|
||||
finally:
|
||||
comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
|
||||
|
||||
|
||||
def test_lazy_disk_weights_respects_dtype_override(tmp_path):
|
||||
if importlib.util.find_spec("fastsafetensors") is None:
|
||||
pytest.skip("fastsafetensors not installed")
|
||||
import comfy.utils
|
||||
import comfy.disk_weights
|
||||
|
||||
prev_cache = comfy.disk_weights.CACHE.max_bytes
|
||||
prev_gds = comfy.disk_weights.ALLOW_GDS
|
||||
prev_pin = comfy.disk_weights.PIN_IF_CPU
|
||||
prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
|
||||
comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
|
||||
|
||||
try:
|
||||
path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.bfloat16), "bias": torch.zeros((4,), dtype=torch.bfloat16)})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
model = torch.nn.Linear(4, 4, bias=True)
|
||||
comfy.utils.load_state_dict(model, sd, strict=True)
|
||||
assert model.weight.device.type == "meta"
|
||||
|
||||
comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
|
||||
assert model.weight.dtype == torch.bfloat16
|
||||
|
||||
comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"), dtype_override=torch.float16)
|
||||
assert model.weight.dtype == torch.float16
|
||||
finally:
|
||||
comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user