mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-17 01:30:50 +08:00
Honor dtype overrides for disk weight loads
This commit is contained in:
parent
1cc2a5733c
commit
0ad0a39f94
@ -47,8 +47,14 @@ class DiskTensorRef:
|
||||
requires_grad: bool
|
||||
is_buffer: bool
|
||||
|
||||
def load(self, device: torch.device, allow_gds: bool, pin_if_cpu: bool) -> torch.Tensor:
|
||||
dtype = getattr(self.meta, "dtype", None)
|
||||
def load(
|
||||
self,
|
||||
device: torch.device,
|
||||
allow_gds: bool,
|
||||
pin_if_cpu: bool,
|
||||
dtype_override: Optional[torch.dtype] = None,
|
||||
) -> torch.Tensor:
|
||||
dtype = dtype_override or getattr(self.meta, "dtype", None)
|
||||
if hasattr(self.state_dict, "get_tensor"):
|
||||
return self.state_dict.get_tensor(
|
||||
self.key,
|
||||
@ -607,10 +613,33 @@ def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
|
||||
return check(kwargs)
|
||||
|
||||
|
||||
def _find_tensor_dtype(args, kwargs) -> Optional[torch.dtype]:
|
||||
def check(obj):
|
||||
if torch.is_tensor(obj):
|
||||
return obj.dtype
|
||||
if isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
dtype = check(item)
|
||||
if dtype is not None:
|
||||
return dtype
|
||||
if isinstance(obj, dict):
|
||||
for item in obj.values():
|
||||
dtype = check(item)
|
||||
if dtype is not None:
|
||||
return dtype
|
||||
return None
|
||||
|
||||
dtype = check(args)
|
||||
if dtype is not None:
|
||||
return dtype
|
||||
return check(kwargs)
|
||||
|
||||
|
||||
def ensure_module_materialized(
|
||||
module: torch.nn.Module,
|
||||
target_device: torch.device,
|
||||
fallback_device: Optional[torch.device] = None,
|
||||
dtype_override: Optional[torch.dtype] = None,
|
||||
):
|
||||
lazy_state = LAZY_MODULE_STATE.get(module)
|
||||
if lazy_state is not None:
|
||||
@ -658,9 +687,12 @@ def ensure_module_materialized(
|
||||
else:
|
||||
target_for_load = target_device
|
||||
if current.device.type == "meta":
|
||||
tensor = disk_ref.load(target_for_load, ALLOW_GDS, PIN_IF_CPU)
|
||||
tensor = disk_ref.load(target_for_load, ALLOW_GDS, PIN_IF_CPU, dtype_override=dtype_override)
|
||||
else:
|
||||
tensor = current.to(device=target_for_load)
|
||||
if dtype_override is not None and current.dtype != dtype_override:
|
||||
tensor = current.to(device=target_for_load, dtype=dtype_override)
|
||||
else:
|
||||
tensor = current.to(device=target_for_load)
|
||||
if is_buffer:
|
||||
module._buffers[name] = tensor
|
||||
else:
|
||||
@ -675,13 +707,21 @@ def ensure_module_materialized(
|
||||
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
    """Materialize a module's disk-backed weights before it runs.

    NOTE(review): the ``(module, args, kwargs)`` signature suggests this is
    registered as a forward pre-hook with kwargs -- confirm at the
    registration site.

    Does nothing unless the module is known to ``REGISTRY`` or has an entry in
    ``LAZY_MODULE_STATE``. Otherwise it picks a target device and a dtype
    override and hands both to ``ensure_module_materialized``.

    Dtype override: ``module.manual_cast_dtype`` when set; otherwise the dtype
    of the call's tensor inputs, but only if that dtype is floating point.

    Device selection: modules with a truthy ``comfy_cast_weights`` attribute
    are materialized on CPU, with the input tensor's device passed as the
    fallback; all other modules target the input tensor's device, or CPU when
    no input tensor is found.

    Args:
        module: The module about to be called.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.
    """
    if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE:
        return

    dtype_override = getattr(module, "manual_cast_dtype", None)
    if dtype_override is None:
        input_dtype = _find_tensor_dtype(args, kwargs)
        # An integer/bool input dtype (e.g. index tensors) must not be used
        # as the weight dtype; only a floating-point input dtype is honored.
        if input_dtype is not None and input_dtype.is_floating_point:
            dtype_override = input_dtype

    if getattr(module, "comfy_cast_weights", False):
        target_device = torch.device("cpu")
        fallback_device = _find_tensor_device(args, kwargs)
    else:
        target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
        fallback_device = None

    # Single materialization call carrying the dtype override.
    ensure_module_materialized(
        module,
        target_device,
        fallback_device=fallback_device,
        dtype_override=dtype_override,
    )
|
||||
|
||||
|
||||
def attach_disk_weight_hooks(model: torch.nn.Module):
|
||||
@ -745,6 +785,7 @@ def load_module_tensor(
|
||||
allow_alternate: bool = True,
|
||||
record_cache: bool = True,
|
||||
temporary: bool = False,
|
||||
dtype_override: Optional[torch.dtype] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
refs = REGISTRY.get(module)
|
||||
if not refs or name not in refs:
|
||||
@ -760,8 +801,8 @@ def load_module_tensor(
|
||||
if current is None:
|
||||
return None
|
||||
if current.device.type != "meta":
|
||||
if current.device != device:
|
||||
tensor = current.to(device=device)
|
||||
if current.device != device or (dtype_override is not None and current.dtype != dtype_override):
|
||||
tensor = current.to(device=device, dtype=dtype_override)
|
||||
if not temporary:
|
||||
if is_buffer:
|
||||
module._buffers[name] = tensor
|
||||
@ -809,7 +850,7 @@ def load_module_tensor(
|
||||
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
|
||||
return current
|
||||
|
||||
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU)
|
||||
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=dtype_override)
|
||||
if temporary:
|
||||
return tensor
|
||||
if is_buffer:
|
||||
|
||||
16
comfy/ops.py
16
comfy/ops.py
@ -103,11 +103,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
bias_source = s.bias
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
if weight_source.device.type == "meta":
|
||||
loaded = comfy.disk_weights.load_module_tensor(s, "weight", device, temporary=True)
|
||||
loaded = comfy.disk_weights.load_module_tensor(
|
||||
s,
|
||||
"weight",
|
||||
device,
|
||||
temporary=True,
|
||||
dtype_override=dtype,
|
||||
)
|
||||
if loaded is not None:
|
||||
weight_source = loaded
|
||||
if bias_source is not None and bias_source.device.type == "meta":
|
||||
loaded_bias = comfy.disk_weights.load_module_tensor(s, "bias", device, temporary=True)
|
||||
loaded_bias = comfy.disk_weights.load_module_tensor(
|
||||
s,
|
||||
"bias",
|
||||
device,
|
||||
temporary=True,
|
||||
dtype_override=bias_dtype,
|
||||
)
|
||||
if loaded_bias is not None:
|
||||
bias_source = loaded_bias
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user