mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 12:32:31 +08:00
Allow setting NeRF embedder dtype for Radiance
Bump Radiance nerf tile size to 32. Support EasyCache/LazyCache on Radiance (maybe).
This commit is contained in:
parent
93d9933aaa
commit
ce1c679f22
@ -15,7 +15,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
patch size, and enriches it with positional information before projecting
|
patch size, and enriches it with positional information before projecting
|
||||||
it to a new hidden size.
|
it to a new hidden size.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None):
|
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None, *, embedder_dtype=None):
|
||||||
"""
|
"""
|
||||||
Initializes the NerfEmbedder.
|
Initializes the NerfEmbedder.
|
||||||
|
|
||||||
@ -29,6 +29,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_freqs = max_freqs
|
self.max_freqs = max_freqs
|
||||||
self.hidden_size_input = hidden_size_input
|
self.hidden_size_input = hidden_size_input
|
||||||
|
self.embedder_dtype = embedder_dtype
|
||||||
|
|
||||||
# A linear layer to project the concatenated input features and
|
# A linear layer to project the concatenated input features and
|
||||||
# positional encodings to the final output dimension.
|
# positional encodings to the final output dimension.
|
||||||
@ -37,7 +38,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@lru_cache(maxsize=4)
|
@lru_cache(maxsize=4)
|
||||||
def fetch_pos(self, patch_size, device, dtype):
|
def fetch_pos(self, patch_size: int, device, dtype) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Generates and caches 2D DCT-like positional embeddings for a given patch size.
|
Generates and caches 2D DCT-like positional embeddings for a given patch size.
|
||||||
|
|
||||||
@ -91,7 +92,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
|
|
||||||
return dct
|
return dct
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Forward pass for the embedder.
|
Forward pass for the embedder.
|
||||||
|
|
||||||
@ -107,26 +108,34 @@ class NerfEmbedder(nn.Module):
|
|||||||
# Infer the patch side length from the number of pixels (P^2).
|
# Infer the patch side length from the number of pixels (P^2).
|
||||||
patch_size = int(P2 ** 0.5)
|
patch_size = int(P2 ** 0.5)
|
||||||
|
|
||||||
|
# Possibly run the operation with a different dtype.
|
||||||
|
orig_dtype = inputs.dtype
|
||||||
|
if self.embedder_dtype is not None and self.embedder_dtype != orig_dtype:
|
||||||
|
embedder = self.embedder.to(dtype=self.embedder_dtype)
|
||||||
|
else:
|
||||||
|
embedder = self.embedder
|
||||||
|
|
||||||
# Fetch the pre-computed or cached positional embeddings.
|
# Fetch the pre-computed or cached positional embeddings.
|
||||||
dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype)
|
dct = self.fetch_pos(patch_size, inputs.device, self.embedder_dtype or inputs.dtype)
|
||||||
|
|
||||||
# Repeat the positional embeddings for each item in the batch.
|
# Repeat the positional embeddings for each item in the batch.
|
||||||
dct = dct.repeat(B, 1, 1)
|
dct = dct.repeat(B, 1, 1)
|
||||||
|
|
||||||
# Concatenate the original input features with the positional embeddings
|
# Concatenate the original input features with the positional embeddings
|
||||||
# along the feature dimension.
|
# along the feature dimension.
|
||||||
inputs = torch.cat([inputs, dct], dim=-1)
|
inputs = torch.cat((inputs, dct), dim=-1)
|
||||||
|
|
||||||
# Project the combined tensor to the target hidden size.
|
# Project the combined tensor to the target hidden size.
|
||||||
inputs = self.embedder(inputs)
|
inputs = embedder(inputs)
|
||||||
|
|
||||||
return inputs
|
# No-op if already the same dtype.
|
||||||
|
return inputs.to(dtype=orig_dtype)
|
||||||
|
|
||||||
class NerfGLUBlock(nn.Module):
|
class NerfGLUBlock(nn.Module):
|
||||||
"""
|
"""
|
||||||
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
|
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
|
||||||
"""
|
"""
|
||||||
def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, dtype=None, device=None, operations=None):
|
def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# The total number of parameters for the MLP is increased to accommodate
|
# The total number of parameters for the MLP is increased to accommodate
|
||||||
# the gate, value, and output projection matrices.
|
# the gate, value, and output projection matrices.
|
||||||
@ -137,7 +146,7 @@ class NerfGLUBlock(nn.Module):
|
|||||||
self.mlp_ratio = mlp_ratio
|
self.mlp_ratio = mlp_ratio
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, s):
|
def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
|
||||||
batch_size, num_x, hidden_size_x = x.shape
|
batch_size, num_x, hidden_size_x = x.shape
|
||||||
mlp_params = self.param_generator(s)
|
mlp_params = self.param_generator(s)
|
||||||
|
|
||||||
@ -160,8 +169,7 @@ class NerfGLUBlock(nn.Module):
|
|||||||
# Apply the final output projection.
|
# Apply the final output projection.
|
||||||
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
|
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
|
||||||
|
|
||||||
x = x + res_x
|
return x + res_x
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class NerfFinalLayer(nn.Module):
|
class NerfFinalLayer(nn.Module):
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
|
# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
@ -37,6 +38,7 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
|
|||||||
nerf_max_freqs: int
|
nerf_max_freqs: int
|
||||||
nerf_tile_size: int
|
nerf_tile_size: int
|
||||||
nerf_final_head_type: str
|
nerf_final_head_type: str
|
||||||
|
nerf_embedder_dtype: Optional[torch.dtype]
|
||||||
|
|
||||||
|
|
||||||
class ChromaRadiance(chroma_model.Chroma):
|
class ChromaRadiance(chroma_model.Chroma):
|
||||||
@ -101,7 +103,12 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
|
|
||||||
self.single_blocks = nn.ModuleList(
|
self.single_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
SingleStreamBlock(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_heads,
|
||||||
|
mlp_ratio=params.mlp_ratio,
|
||||||
|
dtype=dtype, device=device, operations=operations,
|
||||||
|
)
|
||||||
for _ in range(params.depth_single_blocks)
|
for _ in range(params.depth_single_blocks)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -114,6 +121,7 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
|
embedder_dtype=params.nerf_embedder_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.nerf_blocks = nn.ModuleList([
|
self.nerf_blocks = nn.ModuleList([
|
||||||
@ -135,7 +143,6 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
)
|
)
|
||||||
self._nerf_final_layer = self.nerf_final_layer
|
|
||||||
elif params.nerf_final_head_type == "conv":
|
elif params.nerf_final_head_type == "conv":
|
||||||
self.nerf_final_layer_conv = NerfFinalLayerConv(
|
self.nerf_final_layer_conv = NerfFinalLayerConv(
|
||||||
params.nerf_hidden_size,
|
params.nerf_hidden_size,
|
||||||
@ -144,14 +151,23 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
)
|
)
|
||||||
self._nerf_final_layer = self.nerf_final_layer_conv
|
|
||||||
else:
|
else:
|
||||||
errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
|
errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
|
||||||
|
raise ValueError(errstr)
|
||||||
|
|
||||||
self.skip_mmdit = []
|
self.skip_mmdit = []
|
||||||
self.skip_dit = []
|
self.skip_dit = []
|
||||||
self.lite = False
|
self.lite = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _nerf_final_layer(self) -> nn.Module:
|
||||||
|
if self.params.nerf_final_head_type == "linear":
|
||||||
|
return self.nerf_final_layer
|
||||||
|
if self.params.nerf_final_head_type == "conv":
|
||||||
|
return self.nerf_final_layer_conv
|
||||||
|
# Impossible to get here as we raise an error on unexpected types on initialization.
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def forward_tiled_nerf(
|
def forward_tiled_nerf(
|
||||||
self,
|
self,
|
||||||
nerf_hidden: Tensor,
|
nerf_hidden: Tensor,
|
||||||
@ -337,7 +353,7 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
)
|
)
|
||||||
return self._nerf_final_layer(img_dct)
|
return self._nerf_final_layer(img_dct)
|
||||||
|
|
||||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||||
|
|
||||||
|
|||||||
@ -213,8 +213,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["nerf_mlp_ratio"] = 4
|
dit_config["nerf_mlp_ratio"] = 4
|
||||||
dit_config["nerf_depth"] = 4
|
dit_config["nerf_depth"] = 4
|
||||||
dit_config["nerf_max_freqs"] = 8
|
dit_config["nerf_max_freqs"] = 8
|
||||||
dit_config["nerf_tile_size"] = 16
|
dit_config["nerf_tile_size"] = 32
|
||||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||||
|
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||||
else:
|
else:
|
||||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user