Minor Chroma Radiance cleanups

This commit is contained in:
blepping 2025-09-02 08:19:30 -06:00
parent 50f3b65a48
commit a46078afe7
2 changed files with 5 additions and 4 deletions

View File

@ -6,6 +6,7 @@ from torch import nn
from comfy.ldm.flux.layers import RMSNorm from comfy.ldm.flux.layers import RMSNorm
class NerfEmbedder(nn.Module): class NerfEmbedder(nn.Module):
""" """
An embedder module that combines input features with a 2D positional An embedder module that combines input features with a 2D positional
@ -130,6 +131,7 @@ class NerfEmbedder(nn.Module):
# No-op if already the same dtype. # No-op if already the same dtype.
return inputs.to(dtype=orig_dtype) return inputs.to(dtype=orig_dtype)
class NerfGLUBlock(nn.Module): class NerfGLUBlock(nn.Module):
""" """
A NerfBlock using a Gated Linear Unit (GLU) like MLP. A NerfBlock using a Gated Linear Unit (GLU) like MLP.
@ -182,6 +184,7 @@ class NerfFinalLayer(nn.Module):
# So we temporarily move the channel dimension to the end for the norm operation. # So we temporarily move the channel dimension to the end for the norm operation.
return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1) return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
class NerfFinalLayerConv(nn.Module): class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__() super().__init__()

View File

@ -10,10 +10,7 @@ from torch import Tensor, nn
from einops import repeat from einops import repeat
import comfy.ldm.common_dit import comfy.ldm.common_dit
from comfy.ldm.flux.layers import ( from comfy.ldm.flux.layers import EmbedND
EmbedND,
timestep_embedding,
)
from .layers import ( from .layers import (
DoubleStreamBlock, DoubleStreamBlock,
@ -29,6 +26,7 @@ from .layers_dct import (
from . import model as chroma_model from . import model as chroma_model
@dataclass @dataclass
class ChromaRadianceParams(chroma_model.ChromaParams): class ChromaRadianceParams(chroma_model.ChromaParams):
patch_size: int patch_size: int