Compare commits

...

4 Commits

Author SHA1 Message Date
Jukka Seppänen
83413bc06c
Merge c03a90ecfc into dcff27fe3f 2026-01-27 18:50:11 -05:00
guill
dcff27fe3f
Add support for dev-only nodes. (#12106)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
When a node is declared as dev-only, it doesn't show in the default UI
unless the dev mode is enabled in the settings. The intention is to
allow nodes related to unit testing to be included in ComfyUI
distributions without confusing the average user.
2026-01-27 13:03:29 -08:00
kijai
c03a90ecfc Fix encode 2026-01-24 11:52:00 +02:00
kijai
570b11198b Flux2: Support Tiny VAE (taef2) 2026-01-23 15:43:29 +02:00
7 changed files with 74 additions and 15 deletions

View File

@ -236,6 +236,8 @@ class ComfyNodeABC(ABC):
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
DEV_ONLY: bool
"""Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled."""
API_NODE: Optional[bool]
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""

View File

@ -222,6 +222,7 @@ class Flux2(LatentFormat):
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
self.taesd_decoder_name = "taef2_decoder"
def process_in(self, latent):
return latent

View File

@ -451,7 +451,7 @@ class VAE:
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels, use_midblock_gn = True if "taesd_decoder.3.pool.0.weight" in sd else False)
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA()
self.downscale_ratio = 4

View File

