Add validation for invalid HiDream input dimensions

This commit is contained in:
unknown 2026-05-14 13:36:08 +05:30
parent 1f28908d6e
commit cbcadb3e61

View File

@ -148,6 +148,12 @@ class HiDreamO1Transformer(nn.Module):
raise ValueError("HiDreamO1Transformer requires input_ids and position_ids in conditioning")
B, _, H, W = x.shape
if H % self.patch_size!=0 or w% self.patch_size!=0:
raise ValueError(
f"Input dimensions ({H},{W}) must be divisible"
f"by patch size{self.patch_size}"
)
h_p, w_p = H // self.patch_size, W // self.patch_size
tgt_image_len = h_p * w_p