From 1bcda6df987a6c92b39d8b6d29e0b029450d67d0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 23 Oct 2025 18:21:14 -0700 Subject: [PATCH] WIP way to support multi multi dimensional latents. (#10456) --- comfy/model_base.py | 10 ++++- comfy/nested_tensor.py | 91 ++++++++++++++++++++++++++++++++++++++++++ comfy/sample.py | 27 ++++++++++--- comfy/samplers.py | 23 +++++++---- comfy/utils.py | 22 ++++++++++ 5 files changed, 158 insertions(+), 15 deletions(-) create mode 100644 comfy/nested_tensor.py diff --git a/comfy/model_base.py b/comfy/model_base.py index 8274c7dea..e877f19ac 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -197,8 +197,14 @@ class BaseModel(torch.nn.Module): extra_conds[o] = extra t = self.process_timestep(t, x=x, **extra_conds) - model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() - return self.model_sampling.calculate_denoised(sigma, model_output, x) + if "latent_shapes" in extra_conds: + xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) + + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) + if len(model_output) > 1 and not torch.is_tensor(model_output): + model_output, _ = utils.pack_latents(model_output) + + return self.model_sampling.calculate_denoised(sigma, model_output.float(), x) def process_timestep(self, timestep, **kwargs): return timestep diff --git a/comfy/nested_tensor.py b/comfy/nested_tensor.py new file mode 100644 index 000000000..b700816fa --- /dev/null +++ b/comfy/nested_tensor.py @@ -0,0 +1,91 @@ +import torch + +class NestedTensor: + def __init__(self, tensors): + self.tensors = list(tensors) + self.is_nested = True + + def _copy(self): + return NestedTensor(self.tensors) + + def apply_operation(self, other, operation): + o = self._copy() + if isinstance(other, NestedTensor): + for i, t in enumerate(o.tensors): + o.tensors[i] = operation(t, other.tensors[i]) + else: + for i, t in enumerate(o.tensors): + o.tensors[i] = operation(t, other) + return o + + def __add__(self, b): + return self.apply_operation(b, lambda x, y: x + y) + + def __sub__(self, b): + return self.apply_operation(b, lambda x, y: x - y) + + def __mul__(self, b): + return self.apply_operation(b, lambda x, y: x * y) + + # def __itruediv__(self, b): + # return self.apply_operation(b, lambda x, y: x / y) + + def __truediv__(self, b): + return self.apply_operation(b, lambda x, y: x / y) + + def __getitem__(self, *args, **kwargs): + return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs)) + + def unbind(self): + return self.tensors + + def to(self, *args, **kwargs): + o = self._copy() + for i, t in enumerate(o.tensors): + o.tensors[i] = t.to(*args, **kwargs) + return o + + def new_ones(self, *args, **kwargs): + return self.tensors[0].new_ones(*args, **kwargs) + + def float(self): + return self.to(dtype=torch.float) + + def chunk(self, *args, **kwargs): + return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs)) + + def size(self): + return self.tensors[0].size() + + @property + def shape(self): + return self.tensors[0].shape + + @property + def ndim(self): + dims = 0 + for t in self.tensors: + dims = max(t.ndim, dims) + return dims + + @property + def device(self): + return self.tensors[0].device + + @property + def dtype(self): + return self.tensors[0].dtype + + @property + def layout(self): + return self.tensors[0].layout + + +def cat_nested(tensors, *args, **kwargs): + cated_tensors = [] + for i in range(len(tensors[0].tensors)): + tens = [] + for j in range(len(tensors)): + tens.append(tensors[j].tensors[i]) + cated_tensors.append(torch.cat(tens, *args, **kwargs)) + return NestedTensor(cated_tensors) diff --git a/comfy/sample.py b/comfy/sample.py index be5a7e246..b1395da84 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -4,13 +4,9 @@ import comfy.samplers import comfy.utils import numpy as np import logging +import comfy.nested_tensor -def prepare_noise(latent_image, seed, noise_inds=None): - """ - creates random noise given a latent image and a seed. - optional arg skip can be used to skip and discard x number of noise generations for a given seed - """ - generator = torch.manual_seed(seed) +def prepare_noise_inner(latent_image, generator, noise_inds=None): if noise_inds is None: return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") @@ -22,9 +18,28 @@ def prepare_noise(latent_image, seed, noise_inds=None): noises.append(noise) noises = [noises[i] for i in inverse] noises = torch.cat(noises, axis=0) + +def prepare_noise(latent_image, seed, noise_inds=None): + """ + creates random noise given a latent image and a seed. + optional arg skip can be used to skip and discard x number of noise generations for a given seed + """ + generator = torch.manual_seed(seed) + + if latent_image.is_nested: + tensors = latent_image.unbind() + noises = [] + for t in tensors: + noises.append(prepare_noise_inner(t, generator, noise_inds)) + noises = comfy.nested_tensor.NestedTensor(noises) + else: + noises = prepare_noise_inner(latent_image, generator, noise_inds) + return noises def fix_empty_latent_channels(model, latent_image): + if latent_image.is_nested: + return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) diff --git a/comfy/samplers.py b/comfy/samplers.py index e7efaf470..fa4640842 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -782,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): return KSAMPLER(sampler_function, extra_options, inpaint_options) -def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None): +def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None): for k in conds: conds[k] = conds[k][:] resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device) @@ -792,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N if hasattr(model, 'extra_conds'): for k in conds: - conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes) #make sure each cond area has an opposite one with the same area for k in conds: @@ -962,11 +962,11 @@ class CFGGuider: def predict_noise(self, x, timestep, model_options={}, seed=None): return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) - def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): + def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None): if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. latent_image = self.inner_model.process_latent_in(latent_image) - self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) + self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes) extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas @@ -980,7 +980,7 @@ class CFGGuider: samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return self.inner_model.process_latent_out(samples.to(torch.float32)) - def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None): self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device @@ -994,7 +994,7 @@ class CFGGuider: try: self.model_patcher.pre_run() - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: self.model_patcher.cleanup() @@ -1007,6 +1007,12 @@ class CFGGuider: if sigmas.shape[-1] == 0: return latent_image + if latent_image.is_nested: + latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind()) + noise, _ = comfy.utils.pack_latents(noise.unbind()) + else: + latent_shapes = [latent_image.shape] + self.conds = {} for k in self.original_conds: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) @@ -1026,7 +1032,7 @@ class CFGGuider: self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) - output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options @@ -1034,6 +1040,9 @@ class CFGGuider: self.model_patcher.restore_hook_patches() del self.conds + + if len(latent_shapes) > 1: + output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes)) return output diff --git a/comfy/utils.py b/comfy/utils.py index 0fd03f165..4bd281057 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1106,3 +1106,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out): dim=1 ) return out + +def pack_latents(latents): + latent_shapes = [] + tensors = [] + for tensor in latents: + latent_shapes.append(tensor.shape) + tensors.append(tensor.reshape(tensor.shape[0], 1, -1)) + + latent = torch.cat(tensors, dim=-1) + return latent, latent_shapes + +def unpack_latents(combined_latent, latent_shapes): + if len(latent_shapes) > 1: + output_tensors = [] + for shape in latent_shapes: + cut = math.prod(shape[1:]) + tens = combined_latent[:, :, :cut] + combined_latent = combined_latent[:, :, cut:] + output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) + else: + output_tensors = combined_latent + return output_tensors