This commit is contained in:
Reiner "Tiles" Prokein 2026-03-15 20:03:10 +00:00 committed by GitHub
commit 65db7325d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 15 deletions

View File

@ -98,8 +98,10 @@ class Resample(nn.Module):
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Fix: mutable default argument feat_idx=[0] would persist between calls
def forward(self, x, feat_cache=None, feat_idx=None):
if feat_idx is None:
feat_idx = [0]
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
@ -176,8 +178,10 @@ class ResidualBlock(nn.Module):
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Fix: mutable default argument feat_idx=[0] would persist between calls
def forward(self, x, feat_cache=None, feat_idx=None):
if feat_idx is None:
feat_idx = [0]
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
@ -282,8 +286,10 @@ class Encoder3d(nn.Module):
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Fix: mutable default argument feat_idx=[0] would persist between calls
def forward(self, x, feat_cache=None, feat_idx=None):
if feat_idx is None:
feat_idx = [0]
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
@ -388,8 +394,10 @@ class Decoder3d(nn.Module):
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, output_channels, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Fix: mutable default argument feat_idx=[0] would persist between calls
def forward(self, x, feat_cache=None, feat_idx=None):
if feat_idx is None:
feat_idx = [0]
## conv1
if feat_cache is not None:
idx = feat_idx[0]

View File

@ -54,7 +54,10 @@ class Resample(nn.Module):
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Fix: mutable default argument feat_idx=[0] would persist between calls
def forward(self, x, feat_cache=None, feat_idx=None):
if feat_idx is None:
feat_idx = [0]
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
@ -135,7 +138,10 @@ class ResidualBlock(nn.Module):
CausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim else nn.Identity())
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Fix: mutable default argument feat_idx=[0] would persist between calls
def forward(self, x, feat_cache=None, feat_idx=None):
if feat_idx is None:
feat_idx = [0]
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
@ -326,7 +332,10 @@ class Down_ResidualBlock(nn.Module):
self.downsamples = nn.Sequential(*downsamples)
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Fix: mutable default argument feat_idx=[0] would persist between calls
def forward(self, x, feat_cache=None, feat_idx=None):
if feat_idx is None:
feat_idx = [0]
x_copy = x
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
@ -367,8 +376,10 @@ class Up_ResidualBlock(nn.Module):
upsamples.append(Resample(out_dim, mode=mode))
self.upsamples = nn.Sequential(*upsamples)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
# Fix: mutable default argument feat_idx=[0] would persist between calls
def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
if feat_idx is None:
feat_idx = [0]
x_main = x
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx)
@ -438,7 +449,10 @@ class Encoder3d(nn.Module):
CausalConv3d(out_dim, z_dim, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Fix: mutable default argument feat_idx=[0] would persist between calls
def forward(self, x, feat_cache=None, feat_idx=None):
if feat_idx is None:
feat_idx = [0]
if feat_cache is not None:
idx = feat_idx[0]
@ -550,7 +564,10 @@ class Decoder3d(nn.Module):
CausalConv3d(out_dim, 12, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
# Fix: mutable default argument feat_idx=[0] would persist between calls
def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
if feat_idx is None:
feat_idx = [0]
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()