Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-03-12 12:47:45 +08:00)
Move Chroma Radiance to its own directory in ldm
Minor code cleanups and tooltip improvements
This commit is contained in:
parent a46078afe7
commit f1f5b7d9b5
@ -16,7 +16,15 @@ class NerfEmbedder(nn.Module):
|
||||
patch size, and enriches it with positional information before projecting
|
||||
it to a new hidden size.
|
||||
"""
|
||||
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_size_input: int,
|
||||
max_freqs: int,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
"""
|
||||
Initializes the NerfEmbedder.
|
||||
|
||||
@ -38,7 +46,7 @@ class NerfEmbedder(nn.Module):
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=4)
|
||||
def fetch_pos(self, patch_size: int, device, dtype) -> torch.Tensor:
|
||||
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Generates and caches 2D DCT-like positional embeddings for a given patch size.
|
||||
|
||||
@ -179,14 +187,14 @@ class NerfFinalLayer(nn.Module):
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
|
||||
# So we temporarily move the channel dimension to the end for the norm operation.
|
||||
return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
|
||||
|
||||
|
||||
class NerfFinalLayerConv(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.conv = operations.Conv2d(
|
||||
@ -198,7 +206,7 @@ class NerfFinalLayerConv(nn.Module):
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
|
||||
# So we temporarily move the channel dimension to the end for the norm operation.
|
||||
return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
|
||||
@ -12,29 +12,28 @@ import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
from .layers import (
|
||||
from comfy.ldm.chroma.model import Chroma, ChromaParams
|
||||
from comfy.ldm.chroma.layers import (
|
||||
DoubleStreamBlock,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
)
|
||||
from .layers_dct import (
|
||||
from .layers import (
|
||||
NerfEmbedder,
|
||||
NerfGLUBlock,
|
||||
NerfFinalLayer,
|
||||
NerfFinalLayerConv,
|
||||
)
|
||||
|
||||
from . import model as chroma_model
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaRadianceParams(chroma_model.ChromaParams):
|
||||
class ChromaRadianceParams(ChromaParams):
|
||||
patch_size: int
|
||||
nerf_hidden_size: int
|
||||
nerf_mlp_ratio: int
|
||||
nerf_depth: int
|
||||
nerf_max_freqs: int
|
||||
# nerf_tile_size of 0 means unlimited.
|
||||
# Setting nerf_tile_size to 0 disables tiling.
|
||||
nerf_tile_size: int
|
||||
# Currently one of linear (legacy) or conv.
|
||||
nerf_final_head_type: str
|
||||
@ -42,12 +41,14 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
|
||||
|
||||
class ChromaRadiance(chroma_model.Chroma):
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
if operations is None:
|
||||
raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
params = ChromaRadianceParams(**kwargs)
|
||||
@ -188,7 +189,9 @@ class ChromaRadiance(chroma_model.Chroma):
|
||||
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
|
||||
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
||||
|
||||
if params.nerf_tile_size > 0:
|
||||
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
|
||||
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
|
||||
# the tile size.
|
||||
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||
else:
|
||||
# Reshape for per-patch processing
|
||||
@ -219,8 +222,8 @@ class ChromaRadiance(chroma_model.Chroma):
|
||||
self,
|
||||
nerf_hidden: Tensor,
|
||||
nerf_pixels: Tensor,
|
||||
B: int,
|
||||
C: int,
|
||||
batch: int,
|
||||
channels: int,
|
||||
num_patches: int,
|
||||
patch_size: int,
|
||||
params: ChromaRadianceParams,
|
||||
@ -246,9 +249,9 @@ class ChromaRadiance(chroma_model.Chroma):
|
||||
|
||||
# Reshape the tile for per-patch processing
|
||||
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
|
||||
nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, params.hidden_size)
|
||||
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
|
||||
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
|
||||
nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2)
|
||||
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
|
||||
@ -284,7 +287,16 @@ class ChromaRadiance(chroma_model.Chroma):
|
||||
params_dict |= overrides
|
||||
return params.__class__(**params_dict)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
def _forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
timestep: Tensor,
|
||||
context: Tensor,
|
||||
guidance: Optional[Tensor],
|
||||
control: Optional[dict]=None,
|
||||
transformer_options: dict={},
|
||||
**kwargs: dict,
|
||||
) -> Tensor:
|
||||
bs, c, h, w = x.shape
|
||||
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
@ -303,5 +315,15 @@ class ChromaRadiance(chroma_model.Chroma):
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
|
||||
img_out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
img_out = self.forward_orig(
|
||||
img,
|
||||
img_ids,
|
||||
context,
|
||||
txt_ids,
|
||||
timestep,
|
||||
guidance,
|
||||
control,
|
||||
transformer_options,
|
||||
attn_mask=kwargs.get("attention_mask", None),
|
||||
)
|
||||
return self.forward_nerf(img, img_out, params)
|
||||
@ -42,7 +42,7 @@ import comfy.ldm.wan.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.chroma.model_dct
|
||||
import comfy.ldm.chroma_radiance.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
@ -1334,7 +1334,7 @@ class Chroma(Flux):
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model_dct.ChromaRadiance)
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance)
|
||||
|
||||
class ACEStep(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
|
||||
@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight)
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 16
|
||||
@ -204,7 +204,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["out_dim"] = 3072
|
||||
dit_config["hidden_dim"] = 5120
|
||||
dit_config["n_layers"] = 5
|
||||
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Radiance
|
||||
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
|
||||
dit_config["image_model"] = "chroma_radiance"
|
||||
dit_config["in_channels"] = 3
|
||||
dit_config["out_channels"] = 3
|
||||
|
||||
@ -29,13 +29,8 @@ class EmptyChromaRadianceLatentImage(io.ComfyNode):
|
||||
|
||||
|
||||
class ChromaRadianceStubVAE:
|
||||
@classmethod
|
||||
def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
|
||||
device = comfy.model_management.intermediate_device()
|
||||
if pixels.ndim == 3:
|
||||
pixels = pixels.unsqueeze(0)
|
||||
elif pixels.ndim != 4:
|
||||
raise ValueError("Unexpected input image shape")
|
||||
@staticmethod
|
||||
def vae_encode_crop_pixels(pixels: torch.Tensor) -> torch.Tensor:
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
d_adj = (dims[d] // 16) * 16
|
||||
@ -43,6 +38,17 @@ class ChromaRadianceStubVAE:
|
||||
continue
|
||||
d_offset = (dims[d] % 16) // 2
|
||||
pixels = pixels.narrow(d + 1, d_offset, d_adj)
|
||||
return pixels
|
||||
|
||||
@classmethod
|
||||
def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
|
||||
device = comfy.model_management.intermediate_device()
|
||||
if pixels.ndim == 3:
|
||||
pixels = pixels.unsqueeze(0)
|
||||
elif pixels.ndim != 4:
|
||||
raise ValueError("Unexpected input image shape")
|
||||
# Ensure the image has spatial dimensions that are multiples of 16.
|
||||
pixels = cls.vae_encode_crop_pixels(pixels)
|
||||
h, w, c = pixels.shape[1:]
|
||||
if h < 16 or w < 16:
|
||||
raise ValueError("Chroma Radiance image inputs must have height/width of at least 16 pixels.")
|
||||
@ -51,6 +57,7 @@ class ChromaRadianceStubVAE:
|
||||
pixels = pixels.expand(-1, -1, -1, 3)
|
||||
elif c != 3:
|
||||
raise ValueError("Unexpected number of channels in input image")
|
||||
# Rescale to -1..1 and move the channel dimension to position 1.
|
||||
latent = pixels.to(device=device, dtype=torch.float32, copy=True)
|
||||
latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
|
||||
latent -= 0.5
|
||||
@ -60,6 +67,7 @@ class ChromaRadianceStubVAE:
|
||||
@classmethod
|
||||
def decode(cls, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
|
||||
device = comfy.model_management.intermediate_device()
|
||||
# Rescale to 0..1 and move the channel dimension to the end.
|
||||
img = samples.to(device=device, dtype=torch.float32, copy=True)
|
||||
img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
|
||||
img += 1.0
|
||||
@ -71,6 +79,7 @@ class ChromaRadianceStubVAE:
|
||||
|
||||
@classmethod
|
||||
def spacial_compression_decode(cls) -> int:
|
||||
# This just exists so the tiled VAE nodes don't crash.
|
||||
return 1
|
||||
|
||||
spacial_compression_encode = spacial_compression_decode
|
||||
@ -115,7 +124,7 @@ class ChromaRadianceStubVAENode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ChromaRadianceStubVAE",
|
||||
category="vae/chroma_radiance",
|
||||
description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.",
|
||||
description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Chroma Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.",
|
||||
outputs=[io.Vae.Output()],
|
||||
)
|
||||
|
||||
@ -129,37 +138,39 @@ class ChromaRadianceOptions(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ChromaRadianceOptions",
|
||||
category="model_patches/chroma_radiance",
|
||||
description="Allows setting some advanced options for the Chroma Radiance model.",
|
||||
description="Allows setting advanced options for the Chroma Radiance model.",
|
||||
inputs=[
|
||||
io.Model.Input(id="model"),
|
||||
io.Boolean.Input(
|
||||
id="preserve_wrapper",
|
||||
default=True,
|
||||
tooltip="When enabled preserves an existing model wrapper if it exists. Generally should be left enabled.",
|
||||
tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.",
|
||||
),
|
||||
io.Float.Input(
|
||||
id="start_sigma",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
tooltip="First sigma that these options will be in effect.",
|
||||
),
|
||||
io.Float.Input(
|
||||
id="end_sigma",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
tooltip="Last sigma that these options will be in effect.",
|
||||
),
|
||||
io.Int.Input(
|
||||
id="nerf_tile_size",
|
||||
default=-1,
|
||||
min=-1,
|
||||
tooltip="Allows overriding the default NeRF tile size. -1 means use the default. 0 means use non-tiling mode (may require a lot of VRAM).",
|
||||
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
|
||||
),
|
||||
io.Combo.Input(
|
||||
id="nerf_embedder_dtype",
|
||||
default="default",
|
||||
options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
|
||||
tooltip="Allows overriding the dtype the NeRF embedder uses.",
|
||||
tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user