mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-02 20:37:35 +08:00
Radiance: support variant with nonzero txt_ids (#14206)
Some checks are pending
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This commit is contained in:
parent
e88a81d316
commit
c96fcddb81
@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams):
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
use_x0: bool
|
||||
# Use sequential txt_ids instead of zeros
|
||||
use_sequential_txt_ids: bool
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
@ -162,6 +164,9 @@ class ChromaRadiance(Chroma):
|
||||
if params.use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
if params.use_sequential_txt_ids:
|
||||
self.register_buffer("__sequential__", torch.tensor([]))
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
@ -313,6 +318,9 @@ class ChromaRadiance(Chroma):
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
# Radiance after 2026-05-22 uses sequential txt_ids instead of zeros
|
||||
if params.use_sequential_txt_ids:
|
||||
txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1)
|
||||
|
||||
img_out = self.forward_orig(
|
||||
img,
|
||||
|
||||
@ -313,6 +313,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["use_x0"] = True
|
||||
else:
|
||||
dit_config["use_x0"] = False
|
||||
if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids
|
||||
dit_config["use_sequential_txt_ids"] = True
|
||||
else:
|
||||
dit_config["use_sequential_txt_ids"] = False
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
|
||||
@ -65,6 +65,12 @@ class ChromaRadianceOptions(io.ComfyNode):
|
||||
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
|
||||
advanced=True,
|
||||
),
|
||||
io.Boolean.Input(
|
||||
id="force_sequential_txt_ids",
|
||||
default=False,
|
||||
tooltip="Force usage of sequential text token IDs instead of zeroes. Should be used for checkpoints from 2026-05-22 to 2026-06-01 that are trained in this way but do not contain the __sequential__ key in the state dict.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
)
|
||||
@ -78,11 +84,15 @@ class ChromaRadianceOptions(io.ComfyNode):
|
||||
start_sigma: float,
|
||||
end_sigma: float,
|
||||
nerf_tile_size: int,
|
||||
force_sequential_txt_ids: bool,
|
||||
) -> io.NodeOutput:
|
||||
radiance_options = {}
|
||||
if nerf_tile_size >= 0:
|
||||
radiance_options["nerf_tile_size"] = nerf_tile_size
|
||||
|
||||
if force_sequential_txt_ids:
|
||||
radiance_options["use_sequential_txt_ids"] = True
|
||||
|
||||
if not radiance_options:
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user