Honor dtype overrides for disk weight loads

This commit is contained in:
ifilipis 2026-01-08 20:02:27 +02:00
parent 1cc2a5733c
commit 0ad0a39f94
2 changed files with 63 additions and 10 deletions

View File

@@ -47,8 +47,14 @@ class DiskTensorRef:
requires_grad: bool
is_buffer: bool
def load(self, device: torch.device, allow_gds: bool, pin_if_cpu: bool) -> torch.Tensor:
dtype = getattr(self.meta, "dtype", None)
def load(
self,
device: torch.device,
allow_gds: bool,
pin_if_cpu: bool,
dtype_override: Optional[torch.dtype] = None,
) -> torch.Tensor:
dtype = dtype_override or getattr(self.meta, "dtype", None)
if hasattr(self.state_dict, "get_tensor"):
return self.state_dict.get_tensor(
self.key,
@@ -607,10 +613,33 @@ def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
return check(kwargs)
def _find_tensor_dtype(args, kwargs) -> Optional[torch.dtype]:
def check(obj):
if torch.is_tensor(obj):
return obj.dtype
if isinstance(obj, (list, tuple)):
for item in obj:
dtype = check(item)
if dtype is not None:
return dtype
if isinstance(obj, dict):
for item in obj.values():
dtype = check(item)
if dtype is not None:
return dtype
return None
dtype = check(args)
if dtype is not None:
return dtype
return check(kwargs)
def ensure_module_materialized(
module: torch.nn.Module,
target_device: torch.device,
fallback_device: Optional[torch.device] = None,
dtype_override: Optional[torch.dtype] = None,
):
lazy_state = LAZY_MODULE_STATE.get(module)
if lazy_state is not None:
@@ -658,9 +687,12 @@ def ensure_module_materialized(
else:
target_for_load = target_device
if current.device.type == "meta":
tensor = disk_ref.load(target_for_load, ALLOW_GDS, PIN_IF_CPU)
tensor = disk_ref.load(target_for_load, ALLOW_GDS, PIN_IF_CPU, dtype_override=dtype_override)
else:
tensor = current.to(device=target_for_load)
if dtype_override is not None and current.dtype != dtype_override:
tensor = current.to(device=target_for_load, dtype=dtype_override)
else:
tensor = current.to(device=target_for_load)
if is_buffer:
module._buffers[name] = tensor
else:
@@ -675,13 +707,21 @@ def ensure_module_materialized(
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
    """Forward pre-hook: materialize disk-backed weights before the module runs.

    Derives a target device from the incoming tensors and a dtype override
    (an explicit ``module.manual_cast_dtype`` wins over the dtype inferred
    from the inputs), then delegates the actual load to
    ``ensure_module_materialized``.
    """
    # Nothing to do for modules that have no disk refs and no lazy state.
    if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE:
        return
    input_dtype = _find_tensor_dtype(args, kwargs)
    manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
    # manual_cast_dtype takes precedence over the inferred input dtype.
    dtype_override = manual_cast_dtype or input_dtype
    if getattr(module, "comfy_cast_weights", False):
        # Weights are cast at use time: keep the load on CPU and remember
        # where the inputs live as a fallback placement.
        target_device = torch.device("cpu")
        fallback_device = _find_tensor_device(args, kwargs)
    else:
        target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
        fallback_device = None
    # Single materialization call (the diff artifact that retained the old
    # un-overridden call alongside this one would have loaded twice).
    ensure_module_materialized(
        module,
        target_device,
        fallback_device=fallback_device,
        dtype_override=dtype_override,
    )
def attach_disk_weight_hooks(model: torch.nn.Module):
@@ -745,6 +785,7 @@ def load_module_tensor(
allow_alternate: bool = True,
record_cache: bool = True,
temporary: bool = False,
dtype_override: Optional[torch.dtype] = None,
) -> Optional[torch.Tensor]:
refs = REGISTRY.get(module)
if not refs or name not in refs:
@@ -760,8 +801,8 @@ def load_module_tensor(
if current is None:
return None
if current.device.type != "meta":
if current.device != device:
tensor = current.to(device=device)
if current.device != device or (dtype_override is not None and current.dtype != dtype_override):
tensor = current.to(device=device, dtype=dtype_override)
if not temporary:
if is_buffer:
module._buffers[name] = tensor
@@ -809,7 +850,7 @@ def load_module_tensor(
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
return current
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU)
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=dtype_override)
if temporary:
return tensor
if is_buffer:

View File

@@ -103,11 +103,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
bias_source = s.bias
if comfy.disk_weights.disk_weights_enabled():
if weight_source.device.type == "meta":
loaded = comfy.disk_weights.load_module_tensor(s, "weight", device, temporary=True)
loaded = comfy.disk_weights.load_module_tensor(
s,
"weight",
device,
temporary=True,
dtype_override=dtype,
)
if loaded is not None:
weight_source = loaded
if bias_source is not None and bias_source.device.type == "meta":
loaded_bias = comfy.disk_weights.load_module_tensor(s, "bias", device, temporary=True)
loaded_bias = comfy.disk_weights.load_module_tensor(
s,
"bias",
device,
temporary=True,
dtype_override=bias_dtype,
)
if loaded_bias is not None:
bias_source = loaded_bias