Cleanup degrade_sigma passthrough

This commit is contained in:
kijai 2026-05-26 10:34:16 +03:00
parent 28c44cb2d7
commit 700ad0c4dc
3 changed files with 4 additions and 4 deletions

View File

@ -240,7 +240,7 @@ class PidNet(PixDiT_T2I):
Hs = x.shape[2] // self.patch_size Hs = x.shape[2] // self.patch_size
Ws = x.shape[3] // self.patch_size Ws = x.shape[3] // self.patch_size
degrade_sigma = torch.as_tensor(degrade_sigma if degrade_sigma is not None else 0.0, device=x.device, dtype=torch.float32).reshape(-1) degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
if degrade_sigma.numel() == 1 and B > 1: if degrade_sigma.numel() == 1 and B > 1:
degrade_sigma = degrade_sigma.expand(B).contiguous() degrade_sigma = degrade_sigma.expand(B).contiguous()

View File

@ -1428,8 +1428,6 @@ class PiD(BaseModel):
out["lq_latent"] = comfy.conds.CONDRegular(lq_latent) out["lq_latent"] = comfy.conds.CONDRegular(lq_latent)
degrade_sigma = kwargs.get("degrade_sigma", None) degrade_sigma = kwargs.get("degrade_sigma", None)
if degrade_sigma is not None: if degrade_sigma is not None:
if not isinstance(degrade_sigma, torch.Tensor):
degrade_sigma = torch.tensor([float(degrade_sigma)], dtype=torch.float32)
out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma) out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
return out return out

View File

@ -1,5 +1,6 @@
"""PiD (Pixel Diffusion Decoder) node""" """PiD (Pixel Diffusion Decoder) node"""
import torch
from typing_extensions import override from typing_extensions import override
import node_helpers import node_helpers
@ -46,8 +47,9 @@ class PiDConditioning(io.ComfyNode):
@classmethod @classmethod
def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput: def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput:
lq_latent = _LATENT_FORMAT_CLASSES[latent_format]().process_in(latent["samples"]) lq_latent = _LATENT_FORMAT_CLASSES[latent_format]().process_in(latent["samples"])
sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32)
return io.NodeOutput(node_helpers.conditioning_set_values( return io.NodeOutput(node_helpers.conditioning_set_values(
positive, {"lq_latent": lq_latent, "degrade_sigma": float(degrade_sigma)}, positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t},
)) ))