Merge branch 'master' into dr-support-pip-cm

This commit is contained in:
Dr.Lt.Data 2025-10-02 12:49:48 +09:00
commit 28092933c1
2 changed files with 30 additions and 23 deletions

View File

@ -468,55 +468,46 @@ class WanVAE(nn.Module):
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
self._enc_conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@ -652,6 +652,7 @@ class VAE:
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -667,6 +668,13 @@ class VAE:
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@ -713,6 +721,7 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
do_tile = False
if self.latent_dim == 3 and pixel_samples.ndim < 5:
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@ -734,6 +743,13 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4