Move mask to CPU

This commit is contained in:
Alexis Rolland 2026-05-08 18:25:36 +08:00
parent b42f8a7ba7
commit 75fbd9c933
2 changed files with 2 additions and 2 deletions

View File

@ -47,7 +47,7 @@ class BackgroundRemovalModel():
out = self.model(pixel_values=pixel_values) out = self.model(pixel_values=pixel_values)
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False) out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid() mask = out.sigmoid().float().cpu()
if mask.ndim == 3: if mask.ndim == 3:
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)
if mask.shape[1] != 1: if mask.shape[1] != 1:

View File

@ -203,7 +203,7 @@ class JoinImageWithAlpha(io.ComfyNode):
@classmethod @classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = max(len(image), len(alpha)) batch_size = max(len(image), len(alpha))
alpha = 1.0 - resize_mask(alpha.to(image.device), image.shape[1:]) alpha = 1.0 - resize_mask(alpha.to(image), image.shape[1:])
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size) alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
image = comfy.utils.repeat_to_batch_size(image, batch_size) image = comfy.utils.repeat_to_batch_size(image, batch_size)
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1)) return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))