mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-30 20:32:45 +08:00
removed .tos
This commit is contained in:
parent
f6aeec66f7
commit
37f80907bc
@ -102,7 +102,7 @@ class WindowAttention(nn.Module):
|
|||||||
q = q * self.scale
|
q = q * self.scale
|
||||||
attn = (q @ k.transpose(-2, -1))
|
attn = (q @ k.transpose(-2, -1))
|
||||||
|
|
||||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
|
||||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||||
attn = attn + relative_position_bias.unsqueeze(0)
|
attn = attn + relative_position_bias.unsqueeze(0)
|
||||||
@ -422,7 +422,7 @@ class DeformableConv2d(nn.Module):
|
|||||||
padding=self.padding,
|
padding=self.padding,
|
||||||
bias=True, device=device, dtype=dtype)
|
bias=True, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.regular_conv = operations.Conv2d(in_channels,
|
self.regular_conv = torch.nn.Conv2d(in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
@ -432,11 +432,6 @@ class DeformableConv2d(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
offset = self.offset_conv(x)
|
offset = self.offset_conv(x)
|
||||||
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
||||||
dtype = self.regular_conv.weight.dtype
|
|
||||||
device = self.regular_conv.weight.device
|
|
||||||
x = x.to(dtype).to(device)
|
|
||||||
offset = offset.to(dtype).to(device)
|
|
||||||
modulator = modulator.to(dtype).to(device)
|
|
||||||
x = deform_conv2d(
|
x = deform_conv2d(
|
||||||
input=x,
|
input=x,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
@ -517,8 +512,6 @@ class ASPPDeformable(nn.Module):
|
|||||||
self.relu = nn.ReLU(inplace=True)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
device = self.conv1.weight.device
|
|
||||||
x = x.to(device)
|
|
||||||
x1 = self.aspp1(x)
|
x1 = self.aspp1(x)
|
||||||
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
|
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
|
||||||
x5 = self.global_avg_pool(x)
|
x5 = self.global_avg_pool(x)
|
||||||
@ -637,8 +630,7 @@ class Decoder(nn.Module):
|
|||||||
return torch.cat(patches_batch, dim=0)
|
return torch.cat(patches_batch, dim=0)
|
||||||
|
|
||||||
def forward(self, features):
|
def forward(self, features):
|
||||||
device = next(self.ipt_blk5.parameters()).device
|
x, x1, x2, x3, x4 = features
|
||||||
x, x1, x2, x3, x4 = [t.to(device) for t in features]
|
|
||||||
|
|
||||||
patches_batch = self.get_patches_batch(x, x4) if self.split else x
|
patches_batch = self.get_patches_batch(x, x4) if self.split else x
|
||||||
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
|
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user