Move Chroma Radiance to its own directory in ldm

Minor code cleanups and tooltip improvements
This commit is contained in:
blepping 2025-09-04 23:27:44 -06:00
parent a46078afe7
commit f1f5b7d9b5
5 changed files with 76 additions and 35 deletions

View File

@ -16,7 +16,15 @@ class NerfEmbedder(nn.Module):
patch size, and enriches it with positional information before projecting patch size, and enriches it with positional information before projecting
it to a new hidden size. it to a new hidden size.
""" """
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None): def __init__(
self,
in_channels: int,
hidden_size_input: int,
max_freqs: int,
dtype=None,
device=None,
operations=None,
):
""" """
Initializes the NerfEmbedder. Initializes the NerfEmbedder.
@ -38,7 +46,7 @@ class NerfEmbedder(nn.Module):
) )
@lru_cache(maxsize=4) @lru_cache(maxsize=4)
def fetch_pos(self, patch_size: int, device, dtype) -> torch.Tensor: def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
""" """
Generates and caches 2D DCT-like positional embeddings for a given patch size. Generates and caches 2D DCT-like positional embeddings for a given patch size.
@ -179,14 +187,14 @@ class NerfFinalLayer(nn.Module):
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
# So we temporarily move the channel dimension to the end for the norm operation. # So we temporarily move the channel dimension to the end for the norm operation.
return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1) return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
class NerfFinalLayerConv(nn.Module): class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.conv = operations.Conv2d( self.conv = operations.Conv2d(
@ -198,7 +206,7 @@ class NerfFinalLayerConv(nn.Module):
device=device, device=device,
) )
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
# So we temporarily move the channel dimension to the end for the norm operation. # So we temporarily move the channel dimension to the end for the norm operation.
return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1)) return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))

View File

@ -12,29 +12,28 @@ import comfy.ldm.common_dit
from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.layers import EmbedND
from .layers import ( from comfy.ldm.chroma.model import Chroma, ChromaParams
from comfy.ldm.chroma.layers import (
DoubleStreamBlock, DoubleStreamBlock,
SingleStreamBlock, SingleStreamBlock,
Approximator, Approximator,
) )
from .layers_dct import ( from .layers import (
NerfEmbedder, NerfEmbedder,
NerfGLUBlock, NerfGLUBlock,
NerfFinalLayer, NerfFinalLayer,
NerfFinalLayerConv, NerfFinalLayerConv,
) )
from . import model as chroma_model
@dataclass @dataclass
class ChromaRadianceParams(chroma_model.ChromaParams): class ChromaRadianceParams(ChromaParams):
patch_size: int patch_size: int
nerf_hidden_size: int nerf_hidden_size: int
nerf_mlp_ratio: int nerf_mlp_ratio: int
nerf_depth: int nerf_depth: int
nerf_max_freqs: int nerf_max_freqs: int
# nerf_tile_size of 0 means unlimited. # Setting nerf_tile_size to 0 disables tiling.
nerf_tile_size: int nerf_tile_size: int
# Currently one of linear (legacy) or conv. # Currently one of linear (legacy) or conv.
nerf_final_head_type: str nerf_final_head_type: str
@ -42,12 +41,14 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
nerf_embedder_dtype: Optional[torch.dtype] nerf_embedder_dtype: Optional[torch.dtype]
class ChromaRadiance(chroma_model.Chroma): class ChromaRadiance(Chroma):
""" """
Transformer model for flow matching on sequences. Transformer model for flow matching on sequences.
""" """
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
if operations is None:
raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
nn.Module.__init__(self) nn.Module.__init__(self)
self.dtype = dtype self.dtype = dtype
params = ChromaRadianceParams(**kwargs) params = ChromaRadianceParams(**kwargs)
@ -188,7 +189,9 @@ class ChromaRadiance(chroma_model.Chroma):
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size) nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
if params.nerf_tile_size > 0: if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
# the tile size.
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params) img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
else: else:
# Reshape for per-patch processing # Reshape for per-patch processing
@ -219,8 +222,8 @@ class ChromaRadiance(chroma_model.Chroma):
self, self,
nerf_hidden: Tensor, nerf_hidden: Tensor,
nerf_pixels: Tensor, nerf_pixels: Tensor,
B: int, batch: int,
C: int, channels: int,
num_patches: int, num_patches: int,
patch_size: int, patch_size: int,
params: ChromaRadianceParams, params: ChromaRadianceParams,
@ -246,9 +249,9 @@ class ChromaRadiance(chroma_model.Chroma):
# Reshape the tile for per-patch processing # Reshape the tile for per-patch processing
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D] # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, params.hidden_size) nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C] # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2) nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
# get DCT-encoded pixel embeddings [pixel-dct] # get DCT-encoded pixel embeddings [pixel-dct]
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype) img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
@ -284,7 +287,16 @@ class ChromaRadiance(chroma_model.Chroma):
params_dict |= overrides params_dict |= overrides
return params.__class__(**params_dict) return params.__class__(**params_dict)
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): def _forward(
self,
x: Tensor,
timestep: Tensor,
context: Tensor,
guidance: Optional[Tensor],
control: Optional[dict]=None,
transformer_options: dict={},
**kwargs: dict,
) -> Tensor:
bs, c, h, w = x.shape bs, c, h, w = x.shape
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
@ -303,5 +315,15 @@ class ChromaRadiance(chroma_model.Chroma):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
img_out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) img_out = self.forward_orig(
img,
img_ids,
context,
txt_ids,
timestep,
guidance,
control,
transformer_options,
attn_mask=kwargs.get("attention_mask", None),
)
return self.forward_nerf(img, img_out, params) return self.forward_nerf(img, img_out, params)

