mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 08:12:34 +08:00
wan: vae: restyle a little to match LTX
This commit is contained in:
parent
8658c570f9
commit
0f5621ee32
@ -380,28 +380,25 @@ class Decoder3d(nn.Module):
|
||||
|
||||
out_chunks = []
|
||||
|
||||
def run_head(x, feat_idx):
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = layer(x, feat_cache[feat_idx[0]])
|
||||
feat_cache[feat_idx[0]] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
out_chunks.append(x)
|
||||
|
||||
def run_upsamples(layer_idx, x_ref, feat_idx):
|
||||
def run_up(layer_idx, x_ref, feat_idx):
|
||||
x = x_ref[0]
|
||||
x_ref[0] = None
|
||||
if layer_idx >= len(self.upsamples):
|
||||
run_head(x, feat_idx)
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = layer(x, feat_cache[feat_idx[0]])
|
||||
feat_cache[feat_idx[0]] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
out_chunks.append(x)
|
||||
return
|
||||
|
||||
layer = self.upsamples[layer_idx]
|
||||
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
|
||||
for frame_idx in range(x.shape[2]):
|
||||
run_upsamples(
|
||||
run_up(
|
||||
layer_idx,
|
||||
[x[:, :, frame_idx:frame_idx + 1, :, :]],
|
||||
feat_idx.copy(),
|
||||
@ -416,9 +413,9 @@ class Decoder3d(nn.Module):
|
||||
|
||||
next_x_ref = [x]
|
||||
del x
|
||||
run_upsamples(layer_idx + 1, next_x_ref, feat_idx)
|
||||
run_up(layer_idx + 1, next_x_ref, feat_idx)
|
||||
|
||||
run_upsamples(0, [x], feat_idx)
|
||||
run_up(0, [x], feat_idx)
|
||||
return out_chunks
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user