mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 05:23:37 +08:00
tested with a modern version of comfyui
This commit is contained in:
parent
459150582e
commit
af6c5d6de9
@ -366,8 +366,6 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None):
|
def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None):
|
||||||
"""Safe interpolate operation that handles Half precision for problematic modes"""
|
|
||||||
# Modes qui peuvent causer des problèmes avec Half precision
|
|
||||||
problematic_modes = ['bilinear', 'bicubic', 'trilinear']
|
problematic_modes = ['bilinear', 'bicubic', 'trilinear']
|
||||||
|
|
||||||
if mode in problematic_modes:
|
if mode in problematic_modes:
|
||||||
@ -419,10 +417,8 @@ def extend_head(tensor, times: int = 2, memory = None):
|
|||||||
return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2)
|
return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2)
|
||||||
|
|
||||||
def cache_send_recv(tensor, cache_size, times, memory=None):
|
def cache_send_recv(tensor, cache_size, times, memory=None):
|
||||||
# Single GPU inference - simplified cache handling
|
|
||||||
recv_buffer = None
|
recv_buffer = None
|
||||||
|
|
||||||
# Handle memory buffer for single GPU case
|
|
||||||
if memory is not None:
|
if memory is not None:
|
||||||
recv_buffer = memory.to(tensor[0])
|
recv_buffer = memory.to(tensor[0])
|
||||||
elif times > 0:
|
elif times > 0:
|
||||||
@ -2051,19 +2047,6 @@ class VideoAutoencoderKL(nn.Module):
|
|||||||
h = self.decode(h.latent_dist.mode())
|
h = self.decode(h.latent_dist.mode())
|
||||||
return h.sample
|
return h.sample
|
||||||
|
|
||||||
def load_state_dict(self, state_dict, strict=False):
|
|
||||||
# Newer version of diffusers changed the model keys,
|
|
||||||
# causing incompatibility with old checkpoints.
|
|
||||||
# They provided a method for conversion.
|
|
||||||
# We call conversion before loading state_dict.
|
|
||||||
convert_deprecated_attention_blocks = getattr(
|
|
||||||
self, "_convert_deprecated_attention_blocks", None
|
|
||||||
)
|
|
||||||
if callable(convert_deprecated_attention_blocks):
|
|
||||||
convert_deprecated_attention_blocks(state_dict)
|
|
||||||
return super().load_state_dict(state_dict, strict)
|
|
||||||
|
|
||||||
|
|
||||||
class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -2099,16 +2082,12 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
|||||||
z = p.squeeze(2)
|
z = p.squeeze(2)
|
||||||
return z, p
|
return z, p
|
||||||
|
|
||||||
def decode(self, z: torch.FloatTensor):
|
def decode(self, z):
|
||||||
b, tc, h, w = z.shape
|
b, tc, h, w = z.shape
|
||||||
z = z.view(b, 16, -1, h, w)
|
latent = z.view(b, 16, -1, h, w)
|
||||||
z = z.movedim(1, -1)
|
|
||||||
latent = z.unsqueeze(0)
|
|
||||||
scale = 0.9152
|
scale = 0.9152
|
||||||
shift = 0
|
shift = 0
|
||||||
latent = latent / scale + shift
|
latent = latent / scale + shift
|
||||||
latent = rearrange(latent, "b ... c -> b c ...")
|
|
||||||
latent = latent.squeeze(2)
|
|
||||||
|
|
||||||
if latent.ndim == 4:
|
if latent.ndim == 4:
|
||||||
latent = latent.unsqueeze(2)
|
latent = latent.unsqueeze(2)
|
||||||
|
|||||||
@ -440,9 +440,7 @@ class CLIP:
|
|||||||
class VAE:
|
class VAE:
|
||||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||||
if (metadata is not None and metadata["keep_diffusers_format"] == "true"):
|
if metadata is None or metadata.get("keep_diffusers_format") != "true":
|
||||||
pass
|
|
||||||
else:
|
|
||||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||||
|
|
||||||
if model_management.is_amd():
|
if model_management.is_amd():
|
||||||
|
|||||||
@ -337,7 +337,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
|
def execute(cls, images, vae, resolution, spatial_tile_size, spatial_overlap, temporal_tile_size, enable_tiling):
|
||||||
|
|
||||||
comfy.model_management.load_models_gpu([vae.patcher])
|
comfy.model_management.load_models_gpu([vae.patcher])
|
||||||
vae_model = vae.first_stage_model
|
vae_model = vae.first_stage_model
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user