From 80b7c9455bf7afba7a9e95a1eb76b172408ab56c Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sat, 13 Sep 2025 15:03:34 -0700
Subject: [PATCH] Changes to the previous radiance commit. (#9851)

---
 comfy/ldm/chroma_radiance/model.py |  7 +--
 comfy/pixel_space_convert.py       | 16 +++++++
 comfy/sd.py                        | 69 +++++-------------------------
 comfy/supported_models.py          |  2 +-
 nodes.py                           |  7 +--
 5 files changed, 35 insertions(+), 66 deletions(-)
 create mode 100644 comfy/pixel_space_convert.py

diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py
index f7eb7a22e..47aa11b04 100644
--- a/comfy/ldm/chroma_radiance/model.py
+++ b/comfy/ldm/chroma_radiance/model.py
@@ -306,8 +306,9 @@ class ChromaRadiance(Chroma):
 
         params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))
 
-        h_len = ((h + (self.patch_size // 2)) // self.patch_size)
-        w_len = ((w + (self.patch_size // 2)) // self.patch_size)
+        h_len = (img.shape[-2] // self.patch_size)
+        w_len = (img.shape[-1] // self.patch_size)
+
         img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
         img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
         img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
@@ -325,4 +326,4 @@ class ChromaRadiance(Chroma):
             transformer_options,
             attn_mask=kwargs.get("attention_mask", None),
         )
-        return self.forward_nerf(img, img_out, params)
+        return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
diff --git a/comfy/pixel_space_convert.py b/comfy/pixel_space_convert.py
new file mode 100644
index 000000000..049bbcfb4
--- /dev/null
+++ b/comfy/pixel_space_convert.py
@@ -0,0 +1,16 @@
+import torch
+
+
+# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
+# to LATENT B, C, H, W and values on the scale of -1..1.
+class PixelspaceConversionVAE(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0))
+
+    def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
+        return pixels
+
+    def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
+        return samples
+
diff --git a/comfy/sd.py b/comfy/sd.py
index cb92802e9..2df340739 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
 import comfy.ldm.hunyuan3d.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
 import comfy.ldm.hunyuan_video.vae
+import comfy.pixel_space_convert
 import yaml
 import math
 import os
@@ -516,6 +517,15 @@ class VAE:
                 self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
                 self.disable_offload = True
                 self.extra_1d_channel = 16
+            elif "pixel_space_vae" in sd:
+                self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE()
+                self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+                self.downscale_ratio = 1
+                self.upscale_ratio = 1
+                self.latent_channels = 3
+                self.latent_dim = 2
+                self.output_channels = 3
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
                 self.first_stage_model = None
@@ -785,65 +795,6 @@ class VAE:
         except:
             return None
 
-# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
-# to LATENT B, C, H, W and values on the scale of -1..1.
-class PixelspaceConversionVAE:
-    def __init__(self, size_increment: int=16):
-        self.intermediate_device = comfy.model_management.intermediate_device()
-        self.size_increment = size_increment
-
-    def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor:
-        if self.size_increment == 1:
-            return pixels
-        dims = pixels.shape[1:-1]
-        for d in range(len(dims)):
-            d_adj = (dims[d] // self.size_increment) * self.size_increment
-            if d_adj == d:
-                continue
-            d_offset = (dims[d] % self.size_increment) // 2
-            pixels = pixels.narrow(d + 1, d_offset, d_adj)
-        return pixels
-
-    def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
-        if pixels.ndim == 3:
-            pixels = pixels.unsqueeze(0)
-        elif pixels.ndim != 4:
-            raise ValueError("Unexpected input image shape")
-        # Ensure the image has spatial dimensions that are multiples of 16.
-        pixels = self.vae_encode_crop_pixels(pixels)
-        h, w, c = pixels.shape[1:]
-        if h < self.size_increment or w < self.size_increment:
-            raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).")
-        pixels= pixels[..., :3]
-        if c == 1:
-            pixels = pixels.expand(-1, -1, -1, 3)
-        elif c != 3:
-            raise ValueError("Unexpected number of channels in input image")
-        # Rescale to -1..1 and move the channel dimension to position 1.
-        latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
-        latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
-        latent -= 0.5
-        latent *= 2
-        return latent.clamp_(-1, 1)
-
-    def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
-        # Rescale to 0..1 and move the channel dimension to the end.
-        img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
-        img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
-        img += 1.0
-        img *= 0.5
-        return img.clamp_(0, 1)
-
-    encode_tiled = encode
-    decode_tiled = decode
-
-    @classmethod
-    def spacial_compression_decode(cls) -> int:
-        # This just exists so the tiled VAE nodes don't crash.
-        return 1
-
-    spacial_compression_encode = spacial_compression_decode
-    temporal_compression_decode = spacial_compression_decode
 
 class StyleModel:
     def __init__(self, model, device="cpu"):
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index be36b5dfe..557902d11 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1213,7 +1213,7 @@ class ChromaRadiance(Chroma):
     latent_format = comfy.latent_formats.ChromaRadiance
 
     # Pixel-space model, no spatial compression for model input.
-    memory_usage_factor = 0.0325
+    memory_usage_factor = 0.038
 
     def get_model(self, state_dict, prefix="", device=None):
         return model_base.ChromaRadiance(self, device=device)
diff --git a/nodes.py b/nodes.py
index 76b8cbac8..5a5fdcb8e 100644
--- a/nodes.py
+++ b/nodes.py
@@ -730,7 +730,7 @@ class VAELoader:
             vaes.append("taesd3")
         if f1_taesd_dec and f1_taesd_enc:
            vaes.append("taef1")
-        vaes.append("chroma_radiance")
+        vaes.append("pixel_space")
         return vaes
 
     @staticmethod
@@ -773,8 +773,9 @@ class VAELoader:
 
     #TODO: scale factor?
    def load_vae(self, vae_name):
-        if vae_name == "chroma_radiance":
-            return (comfy.sd.PixelspaceConversionVAE(),)
+        if vae_name == "pixel_space":
+            sd = {}
+            sd["pixel_space_vae"] = torch.tensor(1.0)
         elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
             sd = self.load_taesd(vae_name)
         else:
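
The chroma_radiance/model.py hunks stop re-deriving the token grid from h and w and instead read it off the already-prepared img tensor, then crop the NeRF output back to the caller's h x w. Below is a minimal standalone sketch (plain PyTorch, not part of the patch) of that pad-then-crop pattern; pad_to_patch_multiple is a hypothetical helper and patch_size = 16 is an assumed value, since in the model the padding happens upstream of this hunk and the real patch size comes from the model configuration.

import torch
import torch.nn.functional as F

def pad_to_patch_multiple(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Pad a B, C, H, W tensor on the bottom/right so H and W become multiples of patch_size.
    h, w = img.shape[-2:]
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size
    return F.pad(img, (0, pad_w, 0, pad_h))

patch_size = 16                              # assumed for illustration
x = torch.randn(1, 3, 100, 130)              # original h=100, w=130
img = pad_to_patch_multiple(x, patch_size)   # padded to 112 x 144
h_len = img.shape[-2] // patch_size          # 7, same form as the new h_len expression
w_len = img.shape[-1] // patch_size          # 9, same form as the new w_len expression
out = img                                    # stand-in for the forward_nerf output at padded resolution
out = out[:, :, :100, :130]                  # crop back to the requested size, as in the patch

Deriving h_len and w_len from img.shape keeps the positional ids consistent with whatever padding was actually applied, and the trailing [:, :, :h, :w] guarantees the returned image matches the requested resolution.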
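
The remaining hunks replace the old standalone PixelspaceConversionVAE in comfy/sd.py with a torch.nn.Module in comfy/pixel_space_convert.py that is a pure identity passthrough: its single pixel_space_vae parameter exists so the loader can recognize the fake state dict that nodes.py now builds ({"pixel_space_vae": torch.tensor(1.0)}) and route it to the new elif branch in comfy.sd.VAE. The IMAGE B, H, W, C, 0..1 to LATENT B, C, H, W, -1..1 conversion described in the file comment is presumably left to the generic VAE wrapper rather than the module itself. A self-contained sketch of that round trip in plain PyTorch follows; image_to_latent and latent_to_image are illustrative helpers standing in for the wrapper's processing, not functions from the patch.

import torch

class PixelspaceConversionVAE(torch.nn.Module):
    # Identity "VAE" as in comfy/pixel_space_convert.py; the parameter only
    # serves as a detectable key when the module's state dict is inspected.
    def __init__(self):
        super().__init__()
        self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0))

    def encode(self, pixels, *_args, **_kwargs):
        return pixels

    def decode(self, samples, *_args, **_kwargs):
        return samples

def image_to_latent(image: torch.Tensor) -> torch.Tensor:
    # IMAGE B, H, W, C in 0..1  ->  LATENT B, C, H, W in -1..1 (illustrative helper)
    return image.clamp(0, 1).movedim(-1, 1) * 2.0 - 1.0

def latent_to_image(latent: torch.Tensor) -> torch.Tensor:
    # LATENT B, C, H, W in -1..1  ->  IMAGE B, H, W, C in 0..1 (illustrative helper)
    return ((latent.clamp(-1, 1) + 1.0) * 0.5).movedim(1, -1)

vae = PixelspaceConversionVAE()
image = torch.rand(1, 64, 64, 3)                # IMAGE layout, values in 0..1
latent = vae.encode(image_to_latent(image))     # identity encode, shape (1, 3, 64, 64)
restored = latent_to_image(vae.decode(latent))  # identity decode, back to IMAGE layout
assert torch.allclose(image, restored, atol=1e-6)

Because both encode and decode are identities, the memory_used_encode/decode lambdas in the new sd.py branch only need to account for the raw pixel tensor, and downscale_ratio/upscale_ratio are simply 1.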