From e8eab4dbc6487a14a0c548a806b0b772e274e3df Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Fri, 6 Sep 2024 11:04:32 -0700 Subject: [PATCH] Fix tensor types --- comfy/component_model/tensor_types.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/comfy/component_model/tensor_types.py b/comfy/component_model/tensor_types.py index ed88a6750..a6e533875 100644 --- a/comfy/component_model/tensor_types.py +++ b/comfy/component_model/tensor_types.py @@ -1,8 +1,19 @@ +from typing import TypedDict + from jaxtyping import Float from torch import Tensor +from typing_extensions import NotRequired ImageBatch = Float[Tensor, "batch height width channels"] +LatentBatch = Float[Tensor, "batch channels height width"] +SD1LatentBatch = Float[Tensor, "batch 8 height width"] +SD3LatentBatch = Float[Tensor, "batch 16 height width"] MaskBatch = Float[Tensor, "batch height width"] RGBImageBatch = Float[Tensor, "batch height width 3"] RGBAImageBatch = Float[Tensor, "batch height width 4"] RGBImage = Float[Tensor, "height width 3"] + + +class Latent(TypedDict): + samples: LatentBatch + noise_mask: NotRequired[LatentBatch]