From e7073b5eec9e840b7b27468b98a929b7af031bc4 Mon Sep 17 00:00:00 2001 From: blepping Date: Tue, 2 Sep 2025 04:59:23 -0600 Subject: [PATCH] Add ChromaRadianceOptions node and backend support. Cleanups/refactoring to reduce code duplication with Chroma. --- comfy/ldm/chroma/model.py | 10 +- comfy/ldm/chroma/model_dct.py | 227 ++++++++++---------------- comfy_extras/nodes_chroma_radiance.py | 81 +++++++++ 3 files changed, 171 insertions(+), 147 deletions(-) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 4f709f87d..ad1c523fe 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -151,8 +151,6 @@ class Chroma(nn.Module): attn_mask: Tensor = None, ) -> Tensor: patches_replace = transformer_options.get("patches_replace", {}) - if img.ndim != 3 or txt.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img img = self.img_in(img) @@ -254,8 +252,9 @@ class Chroma(nn.Module): img[:, txt.shape[1] :, ...] += add img = img[:, txt.shape[1] :, ...] - final_mod = self.get_modulations(mod_vectors, "final") - img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) + if hasattr(self, "final_layer"): + final_mod = self.get_modulations(mod_vectors, "final") + img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) return img def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): @@ -271,6 +270,9 @@ class Chroma(nn.Module): img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size) + if img.ndim != 3 or context.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + h_len = ((h + (self.patch_size // 2)) // self.patch_size) w_len = ((w + (self.patch_size // 2)) // self.patch_size) img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) diff --git a/comfy/ldm/chroma/model_dct.py b/comfy/ldm/chroma/model_dct.py index dfcd6d1de..a8c6a461f 100644 --- a/comfy/ldm/chroma/model_dct.py +++ b/comfy/ldm/chroma/model_dct.py @@ -36,8 +36,11 @@ class ChromaRadianceParams(chroma_model.ChromaParams): nerf_mlp_ratio: int nerf_depth: int nerf_max_freqs: int + # nerf_tile_size of 0 means unlimited. nerf_tile_size: int + # Currently one of linear (legacy) or conv. nerf_final_head_type: str + # None means use the same dtype as the model. nerf_embedder_dtype: Optional[torch.dtype] @@ -168,6 +171,53 @@ class ChromaRadiance(chroma_model.Chroma): # Impossible to get here as we raise an error on unexpected types on initialization. raise NotImplementedError + def img_in(self, img: Tensor) -> Tensor: + img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P] + # flatten into a sequence for the transformer. + return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden] + + def forward_nerf( + self, + img_orig: Tensor, + img_out: Tensor, + params: ChromaRadianceParams, + ) -> Tensor: + B, C, H, W = img_orig.shape + num_patches = img_out.shape[1] + patch_size = params.patch_size + + # Store the raw pixel values of each patch for the NeRF head later. + # unfold creates patches: [B, C * P * P, NumPatches] + nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size) + nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] + + if params.nerf_tile_size > 0: + img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params) + else: + # Reshape for per-patch processing + nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size) + nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2) + + # Get DCT-encoded pixel embeddings [pixel-dct] + img_dct = self.nerf_image_embedder(nerf_pixels) + + # Pass through the dynamic MLP blocks (the NeRF) + for block in self.nerf_blocks: + img_dct = block(img_dct, nerf_hidden) + + # Reassemble the patches into the final image. + img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P] + # Reshape to combine with batch dimension for fold + img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P] + img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches] + img_dct = nn.functional.fold( + img_dct, + output_size=(H, W), + kernel_size=patch_size, + stride=patch_size, + ) + return self._nerf_final_layer(img_dct) + def forward_tiled_nerf( self, nerf_hidden: Tensor, @@ -175,17 +225,18 @@ class ChromaRadiance(chroma_model.Chroma): B: int, C: int, num_patches: int, - tile_size: int = 16 + patch_size: int, + params: ChromaRadianceParams, ) -> Tensor: """ Processes the NeRF head in tiles to save memory. nerf_hidden has shape [B, L, D] nerf_pixels has shape [B, L, C * P * P] """ + tile_size = params.nerf_tile_size output_tiles = [] # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1. for i in range(0, num_patches, tile_size): - # end = min(i + tile_size, num_patches) # Slice the current tile from the input tensors @@ -197,9 +248,9 @@ class ChromaRadiance(chroma_model.Chroma): # Reshape the tile for per-patch processing # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D] - nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, self.params.hidden_size) + nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, params.hidden_size) # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C] - nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, self.params.patch_size**2).transpose(1, 2) + nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2) # get DCT-encoded pixel embeddings [pixel-dct] img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile) @@ -213,150 +264,39 @@ class ChromaRadiance(chroma_model.Chroma): # Concatenate the processed tiles along the patch dimension return torch.cat(output_tiles, dim=0) - def forward_orig( - self, - img: Tensor, - img_ids: Tensor, - txt: Tensor, - txt_ids: Tensor, - timesteps: Tensor, - guidance: Tensor = None, - control = None, - transformer_options={}, - attn_mask: Tensor = None, - ) -> Tensor: - patches_replace = transformer_options.get("patches_replace", {}) - if img.ndim != 4: - raise ValueError("Input img tensor must be in [B, C, H, W] format.") - if txt.ndim != 3: - raise ValueError("Input txt tensors must have 3 dimensions.") - B, C, H, W = img.shape - - # gemini gogogo idk how to unfold and pack the patch properly :P - # Store the raw pixel values of each patch for the NeRF head later. - # unfold creates patches: [B, C * P * P, NumPatches] - nerf_pixels = nn.functional.unfold(img, kernel_size=self.params.patch_size, stride=self.params.patch_size) - nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] - - # partchify ops - img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P] - num_patches = img.shape[2] * img.shape[3] - # flatten into a sequence for the transformer. - img = img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden] - - # distilled vector guidance - mod_index_length = 344 - distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype) - # guidance = guidance * - distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype) - - # get all modulation index - modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype) - # we need to broadcast the modulation index here so each batch has all of the index - modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype) - # and we need to broadcast timestep and guidance along too - timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype) - # then and only then we could concatenate it together - input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype) - - mod_vectors = self.distilled_guidance_layer(input_vec) - - txt = self.txt_in(txt) - - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) - - blocks_replace = patches_replace.get("dit", {}) - for i, block in enumerate(self.double_blocks): - if i not in self.skip_mmdit: - double_mod = ( - self.get_modulations(mod_vectors, "double_img", idx=i), - self.get_modulations(mod_vectors, "double_txt", idx=i), - ) - if ("double_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"], out["txt"] = block(img=args["img"], - txt=args["txt"], - vec=args["vec"], - pe=args["pe"], - attn_mask=args.get("attn_mask")) - return out - - out = blocks_replace[("double_block", i)]({"img": img, - "txt": txt, - "vec": double_mod, - "pe": pe, - "attn_mask": attn_mask}, - {"original_block": block_wrap}) - txt = out["txt"] - img = out["img"] - else: - img, txt = block(img=img, - txt=txt, - vec=double_mod, - pe=pe, - attn_mask=attn_mask) - - if control is not None: # Controlnet - control_i = control.get("input") - if i < len(control_i): - add = control_i[i] - if add is not None: - img += add - - img = torch.cat((txt, img), 1) - - for i, block in enumerate(self.single_blocks): - if i not in self.skip_dit: - single_mod = self.get_modulations(mod_vectors, "single", idx=i) - if ("single_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"] = block(args["img"], - vec=args["vec"], - pe=args["pe"], - attn_mask=args.get("attn_mask")) - return out - - out = blocks_replace[("single_block", i)]({"img": img, - "vec": single_mod, - "pe": pe, - "attn_mask": attn_mask}, - {"original_block": block_wrap}) - img = out["img"] - else: - img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask) - - if control is not None: # Controlnet - control_o = control.get("output") - if i < len(control_o): - add = control_o[i] - if add is not None: - img[:, txt.shape[1] :, ...] += add - - img = img[:, txt.shape[1] :, ...] - - img_dct = self.forward_tiled_nerf(img, nerf_pixels, B, C, num_patches, tile_size=self.params.nerf_tile_size) - - # gemini gogogo idk how to fold this properly :P - # Reassemble the patches into the final image. - img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P] - # Reshape to combine with batch dimension for fold - img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P] - img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches] - img_dct = nn.functional.fold( - img_dct, - output_size=(H, W), - kernel_size=self.params.patch_size, - stride=self.params.patch_size + def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams: + params = self.params + if not overrides: + return params + params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__} + nullable_keys = frozenset(("nerf_embedder_dtype",)) + bad_keys = tuple(k for k in overrides if k not in params_dict) + if bad_keys: + e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" + raise ValueError(e) + bad_keys = tuple( + k + for k, v in overrides.items() + if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys) ) - return self._nerf_final_layer(img_dct) + if bad_keys: + e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" + raise ValueError(e) + # At this point it's all valid keys and values so we can merge with the existing params. + params_dict |= overrides + return params.__class__(**params_dict) def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + if img.ndim != 4: + raise ValueError("Input img tensor must be in [B, C, H, W] format.") + if context.ndim != 3: + raise ValueError("Input txt tensors must have 3 dimensions.") + + params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {})) + h_len = ((h + (self.patch_size // 2)) // self.patch_size) w_len = ((w + (self.patch_size // 2)) // self.patch_size) img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) @@ -365,4 +305,5 @@ class ChromaRadiance(chroma_model.Chroma): img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - return self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) + img_out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) + return self.forward_nerf(img, img_out, params) diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py index 0200ab8e6..f5976cbcd 100644 --- a/comfy_extras/nodes_chroma_radiance.py +++ b/comfy_extras/nodes_chroma_radiance.py @@ -1,4 +1,5 @@ from typing_extensions import override +from typing import Callable import torch @@ -122,6 +123,85 @@ class ChromaRadianceStubVAENode(io.ComfyNode): def execute(cls) -> io.NodeOutput: return io.NodeOutput(ChromaRadianceStubVAE()) +class ChromaRadianceOptions(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ChromaRadianceOptions", + category="model_patches/chroma_radiance", + description="Allows setting some advanced options for the Chroma Radiance model.", + inputs=[ + io.Model.Input(id="model"), + io.Boolean.Input( + id="preserve_wrapper", + default=True, + tooltip="When enabled preserves an existing model wrapper if it exists. Generally should be left enabled.", + ), + io.Float.Input( + id="start_sigma", + default=1.0, + min=0.0, + max=1.0, + ), + io.Float.Input( + id="end_sigma", + default=0.0, + min=0.0, + max=1.0, + ), + io.Int.Input( + id="nerf_tile_size", + default=-1, + min=-1, + tooltip="Allows overriding the default NeRF tile size. -1 means use the default. 0 means use non-tiling mode (may require a lot of VRAM).", + ), + io.Combo.Input( + id="nerf_embedder_dtype", + default="default", + options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"], + tooltip="Allows overriding the dtype the NeRF embedder uses.", + ), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def execute( + cls, + *, + model: io.Model.Type, + preserve_wrapper: bool, + start_sigma: float, + end_sigma: float, + nerf_tile_size: int, + nerf_embedder_dtype: str, + ) -> io.NodeOutput: + radiance_options = {} + if nerf_tile_size >= 0: + radiance_options["nerf_tile_size"] = nerf_tile_size + if nerf_embedder_dtype != "default": + radiance_options["nerf_embedder_dtype"] = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, "float64": torch.float64}.get(nerf_embedder_dtype) + + if not radiance_options: + return io.NodeOutput(model) + + old_wrapper = model.model_options.get("model_function_wrapper") + + def model_function_wrapper(apply_model: Callable, args: dict) -> torch.Tensor: + c = args["c"].copy() + sigma = args["timestep"].max().detach().cpu().item() + if end_sigma <= sigma <= start_sigma: + transformer_options = c.get("transformer_options", {}).copy() + transformer_options["chroma_radiance_options"] = radiance_options.copy() + c["transformer_options"] = transformer_options + if not (preserve_wrapper and old_wrapper): + return apply_model(args["input"], args["timestep"], **c) + return old_wrapper(apply_model, args | {"c": c}) + + model = model.clone() + model.set_model_unet_function_wrapper(model_function_wrapper) + return io.NodeOutput(model) + class ChromaRadianceExtension(ComfyExtension): @override @@ -131,6 +211,7 @@ class ChromaRadianceExtension(ComfyExtension): ChromaRadianceLatentToImage, ChromaRadianceImageToLatent, ChromaRadianceStubVAENode, + ChromaRadianceOptions, ]