diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py index e25c0c5bd..6ec0d9a01 100644 --- a/comfy/disk_weights.py +++ b/comfy/disk_weights.py @@ -47,8 +47,14 @@ class DiskTensorRef: requires_grad: bool is_buffer: bool - def load(self, device: torch.device, allow_gds: bool, pin_if_cpu: bool) -> torch.Tensor: - dtype = getattr(self.meta, "dtype", None) + def load( + self, + device: torch.device, + allow_gds: bool, + pin_if_cpu: bool, + dtype_override: Optional[torch.dtype] = None, + ) -> torch.Tensor: + dtype = dtype_override or getattr(self.meta, "dtype", None) if hasattr(self.state_dict, "get_tensor"): return self.state_dict.get_tensor( self.key, @@ -607,10 +613,33 @@ def _find_tensor_device(args, kwargs) -> Optional[torch.device]: return check(kwargs) +def _find_tensor_dtype(args, kwargs) -> Optional[torch.dtype]: + def check(obj): + if torch.is_tensor(obj): + return obj.dtype + if isinstance(obj, (list, tuple)): + for item in obj: + dtype = check(item) + if dtype is not None: + return dtype + if isinstance(obj, dict): + for item in obj.values(): + dtype = check(item) + if dtype is not None: + return dtype + return None + + dtype = check(args) + if dtype is not None: + return dtype + return check(kwargs) + + def ensure_module_materialized( module: torch.nn.Module, target_device: torch.device, fallback_device: Optional[torch.device] = None, + dtype_override: Optional[torch.dtype] = None, ): lazy_state = LAZY_MODULE_STATE.get(module) if lazy_state is not None: @@ -658,9 +687,12 @@ def ensure_module_materialized( else: target_for_load = target_device if current.device.type == "meta": - tensor = disk_ref.load(target_for_load, ALLOW_GDS, PIN_IF_CPU) + tensor = disk_ref.load(target_for_load, ALLOW_GDS, PIN_IF_CPU, dtype_override=dtype_override) else: - tensor = current.to(device=target_for_load) + if dtype_override is not None and current.dtype != dtype_override: + tensor = current.to(device=target_for_load, dtype=dtype_override) + else: + tensor = current.to(device=target_for_load) if is_buffer: module._buffers[name] = tensor else: @@ -675,13 +707,21 @@ def ensure_module_materialized( def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs): if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE: return + input_dtype = _find_tensor_dtype(args, kwargs) + manual_cast_dtype = getattr(module, "manual_cast_dtype", None) + dtype_override = manual_cast_dtype or input_dtype if getattr(module, "comfy_cast_weights", False): target_device = torch.device("cpu") fallback_device = _find_tensor_device(args, kwargs) else: target_device = _find_tensor_device(args, kwargs) or torch.device("cpu") fallback_device = None - ensure_module_materialized(module, target_device, fallback_device=fallback_device) + ensure_module_materialized( + module, + target_device, + fallback_device=fallback_device, + dtype_override=dtype_override, + ) def attach_disk_weight_hooks(model: torch.nn.Module): @@ -745,6 +785,7 @@ def load_module_tensor( allow_alternate: bool = True, record_cache: bool = True, temporary: bool = False, + dtype_override: Optional[torch.dtype] = None, ) -> Optional[torch.Tensor]: refs = REGISTRY.get(module) if not refs or name not in refs: @@ -760,8 +801,8 @@ def load_module_tensor( if current is None: return None if current.device.type != "meta": - if current.device != device: - tensor = current.to(device=device) + if current.device != device or (dtype_override is not None and current.dtype != dtype_override): + tensor = current.to(device=device, dtype=dtype_override) if not temporary: if is_buffer: module._buffers[name] = tensor @@ -809,7 +850,7 @@ def load_module_tensor( _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred") return current - tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU) + tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=dtype_override) if temporary: return tensor if is_buffer: diff --git a/comfy/ops.py b/comfy/ops.py index 06f28317a..67f151381 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -103,11 +103,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of bias_source = s.bias if comfy.disk_weights.disk_weights_enabled(): if weight_source.device.type == "meta": - loaded = comfy.disk_weights.load_module_tensor(s, "weight", device, temporary=True) + loaded = comfy.disk_weights.load_module_tensor( + s, + "weight", + device, + temporary=True, + dtype_override=dtype, + ) if loaded is not None: weight_source = loaded if bias_source is not None and bias_source.device.type == "meta": - loaded_bias = comfy.disk_weights.load_module_tensor(s, "bias", device, temporary=True) + loaded_bias = comfy.disk_weights.load_module_tensor( + s, + "bias", + device, + temporary=True, + dtype_override=bias_dtype, + ) if loaded_bias is not None: bias_source = loaded_bias