Move Chroma Radiance to its own directory in ldm

Minor code cleanups and tooltip improvements
This commit is contained in:
blepping 2025-09-04 23:27:44 -06:00
parent a46078afe7
commit f1f5b7d9b5
5 changed files with 76 additions and 35 deletions

View File

@ -16,7 +16,15 @@ class NerfEmbedder(nn.Module):
patch size, and enriches it with positional information before projecting
it to a new hidden size.
"""
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None):
def __init__(
self,
in_channels: int,
hidden_size_input: int,
max_freqs: int,
dtype=None,
device=None,
operations=None,
):
"""
Initializes the NerfEmbedder.
@ -38,7 +46,7 @@ class NerfEmbedder(nn.Module):
)
@lru_cache(maxsize=4)
def fetch_pos(self, patch_size: int, device, dtype) -> torch.Tensor:
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
"""
Generates and caches 2D DCT-like positional embeddings for a given patch size.
@ -179,14 +187,14 @@ class NerfFinalLayer(nn.Module):
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
# So we temporarily move the channel dimension to the end for the norm operation.
return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.conv = operations.Conv2d(
@ -198,7 +206,7 @@ class NerfFinalLayerConv(nn.Module):
device=device,
)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
# So we temporarily move the channel dimension to the end for the norm operation.
return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))

View File

@ -12,29 +12,28 @@ import comfy.ldm.common_dit
from comfy.ldm.flux.layers import EmbedND
from .layers import (
from comfy.ldm.chroma.model import Chroma, ChromaParams
from comfy.ldm.chroma.layers import (
DoubleStreamBlock,
SingleStreamBlock,
Approximator,
)
from .layers_dct import (
from .layers import (
NerfEmbedder,
NerfGLUBlock,
NerfFinalLayer,
NerfFinalLayerConv,
)
from . import model as chroma_model
@dataclass
class ChromaRadianceParams(chroma_model.ChromaParams):
class ChromaRadianceParams(ChromaParams):
patch_size: int
nerf_hidden_size: int
nerf_mlp_ratio: int
nerf_depth: int
nerf_max_freqs: int
# nerf_tile_size of 0 means unlimited.
# Setting nerf_tile_size to 0 disables tiling.
nerf_tile_size: int
# Currently one of linear (legacy) or conv.
nerf_final_head_type: str
@ -42,12 +41,14 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
nerf_embedder_dtype: Optional[torch.dtype]
class ChromaRadiance(chroma_model.Chroma):
class ChromaRadiance(Chroma):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
if operations is None:
raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
nn.Module.__init__(self)
self.dtype = dtype
params = ChromaRadianceParams(**kwargs)
@ -188,7 +189,9 @@ class ChromaRadiance(chroma_model.Chroma):
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
if params.nerf_tile_size > 0:
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
# the tile size.
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
else:
# Reshape for per-patch processing
@ -219,8 +222,8 @@ class ChromaRadiance(chroma_model.Chroma):
self,
nerf_hidden: Tensor,
nerf_pixels: Tensor,
B: int,
C: int,
batch: int,
channels: int,
num_patches: int,
patch_size: int,
params: ChromaRadianceParams,
@ -246,9 +249,9 @@ class ChromaRadiance(chroma_model.Chroma):
# Reshape the tile for per-patch processing
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, params.hidden_size)
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2)
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
# get DCT-encoded pixel embeddings [pixel-dct]
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
@ -284,7 +287,16 @@ class ChromaRadiance(chroma_model.Chroma):
params_dict |= overrides
return params.__class__(**params_dict)
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
def _forward(
self,
x: Tensor,
timestep: Tensor,
context: Tensor,
guidance: Optional[Tensor],
control: Optional[dict]=None,
transformer_options: dict={},
**kwargs: dict,
) -> Tensor:
bs, c, h, w = x.shape
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
@ -303,5 +315,15 @@ class ChromaRadiance(chroma_model.Chroma):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
img_out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
img_out = self.forward_orig(
img,
img_ids,
context,
txt_ids,
timestep,
guidance,
control,
transformer_options,
attn_mask=kwargs.get("attention_mask", None),
)
return self.forward_nerf(img, img_out, params)

View File

@ -42,7 +42,7 @@ import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.chroma.model_dct
import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
@ -1334,7 +1334,7 @@ class Chroma(Flux):
class ChromaRadiance(Chroma):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model_dct.ChromaRadiance)
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance)
class ACEStep(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

View File

@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight)
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
@ -204,7 +204,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Radiance
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3

View File

@ -29,13 +29,8 @@ class EmptyChromaRadianceLatentImage(io.ComfyNode):
class ChromaRadianceStubVAE:
@classmethod
def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
device = comfy.model_management.intermediate_device()
if pixels.ndim == 3:
pixels = pixels.unsqueeze(0)
elif pixels.ndim != 4:
raise ValueError("Unexpected input image shape")
@staticmethod
def vae_encode_crop_pixels(pixels: torch.Tensor) -> torch.Tensor:
dims = pixels.shape[1:-1]
for d in range(len(dims)):
d_adj = (dims[d] // 16) * 16
@ -43,6 +38,17 @@ class ChromaRadianceStubVAE:
continue
d_offset = (dims[d] % 16) // 2
pixels = pixels.narrow(d + 1, d_offset, d_adj)
return pixels
@classmethod
def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
device = comfy.model_management.intermediate_device()
if pixels.ndim == 3:
pixels = pixels.unsqueeze(0)
elif pixels.ndim != 4:
raise ValueError("Unexpected input image shape")
# Ensure the image has spatial dimensions that are multiples of 16.
pixels = cls.vae_encode_crop_pixels(pixels)
h, w, c = pixels.shape[1:]
if h < 16 or w < 16:
raise ValueError("Chroma Radiance image inputs must have height/width of at least 16 pixels.")
@ -51,6 +57,7 @@ class ChromaRadianceStubVAE:
pixels = pixels.expand(-1, -1, -1, 3)
elif c != 3:
raise ValueError("Unexpected number of channels in input image")
# Rescale to -1..1 and move the channel dimension to position 1.
latent = pixels.to(device=device, dtype=torch.float32, copy=True)
latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
latent -= 0.5
@ -60,6 +67,7 @@ class ChromaRadianceStubVAE:
@classmethod
def decode(cls, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
device = comfy.model_management.intermediate_device()
# Rescale to 0..1 and move the channel dimension to the end.
img = samples.to(device=device, dtype=torch.float32, copy=True)
img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
img += 1.0
@ -71,6 +79,7 @@ class ChromaRadianceStubVAE:
@classmethod
def spacial_compression_decode(cls) -> int:
# This just exists so the tiled VAE nodes don't crash.
return 1
spacial_compression_encode = spacial_compression_decode
@ -115,7 +124,7 @@ class ChromaRadianceStubVAENode(io.ComfyNode):
return io.Schema(
node_id="ChromaRadianceStubVAE",
category="vae/chroma_radiance",
description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.",
description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Chroma Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.",
outputs=[io.Vae.Output()],
)
@ -129,37 +138,39 @@ class ChromaRadianceOptions(io.ComfyNode):
return io.Schema(
node_id="ChromaRadianceOptions",
category="model_patches/chroma_radiance",
description="Allows setting some advanced options for the Chroma Radiance model.",
description="Allows setting advanced options for the Chroma Radiance model.",
inputs=[
io.Model.Input(id="model"),
io.Boolean.Input(
id="preserve_wrapper",
default=True,
tooltip="When enabled preserves an existing model wrapper if it exists. Generally should be left enabled.",
tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.",
),
io.Float.Input(
id="start_sigma",
default=1.0,
min=0.0,
max=1.0,
tooltip="First sigma that these options will be in effect.",
),
io.Float.Input(
id="end_sigma",
default=0.0,
min=0.0,
max=1.0,
tooltip="Last sigma that these options will be in effect.",
),
io.Int.Input(
id="nerf_tile_size",
default=-1,
min=-1,
tooltip="Allows overriding the default NeRF tile size. -1 means use the default. 0 means use non-tiling mode (may require a lot of VRAM).",
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
),
io.Combo.Input(
id="nerf_embedder_dtype",
default="default",
options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
tooltip="Allows overriding the dtype the NeRF embedder uses.",
tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.",
),
],
outputs=[io.Model.Output()],