Compare commits

...

6 Commits

Author SHA1 Message Date
rattus
ff1323928e
Merge 2d96b2fdf1 into 451af70154 2026-01-21 14:05:51 +00:00
Rattus
2d96b2fdf1 MPDynamic: Add support for model defined dtype
If the model defines a dtype that is different to what is in the state
dict, respect that at load time. This is done as part of the casting
process.
2026-01-22 00:05:25 +10:00
Rattus
65b9729912 ops: fix __init__ return 2026-01-22 00:02:11 +10:00
Rattus
4979c075c9 archive the model defined dtypes
Scan created models and save off the dtypes as defined by the model
creation process. This is needed for assign=True, which will override
the dtypes.
2026-01-22 00:00:33 +10:00
Rattus
6e641d88ed mp: big bump on the VBAR sizes
Now that the model defined dtype is decoupled from the state_dict
dtypes we need to be able to handle worst case scenario casts between
the SD and VBAR.
2026-01-21 23:57:52 +10:00
Alexander Piskun
451af70154
fix(api-nodes-Vidu): allow passing up to 7 subjects in Vidu Reference node (#12002)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-21 04:03:45 -08:00
8 changed files with 86 additions and 10 deletions

View File

@ -1,5 +1,34 @@
import math
import torch
from typing import NamedTuple
from comfy.quant_ops import QuantizedTensor
class TensorGeometry(NamedTuple):
shape: any
dtype: torch.dtype
def element_size(self):
info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
return info.bits // 8
def numel(self):
return math.prod(self.shape)
def tensors_to_geometries(tensors, dtype=None):
geometries = []
for t in tensors:
if t is None or isinstance(t, QuantizedTensor):
geometries.append(t)
continue
tdtype = t.dtype
if hasattr(t, "_model_dtype"):
tdtype = t._model_dtype
if dtype is not None:
tdtype = dtype
geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
return geometries
def vram_aligned_size(tensor):
if isinstance(tensor, list):
return sum([vram_aligned_size(t) for t in tensor])

View File

@ -148,6 +148,8 @@ class BaseModel(torch.nn.Module):
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
comfy.model_management.archive_model_dtypes(self.diffusion_model)
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0

View File

@ -774,6 +774,11 @@ def cleanup_models_gc():
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
def archive_model_dtypes(model):
for name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
def cleanup_models():
to_delete = []
@ -1185,12 +1190,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
assert r is None
assert stream is None
r = torch.empty_like(weight, dtype=dtype, device=device)
r = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like([weight], raw_tensor)[0]
v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0]
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
#always take a deep copy even if _v is good, as we have no reasonable point to unpin

View File

@ -1423,7 +1423,10 @@ class ModelPatcherDynamic(ModelPatcher):
return None
vbar = self.model.dynamic_vbars.get(self.load_device, None)
if create and vbar is None:
vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 1.2, self.load_device.index)
# x10. We dont know what model defined type casts we have in the vbar, but virtual address
# space is pretty free. This will cover someone casting an entire model from FP4 to FP32
# with some left over.
vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index)
self.model.dynamic_vbars[self.load_device] = vbar
return vbar
@ -1501,6 +1504,8 @@ class ModelPatcherDynamic(ModelPatcher):
weight_function = []
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
return 0
if key in self.patches:
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
@ -1510,7 +1515,12 @@ class ModelPatcherDynamic(ModelPatcher):
if key in self.weight_wrapper_patches:
weight_function.extend(self.weight_wrapper_patches[key])
setattr(m, param_key + "_function", weight_function)
return comfy.memory_management.vram_aligned_size(weight)
geometry = weight
if not isinstance(weight, QuantizedTensor):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@ -1532,9 +1542,13 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key)
weight.seed_key = key
set_dirty(weight, dirty)
weight_size = weight.numel() * weight.element_size()
geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")

View File

@ -81,6 +81,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
@ -88,6 +89,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if not resident:
cast_dest = None
xfer_source = [ s.weight, s.bias ]
@ -95,6 +97,16 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if pin is not None:
xfer_source = [ pin ]
resident = True #If pinned data exists, it always has LowVram already applied
else:
for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
if data is None:
continue
if data.dtype != geometry.dtype:
cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None
break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
@ -111,6 +123,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
if cast_dest is not None:
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
if post_cast is not None:
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
pin = None
if signature is not None:
#If we are able to increase our load level (e.g. user reduces resolution or batch number)
@ -122,7 +141,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
comfy.pinned_memory.pin_memory(s)
pin = comfy.pinned_memory.get_pin(s)
params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], xfer_dest)
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]
@ -283,7 +302,8 @@ class disable_weight_init:
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
return super().__init__(in_features, out_features, bias, device, dtype)
super().__init__(in_features, out_features, bias, device, dtype)
return
# Issue is with `torch.empty` still reserving the full memory for the layer.
# Windows doesn't over-commit memory so without this, We are momentarily commit
@ -296,6 +316,8 @@ class disable_weight_init:
self.weight = None
self.bias = None
self.comfy_need_lazy_init_bias=bias
self.weight_comfy_model_dtype = dtype
self.bias_comfy_model_dtype = dtype
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):

View File

@ -11,7 +11,7 @@ def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return
#FIXME: This is a RAM cache trigger event
params = [ module.weight, module.bias ]
params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ])
size = comfy.memory_management.vram_aligned_size(params)
pin = torch.empty((size,), dtype=torch.uint8)
if comfy.model_management.pin_memory(pin):

View File

@ -127,6 +127,8 @@ class CLIP:
self.cond_stage_model.to(offload_device)
logging.warning("Had to shift TE back.")
model_management.archive_model_dtypes(self.cond_stage_model)
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
@ -675,6 +677,8 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
model_management.archive_model_dtypes(self.first_stage_model)
if device is None:
device = model_management.vae_device()
self.device = device

View File

@ -703,7 +703,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
"subjects",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("reference_images"),
names=["subject1", "subject2", "subject3"],
names=["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"],
min=1,
),
tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). "
@ -738,7 +738,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
control_after_generate=True,
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]),
IO.Combo.Input("resolution", options=["720p"]),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Combo.Input(
"movement_amplitude",
options=["auto", "small", "medium", "large"],