From 2d96b2fdf1f6aef792e3e69e3686ff9b824189a7 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Thu, 22 Jan 2026 00:03:01 +1000
Subject: [PATCH] MPDynamic: Add support for model defined dtype

If the model defines a dtype that is different to what is in the state
dict, respect that at load time. This is done as part of the casting
process.
---
 comfy/memory_management.py | 29 +++++++++++++++++++++++++++++
 comfy/model_management.py  |  4 ++--
 comfy/model_patcher.py     | 15 +++++++++++++--
 comfy/ops.py               | 21 ++++++++++++++++++++-
 comfy/pinned_memory.py     |  2 +-
 5 files changed, 65 insertions(+), 6 deletions(-)

diff --git a/comfy/memory_management.py b/comfy/memory_management.py
index 3765de0a1..858bd4cc7 100644
--- a/comfy/memory_management.py
+++ b/comfy/memory_management.py
@@ -1,5 +1,34 @@
+import math
+import torch
+from typing import NamedTuple
+
 from comfy.quant_ops import QuantizedTensor
 
+class TensorGeometry(NamedTuple):
+    shape: any
+    dtype: torch.dtype
+
+    def element_size(self):
+        info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
+        return info.bits // 8
+
+    def numel(self):
+        return math.prod(self.shape)
+
+def tensors_to_geometries(tensors, dtype=None):
+    geometries = []
+    for t in tensors:
+        if t is None or isinstance(t, QuantizedTensor):
+            geometries.append(t)
+            continue
+        tdtype = t.dtype
+        if hasattr(t, "_model_dtype"):
+            tdtype = t._model_dtype
+        if dtype is not None:
+            tdtype = dtype
+        geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
+    return geometries
+
 def vram_aligned_size(tensor):
     if isinstance(tensor, list):
         return sum([vram_aligned_size(t) for t in tensor])
diff --git a/comfy/model_management.py b/comfy/model_management.py
index cdb9542c0..527197447 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1190,12 +1190,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
         assert r is None
         assert stream is None
 
-        r = torch.empty_like(weight, dtype=dtype, device=device)
+        r = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
 
         signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
         if signature is not None:
             raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
-            v_tensor = comfy.memory_management.interpret_gathered_like([weight], raw_tensor)[0]
+            v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0]
 
             if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
                 #always take a deep copy even if _v is good, as we have no reasonable point to unpin
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 6b25436f2..1ef5b6661 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1504,6 +1504,8 @@ class ModelPatcherDynamic(ModelPatcher):
             weight_function = []
 
             weight, _, _ = get_key_weight(self.model, key)
+            if weight is None:
+                return 0
             if key in self.patches:
                 setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
                 num_patches += 1
@@ -1513,7 +1515,12 @@ class ModelPatcherDynamic(ModelPatcher):
             if key in self.weight_wrapper_patches:
                 weight_function.extend(self.weight_wrapper_patches[key])
             setattr(m, param_key + "_function", weight_function)
-            return comfy.memory_management.vram_aligned_size(weight)
+            geometry = weight
+            if not isinstance(weight, QuantizedTensor):
+                model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
+                weight._model_dtype = model_dtype
+                geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
+            return comfy.memory_management.vram_aligned_size(geometry)
 
         if hasattr(m, "comfy_cast_weights"):
             m.comfy_cast_weights = True
@@ -1535,9 +1542,13 @@ class ModelPatcherDynamic(ModelPatcher):
                 weight, _, _ = get_key_weight(self.model, key)
                 weight.seed_key = key
                 set_dirty(weight, dirty)
-                weight_size = weight.numel() * weight.element_size()
+                geometry = weight
+                model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
+                geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
+                weight_size = geometry.numel() * geometry.element_size()
                 if vbar is not None and not hasattr(weight, "_v"):
                     weight._v = vbar.alloc(weight_size)
+                    weight._model_dtype = model_dtype
                 allocated_size += weight_size
 
         logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
diff --git a/comfy/ops.py b/comfy/ops.py
index 5bdb54cc6..9710b2de2 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -81,6 +81,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
 def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
     offload_stream = None
     xfer_dest = None
+    cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
 
     signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
     if signature is not None:
@@ -88,6 +89,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
     resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
 
     if not resident:
+        cast_dest = None
         xfer_source = [ s.weight, s.bias ]
 
 
@@ -95,6 +97,16 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
         if pin is not None:
             xfer_source = [ pin ]
             resident = True #If pinned data exists, it always has LowVram already applied
+        else:
+            for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
+                if data is None:
+                    continue
+                if data.dtype != geometry.dtype:
+                    cast_dest = xfer_dest
+                    if cast_dest is None:
+                        cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
+                    xfer_dest = None
+                    break
 
         dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
         offload_stream = comfy.model_management.get_offload_stream(device)
@@ -111,6 +123,13 @@
         comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
         comfy.model_management.sync_stream(device, offload_stream)
 
+        if cast_dest is not None:
+            for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest),
+                                           comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
+                if post_cast is not None:
+                    post_cast.copy_(pre_cast)
+            xfer_dest = cast_dest
+
     pin = None
     if signature is not None:
         #If we are able to increase our load level (e.g. user reduces resolution or batch number)
@@ -122,7 +141,7 @@
             comfy.pinned_memory.pin_memory(s)
             pin = comfy.pinned_memory.get_pin(s)
 
-    params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], xfer_dest)
+    params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
 
     weight = params[0]
     bias = params[1]
diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py
index 650e27a10..0650e4d1a 100644
--- a/comfy/pinned_memory.py
+++ b/comfy/pinned_memory.py
@@ -11,7 +11,7 @@ def pin_memory(module):
     if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
         return
     #FIXME: This is a RAM cache trigger event
-    params = [ module.weight, module.bias ]
+    params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ])
     size = comfy.memory_management.vram_aligned_size(params)
     pin = torch.empty((size,), dtype=torch.uint8)
     if comfy.model_management.pin_memory(pin):
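
Note (illustration only, not part of the patch): the sketch below shows what the
new TensorGeometry sizing changes when a parameter is stored in the state dict
at one dtype but the model declares another. The class body mirrors the one
added to comfy/memory_management.py above; the fp8 state-dict dtype, fp16 model
dtype, and 4096x4096 shape are assumed example values, not taken from any real
model.

    import math
    from typing import NamedTuple

    import torch

    class TensorGeometry(NamedTuple):
        shape: any
        dtype: torch.dtype

        def element_size(self):
            info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
            return info.bits // 8

        def numel(self):
            return math.prod(self.shape)

    # Parameter as it sits in the state dict (fp8, 1 byte per element).
    stored = torch.empty((4096, 4096), dtype=torch.float8_e4m3fn)
    # dtype the model declares for this parameter; the patch reads it from the
    # module's "<param>_comfy_model_dtype" attribute and stashes it on the
    # tensor as _model_dtype.
    model_dtype = torch.float16

    geometry = TensorGeometry(shape=stored.shape, dtype=model_dtype)
    print(stored.numel() * stored.element_size())      # 16777216 bytes as stored
    print(geometry.numel() * geometry.element_size())  # 33554432 bytes staged for load

Sizing the vbar allocation and the pinned host buffer from the geometry rather
than from the stored tensor is what gives the cast path in ops.py a destination
sized for the weight at the model's dtype.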