From e41e0060b96885f44bc87c62e83b185bf7165e22 Mon Sep 17 00:00:00 2001
From: silveroxides
Date: Fri, 19 Dec 2025 15:46:13 +0100
Subject: [PATCH] Radiance Refactoring

---
 comfy/ldm/chroma_radiance/model.py | 189 +++++++++++++++++++++++++++--
 1 file changed, 182 insertions(+), 7 deletions(-)

diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py
index 70d173889..4339f457e 100644
--- a/comfy/ldm/chroma_radiance/model.py
+++ b/comfy/ldm/chroma_radiance/model.py
@@ -7,14 +7,15 @@ from typing import Optional
 import torch
 from torch import Tensor, nn
-from einops import repeat
+from einops import rearrange, repeat
 
 import comfy.ldm.common_dit
+import comfy.patcher_extension
 
-from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
+from comfy.ldm.flux.layers import EmbedND, timestep_embedding, DoubleStreamBlock, SingleStreamBlock
 
-from comfy.ldm.chroma.model import Chroma, ChromaParams
 from comfy.ldm.chroma.layers import (
     Approximator,
+    ChromaModulationOut,
 )
 from .layers import (
     NerfEmbedder,
@@ -25,7 +26,26 @@ from .layers import (
 
 
 @dataclass
-class ChromaRadianceParams(ChromaParams):
+class ChromaRadianceParams:
+    # Fields from ChromaParams (now independent)
+    in_channels: int
+    out_channels: int
+    context_in_dim: int
+    hidden_size: int
+    mlp_ratio: float
+    num_heads: int
+    depth: int
+    depth_single_blocks: int
+    axes_dim: list
+    theta: int
+    qkv_bias: bool
+    in_dim: int
+    out_dim: int
+    hidden_dim: int
+    n_layers: int
+    txt_ids_dims: list
+    vec_in_dim: int
+    # ChromaRadiance-specific fields
     patch_size: int
     nerf_hidden_size: int
     nerf_mlp_ratio: int
@@ -39,7 +59,7 @@
     nerf_embedder_dtype: Optional[torch.dtype]
     use_x0: bool
 
-class ChromaRadiance(Chroma):
+class ChromaRadiance(nn.Module):
     """
     Transformer model for flow matching on sequences.
     """
@@ -47,7 +67,7 @@
     def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
         if operations is None:
             raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
-        nn.Module.__init__(self)
+        super().__init__()
         self.dtype = dtype
         params = ChromaRadianceParams(**kwargs)
         self.params = params
@@ -176,6 +196,155 @@
         # flatten into a sequence for the transformer.
         return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
 
+    def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
+        # This function slices up the modulations tensor which has the following layout:
+        #   single     : num_single_blocks * 3 elements
+        #   double_img : num_double_blocks * 6 elements
+        #   double_txt : num_double_blocks * 6 elements
+        #   final      : 2 elements
+        if block_type == "final":
+            return (tensor[:, -2:-1, :], tensor[:, -1:, :])
+        single_block_count = self.params.depth_single_blocks
+        double_block_count = self.params.depth
+        offset = 3 * idx
+        if block_type == "single":
+            return ChromaModulationOut.from_offset(tensor, offset)
+        # Double block modulations are 6 elements so we double 3 * idx.
+        offset *= 2
+        if block_type in {"double_img", "double_txt"}:
+            # Advance past the single block modulations.
+            offset += 3 * single_block_count
+            if block_type == "double_txt":
+                # Advance past the double block img modulations.
+                offset += 6 * double_block_count
+            return (
+                ChromaModulationOut.from_offset(tensor, offset),
+                ChromaModulationOut.from_offset(tensor, offset + 3),
+            )
+        raise ValueError("Bad block_type")
+
+    def forward_orig(
+        self,
+        img: Tensor,
+        img_ids: Tensor,
+        txt: Tensor,
+        txt_ids: Tensor,
+        timesteps: Tensor,
+        guidance: Tensor = None,
+        control = None,
+        transformer_options={},
+        attn_mask: Tensor = None,
+    ) -> Tensor:
+        patches_replace = transformer_options.get("patches_replace", {})
+
+        # running on sequences img
+        img = self.img_in(img)
+
+        # distilled vector guidance
+        mod_index_length = 344
+        distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
+        # guidance = guidance *
+        distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
+
+        # get all modulation index
+        modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
+        # we need to broadcast the modulation index here so each batch has all of the index
+        modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
+        # and we need to broadcast timestep and guidance along too
+        timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype)
+        # then and only then we could concatenate it together
+        input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
+
+        mod_vectors = self.distilled_guidance_layer(input_vec)
+
+        txt = self.txt_in(txt)
+
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        pe = self.pe_embedder(ids)
+
+        blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.double_blocks)
+        transformer_options["block_type"] = "double"
+        for i, block in enumerate(self.double_blocks):
+            transformer_options["block_index"] = i
+            if i not in self.skip_mmdit:
+                double_mod = (
+                    self.get_modulations(mod_vectors, "double_img", idx=i),
+                    self.get_modulations(mod_vectors, "double_txt", idx=i),
+                )
+                if ("double_block", i) in blocks_replace:
+                    def block_wrap(args):
+                        out = {}
+                        out["img"], out["txt"] = block(img=args["img"],
+                                                       txt=args["txt"],
+                                                       vec=args["vec"],
+                                                       pe=args["pe"],
+                                                       attn_mask=args.get("attn_mask"),
+                                                       transformer_options=args.get("transformer_options"))
+                        return out
+
+                    out = blocks_replace[("double_block", i)]({"img": img,
+                                                               "txt": txt,
+                                                               "vec": double_mod,
+                                                               "pe": pe,
+                                                               "attn_mask": attn_mask,
+                                                               "transformer_options": transformer_options},
+                                                              {"original_block": block_wrap})
+                    txt = out["txt"]
+                    img = out["img"]
+                else:
+                    img, txt = block(img=img,
+                                     txt=txt,
+                                     vec=double_mod,
+                                     pe=pe,
+                                     attn_mask=attn_mask,
+                                     transformer_options=transformer_options)
+
+                if control is not None: # Controlnet
+                    control_i = control.get("input")
+                    if i < len(control_i):
+                        add = control_i[i]
+                        if add is not None:
+                            img += add
+
+        img = torch.cat((txt, img), 1)
+
+        transformer_options["total_blocks"] = len(self.single_blocks)
+        transformer_options["block_type"] = "single"
+        for i, block in enumerate(self.single_blocks):
+            transformer_options["block_index"] = i
+            if i not in self.skip_dit:
+                single_mod = self.get_modulations(mod_vectors, "single", idx=i)
+                if ("single_block", i) in blocks_replace:
+                    def block_wrap(args):
+                        out = {}
+                        out["img"] = block(args["img"],
+                                           vec=args["vec"],
+                                           pe=args["pe"],
+                                           attn_mask=args.get("attn_mask"),
+                                           transformer_options=args.get("transformer_options"))
+                        return out
+
+                    out = blocks_replace[("single_block", i)]({"img": img,
+                                                               "vec": single_mod,
+                                                               "pe": pe,
+                                                               "attn_mask": attn_mask,
+                                                               "transformer_options": transformer_options},
+                                                              {"original_block": block_wrap})
+                    img = out["img"]
+                else:
+                    img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
+
+                if control is not None: # Controlnet
+                    control_o = control.get("output")
+                    if i < len(control_o):
+                        add = control_o[i]
+                        if add is not None:
+                            img[:, txt.shape[1] :, ...] += add
+
+        img = img[:, txt.shape[1] :, ...]
+        return img
+
     def forward_nerf(
         self,
         img_orig: Tensor,
@@ -285,6 +454,13 @@
             eps = 0.0
         return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
 
+    def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
+
     def _forward(
         self,
         x: Tensor,
@@ -332,4 +508,3 @@
         if hasattr(self, "__x0__"):
             out = self._apply_x0_residual(out, img, timestep)
         return out
-