diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 894540879..77e642a94 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -629,3 +629,20 @@ class Hunyuan3Dv2mini(LatentFormat): class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 + +class ChromaRadiance(LatentFormat): + latent_channels = 3 + + def __init__(self): + self.latent_rgb_factors = [ + # R G B + [ 1.0, 0.0, 0.0 ], + [ 0.0, 1.0, 0.0 ], + [ 0.0, 0.0, 1.0 ] + ] + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent diff --git a/comfy/ldm/chroma/layers_dct.py b/comfy/ldm/chroma/layers_dct.py new file mode 100644 index 000000000..da7fde5e0 --- /dev/null +++ b/comfy/ldm/chroma/layers_dct.py @@ -0,0 +1,180 @@ +# Adapted from https://github.com/lodestone-rock/flow +from functools import lru_cache + +import torch +from torch import nn + +from comfy.ldm.flux.layers import RMSNorm + +class NerfEmbedder(nn.Module): + """ + An embedder module that combines input features with a 2D positional + encoding that mimics the Discrete Cosine Transform (DCT). + + This module takes an input tensor of shape (B, P^2, C), where P is the + patch size, and enriches it with positional information before projecting + it to a new hidden size. + """ + def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None): + """ + Initializes the NerfEmbedder. + + Args: + in_channels (int): The number of channels in the input tensor. + hidden_size_input (int): The desired dimension of the output embedding. + max_freqs (int): The number of frequency components to use for both + the x and y dimensions of the positional encoding. + The total number of positional features will be max_freqs^2. + """ + super().__init__() + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + + # A linear layer to project the concatenated input features and + # positional encodings to the final output dimension. + self.embedder = nn.Sequential( + operations.Linear(in_channels + max_freqs**2, hidden_size_input, device=device, dtype=dtype) + ) + + @lru_cache(maxsize=4) + def fetch_pos(self, patch_size, device, dtype): + """ + Generates and caches 2D DCT-like positional embeddings for a given patch size. + + The LRU cache is a performance optimization that avoids recomputing the + same positional grid on every forward pass. + + Args: + patch_size (int): The side length of the square input patch. + device: The torch device to create the tensors on. + dtype: The torch dtype for the tensors. + + Returns: + A tensor of shape (1, patch_size^2, max_freqs^2) containing the + positional embeddings. + """ + # Create normalized 1D coordinate grids from 0 to 1. + pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + + # Create a 2D meshgrid of coordinates. + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + + # Reshape positions to be broadcastable with frequencies. + # Shape becomes (patch_size^2, 1, 1). + pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + + # Create a 1D tensor of frequency values from 0 to max_freqs-1. + freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device) + + # Reshape frequencies to be broadcastable for creating 2D basis functions. + # freqs_x shape: (1, max_freqs, 1) + # freqs_y shape: (1, 1, max_freqs) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + + # A custom weighting coefficient, not part of standard DCT. + # This seems to down-weight the contribution of higher-frequency interactions. + coeffs = (1 + freqs_x * freqs_y) ** -1 + + # Calculate the 1D cosine basis functions for x and y coordinates. + # This is the core of the DCT formulation. + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + + # Combine the 1D basis functions to create 2D basis functions by element-wise + # multiplication, and apply the custom coefficients. Broadcasting handles the + # combination of all (pos_x, freqs_x) with all (pos_y, freqs_y). + # The result is flattened into a feature vector for each position. + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) + + return dct + + def forward(self, inputs): + """ + Forward pass for the embedder. + + Args: + inputs (Tensor): The input tensor of shape (B, P^2, C). + + Returns: + Tensor: The output tensor of shape (B, P^2, hidden_size_input). + """ + # Get the batch size, number of pixels, and number of channels. + B, P2, C = inputs.shape + + # Infer the patch side length from the number of pixels (P^2). + patch_size = int(P2 ** 0.5) + + # Fetch the pre-computed or cached positional embeddings. + dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype) + + # Repeat the positional embeddings for each item in the batch. + dct = dct.repeat(B, 1, 1) + + # Concatenate the original input features with the positional embeddings + # along the feature dimension. + inputs = torch.cat([inputs, dct], dim=-1) + + # Project the combined tensor to the target hidden size. + inputs = self.embedder(inputs) + + return inputs + +class NerfGLUBlock(nn.Module): + """ + A NerfBlock using a Gated Linear Unit (GLU) like MLP. + """ + def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, device=None, dtype=None, operations=None): + super().__init__() + # The total number of parameters for the MLP is increased to accommodate + # the gate, value, and output projection matrices. + # We now need to generate parameters for 3 matrices. + total_params = 3 * hidden_size_x**2 * mlp_ratio + self.param_generator = operations.Linear(hidden_size_s, total_params, device=device, dtype=dtype) + self.norm = RMSNorm(hidden_size_x, device=device, dtype=dtype, operations=operations) + self.mlp_ratio = mlp_ratio + # nn.init.zeros_(self.param_generator.weight) + # nn.init.zeros_(self.param_generator.bias) + + + def forward(self, x, s): + batch_size, num_x, hidden_size_x = x.shape + mlp_params = self.param_generator(s) + + # Split the generated parameters into three parts for the gate, value, and output projection. + fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1) + + # Reshape the parameters into matrices for batch matrix multiplication. + fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x) + + # Normalize the generated weight matrices as in the original implementation. + fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2) + fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2) + fc2 = torch.nn.functional.normalize(fc2, dim=-2) + + res_x = x + x = self.norm(x) + + # Apply the final output projection. + x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2) + + x = x + res_x + return x + + +class NerfFinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x): + x = self.norm(x) + x = self.linear(x) + return x diff --git a/comfy/ldm/chroma/model_dct.py b/comfy/ldm/chroma/model_dct.py new file mode 100644 index 000000000..4e669f825 --- /dev/null +++ b/comfy/ldm/chroma/model_dct.py @@ -0,0 +1,305 @@ +# Credits: +# Original Flux code can be found on: https://github.com/black-forest-labs/flux +# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow + +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import repeat +import comfy.ldm.common_dit + +from comfy.ldm.flux.layers import ( + EmbedND, + timestep_embedding, +) + +from .layers import ( + DoubleStreamBlock, + SingleStreamBlock, + Approximator, +) +from .layers_dct import NerfEmbedder, NerfGLUBlock, NerfFinalLayer + +from . import model as chroma_model + +@dataclass +class ChromaRadianceParams(chroma_model.ChromaParams): + patch_size: int + nerf_hidden_size: int + nerf_mlp_ratio: int + nerf_depth: int + nerf_max_freqs: int + + +class ChromaRadiance(chroma_model.Chroma): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): + nn.Module.__init__(self) + self.dtype = dtype + params = ChromaRadianceParams(**kwargs) + self.params = params + self.patch_size = params.patch_size + self.in_channels = params.in_channels + self.out_channels = params.out_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.in_dim = params.in_dim + self.out_dim = params.out_dim + self.hidden_dim = params.hidden_dim + self.n_layers = params.n_layers + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in_patch = operations.Conv2d( + params.in_channels, + params.hidden_size, + kernel_size=params.patch_size, + stride=params.patch_size, + bias=True, + device=device, + dtype=dtype, + ) + nn.init.zeros_(self.img_in_patch.weight) + nn.init.zeros_(self.img_in_patch.bias) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + # set as nn identity for now, will overwrite it later. + self.distilled_guidance_layer = Approximator( + in_dim=self.in_dim, + hidden_dim=self.hidden_dim, + out_dim=self.out_dim, + n_layers=self.n_layers, + dtype=dtype, device=device, operations=operations + ) + + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + dtype=dtype, device=device, operations=operations + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + for _ in range(params.depth_single_blocks) + ] + ) + + # pixel channel concat with DCT + self.nerf_image_embedder = NerfEmbedder( + in_channels=params.in_channels, + hidden_size_input=params.nerf_hidden_size, + max_freqs=params.nerf_max_freqs, + dtype=dtype, + device=device, + operations=operations, + ) + + self.nerf_blocks = nn.ModuleList([ + NerfGLUBlock( + hidden_size_s=params.hidden_size, + hidden_size_x=params.nerf_hidden_size, + mlp_ratio=params.nerf_mlp_ratio, + dtype=dtype, + device=device, + operations=operations, + ) for _ in range(params.nerf_depth) + ]) + self.nerf_final_layer = NerfFinalLayer( + params.nerf_hidden_size, + out_channels=params.in_channels, + dtype=dtype, + device=device, + operations=operations, + ) + + self.skip_mmdit = [] + self.skip_dit = [] + self.lite = False + + def forward_orig( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + guidance: Tensor = None, + control = None, + transformer_options={}, + attn_mask: Tensor = None, + ) -> Tensor: + patches_replace = transformer_options.get("patches_replace", {}) + if img.ndim != 4: + raise ValueError("Input img tensor must be in [B, C, H, W] format.") + if txt.ndim != 3: + raise ValueError("Input txt tensors must have 3 dimensions.") + B, C, H, W = img.shape + + # gemini gogogo idk how to unfold and pack the patch properly :P + # Store the raw pixel values of each patch for the NeRF head later. + # unfold creates patches: [B, C * P * P, NumPatches] + nerf_pixels = nn.functional.unfold(img, kernel_size=self.params.patch_size, stride=self.params.patch_size) + nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] + + # partchify ops + img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P] + num_patches = img.shape[2] * img.shape[3] + # flatten into a sequence for the transformer. + img = img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden] + + # distilled vector guidance + mod_index_length = 344 + distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype) + # guidance = guidance * + distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype) + + # get all modulation index + modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype) + # we need to broadcast the modulation index here so each batch has all of the index + modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype) + # and we need to broadcast timestep and guidance along too + timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype) + # then and only then we could concatenate it together + input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype) + + mod_vectors = self.distilled_guidance_layer(input_vec) + + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.double_blocks): + if i not in self.skip_mmdit: + double_mod = ( + self.get_modulations(mod_vectors, "double_img", idx=i), + self.get_modulations(mod_vectors, "double_txt", idx=i), + ) + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block(img=args["img"], + txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("double_block", i)]({"img": img, + "txt": txt, + "vec": double_mod, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, + txt=txt, + vec=double_mod, + pe=pe, + attn_mask=attn_mask) + + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + img += add + + img = torch.cat((txt, img), 1) + + for i, block in enumerate(self.single_blocks): + if i not in self.skip_dit: + single_mod = self.get_modulations(mod_vectors, "single", idx=i) + if ("single_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("single_block", i)]({"img": img, + "vec": single_mod, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + img = out["img"] + else: + img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask) + + if control is not None: # Controlnet + control_o = control.get("output") + if i < len(control_o): + add = control_o[i] + if add is not None: + img[:, txt.shape[1] :, ...] += add + + img = img[:, txt.shape[1] :, ...] + # aliasing + nerf_hidden = img + # reshape for per-patch processing + nerf_hidden = nerf_hidden.reshape(B * num_patches, self.params.hidden_size) + nerf_pixels = nerf_pixels.reshape(B * num_patches, C, self.params.patch_size**2).transpose(1, 2) + + # get DCT-encoded pixel embeddings [pixel-dct] + img_dct = self.nerf_image_embedder(nerf_pixels) + + # pass through the dynamic MLP blocks (the NeRF) + for i, block in enumerate(self.nerf_blocks): + img_dct = block(img_dct, nerf_hidden) + + # final projection to get the output pixel values + img_dct = self.nerf_final_layer(img_dct) # -> [B*NumPatches, P*P, C] + + # gemini gogogo idk how to fold this properly :P + # Reassemble the patches into the final image. + img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P] + # Reshape to combine with batch dimension for fold + img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P] + img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches] + img_dct = nn.functional.fold( + img_dct, + output_size=(H, W), + kernel_size=self.params.patch_size, + stride=self.params.patch_size + ) + + return img_dct + + def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): + bs, c, h, w = x.shape + img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + h_len = ((h + (self.patch_size // 2)) // self.patch_size) + w_len = ((w + (self.patch_size // 2)) // self.patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + + return self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) + + + diff --git a/comfy/model_base.py b/comfy/model_base.py index 324d89cff..1d08c6853 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.wan.model import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model +import comfy.ldm.chroma.model_dct import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model @@ -1320,8 +1321,8 @@ class HiDream(BaseModel): return out class Chroma(Flux): - def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma) + def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma): + super().__init__(model_config, model_type, device=device, unet_model=unet_model) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1331,6 +1332,10 @@ class Chroma(Flux): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class ChromaRadiance(Chroma): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model_dct.ChromaRadiance) + class ACEStep(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index fe983cede..19d5c133e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys): #Flux dit_config = {} dit_config["image_model"] = "flux" dit_config["in_channels"] = 16 @@ -204,6 +204,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["out_dim"] = 3072 dit_config["hidden_dim"] = 5120 dit_config["n_layers"] = 5 + if f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys: #Radiance + dit_config["image_model"] = "chroma_radiance" + dit_config["in_channels"] = 3 + dit_config["out_channels"] = 3 + dit_config["patch_size"] = 16 + dit_config["nerf_hidden_size"] = 64 + dit_config["nerf_mlp_ratio"] = 4 + dit_config["nerf_depth"] = 4 + dit_config["nerf_max_freqs"] = 8 else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 472ea0ae9..c12ff9d9e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1205,6 +1205,16 @@ class Chroma(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) +class ChromaRadiance(Chroma): + unet_config = { + "image_model": "chroma_radiance", + } + + latent_format = comfy.latent_formats.ChromaRadiance + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ChromaRadiance(self, device=device) + class ACEStep(supported_models_base.BASE): unet_config = { "audio_model": "ace", @@ -1338,6 +1348,6 @@ class HunyuanImage21Refiner(HunyuanVideo): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py new file mode 100644 index 000000000..d5752b75f --- /dev/null +++ b/comfy_extras/nodes_chroma_radiance.py @@ -0,0 +1,62 @@ +import torch + +import comfy.model_management + +import nodes + +class EmptyChromaRadianceLatentImage: + def __init__(self): + self.device = comfy.model_management.intermediate_device() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "width": ("INT", {"default": 1024, "min": 2, "max": nodes.MAX_RESOLUTION}), + "height": ("INT", {"default": 1024, "min": 2, "max": nodes.MAX_RESOLUTION}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "go" + + CATEGORY = "latent/chroma_radiance" + + def go(self, *, width, height, batch_size=1): + latent = torch.zeros((batch_size, 3, height, width), device=self.device) + return ({"samples":latent}, ) + + +class ChromaRadianceLatentToImage: + @classmethod + def INPUT_TYPES(s): + return {"required": {"latent": ("LATENT",)}} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "go" + + CATEGORY = "latent/chroma_radiance" + + @classmethod + def go(cls, *, latent): + img = latent["samples"].movedim(1, -1).clamp(-1, 1).contiguous() + img = (img + 1.0) * 0.5 + return (img,) + +class ChromaRadianceImageToLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("IMAGE",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "go" + + CATEGORY = "latent/chroma_radiance" + + @classmethod + def go(cls, *, image): + image = (image.clone().clamp(0, 1) - 0.5) * 2 + image = image.movedim(-1, 1).contiguous() + return ({"samples": image},) + +NODE_CLASS_MAPPINGS = { + "EmptyChromaRadianceLatentImage": EmptyChromaRadianceLatentImage, + "ChromaRadianceLatentToImage": ChromaRadianceLatentToImage, + "ChromaRadianceImageToLatent": ChromaRadianceImageToLatent, +} diff --git a/nodes.py b/nodes.py index 2befb4b75..a08dd7dcf 100644 --- a/nodes.py +++ b/nodes.py @@ -2322,6 +2322,7 @@ async def init_builtin_extra_nodes(): "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py", + "nodes_chroma_radiance.py", "nodes_model_patch.py", "nodes_easycache.py", "nodes_audio_encoder.py",