diff --git a/comfy/ldm/chroma/layers_dct.py b/comfy/ldm/chroma_radiance/layers.py similarity index 94% rename from comfy/ldm/chroma/layers_dct.py rename to comfy/ldm/chroma_radiance/layers.py index 24c1aa7ee..7ed61d69d 100644 --- a/comfy/ldm/chroma/layers_dct.py +++ b/comfy/ldm/chroma_radiance/layers.py @@ -16,7 +16,15 @@ class NerfEmbedder(nn.Module): patch size, and enriches it with positional information before projecting it to a new hidden size. """ - def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None): + def __init__( + self, + in_channels: int, + hidden_size_input: int, + max_freqs: int, + dtype=None, + device=None, + operations=None, + ): """ Initializes the NerfEmbedder. @@ -38,7 +46,7 @@ class NerfEmbedder(nn.Module): ) @lru_cache(maxsize=4) - def fetch_pos(self, patch_size: int, device, dtype) -> torch.Tensor: + def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: """ Generates and caches 2D DCT-like positional embeddings for a given patch size. @@ -179,14 +187,14 @@ class NerfFinalLayer(nn.Module): self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. # So we temporarily move the channel dimension to the end for the norm operation. return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1) class NerfFinalLayerConv(nn.Module): - def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None): super().__init__() self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) self.conv = operations.Conv2d( @@ -198,7 +206,7 @@ class NerfFinalLayerConv(nn.Module): device=device, ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. # So we temporarily move the channel dimension to the end for the norm operation. return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1)) diff --git a/comfy/ldm/chroma/model_dct.py b/comfy/ldm/chroma_radiance/model.py similarity index 89% rename from comfy/ldm/chroma/model_dct.py rename to comfy/ldm/chroma_radiance/model.py index 1c5e0bb3a..393f612f8 100644 --- a/comfy/ldm/chroma/model_dct.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -12,29 +12,28 @@ import comfy.ldm.common_dit from comfy.ldm.flux.layers import EmbedND -from .layers import ( +from comfy.ldm.chroma.model import Chroma, ChromaParams +from comfy.ldm.chroma.layers import ( DoubleStreamBlock, SingleStreamBlock, Approximator, ) -from .layers_dct import ( +from .layers import ( NerfEmbedder, NerfGLUBlock, NerfFinalLayer, NerfFinalLayerConv, ) -from . import model as chroma_model - @dataclass -class ChromaRadianceParams(chroma_model.ChromaParams): +class ChromaRadianceParams(ChromaParams): patch_size: int nerf_hidden_size: int nerf_mlp_ratio: int nerf_depth: int nerf_max_freqs: int - # nerf_tile_size of 0 means unlimited. + # Setting nerf_tile_size to 0 disables tiling. nerf_tile_size: int # Currently one of linear (legacy) or conv. nerf_final_head_type: str @@ -42,12 +41,14 @@ class ChromaRadianceParams(chroma_model.ChromaParams): nerf_embedder_dtype: Optional[torch.dtype] -class ChromaRadiance(chroma_model.Chroma): +class ChromaRadiance(Chroma): """ Transformer model for flow matching on sequences. """ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): + if operations is None: + raise RuntimeError("Attempt to create ChromaRadiance object without setting operations") nn.Module.__init__(self) self.dtype = dtype params = ChromaRadianceParams(**kwargs) @@ -188,7 +189,9 @@ class ChromaRadiance(chroma_model.Chroma): nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size) nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] - if params.nerf_tile_size > 0: + if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size: + # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than + # the tile size. img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params) else: # Reshape for per-patch processing @@ -219,8 +222,8 @@ class ChromaRadiance(chroma_model.Chroma): self, nerf_hidden: Tensor, nerf_pixels: Tensor, - B: int, - C: int, + batch: int, + channels: int, num_patches: int, patch_size: int, params: ChromaRadianceParams, @@ -246,9 +249,9 @@ class ChromaRadiance(chroma_model.Chroma): # Reshape the tile for per-patch processing # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D] - nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, params.hidden_size) + nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size) # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C] - nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2) + nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2) # get DCT-encoded pixel embeddings [pixel-dct] img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype) @@ -284,7 +287,16 @@ class ChromaRadiance(chroma_model.Chroma): params_dict |= overrides return params.__class__(**params_dict) - def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): + def _forward( + self, + x: Tensor, + timestep: Tensor, + context: Tensor, + guidance: Optional[Tensor], + control: Optional[dict]=None, + transformer_options: dict={}, + **kwargs: dict, + ) -> Tensor: bs, c, h, w = x.shape img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) @@ -303,5 +315,15 @@ class ChromaRadiance(chroma_model.Chroma): img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - img_out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) + img_out = self.forward_orig( + img, + img_ids, + context, + txt_ids, + timestep, + guidance, + control, + transformer_options, + attn_mask=kwargs.get("attention_mask", None), + ) return self.forward_nerf(img, img_out, params) diff --git a/comfy/model_base.py b/comfy/model_base.py index 1d08c6853..252dfcf69 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,7 +42,7 @@ import comfy.ldm.wan.model import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model -import comfy.ldm.chroma.model_dct +import comfy.ldm.chroma_radiance.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model @@ -1334,7 +1334,7 @@ class Chroma(Flux): class ChromaRadiance(Chroma): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model_dct.ChromaRadiance) + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance) class ACEStep(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index bfc6a188a..03d44f65e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight) + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} dit_config["image_model"] = "flux" dit_config["in_channels"] = 16 @@ -204,7 +204,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["out_dim"] = 3072 dit_config["hidden_dim"] = 5120 dit_config["n_layers"] = 5 - if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Radiance + if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance dit_config["image_model"] = "chroma_radiance" dit_config["in_channels"] = 3 dit_config["out_channels"] = 3 diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py index f5976cbcd..807828899 100644 --- a/comfy_extras/nodes_chroma_radiance.py +++ b/comfy_extras/nodes_chroma_radiance.py @@ -29,13 +29,8 @@ class EmptyChromaRadianceLatentImage(io.ComfyNode): class ChromaRadianceStubVAE: - @classmethod - def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: - device = comfy.model_management.intermediate_device() - if pixels.ndim == 3: - pixels = pixels.unsqueeze(0) - elif pixels.ndim != 4: - raise ValueError("Unexpected input image shape") + @staticmethod + def vae_encode_crop_pixels(pixels: torch.Tensor) -> torch.Tensor: dims = pixels.shape[1:-1] for d in range(len(dims)): d_adj = (dims[d] // 16) * 16 @@ -43,6 +38,17 @@ class ChromaRadianceStubVAE: continue d_offset = (dims[d] % 16) // 2 pixels = pixels.narrow(d + 1, d_offset, d_adj) + return pixels + + @classmethod + def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + device = comfy.model_management.intermediate_device() + if pixels.ndim == 3: + pixels = pixels.unsqueeze(0) + elif pixels.ndim != 4: + raise ValueError("Unexpected input image shape") + # Ensure the image has spatial dimensions that are multiples of 16. + pixels = cls.vae_encode_crop_pixels(pixels) h, w, c = pixels.shape[1:] if h < 16 or w < 16: raise ValueError("Chroma Radiance image inputs must have height/width of at least 16 pixels.") @@ -51,6 +57,7 @@ class ChromaRadianceStubVAE: pixels = pixels.expand(-1, -1, -1, 3) elif c != 3: raise ValueError("Unexpected number of channels in input image") + # Rescale to -1..1 and move the channel dimension to position 1. latent = pixels.to(device=device, dtype=torch.float32, copy=True) latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous() latent -= 0.5 @@ -60,6 +67,7 @@ class ChromaRadianceStubVAE: @classmethod def decode(cls, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: device = comfy.model_management.intermediate_device() + # Rescale to 0..1 and move the channel dimension to the end. img = samples.to(device=device, dtype=torch.float32, copy=True) img = img.clamp_(-1, 1).movedim(1, -1).contiguous() img += 1.0 @@ -71,6 +79,7 @@ class ChromaRadianceStubVAE: @classmethod def spacial_compression_decode(cls) -> int: + # This just exists so the tiled VAE nodes don't crash. return 1 spacial_compression_encode = spacial_compression_decode @@ -115,7 +124,7 @@ class ChromaRadianceStubVAENode(io.ComfyNode): return io.Schema( node_id="ChromaRadianceStubVAE", category="vae/chroma_radiance", - description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.", + description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Chroma Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.", outputs=[io.Vae.Output()], ) @@ -129,37 +138,39 @@ class ChromaRadianceOptions(io.ComfyNode): return io.Schema( node_id="ChromaRadianceOptions", category="model_patches/chroma_radiance", - description="Allows setting some advanced options for the Chroma Radiance model.", + description="Allows setting advanced options for the Chroma Radiance model.", inputs=[ io.Model.Input(id="model"), io.Boolean.Input( id="preserve_wrapper", default=True, - tooltip="When enabled preserves an existing model wrapper if it exists. Generally should be left enabled.", + tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.", ), io.Float.Input( id="start_sigma", default=1.0, min=0.0, max=1.0, + tooltip="First sigma that these options will be in effect.", ), io.Float.Input( id="end_sigma", default=0.0, min=0.0, max=1.0, + tooltip="Last sigma that these options will be in effect.", ), io.Int.Input( id="nerf_tile_size", default=-1, min=-1, - tooltip="Allows overriding the default NeRF tile size. -1 means use the default. 0 means use non-tiling mode (may require a lot of VRAM).", + tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).", ), io.Combo.Input( id="nerf_embedder_dtype", default="default", options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"], - tooltip="Allows overriding the dtype the NeRF embedder uses.", + tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.", ), ], outputs=[io.Model.Output()],