mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-03 13:52:31 +08:00
Add high quality preview support for Flux2 latents (#13496)
This commit is contained in:
parent
5eeae3f1d8
commit
a164c82913
@ -224,6 +224,7 @@ class Flux2(LatentFormat):
|
|||||||
|
|
||||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||||
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
|
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
|
||||||
|
self.taesd_decoder_name = "taef2_decoder"
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
return latent
|
return latent
|
||||||
|
|||||||
@ -479,7 +479,10 @@ class VAE:
|
|||||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
||||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||||
elif "taesd_decoder.1.weight" in sd:
|
elif "taesd_decoder.1.weight" in sd:
|
||||||
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
|
if isinstance(metadata, dict) and "tae_latent_channels" in metadata:
|
||||||
|
self.latent_channels = metadata["tae_latent_channels"]
|
||||||
|
else:
|
||||||
|
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
|
||||||
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
|
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
|
||||||
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
||||||
self.first_stage_model = StageA()
|
self.first_stage_model = StageA()
|
||||||
|
|||||||
@ -17,32 +17,79 @@ class Clamp(nn.Module):
|
|||||||
return torch.tanh(x / 3) * 3
|
return torch.tanh(x / 3) * 3
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(self, n_in, n_out):
|
def __init__(self, n_in: int, n_out: int, use_midblock_gn: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
||||||
self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||||
self.fuse = nn.ReLU()
|
self.fuse = nn.ReLU()
|
||||||
def forward(self, x):
|
if not use_midblock_gn:
|
||||||
|
self.pool = None
|
||||||
|
return
|
||||||
|
n_gn = n_in * 4
|
||||||
|
self.pool = nn.Sequential(
|
||||||
|
comfy.ops.disable_weight_init.Conv2d(n_in, n_gn, 1, bias=False),
|
||||||
|
comfy.ops.disable_weight_init.GroupNorm(4, n_gn),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
comfy.ops.disable_weight_init.Conv2d(n_gn, n_in, 1, bias=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.pool is not None:
|
||||||
|
x = x + self.pool(x)
|
||||||
return self.fuse(self.conv(x) + self.skip(x))
|
return self.fuse(self.conv(x) + self.skip(x))
|
||||||
|
|
||||||
def Encoder(latent_channels=4):
|
class Encoder(nn.Sequential):
|
||||||
return nn.Sequential(
|
def __init__(self, latent_channels: int = 4, use_gn: bool = False):
|
||||||
conv(3, 64), Block(64, 64),
|
super().__init__(
|
||||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
conv(3, 64), Block(64, 64),
|
||||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
conv(64, latent_channels),
|
conv(64, 64, stride=2, bias=False), Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn),
|
||||||
)
|
conv(64, latent_channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
class Decoder(nn.Sequential):
|
||||||
|
def __init__(self, latent_channels: int = 4, use_gn: bool = False):
|
||||||
|
super().__init__(
|
||||||
|
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
||||||
|
Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
|
Block(64, 64), conv(64, 3),
|
||||||
|
)
|
||||||
|
|
||||||
|
class DecoderFlux2(Decoder):
|
||||||
|
def __init__(self, latent_channels: int = 128, use_gn: bool = True):
|
||||||
|
if latent_channels != 128 or not use_gn:
|
||||||
|
raise ValueError("Unexpected parameters for Flux2 TAE module")
|
||||||
|
super().__init__(latent_channels=32, use_gn=True)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
x = (
|
||||||
|
x
|
||||||
|
.reshape(B, 32, 2, 2, H, W)
|
||||||
|
.permute(0, 1, 4, 2, 5, 3)
|
||||||
|
.reshape(B, 32, H * 2, W * 2)
|
||||||
|
)
|
||||||
|
return super().forward(x)
|
||||||
|
|
||||||
|
class EncoderFlux2(Encoder):
|
||||||
|
def __init__(self, latent_channels: int = 128, use_gn: bool = True):
|
||||||
|
if latent_channels != 128 or not use_gn:
|
||||||
|
raise ValueError("Unexpected parameters for Flux2 TAE module")
|
||||||
|
super().__init__(latent_channels=32, use_gn=True)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
result = super().forward(x)
|
||||||
|
B, C, H, W = result.shape
|
||||||
|
return (
|
||||||
|
result
|
||||||
|
.reshape(B, C, H // 2, 2, W // 2, 2)
|
||||||
|
.permute(0, 1, 3, 5, 2, 4)
|
||||||
|
.reshape(B, 128, H // 2, W // 2)
|
||||||
|
)
|
||||||
|
|
||||||
def Decoder(latent_channels=4):
|
|
||||||
return nn.Sequential(
|
|
||||||
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
|
||||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
|
||||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
|
||||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
|
||||||
Block(64, 64), conv(64, 3),
|
|
||||||
)
|
|
||||||
|
|
||||||
class TAESD(nn.Module):
|
class TAESD(nn.Module):
|
||||||
latent_magnitude = 3
|
latent_magnitude = 3
|
||||||
@ -51,8 +98,15 @@ class TAESD(nn.Module):
|
|||||||
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
|
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
|
||||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.taesd_encoder = Encoder(latent_channels=latent_channels)
|
if latent_channels == 128:
|
||||||
self.taesd_decoder = Decoder(latent_channels=latent_channels)
|
encoder_class = EncoderFlux2
|
||||||
|
decoder_class = DecoderFlux2
|
||||||
|
else:
|
||||||
|
encoder_class = Encoder
|
||||||
|
decoder_class = Decoder
|
||||||
|
self.taesd_encoder = encoder_class(latent_channels=latent_channels)
|
||||||
|
self.taesd_decoder = decoder_class(latent_channels=latent_channels)
|
||||||
|
|
||||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||||
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
|
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
|
||||||
if encoder_path is not None:
|
if encoder_path is not None:
|
||||||
@ -61,19 +115,19 @@ class TAESD(nn.Module):
|
|||||||
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def scale_latents(x):
|
def scale_latents(x: torch.Tensor) -> torch.Tensor:
|
||||||
"""raw latents -> [0, 1]"""
|
"""raw latents -> [0, 1]"""
|
||||||
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
|
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def unscale_latents(x):
|
def unscale_latents(x: torch.Tensor) -> torch.Tensor:
|
||||||
"""[0, 1] -> raw latents"""
|
"""[0, 1] -> raw latents"""
|
||||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||||
|
|
||||||
def decode(self, x):
|
def decode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
|
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
|
||||||
x_sample = x_sample.sub(0.5).mul(2)
|
x_sample = x_sample.sub(0.5).mul(2)
|
||||||
return x_sample
|
return x_sample
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
|
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
|
||||||
|
|||||||
53
nodes.py
53
nodes.py
@ -728,50 +728,26 @@ class LoraLoaderModelOnly(LoraLoader):
|
|||||||
|
|
||||||
class VAELoader:
|
class VAELoader:
|
||||||
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
|
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
|
||||||
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
|
image_taes = ["taesd", "taesdxl", "taesd3", "taef1", "taef2"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def vae_list(s):
|
def vae_list(s):
|
||||||
vaes = folder_paths.get_filename_list("vae")
|
vaes = folder_paths.get_filename_list("vae")
|
||||||
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
||||||
sdxl_taesd_enc = False
|
have_img_encoder, have_img_decoder = set(), set()
|
||||||
sdxl_taesd_dec = False
|
|
||||||
sd1_taesd_enc = False
|
|
||||||
sd1_taesd_dec = False
|
|
||||||
sd3_taesd_enc = False
|
|
||||||
sd3_taesd_dec = False
|
|
||||||
f1_taesd_enc = False
|
|
||||||
f1_taesd_dec = False
|
|
||||||
|
|
||||||
for v in approx_vaes:
|
for v in approx_vaes:
|
||||||
if v.startswith("taesd_decoder."):
|
parts = v.split("_", 1)
|
||||||
sd1_taesd_dec = True
|
if len(parts) != 2 or parts[0] not in s.image_taes:
|
||||||
elif v.startswith("taesd_encoder."):
|
|
||||||
sd1_taesd_enc = True
|
|
||||||
elif v.startswith("taesdxl_decoder."):
|
|
||||||
sdxl_taesd_dec = True
|
|
||||||
elif v.startswith("taesdxl_encoder."):
|
|
||||||
sdxl_taesd_enc = True
|
|
||||||
elif v.startswith("taesd3_decoder."):
|
|
||||||
sd3_taesd_dec = True
|
|
||||||
elif v.startswith("taesd3_encoder."):
|
|
||||||
sd3_taesd_enc = True
|
|
||||||
elif v.startswith("taef1_encoder."):
|
|
||||||
f1_taesd_dec = True
|
|
||||||
elif v.startswith("taef1_decoder."):
|
|
||||||
f1_taesd_enc = True
|
|
||||||
else:
|
|
||||||
for tae in s.video_taes:
|
for tae in s.video_taes:
|
||||||
if v.startswith(tae):
|
if v.startswith(tae):
|
||||||
vaes.append(v)
|
vaes.append(v)
|
||||||
|
break
|
||||||
if sd1_taesd_dec and sd1_taesd_enc:
|
continue
|
||||||
vaes.append("taesd")
|
if parts[1].startswith("encoder."):
|
||||||
if sdxl_taesd_dec and sdxl_taesd_enc:
|
have_img_encoder.add(parts[0])
|
||||||
vaes.append("taesdxl")
|
elif parts[1].startswith("decoder."):
|
||||||
if sd3_taesd_dec and sd3_taesd_enc:
|
have_img_decoder.add(parts[0])
|
||||||
vaes.append("taesd3")
|
vaes += [k for k in have_img_decoder if k in have_img_encoder]
|
||||||
if f1_taesd_dec and f1_taesd_enc:
|
|
||||||
vaes.append("taef1")
|
|
||||||
vaes.append("pixel_space")
|
vaes.append("pixel_space")
|
||||||
return vaes
|
return vaes
|
||||||
|
|
||||||
@ -827,6 +803,11 @@ class VAELoader:
|
|||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||||
|
if vae_name == "taef2":
|
||||||
|
if metadata is None:
|
||||||
|
metadata = {"tae_latent_channels": 128}
|
||||||
|
else:
|
||||||
|
metadata["tae_latent_channels"] = 128
|
||||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||||
vae.throw_exception_if_invalid()
|
vae.throw_exception_if_invalid()
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user