mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 13:17:45 +08:00
Add ChromaRadianceOptions node and backend support.
Cleanups/refactoring to reduce code duplication with Chroma.
This commit is contained in:
parent
a3e5850c8b
commit
e7073b5eec
@ -151,8 +151,6 @@ class Chroma(nn.Module):
|
|||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
|
||||||
|
|
||||||
# running on sequences img
|
# running on sequences img
|
||||||
img = self.img_in(img)
|
img = self.img_in(img)
|
||||||
@ -254,8 +252,9 @@ class Chroma(nn.Module):
|
|||||||
img[:, txt.shape[1] :, ...] += add
|
img[:, txt.shape[1] :, ...] += add
|
||||||
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
img = img[:, txt.shape[1] :, ...]
|
||||||
final_mod = self.get_modulations(mod_vectors, "final")
|
if hasattr(self, "final_layer"):
|
||||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
final_mod = self.get_modulations(mod_vectors, "final")
|
||||||
|
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||||
@ -271,6 +270,9 @@ class Chroma(nn.Module):
|
|||||||
|
|
||||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
|
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
|
||||||
|
|
||||||
|
if img.ndim != 3 or context.ndim != 3:
|
||||||
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
|
|
||||||
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
|
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
|
||||||
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
|
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
|
|||||||
@ -36,8 +36,11 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
|
|||||||
nerf_mlp_ratio: int
|
nerf_mlp_ratio: int
|
||||||
nerf_depth: int
|
nerf_depth: int
|
||||||
nerf_max_freqs: int
|
nerf_max_freqs: int
|
||||||
|
# nerf_tile_size of 0 means unlimited.
|
||||||
nerf_tile_size: int
|
nerf_tile_size: int
|
||||||
|
# Currently one of linear (legacy) or conv.
|
||||||
nerf_final_head_type: str
|
nerf_final_head_type: str
|
||||||
|
# None means use the same dtype as the model.
|
||||||
nerf_embedder_dtype: Optional[torch.dtype]
|
nerf_embedder_dtype: Optional[torch.dtype]
|
||||||
|
|
||||||
|
|
||||||
@ -168,6 +171,53 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
# Impossible to get here as we raise an error on unexpected types on initialization.
|
# Impossible to get here as we raise an error on unexpected types on initialization.
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def img_in(self, img: Tensor) -> Tensor:
|
||||||
|
img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
|
||||||
|
# flatten into a sequence for the transformer.
|
||||||
|
return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
|
||||||
|
|
||||||
|
def forward_nerf(
|
||||||
|
self,
|
||||||
|
img_orig: Tensor,
|
||||||
|
img_out: Tensor,
|
||||||
|
params: ChromaRadianceParams,
|
||||||
|
) -> Tensor:
|
||||||
|
B, C, H, W = img_orig.shape
|
||||||
|
num_patches = img_out.shape[1]
|
||||||
|
patch_size = params.patch_size
|
||||||
|
|
||||||
|
# Store the raw pixel values of each patch for the NeRF head later.
|
||||||
|
# unfold creates patches: [B, C * P * P, NumPatches]
|
||||||
|
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
|
||||||
|
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
||||||
|
|
||||||
|
if params.nerf_tile_size > 0:
|
||||||
|
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||||
|
else:
|
||||||
|
# Reshape for per-patch processing
|
||||||
|
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||||
|
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||||
|
|
||||||
|
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||||
|
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||||
|
|
||||||
|
# Pass through the dynamic MLP blocks (the NeRF)
|
||||||
|
for block in self.nerf_blocks:
|
||||||
|
img_dct = block(img_dct, nerf_hidden)
|
||||||
|
|
||||||
|
# Reassemble the patches into the final image.
|
||||||
|
img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
|
||||||
|
# Reshape to combine with batch dimension for fold
|
||||||
|
img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
|
||||||
|
img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
|
||||||
|
img_dct = nn.functional.fold(
|
||||||
|
img_dct,
|
||||||
|
output_size=(H, W),
|
||||||
|
kernel_size=patch_size,
|
||||||
|
stride=patch_size,
|
||||||
|
)
|
||||||
|
return self._nerf_final_layer(img_dct)
|
||||||
|
|
||||||
def forward_tiled_nerf(
|
def forward_tiled_nerf(
|
||||||
self,
|
self,
|
||||||
nerf_hidden: Tensor,
|
nerf_hidden: Tensor,
|
||||||
@ -175,17 +225,18 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
B: int,
|
B: int,
|
||||||
C: int,
|
C: int,
|
||||||
num_patches: int,
|
num_patches: int,
|
||||||
tile_size: int = 16
|
patch_size: int,
|
||||||
|
params: ChromaRadianceParams,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Processes the NeRF head in tiles to save memory.
|
Processes the NeRF head in tiles to save memory.
|
||||||
nerf_hidden has shape [B, L, D]
|
nerf_hidden has shape [B, L, D]
|
||||||
nerf_pixels has shape [B, L, C * P * P]
|
nerf_pixels has shape [B, L, C * P * P]
|
||||||
"""
|
"""
|
||||||
|
tile_size = params.nerf_tile_size
|
||||||
output_tiles = []
|
output_tiles = []
|
||||||
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
||||||
for i in range(0, num_patches, tile_size):
|
for i in range(0, num_patches, tile_size):
|
||||||
#
|
|
||||||
end = min(i + tile_size, num_patches)
|
end = min(i + tile_size, num_patches)
|
||||||
|
|
||||||
# Slice the current tile from the input tensors
|
# Slice the current tile from the input tensors
|
||||||
@ -197,9 +248,9 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
|
|
||||||
# Reshape the tile for per-patch processing
|
# Reshape the tile for per-patch processing
|
||||||
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
|
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
|
||||||
nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, self.params.hidden_size)
|
nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, params.hidden_size)
|
||||||
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
|
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
|
||||||
nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, self.params.patch_size**2).transpose(1, 2)
|
nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2)
|
||||||
|
|
||||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||||
@ -213,150 +264,39 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
# Concatenate the processed tiles along the patch dimension
|
# Concatenate the processed tiles along the patch dimension
|
||||||
return torch.cat(output_tiles, dim=0)
|
return torch.cat(output_tiles, dim=0)
|
||||||
|
|
||||||
def forward_orig(
|
def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
|
||||||
self,
|
params = self.params
|
||||||
img: Tensor,
|
if not overrides:
|
||||||
img_ids: Tensor,
|
return params
|
||||||
txt: Tensor,
|
params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__}
|
||||||
txt_ids: Tensor,
|
nullable_keys = frozenset(("nerf_embedder_dtype",))
|
||||||
timesteps: Tensor,
|
bad_keys = tuple(k for k in overrides if k not in params_dict)
|
||||||
guidance: Tensor = None,
|
if bad_keys:
|
||||||
control = None,
|
e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||||
transformer_options={},
|
raise ValueError(e)
|
||||||
attn_mask: Tensor = None,
|
bad_keys = tuple(
|
||||||
) -> Tensor:
|
k
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
for k, v in overrides.items()
|
||||||
if img.ndim != 4:
|
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
|
||||||
raise ValueError("Input img tensor must be in [B, C, H, W] format.")
|
|
||||||
if txt.ndim != 3:
|
|
||||||
raise ValueError("Input txt tensors must have 3 dimensions.")
|
|
||||||
B, C, H, W = img.shape
|
|
||||||
|
|
||||||
# gemini gogogo idk how to unfold and pack the patch properly :P
|
|
||||||
# Store the raw pixel values of each patch for the NeRF head later.
|
|
||||||
# unfold creates patches: [B, C * P * P, NumPatches]
|
|
||||||
nerf_pixels = nn.functional.unfold(img, kernel_size=self.params.patch_size, stride=self.params.patch_size)
|
|
||||||
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
|
||||||
|
|
||||||
# partchify ops
|
|
||||||
img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
|
|
||||||
num_patches = img.shape[2] * img.shape[3]
|
|
||||||
# flatten into a sequence for the transformer.
|
|
||||||
img = img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
|
|
||||||
|
|
||||||
# distilled vector guidance
|
|
||||||
mod_index_length = 344
|
|
||||||
distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
|
|
||||||
# guidance = guidance *
|
|
||||||
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
|
|
||||||
|
|
||||||
# get all modulation index
|
|
||||||
modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
|
|
||||||
# we need to broadcast the modulation index here so each batch has all of the index
|
|
||||||
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
|
|
||||||
# and we need to broadcast timestep and guidance along too
|
|
||||||
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype)
|
|
||||||
# then and only then we could concatenate it together
|
|
||||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
|
|
||||||
|
|
||||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
|
||||||
|
|
||||||
txt = self.txt_in(txt)
|
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
|
||||||
pe = self.pe_embedder(ids)
|
|
||||||
|
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
|
||||||
for i, block in enumerate(self.double_blocks):
|
|
||||||
if i not in self.skip_mmdit:
|
|
||||||
double_mod = (
|
|
||||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
|
||||||
self.get_modulations(mod_vectors, "double_txt", idx=i),
|
|
||||||
)
|
|
||||||
if ("double_block", i) in blocks_replace:
|
|
||||||
def block_wrap(args):
|
|
||||||
out = {}
|
|
||||||
out["img"], out["txt"] = block(img=args["img"],
|
|
||||||
txt=args["txt"],
|
|
||||||
vec=args["vec"],
|
|
||||||
pe=args["pe"],
|
|
||||||
attn_mask=args.get("attn_mask"))
|
|
||||||
return out
|
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": img,
|
|
||||||
"txt": txt,
|
|
||||||
"vec": double_mod,
|
|
||||||
"pe": pe,
|
|
||||||
"attn_mask": attn_mask},
|
|
||||||
{"original_block": block_wrap})
|
|
||||||
txt = out["txt"]
|
|
||||||
img = out["img"]
|
|
||||||
else:
|
|
||||||
img, txt = block(img=img,
|
|
||||||
txt=txt,
|
|
||||||
vec=double_mod,
|
|
||||||
pe=pe,
|
|
||||||
attn_mask=attn_mask)
|
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
|
||||||
control_i = control.get("input")
|
|
||||||
if i < len(control_i):
|
|
||||||
add = control_i[i]
|
|
||||||
if add is not None:
|
|
||||||
img += add
|
|
||||||
|
|
||||||
img = torch.cat((txt, img), 1)
|
|
||||||
|
|
||||||
for i, block in enumerate(self.single_blocks):
|
|
||||||
if i not in self.skip_dit:
|
|
||||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
|
||||||
if ("single_block", i) in blocks_replace:
|
|
||||||
def block_wrap(args):
|
|
||||||
out = {}
|
|
||||||
out["img"] = block(args["img"],
|
|
||||||
vec=args["vec"],
|
|
||||||
pe=args["pe"],
|
|
||||||
attn_mask=args.get("attn_mask"))
|
|
||||||
return out
|
|
||||||
|
|
||||||
out = blocks_replace[("single_block", i)]({"img": img,
|
|
||||||
"vec": single_mod,
|
|
||||||
"pe": pe,
|
|
||||||
"attn_mask": attn_mask},
|
|
||||||
{"original_block": block_wrap})
|
|
||||||
img = out["img"]
|
|
||||||
else:
|
|
||||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
|
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
|
||||||
control_o = control.get("output")
|
|
||||||
if i < len(control_o):
|
|
||||||
add = control_o[i]
|
|
||||||
if add is not None:
|
|
||||||
img[:, txt.shape[1] :, ...] += add
|
|
||||||
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
|
||||||
|
|
||||||
img_dct = self.forward_tiled_nerf(img, nerf_pixels, B, C, num_patches, tile_size=self.params.nerf_tile_size)
|
|
||||||
|
|
||||||
# gemini gogogo idk how to fold this properly :P
|
|
||||||
# Reassemble the patches into the final image.
|
|
||||||
img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
|
|
||||||
# Reshape to combine with batch dimension for fold
|
|
||||||
img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
|
|
||||||
img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
|
|
||||||
img_dct = nn.functional.fold(
|
|
||||||
img_dct,
|
|
||||||
output_size=(H, W),
|
|
||||||
kernel_size=self.params.patch_size,
|
|
||||||
stride=self.params.patch_size
|
|
||||||
)
|
)
|
||||||
return self._nerf_final_layer(img_dct)
|
if bad_keys:
|
||||||
|
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||||
|
raise ValueError(e)
|
||||||
|
# At this point it's all valid keys and values so we can merge with the existing params.
|
||||||
|
params_dict |= overrides
|
||||||
|
return params.__class__(**params_dict)
|
||||||
|
|
||||||
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||||
|
|
||||||
|
if img.ndim != 4:
|
||||||
|
raise ValueError("Input img tensor must be in [B, C, H, W] format.")
|
||||||
|
if context.ndim != 3:
|
||||||
|
raise ValueError("Input txt tensors must have 3 dimensions.")
|
||||||
|
|
||||||
|
params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))
|
||||||
|
|
||||||
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
|
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
|
||||||
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
|
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
@ -365,4 +305,5 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
return self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
img_out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||||
|
return self.forward_nerf(img, img_out, params)
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -122,6 +123,85 @@ class ChromaRadianceStubVAENode(io.ComfyNode):
|
|||||||
def execute(cls) -> io.NodeOutput:
|
def execute(cls) -> io.NodeOutput:
|
||||||
return io.NodeOutput(ChromaRadianceStubVAE())
|
return io.NodeOutput(ChromaRadianceStubVAE())
|
||||||
|
|
||||||
|
class ChromaRadianceOptions(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ChromaRadianceOptions",
|
||||||
|
category="model_patches/chroma_radiance",
|
||||||
|
description="Allows setting some advanced options for the Chroma Radiance model.",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input(id="model"),
|
||||||
|
io.Boolean.Input(
|
||||||
|
id="preserve_wrapper",
|
||||||
|
default=True,
|
||||||
|
tooltip="When enabled preserves an existing model wrapper if it exists. Generally should be left enabled.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
id="start_sigma",
|
||||||
|
default=1.0,
|
||||||
|
min=0.0,
|
||||||
|
max=1.0,
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
id="end_sigma",
|
||||||
|
default=0.0,
|
||||||
|
min=0.0,
|
||||||
|
max=1.0,
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
id="nerf_tile_size",
|
||||||
|
default=-1,
|
||||||
|
min=-1,
|
||||||
|
tooltip="Allows overriding the default NeRF tile size. -1 means use the default. 0 means use non-tiling mode (may require a lot of VRAM).",
|
||||||
|
),
|
||||||
|
io.Combo.Input(
|
||||||
|
id="nerf_embedder_dtype",
|
||||||
|
default="default",
|
||||||
|
options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
|
||||||
|
tooltip="Allows overriding the dtype the NeRF embedder uses.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
model: io.Model.Type,
|
||||||
|
preserve_wrapper: bool,
|
||||||
|
start_sigma: float,
|
||||||
|
end_sigma: float,
|
||||||
|
nerf_tile_size: int,
|
||||||
|
nerf_embedder_dtype: str,
|
||||||
|
) -> io.NodeOutput:
|
||||||
|
radiance_options = {}
|
||||||
|
if nerf_tile_size >= 0:
|
||||||
|
radiance_options["nerf_tile_size"] = nerf_tile_size
|
||||||
|
if nerf_embedder_dtype != "default":
|
||||||
|
radiance_options["nerf_embedder_dtype"] = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, "float64": torch.float64}.get(nerf_embedder_dtype)
|
||||||
|
|
||||||
|
if not radiance_options:
|
||||||
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
old_wrapper = model.model_options.get("model_function_wrapper")
|
||||||
|
|
||||||
|
def model_function_wrapper(apply_model: Callable, args: dict) -> torch.Tensor:
|
||||||
|
c = args["c"].copy()
|
||||||
|
sigma = args["timestep"].max().detach().cpu().item()
|
||||||
|
if end_sigma <= sigma <= start_sigma:
|
||||||
|
transformer_options = c.get("transformer_options", {}).copy()
|
||||||
|
transformer_options["chroma_radiance_options"] = radiance_options.copy()
|
||||||
|
c["transformer_options"] = transformer_options
|
||||||
|
if not (preserve_wrapper and old_wrapper):
|
||||||
|
return apply_model(args["input"], args["timestep"], **c)
|
||||||
|
return old_wrapper(apply_model, args | {"c": c})
|
||||||
|
|
||||||
|
model = model.clone()
|
||||||
|
model.set_model_unet_function_wrapper(model_function_wrapper)
|
||||||
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
class ChromaRadianceExtension(ComfyExtension):
|
class ChromaRadianceExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
@ -131,6 +211,7 @@ class ChromaRadianceExtension(ComfyExtension):
|
|||||||
ChromaRadianceLatentToImage,
|
ChromaRadianceLatentToImage,
|
||||||
ChromaRadianceImageToLatent,
|
ChromaRadianceImageToLatent,
|
||||||
ChromaRadianceStubVAENode,
|
ChromaRadianceStubVAENode,
|
||||||
|
ChromaRadianceOptions,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user