diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py index 3888dc481..4773fff74 100644 --- a/comfy_extras/nodes_chroma_radiance.py +++ b/comfy_extras/nodes_chroma_radiance.py @@ -10,8 +10,8 @@ class EmptyChromaRadianceLatentImage: @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 1024, "min": 2, "max": nodes.MAX_RESOLUTION}), - "height": ("INT", {"default": 1024, "min": 2, "max": nodes.MAX_RESOLUTION}), + return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} RETURN_TYPES = ("LATENT",) FUNCTION = "go" @@ -57,6 +57,19 @@ class ChromaRadianceImageToLatent: CATEGORY = "latent/chroma_radiance" def go(self, *, image): + if image.ndim == 3: + image = image.unsqueeze(0) + elif image.ndim != 4: + raise ValueError("Unexpected input image shape") + h, w, c = image.shape[1:] + if h < 16 or w < 16 or not (h / 16).is_integer() or not (w / 16).is_integer(): + raise ValueError("Chroma Radiance image inputs must have sizes that are multiples of 16.") + if c > 3: + image = image[..., :3] + elif c == 1: + image = image.expand(-1, -1, -1, 3) + elif c != 3: + raise ValueError("Unexpected number of channels in input image") latent = image.to(device=self.device, dtype=torch.float32, copy=True) latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous() latent -= 0.5