diff --git a/comfy_extras/nodes_pid.py b/comfy_extras/nodes_pid.py index 811b9ae8e..71855254e 100644 --- a/comfy_extras/nodes_pid.py +++ b/comfy_extras/nodes_pid.py @@ -21,8 +21,8 @@ class PiDConditioning(io.ComfyNode): inputs=[ io.Conditioning.Input("positive"), io.Latent.Input("latent", tooltip="latent (from VAEEncode or a KSampler)."), - io.Combo.Input("latent_format", options=["flux", "sd3"], default="flux", - tooltip="Flux1 and Flux2 latents auto-detected from channel dim, sd3 has to be selected manually."), + io.Combo.Input("latent_format", options=["flux", "sd3", "sdxl", "qwenimage"], default="flux", + tooltip="Flux1 (16-ch) and Flux2 (128-ch) latents are auto-detected from channel dim under 'flux'. For SD3 (16-ch), SDXL (4-ch), or QwenImage (16-ch), select manually."), io.Float.Input( "degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="0 = clean latent. Increase to denoise corrupted latent outputs.", @@ -36,9 +36,17 @@ class PiDConditioning(io.ComfyNode): samples = latent["samples"] if latent_format == "flux": fmt_cls = comfy.latent_formats.Flux2 if samples.shape[1] == 128 else comfy.latent_formats.Flux - else: + elif latent_format == "sd3": fmt_cls = comfy.latent_formats.SD3 + elif latent_format == "sdxl": + fmt_cls = comfy.latent_formats.SDXL + elif latent_format == "qwenimage": + fmt_cls = comfy.latent_formats.Wan21 + else: + raise ValueError(f"Unknown latent_format: {latent_format}") lq_latent = fmt_cls().process_in(samples) + if lq_latent.ndim == 5: + lq_latent = lq_latent[:, :, 0] sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32) return io.NodeOutput(node_helpers.conditioning_set_values( positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t},