Fix background removal mask output shape

Should be B H W like everything else
This commit is contained in:
kijai 2026-05-29 19:07:42 +03:00
parent e7214d78ee
commit 2dcd375fa0

View File

@ -55,12 +55,7 @@ class BackgroundRemovalModel():
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False) out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
if mask.ndim == 3: return mask.squeeze(1) # (B, 1, H, W) -> (B, H, W)
mask = mask.unsqueeze(0)
if mask.shape[1] != 1:
mask = mask.movedim(-1, 1)
return mask
def load_background_removal_model(sd): def load_background_removal_model(sd):