From 700ad0c4dcb437ab45abe71d8cdbd2ea9ca5ab81 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 26 May 2026 10:34:16 +0300 Subject: [PATCH] Cleanup degrade_sigma passthrough --- comfy/ldm/pixeldit/pid.py | 2 +- comfy/model_base.py | 2 -- comfy_extras/nodes_pid.py | 4 +++- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/pixeldit/pid.py b/comfy/ldm/pixeldit/pid.py index 7283d9788..be5ef7a0e 100644 --- a/comfy/ldm/pixeldit/pid.py +++ b/comfy/ldm/pixeldit/pid.py @@ -240,7 +240,7 @@ class PidNet(PixDiT_T2I): Hs = x.shape[2] // self.patch_size Ws = x.shape[3] // self.patch_size - degrade_sigma = torch.as_tensor(degrade_sigma if degrade_sigma is not None else 0.0, device=x.device, dtype=torch.float32).reshape(-1) + degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1) if degrade_sigma.numel() == 1 and B > 1: degrade_sigma = degrade_sigma.expand(B).contiguous() diff --git a/comfy/model_base.py b/comfy/model_base.py index 1f4d6ebc1..db476c1f2 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1428,8 +1428,6 @@ class PiD(BaseModel): out["lq_latent"] = comfy.conds.CONDRegular(lq_latent) degrade_sigma = kwargs.get("degrade_sigma", None) if degrade_sigma is not None: - if not isinstance(degrade_sigma, torch.Tensor): - degrade_sigma = torch.tensor([float(degrade_sigma)], dtype=torch.float32) out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma) return out diff --git a/comfy_extras/nodes_pid.py b/comfy_extras/nodes_pid.py index 3fecd4bfa..eaeb3531b 100644 --- a/comfy_extras/nodes_pid.py +++ b/comfy_extras/nodes_pid.py @@ -1,5 +1,6 @@ """PiD (Pixel Diffusion Decoder) node""" +import torch from typing_extensions import override import node_helpers @@ -46,8 +47,9 @@ class PiDConditioning(io.ComfyNode): @classmethod def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput: lq_latent = _LATENT_FORMAT_CLASSES[latent_format]().process_in(latent["samples"]) + sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32) return io.NodeOutput(node_helpers.conditioning_set_values( - positive, {"lq_latent": lq_latent, "degrade_sigma": float(degrade_sigma)}, + positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t}, ))