diff --git a/comfy/bg_removal_model.py b/comfy/bg_removal_model.py index 6dec65e63..c772c5f6a 100644 --- a/comfy/bg_removal_model.py +++ b/comfy/bg_removal_model.py @@ -55,12 +55,7 @@ class BackgroundRemovalModel(): out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False) mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) - if mask.ndim == 3: - mask = mask.unsqueeze(0) - if mask.shape[1] != 1: - mask = mask.movedim(-1, 1) - - return mask + return mask.squeeze(1) # (B, 1, H, W) -> (B, H, W) def load_background_removal_model(sd):