Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-20 03:23:00 +08:00)

Commit 28092933c1: Merge branch 'master' into dr-support-pip-cm

comfy/ldm/wan/vae.py
@@ -468,55 +468,46 @@ class WanVAE(nn.Module):
                                  attn_scales, self.temperal_upsample, dropout)
 
     def encode(self, x):
-        self.clear_cache()
+        conv_idx = [0]
+        feat_map = [None] * count_conv3d(self.decoder)
         ## cache
         t = x.shape[2]
         iter_ = 1 + (t - 1) // 4
         ## Split the encode input x along the time axis into 1, 4, 4, 4, ... frames
         for i in range(iter_):
-            self._enc_conv_idx = [0]
+            conv_idx = [0]
             if i == 0:
                 out = self.encoder(
                     x[:, :, :1, :, :],
-                    feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx)
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx)
             else:
                 out_ = self.encoder(
                     x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
-                    feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx)
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx)
                 out = torch.cat([out, out_], 2)
         mu, log_var = self.conv1(out).chunk(2, dim=1)
-        self.clear_cache()
         return mu
 
     def decode(self, z):
-        self.clear_cache()
+        conv_idx = [0]
+        feat_map = [None] * count_conv3d(self.decoder)
         # z: [b,c,t,h,w]
 
         iter_ = z.shape[2]
         x = self.conv2(z)
         for i in range(iter_):
-            self._conv_idx = [0]
+            conv_idx = [0]
             if i == 0:
                 out = self.decoder(
                     x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx)
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx)
             else:
                 out_ = self.decoder(
                     x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx)
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx)
                 out = torch.cat([out, out_], 2)
-        self.clear_cache()
         return out
 
-    def clear_cache(self):
-        self._conv_num = count_conv3d(self.decoder)
-        self._conv_idx = [0]
-        self._feat_map = [None] * self._conv_num
-        #cache encode
-        self._enc_conv_num = count_conv3d(self.encoder)
-        self._enc_conv_idx = [0]
-        self._enc_feat_map = [None] * self._enc_conv_num
-
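
This hunk drops the `clear_cache()` bookkeeping that stored the causal-conv feature cache as instance attributes (`self._feat_map`, `self._conv_idx`, and their `_enc_*` twins) in favor of plain local variables. The practical effect is that `encode` and `decode` no longer mutate shared state on the module: a call that raises partway through cannot leave a stale cache behind, and callers no longer depend on `clear_cache()` being run before and after each pass. The chunking arithmetic itself is unchanged. A quick standalone check of that slicing (illustrative only, not part of the commit):

```python
def temporal_chunks(t):
    """Reproduce WanVAE.encode's time slicing: chunks of 1, 4, 4, 4, ..."""
    iter_ = 1 + (t - 1) // 4
    chunks = []
    for i in range(iter_):
        if i == 0:
            chunks.append((0, 1))                        # x[:, :, :1, :, :]
        else:
            chunks.append((1 + 4 * (i - 1), 1 + 4 * i))  # x[:, :, 1+4*(i-1):1+4*i, :, :]
    return chunks

# Wan video tensors have t = 4k + 1 frames; for t = 17 the chunks
# tile the time axis exactly once, first frame alone, then 4 at a time.
assert temporal_chunks(17) == [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]
```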

comfy/sd.py (16 lines changed)
@@ -652,6 +652,7 @@ class VAE:
     def decode(self, samples_in, vae_options={}):
         self.throw_exception_if_invalid()
         pixel_samples = None
+        do_tile = False
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -667,6 +668,13 @@ class VAE:
                 pixel_samples[x:x+batch_number] = out
         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+            #exception and the exception itself refs them all until we get out of this except block.
+            #So we just set a flag for tiler fallback so that tensor gc can happen once the
+            #exception is fully off the books.
+            do_tile = True
+
+        if do_tile:
             dims = samples_in.ndim - 2
             if dims == 1 or self.extra_1d_channel is not None:
                 pixel_samples = self.decode_tiled_1d(samples_in)
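
The new NOTE comments are the heart of this change: while an except block is live, the exception's `__traceback__` keeps every frame it passed through alive, and those frames keep their locals alive, including whatever large tensors were in flight when the OOM hit. Running the tiled decode from inside the handler would allocate new tensors while the old ones were still unreachable but pinned; setting `do_tile` and falling through lets the exception die first. A minimal, framework-free sketch of the mechanism (toy names, no torch involved):

```python
import gc
import weakref

class BigBuffer:
    """Stand-in for a large CUDA tensor."""

def allocate_then_fail():
    buf = BigBuffer()          # local the fallback path would like freed
    probe = weakref.ref(buf)   # lets us observe buf's lifetime from outside
    raise MemoryError(probe)

do_fallback = False
try:
    allocate_then_fail()
except MemoryError as exc:
    probe = exc.args[0]
    # exc.__traceback__ -> frame of allocate_then_fail -> buf,
    # so buf cannot be reclaimed while we are inside this handler:
    assert probe() is not None
    do_fallback = True         # defer the real work, as the commit does

# Python implicitly deletes `exc` when the except block exits; the
# traceback and its frames go with it, and buf becomes collectable.
gc.collect()  # redundant under CPython refcounting; keeps the check honest
assert probe() is None

if do_fallback:
    pass  # safe point to retry with tiles: the failed attempt is fully freed
```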
@@ -713,6 +721,7 @@ class VAE:
         self.throw_exception_if_invalid()
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1, 1)
+        do_tile = False
         if self.latent_dim == 3 and pixel_samples.ndim < 5:
             if not self.not_video:
                 pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@@ -734,6 +743,13 @@ class VAE:
 
         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+            #exception and the exception itself refs them all until we get out of this except block.
+            #So we just set a flag for tiler fallback so that tensor gc can happen once the
+            #exception is fully off the books.
+            do_tile = True
+
+        if do_tile:
             if self.latent_dim == 3:
                 tile = 256
                 overlap = tile // 4
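
The encode path gets the identical flag treatment, and the fallback parameters are visible here: the tiled retry for 3D (video) latents uses `tile = 256` with `overlap = tile // 4`, i.e. 64 pixels shared between neighbouring tiles so the tiler can blend away seams. A sketch of how such overlapping tile origins can be computed (illustrative only; ComfyUI's actual tiler, `tiled_scale` in comfy/utils.py, also handles blending masks, ragged edges, and latent scale factors):

```python
def tile_origins(size, tile=256, overlap=64):
    """Start offsets of overlapping tiles covering a `size`-pixel edge."""
    stride = tile - overlap
    origins = list(range(0, max(size - tile, 0) + 1, stride))
    if origins[-1] + tile < size:   # cover a ragged tail with one last tile
        origins.append(max(size - tile, 0))
    return origins

# A 1024-pixel edge with 256-px tiles and 64-px overlap starts a tile
# every 192 px:
assert tile_origins(1024) == [0, 192, 384, 576, 768]
```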