mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-04 05:17:50 +08:00
Some checks are pending
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
64 lines
2.5 KiB
Python
64 lines
2.5 KiB
Python
"""PiD (Pixel Diffusion Decoder) node"""
|
|
|
|
import torch
|
|
from typing_extensions import override
|
|
|
|
import node_helpers
|
|
import comfy.latent_formats
|
|
from comfy_api.latest import ComfyExtension, io
|
|
|
|
|
|
class PiDConditioning(io.ComfyNode):
    """Node that attaches a latent and a degrade sigma to a CONDITIONING.

    The latent is passed through the ``process_in`` of the selected latent
    format and stored under the ``"lq_latent"`` key; the sigma is stored as a
    one-element float32 tensor under ``"degrade_sigma"``.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Return the schema (id, labels, inputs, output) for this node."""
        conditioning_in = io.Conditioning.Input("positive")
        latent_in = io.Latent.Input(
            "latent",
            tooltip="latent (from VAEEncode or a KSampler).",
        )
        format_in = io.Combo.Input(
            "latent_format",
            options=["flux", "sd3", "sdxl", "qwenimage"],
            default="flux",
            tooltip="Flux1 (16-ch) and Flux2 (128-ch) latents are auto-detected from channel dim under 'flux'. For SD3 (16-ch), SDXL (4-ch), or QwenImage (16-ch), select manually.",
        )
        sigma_in = io.Float.Input(
            "degrade_sigma",
            default=0.0,
            min=0.0,
            max=1.0,
            step=0.01,
            tooltip="0 = clean latent. Increase to denoise corrupted latent outputs.",
        )
        return io.Schema(
            node_id="PiDConditioning",
            display_name="PiD Conditioning",
            category="advanced/conditioning",
            description="Attaches a latent and a degrade_sigma scalar to a CONDITIONING for PiD decoding/upscaling",
            inputs=[conditioning_in, latent_in, format_in, sigma_in],
            outputs=[io.Conditioning.Output()],
        )

    @classmethod
    def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput:
        """Attach the processed latent and the sigma tensor to ``positive``.

        Args:
            positive: Conditioning to extend.
            latent: Mapping whose ``"samples"`` entry holds the latent tensor.
            latent_format: One of ``"flux"``, ``"sd3"``, ``"sdxl"`` or
                ``"qwenimage"``; selects which ``comfy.latent_formats`` class
                is used for ``process_in``.
            degrade_sigma: Scalar stored alongside the latent.

        Returns:
            A node output wrapping the result of
            ``node_helpers.conditioning_set_values``.

        Raises:
            ValueError: If ``latent_format`` is not a recognised option.
        """
        latent_samples = latent["samples"]

        # Resolve the format class by name so only the selected class is
        # looked up on ``comfy.latent_formats``.
        if latent_format == "flux":
            # A size of 128 on dim 1 selects Flux2; anything else uses Flux.
            format_name = "Flux2" if latent_samples.shape[1] == 128 else "Flux"
        else:
            fixed_formats = {"sd3": "SD3", "sdxl": "SDXL", "qwenimage": "Wan21"}
            format_name = fixed_formats.get(latent_format)
            if format_name is None:
                raise ValueError(f"Unknown latent_format: {latent_format}")

        format_cls = getattr(comfy.latent_formats, format_name)
        processed = format_cls().process_in(latent_samples)

        # 5-D input: keep only index 0 along dim 2 (presumably the first
        # frame of a video-style latent — confirm).
        if processed.ndim == 5:
            processed = processed[:, :, 0]

        sigma_tensor = torch.tensor([float(degrade_sigma)], dtype=torch.float32)
        extra_values = {"lq_latent": processed, "degrade_sigma": sigma_tensor}
        return io.NodeOutput(node_helpers.conditioning_set_values(positive, extra_values))
|
|
|
|
|
|
class PiDExtension(ComfyExtension):
    """Extension exposing the PiD node defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        provided_nodes = [PiDConditioning]
        return provided_nodes
|
|
|
|
|
|
async def comfy_entrypoint() -> PiDExtension:
    """Create and return a new ``PiDExtension`` instance.

    NOTE(review): presumably discovered by name by the ComfyUI extension
    loader — confirm against the loader before renaming.
    """
    extension = PiDExtension()
    return extension
|