This commit is contained in:
Yousef Rafat 2026-04-16 00:44:37 +02:00
parent 2932059118
commit f6aeec66f7

View File

@ -433,9 +433,10 @@ class DeformableConv2d(nn.Module):
offset = self.offset_conv(x) offset = self.offset_conv(x)
modulator = 2. * torch.sigmoid(self.modulator_conv(x)) modulator = 2. * torch.sigmoid(self.modulator_conv(x))
dtype = self.regular_conv.weight.dtype dtype = self.regular_conv.weight.dtype
x = x.to(dtype) device = self.regular_conv.weight.device
offset = offset.to(dtype) x = x.to(dtype).to(device)
modulator = modulator.to(dtype) offset = offset.to(dtype).to(device)
modulator = modulator.to(dtype).to(device)
x = deform_conv2d( x = deform_conv2d(
input=x, input=x,
offset=offset, offset=offset,
@ -516,6 +517,8 @@ class ASPPDeformable(nn.Module):
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
def forward(self, x): def forward(self, x):
device = self.conv1.weight.device
x = x.to(device)
x1 = self.aspp1(x) x1 = self.aspp1(x)
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms] x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
x5 = self.global_avg_pool(x) x5 = self.global_avg_pool(x)
@ -634,7 +637,8 @@ class Decoder(nn.Module):
return torch.cat(patches_batch, dim=0) return torch.cat(patches_batch, dim=0)
def forward(self, features): def forward(self, features):
x, x1, x2, x3, x4 = features device = next(self.ipt_blk5.parameters()).device
x, x1, x2, x3, x4 = [t.to(device) for t in features]
patches_batch = self.get_patches_batch(x, x4) if self.split else x patches_batch = self.get_patches_batch(x, x4) if self.split else x
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1) x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)