This commit is contained in:
Rattus 2025-11-20 04:26:51 +10:00 committed by comfyanonymous
parent 18ae40065a
commit d423272754

View File

@ -10,7 +10,7 @@ ops = comfy.ops.disable_weight_init
class SpatialPadConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, padding=(0, padding, padding), dilation=dilation, **kwargs)
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
return self.conv(x)
@ -98,18 +98,19 @@ def conv_carry(xl, op, conv_carry_in=None, conv_carry_out=None):
x = xl[0]
xl.clear()
if isinstance(op, SpatialPadConv3d):
if conv_carry_in is None:
x = torch.nn.functional.pad(x, (0, 0, 0, 0, 2, 0), mode = 'replicate')
else:
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
out = op(x)
if conv_carry_out is not None:
to_push = x[:, :, -2:, :, :].clone()
conv_carry_out.append(to_push)
if isinstance(op, SpatialPadConv3d):
if conv_carry_in is None:
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
else:
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 0, 0), mode = 'replicate')
out = op(x)
return out