Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-17 08:52:34 +08:00)
Allow setting NeRF embedder dtype for Radiance

Bump Radiance nerf tile size to 32
Support EasyCache/LazyCache on Radiance (maybe)
This commit is contained in: parent 93d9933aaa · commit ce1c679f22
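In practice, the first change lets the NeRF embedder be forced to run in an explicit dtype (for example float32) while the rest of the model stays in fp16/bf16, with the result cast back afterwards; see the `@@ -107,26 +108,34` hunk below. A minimal standalone sketch of that dtype round-trip pattern, using hypothetical names rather than the actual ComfyUI classes:

    # Sketch of the "configurable embedder dtype" idea this commit adds.
    # Standalone illustration only; TinyEmbedder is hypothetical, not a ComfyUI class.
    from typing import Optional

    import torch
    from torch import nn


    class TinyEmbedder(nn.Module):
        def __init__(self, in_features: int, out_features: int, embedder_dtype: Optional[torch.dtype] = None):
            super().__init__()
            self.embedder_dtype = embedder_dtype  # None means "use the input's dtype"
            self.proj = nn.Linear(in_features, out_features)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            orig_dtype = x.dtype
            run_dtype = self.embedder_dtype or orig_dtype
            # Run the projection in the requested dtype (e.g. float32 for extra
            # precision even when the surrounding model runs in fp16/bf16)...
            y = self.proj.to(dtype=run_dtype)(x.to(dtype=run_dtype))
            # ...then cast back so downstream layers see the original dtype.
            return y.to(dtype=orig_dtype)


    x = torch.randn(2, 8, dtype=torch.float16)
    emb = TinyEmbedder(8, 16, embedder_dtype=torch.float32)
    print(emb(x).dtype)  # torch.float16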
@@ -15,7 +15,7 @@ class NerfEmbedder(nn.Module):
     patch size, and enriches it with positional information before projecting
     it to a new hidden size.
     """
-    def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None):
+    def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None, *, embedder_dtype=None):
         """
         Initializes the NerfEmbedder.
 
@@ -29,6 +29,7 @@ class NerfEmbedder(nn.Module):
         super().__init__()
         self.max_freqs = max_freqs
         self.hidden_size_input = hidden_size_input
+        self.embedder_dtype = embedder_dtype
 
         # A linear layer to project the concatenated input features and
         # positional encodings to the final output dimension.
@@ -37,7 +38,7 @@ class NerfEmbedder(nn.Module):
         )
 
     @lru_cache(maxsize=4)
-    def fetch_pos(self, patch_size, device, dtype):
+    def fetch_pos(self, patch_size: int, device, dtype) -> torch.Tensor:
         """
         Generates and caches 2D DCT-like positional embeddings for a given patch size.
 
@@ -91,7 +92,7 @@ class NerfEmbedder(nn.Module):
 
         return dct
 
-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the embedder.
 
@@ -107,26 +108,34 @@ class NerfEmbedder(nn.Module):
         # Infer the patch side length from the number of pixels (P^2).
         patch_size = int(P2 ** 0.5)
 
+        # Possibly run the operation with a different dtype.
+        orig_dtype = inputs.dtype
+        if self.embedder_dtype is not None and self.embedder_dtype != orig_dtype:
+            embedder = self.embedder.to(dtype=self.embedder_dtype)
+        else:
+            embedder = self.embedder
+
         # Fetch the pre-computed or cached positional embeddings.
-        dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype)
+        dct = self.fetch_pos(patch_size, inputs.device, self.embedder_dtype or inputs.dtype)
 
         # Repeat the positional embeddings for each item in the batch.
         dct = dct.repeat(B, 1, 1)
 
         # Concatenate the original input features with the positional embeddings
         # along the feature dimension.
-        inputs = torch.cat([inputs, dct], dim=-1)
+        inputs = torch.cat((inputs, dct), dim=-1)
 
         # Project the combined tensor to the target hidden size.
-        inputs = self.embedder(inputs)
+        inputs = embedder(inputs)
 
-        return inputs
+        # No-op if already the same dtype.
+        return inputs.to(dtype=orig_dtype)
 
 class NerfGLUBlock(nn.Module):
     """
     A NerfBlock using a Gated Linear Unit (GLU) like MLP.
     """
-    def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
         super().__init__()
         # The total number of parameters for the MLP is increased to accommodate
         # the gate, value, and output projection matrices.
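In the hunk above, fetch_pos is decorated with lru_cache and now receives self.embedder_dtype or inputs.dtype, so positional embeddings are built directly in the dtype the embedder will run in and cached per (patch_size, device, dtype) combination. A small standalone sketch of that caching behaviour (illustrative only; the real fetch_pos builds DCT-like embeddings, not zeros):

    # Illustration of lru_cache keying on (size, device, dtype) -- not the real fetch_pos.
    from functools import lru_cache

    import torch


    @lru_cache(maxsize=4)
    def fetch_pos(patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        # The real implementation builds DCT-like positional embeddings;
        # a zero tensor is enough to show the caching behaviour.
        print(f"building pos for {patch_size=} {dtype=}")
        return torch.zeros(patch_size * patch_size, 8, device=device, dtype=dtype)


    a = fetch_pos(16, torch.device("cpu"), torch.float32)  # builds and caches
    b = fetch_pos(16, torch.device("cpu"), torch.float32)  # cache hit, no rebuild
    c = fetch_pos(16, torch.device("cpu"), torch.float16)  # different dtype -> new cache entry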
@@ -137,7 +146,7 @@ class NerfGLUBlock(nn.Module):
         self.mlp_ratio = mlp_ratio
 
 
-    def forward(self, x, s):
+    def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
         batch_size, num_x, hidden_size_x = x.shape
         mlp_params = self.param_generator(s)
 
@@ -160,8 +169,7 @@ class NerfGLUBlock(nn.Module):
         # Apply the final output projection.
         x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
 
-        x = x + res_x
-        return x
+        return x + res_x
 
 
 class NerfFinalLayer(nn.Module):
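As context for the NerfGLUBlock hunks above: the block applies a gated MLP with per-sample weight matrices via torch.bmm, and the change merely folds the residual add into the return. A standalone sketch of that computation with made-up shapes (the real matrices come from param_generator(s)):

    # Batched gated-MLP (GLU) step using per-sample weights, as in NerfGLUBlock.forward.
    # Shapes are invented for illustration; they are not the model's real sizes.
    import torch

    B, N, D, H = 2, 64, 32, 128  # batch, tokens, hidden_size_x, MLP hidden size
    x = torch.randn(B, N, D)
    fc1_gate = torch.randn(B, D, H)
    fc1_value = torch.randn(B, D, H)
    fc2 = torch.randn(B, H, D)

    res_x = x
    out = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
    out = out + res_x  # residual connection, equivalent to the new `return x + res_x`
    print(out.shape)   # torch.Size([2, 64, 32])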
@@ -3,6 +3,7 @@
 # Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
 
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 from torch import Tensor, nn
@@ -37,6 +38,7 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
     nerf_max_freqs: int
     nerf_tile_size: int
     nerf_final_head_type: str
+    nerf_embedder_dtype: Optional[torch.dtype]
 
 
 class ChromaRadiance(chroma_model.Chroma):
@@ -101,7 +103,12 @@ class ChromaRadiance(chroma_model.Chroma):
 
         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(
+                    self.hidden_size,
+                    self.num_heads,
+                    mlp_ratio=params.mlp_ratio,
+                    dtype=dtype, device=device, operations=operations,
+                )
                 for _ in range(params.depth_single_blocks)
             ]
         )
@@ -114,6 +121,7 @@ class ChromaRadiance(chroma_model.Chroma):
             dtype=dtype,
             device=device,
             operations=operations,
+            embedder_dtype=params.nerf_embedder_dtype,
         )
 
         self.nerf_blocks = nn.ModuleList([
@@ -135,7 +143,6 @@ class ChromaRadiance(chroma_model.Chroma):
                 device=device,
                 operations=operations,
             )
-            self._nerf_final_layer = self.nerf_final_layer
         elif params.nerf_final_head_type == "conv":
             self.nerf_final_layer_conv = NerfFinalLayerConv(
                 params.nerf_hidden_size,
@@ -144,14 +151,23 @@ class ChromaRadiance(chroma_model.Chroma):
                 device=device,
                 operations=operations,
             )
-            self._nerf_final_layer = self.nerf_final_layer_conv
         else:
             errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
             raise ValueError(errstr)
 
         self.skip_mmdit = []
         self.skip_dit = []
         self.lite = False
 
+    @property
+    def _nerf_final_layer(self) -> nn.Module:
+        if self.params.nerf_final_head_type == "linear":
+            return self.nerf_final_layer
+        if self.params.nerf_final_head_type == "conv":
+            return self.nerf_final_layer_conv
+        # Impossible to get here as we raise an error on unexpected types on initialization.
+        raise NotImplementedError
+
     def forward_tiled_nerf(
         self,
         nerf_hidden: Tensor,
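The two hunks above drop the stored _nerf_final_layer alias and replace it with a read-only property that dispatches on nerf_final_head_type. A likely motivation (an assumption, not stated in the commit) is that assigning one nn.Module under a second attribute name registers it twice in the module tree, whereas a property resolves the head on access without duplicating it. A minimal sketch of the same property-dispatch pattern outside ComfyUI, with a hypothetical HeadSwitch module:

    # Sketch of the property-dispatch pattern that replaces the stored alias.
    # HeadSwitch is hypothetical; it is not the actual ChromaRadiance class.
    import torch
    from torch import nn


    class HeadSwitch(nn.Module):
        def __init__(self, head_type: str):
            super().__init__()
            if head_type not in ("linear", "conv"):
                raise ValueError(f"Unsupported head type {head_type}")
            self.head_type = head_type
            self.final_linear = nn.Linear(8, 4)
            self.final_conv = nn.Conv2d(8, 4, kernel_size=1)

        @property
        def final_head(self) -> nn.Module:
            # Resolved on access instead of being stored as a second module attribute,
            # so no duplicate reference to the same module is registered.
            if self.head_type == "linear":
                return self.final_linear
            return self.final_conv


    m = HeadSwitch("linear")
    print(type(m.final_head).__name__)  # Linear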
@@ -337,7 +353,7 @@ class ChromaRadiance(chroma_model.Chroma):
         )
         return self._nerf_final_layer(img_dct)
 
-    def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
+    def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
 
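Renaming forward to _forward presumably lets an outer forward (inherited from the parent Chroma model or patched in by a caching feature) intercept the call and decide whether a full evaluation is needed, which is likely what the "Support EasyCache/LazyCache on Radiance (maybe)" part of the commit message refers to; the exact wiring is not shown in this diff. A generic sketch of that forward/_forward split (illustrative only, not ComfyUI's actual caching code):

    # Generic forward/_forward split that lets a caching wrapper short-circuit
    # model evaluation. CachedModel and its timestep-keyed cache are invented.
    import torch
    from torch import nn


    class CachedModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(8, 8)
            self._cache = {}

        def _forward(self, x: torch.Tensor, timestep: float) -> torch.Tensor:
            # The expensive path: what the subclass actually implements.
            return self.net(x)

        def forward(self, x: torch.Tensor, timestep: float) -> torch.Tensor:
            # The wrapper path: a cache (or other patch) may skip _forward entirely.
            key = round(float(timestep), 3)
            if key not in self._cache:
                self._cache[key] = self._forward(x, timestep)
            return self._cache[key]


    m = CachedModel()
    x = torch.randn(1, 8)
    y1 = m(x, 0.5)  # computes and caches
    y2 = m(x, 0.5)  # served from the cache without calling _forward again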
@@ -213,8 +213,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["nerf_mlp_ratio"] = 4
             dit_config["nerf_depth"] = 4
             dit_config["nerf_max_freqs"] = 8
-            dit_config["nerf_tile_size"] = 16
+            dit_config["nerf_tile_size"] = 32
             dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
+            dit_config["nerf_embedder_dtype"] = torch.float32
         else:
             dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         return dit_config