Use intermediate state

This commit is contained in:
Alexis Rolland 2026-05-08 22:57:31 +08:00
parent 75fbd9c933
commit de5dec4e74

View File

@ -47,7 +47,7 @@ class BackgroundRemovalModel():
out = self.model(pixel_values=pixel_values)
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().float().cpu()
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
if mask.ndim == 3:
mask = mask.unsqueeze(0)
if mask.shape[1] != 1: