mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-03 12:57:25 +08:00
Radiance: support variant with nonzero txt_ids (#14206)
Some checks are pending
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This commit is contained in:
parent
e88a81d316
commit
c96fcddb81
@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams):
|
|||||||
# None means use the same dtype as the model.
|
# None means use the same dtype as the model.
|
||||||
nerf_embedder_dtype: Optional[torch.dtype]
|
nerf_embedder_dtype: Optional[torch.dtype]
|
||||||
use_x0: bool
|
use_x0: bool
|
||||||
|
# Use sequential txt_ids instead of zeros
|
||||||
|
use_sequential_txt_ids: bool
|
||||||
|
|
||||||
class ChromaRadiance(Chroma):
|
class ChromaRadiance(Chroma):
|
||||||
"""
|
"""
|
||||||
@ -162,6 +164,9 @@ class ChromaRadiance(Chroma):
|
|||||||
if params.use_x0:
|
if params.use_x0:
|
||||||
self.register_buffer("__x0__", torch.tensor([]))
|
self.register_buffer("__x0__", torch.tensor([]))
|
||||||
|
|
||||||
|
if params.use_sequential_txt_ids:
|
||||||
|
self.register_buffer("__sequential__", torch.tensor([]))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _nerf_final_layer(self) -> nn.Module:
|
def _nerf_final_layer(self) -> nn.Module:
|
||||||
if self.params.nerf_final_head_type == "linear":
|
if self.params.nerf_final_head_type == "linear":
|
||||||
@ -313,6 +318,9 @@ class ChromaRadiance(Chroma):
|
|||||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
# Radiance after 2026-05-22 uses sequential txt_ids instead of zeros
|
||||||
|
if params.use_sequential_txt_ids:
|
||||||
|
txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1)
|
||||||
|
|
||||||
img_out = self.forward_orig(
|
img_out = self.forward_orig(
|
||||||
img,
|
img,
|
||||||
|
|||||||
@ -313,6 +313,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["use_x0"] = True
|
dit_config["use_x0"] = True
|
||||||
else:
|
else:
|
||||||
dit_config["use_x0"] = False
|
dit_config["use_x0"] = False
|
||||||
|
if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids
|
||||||
|
dit_config["use_sequential_txt_ids"] = True
|
||||||
|
else:
|
||||||
|
dit_config["use_sequential_txt_ids"] = False
|
||||||
else:
|
else:
|
||||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||||
|
|||||||
@ -65,6 +65,12 @@ class ChromaRadianceOptions(io.ComfyNode):
|
|||||||
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
|
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
|
||||||
advanced=True,
|
advanced=True,
|
||||||
),
|
),
|
||||||
|
io.Boolean.Input(
|
||||||
|
id="force_sequential_txt_ids",
|
||||||
|
default=False,
|
||||||
|
tooltip="Force usage of sequential text token IDs instead of zeroes. Should be used for checkpoints from 2026-05-22 to 2026-06-01 that are trained in this way but do not contain the __sequential__ key in the state dict.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[io.Model.Output()],
|
outputs=[io.Model.Output()],
|
||||||
)
|
)
|
||||||
@ -78,11 +84,15 @@ class ChromaRadianceOptions(io.ComfyNode):
|
|||||||
start_sigma: float,
|
start_sigma: float,
|
||||||
end_sigma: float,
|
end_sigma: float,
|
||||||
nerf_tile_size: int,
|
nerf_tile_size: int,
|
||||||
|
force_sequential_txt_ids: bool,
|
||||||
) -> io.NodeOutput:
|
) -> io.NodeOutput:
|
||||||
radiance_options = {}
|
radiance_options = {}
|
||||||
if nerf_tile_size >= 0:
|
if nerf_tile_size >= 0:
|
||||||
radiance_options["nerf_tile_size"] = nerf_tile_size
|
radiance_options["nerf_tile_size"] = nerf_tile_size
|
||||||
|
|
||||||
|
if force_sequential_txt_ids:
|
||||||
|
radiance_options["use_sequential_txt_ids"] = True
|
||||||
|
|
||||||
if not radiance_options:
|
if not radiance_options:
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user