mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 21:42:37 +08:00
fix
This commit is contained in:
parent
18ae40065a
commit
d423272754
@ -10,7 +10,7 @@ ops = comfy.ops.disable_weight_init
|
||||
class SpatialPadConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, padding=(0, padding, padding), dilation=dilation, **kwargs)
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
@ -98,18 +98,19 @@ def conv_carry(xl, op, conv_carry_in=None, conv_carry_out=None):
|
||||
x = xl[0]
|
||||
xl.clear()
|
||||
|
||||
if isinstance(op, SpatialPadConv3d):
|
||||
if conv_carry_in is None:
|
||||
x = torch.nn.functional.pad(x, (0, 0, 0, 0, 2, 0), mode = 'replicate')
|
||||
else:
|
||||
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
|
||||
|
||||
out = op(x)
|
||||
|
||||
if conv_carry_out is not None:
|
||||
to_push = x[:, :, -2:, :, :].clone()
|
||||
conv_carry_out.append(to_push)
|
||||
|
||||
if isinstance(op, SpatialPadConv3d):
|
||||
if conv_carry_in is None:
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
|
||||
else:
|
||||
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 0, 0), mode = 'replicate')
|
||||
|
||||
out = op(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user