View File

@ -42,7 +42,7 @@ import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model import comfy.ldm.hidream.model
import comfy.ldm.chroma.model import comfy.ldm.chroma.model
import comfy.ldm.chroma.model_dct import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2 import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model import comfy.ldm.qwen_image.model
@ -1334,7 +1334,7 @@ class Chroma(Flux):
class ChromaRadiance(Chroma): class ChromaRadiance(Chroma):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None): def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model_dct.ChromaRadiance) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance)
class ACEStep(BaseModel): class ACEStep(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

View File

@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = len(guidance_keys) > 0 dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight) if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {} dit_config = {}
dit_config["image_model"] = "flux" dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16 dit_config["in_channels"] = 16
@ -204,7 +204,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072 dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120 dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5 dit_config["n_layers"] = 5
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Radiance if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
dit_config["image_model"] = "chroma_radiance" dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3 dit_config["in_channels"] = 3
dit_config["out_channels"] = 3 dit_config["out_channels"] = 3

View File

@ -29,13 +29,8 @@ class EmptyChromaRadianceLatentImage(io.ComfyNode):
class ChromaRadianceStubVAE: class ChromaRadianceStubVAE:
@classmethod @staticmethod
def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: def vae_encode_crop_pixels(pixels: torch.Tensor) -> torch.Tensor:
device = comfy.model_management.intermediate_device()
if pixels.ndim == 3:
pixels = pixels.unsqueeze(0)
elif pixels.ndim != 4:
raise ValueError("Unexpected input image shape")
dims = pixels.shape[1:-1] dims = pixels.shape[1:-1]
for d in range(len(dims)): for d in range(len(dims)):
d_adj = (dims[d] // 16) * 16 d_adj = (dims[d] // 16) * 16
@ -43,6 +38,17 @@ class ChromaRadianceStubVAE:
continue continue
d_offset = (dims[d] % 16) // 2 d_offset = (dims[d] % 16) // 2
pixels = pixels.narrow(d + 1, d_offset, d_adj) pixels = pixels.narrow(d + 1, d_offset, d_adj)
return pixels
@classmethod
def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
device = comfy.model_management.intermediate_device()
if pixels.ndim == 3:
pixels = pixels.unsqueeze(0)
elif pixels.ndim != 4:
raise ValueError("Unexpected input image shape")
# Ensure the image has spatial dimensions that are multiples of 16.
pixels = cls.vae_encode_crop_pixels(pixels)
h, w, c = pixels.shape[1:] h, w, c = pixels.shape[1:]
if h < 16 or w < 16: if h < 16 or w < 16:
raise ValueError("Chroma Radiance image inputs must have height/width of at least 16 pixels.") raise ValueError("Chroma Radiance image inputs must have height/width of at least 16 pixels.")
@ -51,6 +57,7 @@ class ChromaRadianceStubVAE:
pixels = pixels.expand(-1, -1, -1, 3) pixels = pixels.expand(-1, -1, -1, 3)
elif c != 3: elif c != 3:
raise ValueError("Unexpected number of channels in input image") raise ValueError("Unexpected number of channels in input image")
# Rescale to -1..1 and move the channel dimension to position 1.
latent = pixels.to(device=device, dtype=torch.float32, copy=True) latent = pixels.to(device=device, dtype=torch.float32, copy=True)
latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous() latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
latent -= 0.5 latent -= 0.5
@ -60,6 +67,7 @@ class ChromaRadianceStubVAE:
@classmethod @classmethod
def decode(cls, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: def decode(cls, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
device = comfy.model_management.intermediate_device() device = comfy.model_management.intermediate_device()
# Rescale to 0..1 and move the channel dimension to the end.
img = samples.to(device=device, dtype=torch.float32, copy=True) img = samples.to(device=device, dtype=torch.float32, copy=True)
img = img.clamp_(-1, 1).movedim(1, -1).contiguous() img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
img += 1.0 img += 1.0
@ -71,6 +79,7 @@ class ChromaRadianceStubVAE:
@classmethod @classmethod
def spacial_compression_decode(cls) -> int: def spacial_compression_decode(cls) -> int:
# This just exists so the tiled VAE nodes don't crash.
return 1 return 1
spacial_compression_encode = spacial_compression_decode spacial_compression_encode = spacial_compression_decode
@ -115,7 +124,7 @@ class ChromaRadianceStubVAENode(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="ChromaRadianceStubVAE", node_id="ChromaRadianceStubVAE",
category="vae/chroma_radiance", category="vae/chroma_radiance",
description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.", description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Chroma Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.",
outputs=[io.Vae.Output()], outputs=[io.Vae.Output()],
) )
@ -129,37 +138,39 @@ class ChromaRadianceOptions(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="ChromaRadianceOptions", node_id="ChromaRadianceOptions",
category="model_patches/chroma_radiance", category="model_patches/chroma_radiance",
description="Allows setting some advanced options for the Chroma Radiance model.", description="Allows setting advanced options for the Chroma Radiance model.",
inputs=[ inputs=[
io.Model.Input(id="model"), io.Model.Input(id="model"),
io.Boolean.Input( io.Boolean.Input(
id="preserve_wrapper", id="preserve_wrapper",
default=True, default=True,
tooltip="When enabled preserves an existing model wrapper if it exists. Generally should be left enabled.", tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.",
), ),
io.Float.Input( io.Float.Input(
id="start_sigma", id="start_sigma",
default=1.0, default=1.0,
min=0.0, min=0.0,
max=1.0, max=1.0,
tooltip="First sigma that these options will be in effect.",
), ),
io.Float.Input( io.Float.Input(
id="end_sigma", id="end_sigma",
default=0.0, default=0.0,
min=0.0, min=0.0,
max=1.0, max=1.0,
tooltip="Last sigma that these options will be in effect.",
), ),
io.Int.Input( io.Int.Input(
id="nerf_tile_size", id="nerf_tile_size",
default=-1, default=-1,
min=-1, min=-1,
tooltip="Allows overriding the default NeRF tile size. -1 means use the default. 0 means use non-tiling mode (may require a lot of VRAM).", tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
), ),
io.Combo.Input( io.Combo.Input(
id="nerf_embedder_dtype", id="nerf_embedder_dtype",
default="default", default="default",
options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"], options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
tooltip="Allows overriding the dtype the NeRF embedder uses.", tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.",
), ),
], ],
outputs=[io.Model.Output()], outputs=[io.Model.Output()],