Fix background removal mask output shape (#14171)

This commit is contained in:
Jukka Seppänen 2026-05-29 19:14:32 +03:00 committed by GitHub
parent ea5b092576
commit 54d5be4a8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -55,12 +55,7 @@ class BackgroundRemovalModel():
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
if mask.ndim == 3:
mask = mask.unsqueeze(0)
if mask.shape[1] != 1:
mask = mask.movedim(-1, 1)
return mask
return mask.squeeze(1) # (B, 1, H, W) -> (B, H, W)
def load_background_removal_model(sd):