diff --git a/comfy/component_model/tensor_types.py b/comfy/component_model/tensor_types.py index ed88a6750..a6e533875 100644 --- a/comfy/component_model/tensor_types.py +++ b/comfy/component_model/tensor_types.py @@ -1,8 +1,19 @@ +from typing import TypedDict + from jaxtyping import Float from torch import Tensor +from typing_extensions import NotRequired ImageBatch = Float[Tensor, "batch height width channels"] +LatentBatch = Float[Tensor, "batch channels height width"] +SD1LatentBatch = Float[Tensor, "batch 8 height width"] +SD3LatentBatch = Float[Tensor, "batch 16 height width"] MaskBatch = Float[Tensor, "batch height width"] RGBImageBatch = Float[Tensor, "batch height width 3"] RGBAImageBatch = Float[Tensor, "batch height width 4"] RGBImage = Float[Tensor, "height width 3"] + + +class Latent(TypedDict): + samples: LatentBatch + noise_mask: NotRequired[LatentBatch]