Radiance: support variant with nonzero txt_ids (#14206)
Some checks are pending
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run

This commit is contained in:
person4268 2026-06-02 01:07:48 -04:00 committed by GitHub
parent e88a81d316
commit c96fcddb81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 0 deletions

View File

@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams):
# None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype]
use_x0: bool
# Use sequential txt_ids instead of zeros
use_sequential_txt_ids: bool
class ChromaRadiance(Chroma):
"""
@ -162,6 +164,9 @@ class ChromaRadiance(Chroma):
if params.use_x0:
self.register_buffer("__x0__", torch.tensor([]))
if params.use_sequential_txt_ids:
self.register_buffer("__sequential__", torch.tensor([]))
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
@ -313,6 +318,9 @@ class ChromaRadiance(Chroma):
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
# Radiance after 2026-05-22 uses sequential txt_ids instead of zeros
if params.use_sequential_txt_ids:
txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1)
img_out = self.forward_orig(
img,

View File

@ -313,6 +313,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids
dit_config["use_sequential_txt_ids"] = True
else:
dit_config["use_sequential_txt_ids"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

View File

@ -65,6 +65,12 @@ class ChromaRadianceOptions(io.ComfyNode):
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
advanced=True,
),
io.Boolean.Input(
id="force_sequential_txt_ids",
default=False,
tooltip="Force usage of sequential text token IDs instead of zeroes. Should be used for checkpoints from 2026-05-22 to 2026-06-01 that are trained in this way but do not contain the __sequential__ key in the state dict.",
advanced=True,
),
],
outputs=[io.Model.Output()],
)
@ -78,11 +84,15 @@ class ChromaRadianceOptions(io.ComfyNode):
start_sigma: float,
end_sigma: float,
nerf_tile_size: int,
force_sequential_txt_ids: bool,
) -> io.NodeOutput:
radiance_options = {}
if nerf_tile_size >= 0:
radiance_options["nerf_tile_size"] = nerf_tile_size
if force_sequential_txt_ids:
radiance_options["use_sequential_txt_ids"] = True
if not radiance_options:
return io.NodeOutput(model)