From f6aeec66f7901fbf0d6229a1a57360fe68d52dea Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 16 Apr 2026 00:44:37 +0200 Subject: [PATCH] . --- comfy/image_encoders/birefnet.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/comfy/image_encoders/birefnet.py b/comfy/image_encoders/birefnet.py index edb72eeee..08abcd1a2 100644 --- a/comfy/image_encoders/birefnet.py +++ b/comfy/image_encoders/birefnet.py @@ -433,9 +433,10 @@ class DeformableConv2d(nn.Module): offset = self.offset_conv(x) modulator = 2. * torch.sigmoid(self.modulator_conv(x)) dtype = self.regular_conv.weight.dtype - x = x.to(dtype) - offset = offset.to(dtype) - modulator = modulator.to(dtype) + device = self.regular_conv.weight.device + x = x.to(dtype).to(device) + offset = offset.to(dtype).to(device) + modulator = modulator.to(dtype).to(device) x = deform_conv2d( input=x, offset=offset, @@ -516,6 +517,8 @@ class ASPPDeformable(nn.Module): self.relu = nn.ReLU(inplace=True) def forward(self, x): + device = self.conv1.weight.device + x = x.to(device) x1 = self.aspp1(x) x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms] x5 = self.global_avg_pool(x) @@ -634,7 +637,8 @@ class Decoder(nn.Module): return torch.cat(patches_batch, dim=0) def forward(self, features): - x, x1, x2, x3, x4 = features + device = next(self.ipt_blk5.parameters()).device + x, x1, x2, x3, x4 = [t.to(device) for t in features] patches_batch = self.get_patches_batch(x, x4) if self.split else x x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)