Compare commits

..

No commits in common. "3b418dab2c863605c7f4b3cac70692c2db57bd97" and "459150582ea97aca74dcb15161fae7955100520c" have entirely different histories.

3 changed files with 29 additions and 6 deletions

View File

@ -366,6 +366,8 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None):
"""Safe interpolate operation that handles Half precision for problematic modes"""
# Modes qui peuvent causer des problèmes avec Half precision
problematic_modes = ['bilinear', 'bicubic', 'trilinear']
if mode in problematic_modes:
@ -417,8 +419,10 @@ def extend_head(tensor, times: int = 2, memory = None):
return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2)
def cache_send_recv(tensor, cache_size, times, memory=None):
# Single GPU inference - simplified cache handling
recv_buffer = None
# Handle memory buffer for single GPU case
if memory is not None:
recv_buffer = memory.to(tensor[0])
elif times > 0:
@ -2047,6 +2051,19 @@ class VideoAutoencoderKL(nn.Module):
h = self.decode(h.latent_dist.mode())
return h.sample
def load_state_dict(self, state_dict, strict=False):
# Newer version of diffusers changed the model keys,
# causing incompatibility with old checkpoints.
# They provided a method for conversion.
# We call conversion before loading state_dict.
convert_deprecated_attention_blocks = getattr(
self, "_convert_deprecated_attention_blocks", None
)
if callable(convert_deprecated_attention_blocks):
convert_deprecated_attention_blocks(state_dict)
return super().load_state_dict(state_dict, strict)
class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
def __init__(
self,
@ -2060,7 +2077,6 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
self.temporal_downsample_factor = temporal_downsample_factor
self.freeze_encoder = freeze_encoder
self.original_image_video = None
self.enable_tiling = False
super().__init__(*args, **kwargs)
self.set_memory_limit(0.5, 0.5)
@ -2083,17 +2099,22 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
z = p.squeeze(2)
return z, p
def decode(self, z):
def decode(self, z: torch.FloatTensor):
b, tc, h, w = z.shape
latent = z.view(b, 16, -1, h, w)
z = z.view(b, 16, -1, h, w)
z = z.movedim(1, -1)
latent = z.unsqueeze(0)
scale = 0.9152
shift = 0
latent = latent / scale + shift
latent = rearrange(latent, "b ... c -> b c ...")
latent = latent.squeeze(2)
if latent.ndim == 4:
latent = latent.unsqueeze(2)
self.enable_tiling = self.tiled_args.get("enable_tiling", False)
if self.tiled_args.get("enable_tiling", None) is not None:
self.enable_tiling = self.tiled_args.pop("enable_tiling", False)
if self.enable_tiling:
x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2)

View File

@ -440,7 +440,9 @@ class CLIP:
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
if metadata is None or metadata.get("keep_diffusers_format") != "true":
if (metadata is not None and metadata["keep_diffusers_format"] == "true"):
pass
else:
sd = diffusers_convert.convert_vae_state_dict(sd)
if model_management.is_amd():

View File

@ -337,7 +337,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
)
@classmethod
def execute(cls, images, vae, resolution, spatial_tile_size, spatial_overlap, temporal_tile_size, enable_tiling):
def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
comfy.model_management.load_models_gpu([vae.patcher])
vae_model = vae.first_stage_model