From 37f80907bc8254cf827abfd21673ee2045aab9dd Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 16 Apr 2026 23:27:18 +0200 Subject: [PATCH] removed .tos --- comfy/image_encoders/birefnet.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/comfy/image_encoders/birefnet.py b/comfy/image_encoders/birefnet.py index 08abcd1a2..167c8c067 100644 --- a/comfy/image_encoders/birefnet.py +++ b/comfy/image_encoders/birefnet.py @@ -102,7 +102,7 @@ class WindowAttention(nn.Module): q = q * self.scale attn = (q @ k.transpose(-2, -1)) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) @@ -422,7 +422,7 @@ class DeformableConv2d(nn.Module): padding=self.padding, bias=True, device=device, dtype=dtype) - self.regular_conv = operations.Conv2d(in_channels, + self.regular_conv = torch.nn.Conv2d(in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, @@ -432,11 +432,6 @@ class DeformableConv2d(nn.Module): def forward(self, x): offset = self.offset_conv(x) modulator = 2. * torch.sigmoid(self.modulator_conv(x)) - dtype = self.regular_conv.weight.dtype - device = self.regular_conv.weight.device - x = x.to(dtype).to(device) - offset = offset.to(dtype).to(device) - modulator = modulator.to(dtype).to(device) x = deform_conv2d( input=x, offset=offset, @@ -517,8 +512,6 @@ class ASPPDeformable(nn.Module): self.relu = nn.ReLU(inplace=True) def forward(self, x): - device = self.conv1.weight.device - x = x.to(device) x1 = self.aspp1(x) x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms] x5 = self.global_avg_pool(x) @@ -637,8 +630,7 @@ class Decoder(nn.Module): return torch.cat(patches_batch, dim=0) def forward(self, features): - device = next(self.ipt_blk5.parameters()).device - x, x1, x2, x3, x4 = [t.to(device) for t in features] + x, x1, x2, x3, x4 = features patches_batch = self.get_patches_batch(x, x4) if self.split else x x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)