MPDynamic: Add support for model defined dtype

If the model defines a dtype that is different to what is in the state
dict, respect that at load time. This is done as part of the casting
process.
This commit is contained in:
Rattus 2026-01-22 00:03:01 +10:00
parent 65b9729912
commit 2d96b2fdf1
5 changed files with 65 additions and 6 deletions

View File

@ -1,5 +1,34 @@
import math
import torch
from typing import NamedTuple
from comfy.quant_ops import QuantizedTensor
class TensorGeometry(NamedTuple):
shape: any
dtype: torch.dtype
def element_size(self):
info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
return info.bits // 8
def numel(self):
return math.prod(self.shape)
def tensors_to_geometries(tensors, dtype=None):
geometries = []
for t in tensors:
if t is None or isinstance(t, QuantizedTensor):
geometries.append(t)
continue
tdtype = t.dtype
if hasattr(t, "_model_dtype"):
tdtype = t._model_dtype
if dtype is not None:
tdtype = dtype
geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
return geometries
def vram_aligned_size(tensor):
if isinstance(tensor, list):
return sum([vram_aligned_size(t) for t in tensor])

View File

@ -1190,12 +1190,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
assert r is None
assert stream is None
r = torch.empty_like(weight, dtype=dtype, device=device)
r = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like([weight], raw_tensor)[0]
v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0]
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
#always take a deep copy even if _v is good, as we have no reasonable point to unpin

View File

@ -1504,6 +1504,8 @@ class ModelPatcherDynamic(ModelPatcher):
weight_function = []
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
return 0
if key in self.patches:
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
@ -1513,7 +1515,12 @@ class ModelPatcherDynamic(ModelPatcher):
if key in self.weight_wrapper_patches:
weight_function.extend(self.weight_wrapper_patches[key])
setattr(m, param_key + "_function", weight_function)
return comfy.memory_management.vram_aligned_size(weight)
geometry = weight
if not isinstance(weight, QuantizedTensor):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@ -1535,9 +1542,13 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key)
weight.seed_key = key
set_dirty(weight, dirty)
weight_size = weight.numel() * weight.element_size()
geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")

View File

@ -81,6 +81,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
@ -88,6 +89,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if not resident:
cast_dest = None
xfer_source = [ s.weight, s.bias ]
@ -95,6 +97,16 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if pin is not None:
xfer_source = [ pin ]
resident = True #If pinned data exists, it always has LowVram already applied
else:
for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
if data is None:
continue
if data.dtype != geometry.dtype:
cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None
break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
@ -111,6 +123,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
if cast_dest is not None:
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
if post_cast is not None:
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
pin = None
if signature is not None:
#If we are able to increase our load level (e.g. user reduces resolution or batch number)
@ -122,7 +141,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
comfy.pinned_memory.pin_memory(s)
pin = comfy.pinned_memory.get_pin(s)
params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], xfer_dest)
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]

View File

@ -11,7 +11,7 @@ def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return
#FIXME: This is a RAM cache trigger event
params = [ module.weight, module.bias ]
params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ])
size = comfy.memory_management.vram_aligned_size(params)
pin = torch.empty((size,), dtype=torch.uint8)
if comfy.model_management.pin_memory(pin):