Fix tensor types

This commit is contained in:
doctorpangloss 2024-09-06 11:04:32 -07:00
parent a4fb34a0b8
commit e8eab4dbc6

View File

@ -1,8 +1,19 @@
from typing import TypedDict
from jaxtyping import Float
from torch import Tensor
from typing_extensions import NotRequired
ImageBatch = Float[Tensor, "batch height width channels"]
LatentBatch = Float[Tensor, "batch channels height width"]
SD1LatentBatch = Float[Tensor, "batch 8 height width"]
SD3LatentBatch = Float[Tensor, "batch 16 height width"]
MaskBatch = Float[Tensor, "batch height width"]
RGBImageBatch = Float[Tensor, "batch height width 3"]
RGBAImageBatch = Float[Tensor, "batch height width 4"]
RGBImage = Float[Tensor, "height width 3"]
class Latent(TypedDict):
samples: LatentBatch
noise_mask: NotRequired[LatentBatch]