Radiance: support variant with nonzero txt_ids

This commit is contained in:
person4268 2026-06-01 01:51:22 -04:00
parent 70a2e1a851
commit 3b921372de
3 changed files with 22 additions and 0 deletions

View File

@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams):
# None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype]
use_x0: bool
# Use sequential txt_ids instead of zeros
use_sequential_txt_ids: bool
class ChromaRadiance(Chroma):
"""
@ -162,6 +164,9 @@ class ChromaRadiance(Chroma):
if params.use_x0:
self.register_buffer("__x0__", torch.tensor([]))
if params.use_sequential_txt_ids:
self.register_buffer("__sequential__", torch.tensor([]))
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
@ -313,6 +318,9 @@ class ChromaRadiance(Chroma):
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
# Radiance after 2026-05-22 uses sequential txt_ids instead of zeros
if params.use_sequential_txt_ids:
txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1)
img_out = self.forward_orig(
img,

View File

@ -313,6 +313,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids
dit_config["use_sequential_txt_ids"] = True
else:
dit_config["use_sequential_txt_ids"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

View File

@ -65,6 +65,12 @@ class ChromaRadianceOptions(io.ComfyNode):
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
advanced=True,
),
io.Boolean.Input(
id="force_sequential_txt_ids",
default=False,
tooltip="Force usage of sequential text token IDs instead of zeroes. Should be used for checkpoints from 2026-05-22 to 2026-06-01 that are trained in this way but do not contain the __sequential__ key in the state dict.",
advanced=True,
),
],
outputs=[io.Model.Output()],
)
@ -78,11 +84,15 @@ class ChromaRadianceOptions(io.ComfyNode):
start_sigma: float,
end_sigma: float,
nerf_tile_size: int,
force_sequential_txt_ids: bool,
) -> io.NodeOutput:
radiance_options = {}
if nerf_tile_size >= 0:
radiance_options["nerf_tile_size"] = nerf_tile_size
if force_sequential_txt_ids:
radiance_options["use_sequential_txt_ids"] = True
if not radiance_options:
return io.NodeOutput(model)