@ -17,28 +17,36 @@ class Clamp(nn.Module):
return torch.tanh(x / 3) * 3
class Block(nn.Module):
def __init__(self, n_in, n_out):
def __init__(self, n_in, n_out, use_midblock_gn=False):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
self.pool = None
if use_midblock_gn:
conv1x1, n_gn = lambda n_in, n_out: comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False), n_in*4
self.pool = nn.Sequential(conv1x1(n_in, n_gn), comfy.ops.disable_weight_init.GroupNorm(4, n_gn), nn.ReLU(inplace=True), conv1x1(n_gn, n_in))
def forward(self, x):
if self.pool is not None:
x = x + self.pool(x)
return self.fuse(self.conv(x) + self.skip(x))
def Encoder(latent_channels=4):
def Encoder(latent_channels=4, use_midblock_gn=False):
mb_kw = dict(use_midblock_gn=use_midblock_gn)
return nn.Sequential(
conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64, **mb_kw), Block(64, 64, **mb_kw), Block(64, 64, **mb_kw),
conv(64, latent_channels),
)
def Decoder(latent_channels=4):
def Decoder(latent_channels=4, use_midblock_gn=False):
mb_kw = dict(use_midblock_gn=use_midblock_gn)
return nn.Sequential(
Clamp(), conv(latent_channels, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64, **mb_kw), Block(64, 64, **mb_kw), Block(64, 64, **mb_kw), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), conv(64, 3),
@ -48,17 +56,30 @@ class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4, use_midblock_gn=False):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.taesd_encoder = Encoder(latent_channels=latent_channels)
self.taesd_decoder = Decoder(latent_channels=latent_channels)
self.latent_channels = latent_channels
self.use_midblock_gn = use_midblock_gn
self.taesd_encoder = Encoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn)
self.taesd_decoder = Decoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn)
if encoder_path is not None:
self.taesd_encoder, self.latent_channels = self._load_model(encoder_path, Encoder)
if decoder_path is not None:
self.taesd_decoder, self.latent_channels = self._load_model(decoder_path, Decoder)
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None:
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
def _load_model(self, path, model_class):
"""Load a TAESD encoder or decoder from a file."""
sd = comfy.utils.load_torch_file(path, safe_load=True)
latent_channels = sd["1.weight"].shape[1]
model = model_class(latent_channels=latent_channels, use_midblock_gn="3.pool.0.weight" in sd)
model.load_state_dict(sd)
return model, latent_channels
@staticmethod
def scale_latents(x):
@ -71,9 +92,15 @@ class TAESD(nn.Module):
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode(self, x):
if x.shape[1] == self.latent_channels * 4:
x = x.reshape(x.shape[0], self.latent_channels, 2, 2, x.shape[-2], x.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(x.shape[0], self.latent_channels, x.shape[-2] * 2, x.shape[-1] * 2)
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample
def encode(self, x):
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
x_sample = (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
if self.latent_channels == 32 and self.use_midblock_gn: # Only taef2 for Flux2 currently, pack latents: [B, C, H, W] -> [B, C*4, H//2, W//2]
x_sample = x_sample.reshape(x_sample.shape[0], self.latent_channels, x_sample.shape[-2] // 2, 2, x_sample.shape[-1] // 2, 2).permute(0, 1, 3, 5, 2, 4).reshape(x_sample.shape[0], self.latent_channels * 4, x_sample.shape[-2] // 2, x_sample.shape[-1] // 2)
return x_sample

View File

@ -1247,6 +1247,7 @@ class NodeInfoV1:
output_node: bool=None
deprecated: bool=None
experimental: bool=None
dev_only: bool=None
api_node: bool=None
price_badge: dict | None = None
search_aliases: list[str]=None
@ -1264,6 +1265,7 @@ class NodeInfoV3:
output_node: bool=None
deprecated: bool=None
experimental: bool=None
dev_only: bool=None
api_node: bool=None
price_badge: dict | None = None
@ -1375,6 +1377,8 @@ class Schema:
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
is_experimental: bool=False
"""Flags a node as experimental, informing users that it may change or not work as expected."""
is_dev_only: bool=False
"""Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled."""
is_api_node: bool=False
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
price_badge: PriceBadge | None = None
@ -1485,6 +1489,7 @@ class Schema:
output_node=self.is_output_node,
deprecated=self.is_deprecated,
experimental=self.is_experimental,
dev_only=self.is_dev_only,
api_node=self.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
@ -1519,6 +1524,7 @@ class Schema:
output_node=self.is_output_node,
deprecated=self.is_deprecated,
experimental=self.is_experimental,
dev_only=self.is_dev_only,
api_node=self.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
@ -1791,6 +1797,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._DEPRECATED
_DEV_ONLY = None
@final
@classproperty
def DEV_ONLY(cls): # noqa
if cls._DEV_ONLY is None:
cls.GET_SCHEMA()
return cls._DEV_ONLY
_API_NODE = None
@final
@classproperty
@ -1893,6 +1907,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._EXPERIMENTAL = schema.is_experimental
if cls._DEPRECATED is None:
cls._DEPRECATED = schema.is_deprecated
if cls._DEV_ONLY is None:
cls._DEV_ONLY = schema.is_dev_only
if cls._API_NODE is None:
cls._API_NODE = schema.is_api_node
if cls._OUTPUT_NODE is None:

View File

@ -724,7 +724,7 @@ class LoraLoaderModelOnly(LoraLoader):
class VAELoader:
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1", "taef2"]
@staticmethod
def vae_list(s):
vaes = folder_paths.get_filename_list("vae")
@ -737,6 +737,8 @@ class VAELoader:
sd3_taesd_dec = False
f1_taesd_enc = False
f1_taesd_dec = False
f2_taesd_enc = False
f2_taesd_dec = False
for v in approx_vaes:
if v.startswith("taesd_decoder."):
@ -755,6 +757,10 @@ class VAELoader:
f1_taesd_dec = True
elif v.startswith("taef1_decoder."):
f1_taesd_enc = True
elif v.startswith("taef2_encoder."):
f2_taesd_dec = True
elif v.startswith("taef2_decoder."):
f2_taesd_enc = True
else:
for tae in s.video_taes:
if v.startswith(tae):
@ -768,6 +774,8 @@ class VAELoader:
vaes.append("taesd3")
if f1_taesd_dec and f1_taesd_enc:
vaes.append("taef1")
if f2_taesd_dec and f2_taesd_enc:
vaes.append("taef2")
vaes.append("pixel_space")
return vaes
@ -799,6 +807,9 @@ class VAELoader:
elif name == "taef1":
sd["vae_scale"] = torch.tensor(0.3611)
sd["vae_shift"] = torch.tensor(0.1159)
elif name == "taef2":
sd["vae_scale"] = torch.tensor(1.0)
sd["vae_shift"] = torch.tensor(0.0)
return sd
@classmethod

View File

@ -679,6 +679,8 @@ class PromptServer():
info['deprecated'] = True
if getattr(obj_class, "EXPERIMENTAL", False):
info['experimental'] = True
if getattr(obj_class, "DEV_ONLY", False):
info['dev_only'] = True
if hasattr(obj_class, 'API_NODE'):
info['api_node'] = obj_class.API_NODE