Compare commits

...

9 Commits

Author SHA1 Message Date
Alexander Piskun
e6f06b994e
Merge f1c07a72c4 into e84a200a3c 2026-03-16 04:00:10 +09:00
rattus
e84a200a3c
ops: opt out of deferred weight init if subclassed (#12967)
If a subclass BYO _load_from_state_dict and doesnt call the super() the
needed default init of these weights is missed and can lead to problems
for uninitialized weights.
2026-03-15 11:49:49 -07:00
Dr.Lt.Data
192cb8eeb9
bump manager version to 4.1b5 (#12957) 2026-03-15 11:48:56 -07:00
Jukka Seppänen
0904cc3fe5
LTXV: Accumulate VAE decode results on intermediate_device (#12955)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Generate Pydantic Stubs from api.comfy.org / generate-models (push) Has been cancelled
2026-03-14 18:09:09 -07:00
comfyanonymous
4941cd046e
Update comfyui-frontend-package to version 1.41.20 (#12954) 2026-03-14 19:53:31 -04:00
comfyanonymous
c711b8f437
Add --fp16-intermediates to use fp16 for intermediate values between nodes (#12953)
This is an experimental WIP option that might not work in your workflow but
should lower memory usage if it does.

Currently only the VAE and the load image node will output in fp16 when
this option is turned on.
2026-03-14 19:18:19 -04:00
Jukka Seppänen
1c5db7397d
feat: Support mxfp8 (#12907) 2026-03-14 18:36:29 -04:00
Christian Byrne
e0982a7174
fix: use no-store cache headers to prevent stale frontend chunks (#12911)
After a frontend update (e.g. nightly build), browsers could load
outdated cached index.html and JS/CSS chunks, causing dynamically
imported modules to fail with MIME type errors and vite:preloadError.

Hard refresh (Ctrl+Shift+R) was insufficient to fix the issue because
Cache-Control: no-cache still allows the browser to cache and
revalidate via ETags. aiohttp's FileResponse auto-generates ETags
based on file mtime+size, which may not change after pip reinstall,
so the browser gets 304 Not Modified and serves stale content.

Clearing ALL site data in DevTools did fix it, confirming the HTTP
cache was the root cause.

The fix changes:
- index.html: no-cache -> no-store, must-revalidate
- JS/CSS/JSON entry points: no-cache -> no-store

no-store instructs browsers to never cache these responses, ensuring
every page load fetches the current index.html with correct chunk
references. This is a small tradeoff (~5KB re-download per page load)
for guaranteed correctness after updates.
2026-03-14 18:25:09 -04:00
bigcat88
f1c07a72c4
convert model_merging and video_model nodes to V3 schema
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-03-11 12:25:59 +02:00
16 changed files with 879 additions and 512 deletions

View File

@ -83,6 +83,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

View File

@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
output_block[i:i + slice_size].copy_(block)
return output_fp4, to_blocked(output_block, flatten=False)
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
def roundup(x_val, multiple):
return ((x_val + multiple - 1) // multiple) * multiple
if pad_32x:
rows, cols = x.shape
padded_rows = roundup(rows, 32)
padded_cols = roundup(cols, 32)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
F8_E4M3_MAX = 448.0
E8M0_BIAS = 127
BLOCK_SIZE = 32
rows, cols = x.shape
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
# E8M0 block scales (power-of-2 exponents)
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
block_scales_e8m0 = exp_biased.to(torch.uint8)
zero_mask = (max_abs == 0)
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
# Scale per-block then stochastic round
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)

View File

@ -11,6 +11,7 @@ from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
import comfy.model_management
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
@ -536,7 +537,7 @@ class Decoder(nn.Module):
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample)
output.append(sample.to(comfy.model_management.intermediate_device()))
return
up_block = self.up_blocks[idx]

View File

@ -1050,6 +1050,12 @@ def intermediate_device():
else:
return torch.device("cpu")
def intermediate_dtype():
if args.fp16_intermediates:
return torch.float16
else:
return torch.float32
def vae_device():
if args.cpu_vae:
return torch.device("cpu")
@ -1712,6 +1718,19 @@ def supports_nvfp4_compute(device=None):
return True
def supports_mxfp8_compute(device=None):
if not is_nvidia():
return False
if torch_version_numeric < (2, 10):
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):

View File

@ -336,7 +336,10 @@ class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
# don't trust subclasses that BYO state dict loader to call us.
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
super().__init__(in_features, out_features, bias, device, dtype)
return
@ -357,7 +360,9 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
@ -564,7 +569,10 @@ class disable_weight_init:
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
_freeze=False, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
# don't trust subclasses that BYO state dict loader to call us.
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
norm_type, scale_grad_by_freq, sparse, _weight,
_freeze, device, dtype)
@ -590,7 +598,9 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
@ -857,6 +867,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
@ -1006,12 +1032,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
disabled = set()
if not nvfp4_compute:
disabled.add("nvfp4")
if not mxfp8_compute:
disabled.add("mxfp8")
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")

View File

@ -43,6 +43,18 @@ except ImportError as e:
def get_layout_class(name):
return None
_CK_MXFP8_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
_CK_MXFP8_AVAILABLE = True
except ImportError:
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
if not _CK_MXFP8_AVAILABLE:
class _CKMxfp8Layout:
pass
import comfy.float
# ==============================================================================
@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
params = cls.Params(
scale=block_scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
)
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
QUANT_ALGOS = {
"float8_e4m3fn": {
@ -157,6 +196,14 @@ QUANT_ALGOS = {
},
}
if _CK_MXFP8_AVAILABLE:
QUANT_ALGOS["mxfp8"] = {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
"group_size": 32,
}
# ==============================================================================
# Re-exports for backward compatibility

View File

@ -871,13 +871,16 @@ class VAE:
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels
def vae_output_dtype(self):
return model_management.intermediate_dtype()
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@ -887,16 +890,16 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
else:
og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@ -905,7 +908,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@ -914,7 +917,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
@ -923,7 +926,7 @@ class VAE:
tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
@ -932,7 +935,7 @@ class VAE:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in, vae_options={}):
@ -950,9 +953,9 @@ class VAE:
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
@ -1025,9 +1028,9 @@ class VAE:
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
except Exception as e:

View File

@ -10,146 +10,198 @@ import json
import os
from comfy.cli_args import args
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ModelMergeSimple:
class ModelMergeSimple(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSimple",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, ratio):
@classmethod
def execute(cls, model1, model2, ratio) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
class ModelSubtract:
merge = execute # TODO: remove
class ModelSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSubtract",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, multiplier):
@classmethod
def execute(cls, model1, model2, multiplier) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
return io.NodeOutput(m)
class ModelAdd:
merge = execute # TODO: remove
class ModelAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeAdd",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2):
@classmethod
def execute(cls, model1, model2) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPMergeSimple:
class CLIPMergeSimple(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSimple",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2, ratio):
@classmethod
def execute(cls, clip1, clip2, ratio) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPSubtract:
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
class CLIPSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSubtract",
search_aliases=["clip difference", "text encoder subtract"],
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2, multiplier):
@classmethod
def execute(cls, clip1, clip2, multiplier) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class CLIPAdd:
SEARCH_ALIASES = ["combine clip"]
class CLIPAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
"clip2": ("CLIP",),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeAdd",
search_aliases=["combine clip"],
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, clip1, clip2):
@classmethod
def execute(cls, clip1, clip2) -> io.NodeOutput:
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0, 1.0)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
class ModelMergeBlocks:
class ModelMergeBlocks(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model1": ("MODEL",),
"model2": ("MODEL",),
"input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"
def define_schema(cls):
return io.Schema(
node_id="ModelMergeBlocks",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("input", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("middle", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model_merging"
def merge(self, model1, model2, **kwargs):
@classmethod
def execute(cls, model1, model2, **kwargs) -> io.NodeOutput:
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
default_ratio = next(iter(kwargs.values()))
@ -165,7 +217,10 @@ class ModelMergeBlocks:
last_arg_size = len(arg)
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return (m, )
return io.NodeOutput(m)
merge = execute # TODO: remove
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
@ -226,59 +281,65 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
class CheckpointSave:
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class CheckpointSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CheckpointSave",
display_name="Save Checkpoint",
search_aliases=["save model", "export checkpoint", "merge save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.Clip.Input("clip"),
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
def execute(cls, model, clip, vae, filename_prefix) -> io.NodeOutput:
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
CATEGORY = "advanced/model_merging"
save = execute # TODO: remove
def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
class CLIPSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class CLIPSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPSave",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip"),
io.String.Input("filename_prefix", default="clip/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP",),
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
def execute(cls, clip, filename_prefix) -> io.NodeOutput:
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)
metadata = {}
if not args.disable_metadata:
metadata["format"] = "pt"
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()
output_dir = folder_paths.get_output_directory()
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
current_clip_sd = {}
@ -295,7 +356,7 @@ class CLIPSave:
replace_prefix[prefix] = ""
replace_prefix["transformer."] = ""
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, output_dir)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
@ -303,76 +364,88 @@ class CLIPSave:
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {}
return io.NodeOutput()
class VAESave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
save = execute # TODO: remove
class VAESave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VAESave",
category="advanced/model_merging",
inputs=[
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="vae/ComfyUI_vae"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
def execute(cls, vae, filename_prefix) -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
return io.NodeOutput()
class ModelSave:
SEARCH_ALIASES = ["export model", "checkpoint save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
save = execute # TODO: remove
class ModelSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ModelSave",
search_aliases=["export model", "checkpoint save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.String.Input("filename_prefix", default="diffusion_models/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
def execute(cls, model, filename_prefix) -> io.NodeOutput:
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
CATEGORY = "advanced/model_merging"
save = execute # TODO: remove
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks,
"ModelMergeSubtract": ModelSubtract,
"ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple,
"CLIPMergeSubtract": CLIPSubtract,
"CLIPMergeAdd": CLIPAdd,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
"ModelSave": ModelSave,
}
class ModelMergingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ModelMergeSimple,
ModelMergeBlocks,
ModelSubtract,
ModelAdd,
CheckpointSave,
CLIPMergeSimple,
CLIPSubtract,
CLIPAdd,
CLIPSave,
VAESave,
ModelSave,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointSave": "Save Checkpoint",
}
async def comfy_entrypoint() -> ModelMergingExtension:
return ModelMergingExtension()

View File

@ -1,356 +1,455 @@
import comfy_extras.nodes_model_merging
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
inputs.append(io.Float.Input("time_embed.", **argument))
inputs.append(io.Float.Input("label_emb.", **argument))
for i in range(12):
arg_dict["input_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("input_blocks.{}.".format(i), **argument))
for i in range(3):
arg_dict["middle_block.{}.".format(i)] = argument
inputs.append(io.Float.Input("middle_block.{}.".format(i), **argument))
for i in range(12):
arg_dict["output_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("output_blocks.{}.".format(i), **argument))
arg_dict["out."] = argument
inputs.append(io.Float.Input("out.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeSD1",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
class ModelMergeSD2(ModelMergeSD1):
# SD1 and SD2 have the same blocks
@classmethod
def define_schema(cls):
schema = ModelMergeSD1.define_schema()
schema.node_id = "ModelMergeSD2"
return schema
class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
inputs.append(io.Float.Input("time_embed.", **argument))
inputs.append(io.Float.Input("label_emb.", **argument))
for i in range(9):
arg_dict["input_blocks.{}".format(i)] = argument
inputs.append(io.Float.Input("input_blocks.{}".format(i), **argument))
for i in range(3):
arg_dict["middle_block.{}".format(i)] = argument
inputs.append(io.Float.Input("middle_block.{}".format(i), **argument))
for i in range(9):
arg_dict["output_blocks.{}".format(i)] = argument
inputs.append(io.Float.Input("output_blocks.{}".format(i), **argument))
arg_dict["out."] = argument
inputs.append(io.Float.Input("out.", **argument))
return io.Schema(
node_id="ModelMergeSDXL",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("pos_embed.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("context_embedder.", **argument))
inputs.append(io.Float.Input("y_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(24):
arg_dict["joint_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeSD3_2B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["init_x_linear."] = argument
arg_dict["positional_encoding"] = argument
arg_dict["cond_seq_linear."] = argument
arg_dict["register_tokens"] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("init_x_linear.", **argument))
inputs.append(io.Float.Input("positional_encoding", **argument))
inputs.append(io.Float.Input("cond_seq_linear.", **argument))
inputs.append(io.Float.Input("register_tokens", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(4):
arg_dict["double_layers.{}.".format(i)] = argument
inputs.append(io.Float.Input("double_layers.{}.".format(i), **argument))
for i in range(32):
arg_dict["single_layers.{}.".format(i)] = argument
inputs.append(io.Float.Input("single_layers.{}.".format(i), **argument))
arg_dict["modF."] = argument
arg_dict["final_linear."] = argument
inputs.append(io.Float.Input("modF.", **argument))
inputs.append(io.Float.Input("final_linear.", **argument))
return io.Schema(
node_id="ModelMergeAuraflow",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["img_in."] = argument
arg_dict["time_in."] = argument
arg_dict["guidance_in"] = argument
arg_dict["vector_in."] = argument
arg_dict["txt_in."] = argument
inputs.append(io.Float.Input("img_in.", **argument))
inputs.append(io.Float.Input("time_in.", **argument))
inputs.append(io.Float.Input("guidance_in", **argument))
inputs.append(io.Float.Input("vector_in.", **argument))
inputs.append(io.Float.Input("txt_in.", **argument))
for i in range(19):
arg_dict["double_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("double_blocks.{}.".format(i), **argument))
for i in range(38):
arg_dict["single_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("single_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeFlux1",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
inputs.append(io.Float.Input("pos_embed.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("context_embedder.", **argument))
inputs.append(io.Float.Input("y_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
for i in range(38):
arg_dict["joint_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("joint_blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeSD35_Large",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_frequencies."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t5_y_embedder."] = argument
arg_dict["t5_yproj."] = argument
inputs.append(io.Float.Input("pos_frequencies.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t5_y_embedder.", **argument))
inputs.append(io.Float.Input("t5_yproj.", **argument))
for i in range(48):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeMochiPreview",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["patchify_proj."] = argument
arg_dict["adaln_single."] = argument
arg_dict["caption_projection."] = argument
inputs.append(io.Float.Input("patchify_proj.", **argument))
inputs.append(io.Float.Input("adaln_single.", **argument))
inputs.append(io.Float.Input("caption_projection.", **argument))
for i in range(28):
arg_dict["transformer_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))
arg_dict["scale_shift_table"] = argument
arg_dict["proj_out."] = argument
inputs.append(io.Float.Input("scale_shift_table", **argument))
inputs.append(io.Float.Input("proj_out.", **argument))
return io.Schema(
node_id="ModelMergeLTXV",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmos7B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("affline_norm.", **argument))
for i in range(28):
arg_dict["blocks.block{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmos7B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["extra_pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["affline_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("extra_pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("affline_norm.", **argument))
for i in range(36):
arg_dict["blocks.block{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.block{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmos14B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["patch_embedding."] = argument
arg_dict["time_embedding."] = argument
arg_dict["time_projection."] = argument
arg_dict["text_embedding."] = argument
arg_dict["img_emb."] = argument
inputs.append(io.Float.Input("patch_embedding.", **argument))
inputs.append(io.Float.Input("time_embedding.", **argument))
inputs.append(io.Float.Input("time_projection.", **argument))
inputs.append(io.Float.Input("text_embedding.", **argument))
inputs.append(io.Float.Input("img_emb.", **argument))
for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["head."] = argument
inputs.append(io.Float.Input("head.", **argument))
return io.Schema(
node_id="ModelMergeWAN2_1",
category="advanced/model_merging/model_specific",
description="1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb.",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t_embedding_norm.", **argument))
for i in range(28):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmosPredict2_2B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
inputs.append(io.Float.Input("pos_embedder.", **argument))
inputs.append(io.Float.Input("x_embedder.", **argument))
inputs.append(io.Float.Input("t_embedder.", **argument))
inputs.append(io.Float.Input("t_embedding_norm.", **argument))
for i in range(36):
arg_dict["blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("blocks.{}.".format(i), **argument))
arg_dict["final_layer."] = argument
inputs.append(io.Float.Input("final_layer.", **argument))
return io.Schema(
node_id="ModelMergeCosmosPredict2_14B",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
return {"required": arg_dict}
class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
]
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
argument = dict(default=1.0, min=0.0, max=1.0, step=0.01)
arg_dict["pos_embeds."] = argument
arg_dict["img_in."] = argument
arg_dict["txt_norm."] = argument
arg_dict["txt_in."] = argument
arg_dict["time_text_embed."] = argument
inputs.append(io.Float.Input("pos_embeds.", **argument))
inputs.append(io.Float.Input("img_in.", **argument))
inputs.append(io.Float.Input("txt_norm.", **argument))
inputs.append(io.Float.Input("txt_in.", **argument))
inputs.append(io.Float.Input("time_text_embed.", **argument))
for i in range(60):
arg_dict["transformer_blocks.{}.".format(i)] = argument
inputs.append(io.Float.Input("transformer_blocks.{}.".format(i), **argument))
arg_dict["proj_out."] = argument
inputs.append(io.Float.Input("proj_out.", **argument))
return {"required": arg_dict}
return io.Schema(
node_id="ModelMergeQwenImage",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[io.Model.Output()],
)
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
"ModelMergeSDXL": ModelMergeSDXL,
"ModelMergeSD3_2B": ModelMergeSD3_2B,
"ModelMergeAuraflow": ModelMergeAuraflow,
"ModelMergeFlux1": ModelMergeFlux1,
"ModelMergeSD35_Large": ModelMergeSD35_Large,
"ModelMergeMochiPreview": ModelMergeMochiPreview,
"ModelMergeLTXV": ModelMergeLTXV,
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
"ModelMergeQwenImage": ModelMergeQwenImage,
}
class ModelMergingModelSpecificExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ModelMergeSD1,
ModelMergeSD2,
ModelMergeSDXL,
ModelMergeSD3_2B,
ModelMergeAuraflow,
ModelMergeFlux1,
ModelMergeSD35_Large,
ModelMergeMochiPreview,
ModelMergeLTXV,
ModelMergeCosmos7B,
ModelMergeCosmos14B,
ModelMergeWAN2_1,
ModelMergeCosmosPredict2_2B,
ModelMergeCosmosPredict2_14B,
ModelMergeQwenImage,
]
async def comfy_entrypoint() -> ModelMergingModelSpecificExtension:
return ModelMergingModelSpecificExtension()

View File

@ -6,44 +6,62 @@ import folder_paths
import comfy_extras.nodes_model_merging
import node_helpers
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
class ImageOnlyCheckpointLoader:
class ImageOnlyCheckpointLoader(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
def define_schema(cls):
return io.Schema(
node_id="ImageOnlyCheckpointLoader",
display_name="Image Only Checkpoint Loader (img2vid model)",
category="loaders/video_models",
inputs=[
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("checkpoints")),
],
outputs=[
io.Model.Output(),
io.ClipVision.Output(),
io.Vae.Output(),
],
)
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
@classmethod
def execute(cls, ckpt_name, output_vae=True, output_clip=True) -> io.NodeOutput:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
return io.NodeOutput(out[0], out[3], out[2])
load_checkpoint = execute # TODO: remove
class SVD_img2vid_Conditioning:
class SVD_img2vid_Conditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023, "advanced": True}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01, "advanced": True})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
def define_schema(cls):
return io.Schema(
node_id="SVD_img2vid_Conditioning",
category="conditioning/video_models",
inputs=[
io.ClipVision.Input("clip_vision"),
io.Image.Input("init_image"),
io.Vae.Input("vae"),
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("video_frames", default=14, min=1, max=4096),
io.Int.Input("motion_bucket_id", default=127, min=1, max=1023, advanced=True),
io.Int.Input("fps", default=6, min=1, max=1024),
io.Float.Input("augmentation_level", default=0.0, min=0.0, max=10.0, step=0.01, advanced=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
@classmethod
def execute(cls, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level) -> io.NodeOutput:
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@ -54,20 +72,28 @@ class SVD_img2vid_Conditioning:
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
return io.NodeOutput(positive, negative, {"samples":latent})
class VideoLinearCFGGuidance:
encode = execute # TODO: remove
class VideoLinearCFGGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01, "advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="VideoLinearCFGGuidance",
category="sampling/video_models",
inputs=[
io.Model.Input("model"),
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01, advanced=True),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
@classmethod
def execute(cls, model, min_cfg) -> io.NodeOutput:
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
@ -78,20 +104,28 @@ class VideoLinearCFGGuidance:
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
return io.NodeOutput(m)
class VideoTriangleCFGGuidance:
patch = execute # TODO: remove
class VideoTriangleCFGGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01, "advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="VideoTriangleCFGGuidance",
category="sampling/video_models",
inputs=[
io.Model.Input("model"),
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01, advanced=True),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
@classmethod
def execute(cls, model, min_cfg) -> io.NodeOutput:
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
@ -105,57 +139,79 @@ class VideoTriangleCFGGuidance:
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
return io.NodeOutput(m)
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
CATEGORY = "advanced/model_merging"
patch = execute # TODO: remove
class ImageOnlyCheckpointSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageOnlyCheckpointSave",
search_aliases=["save model", "export checkpoint", "merge save"],
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.ClipVision.Input("clip_vision"),
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip_vision": ("CLIP_VISION",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
def execute(cls, model, clip_vision, vae, filename_prefix) -> io.NodeOutput:
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
save = execute # TODO: remove
class ConditioningSetAreaPercentageVideo:
class ConditioningSetAreaPercentageVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"temporal": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"z": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
def define_schema(cls):
return io.Schema(
node_id="ConditioningSetAreaPercentageVideo",
category="conditioning",
inputs=[
io.Conditioning.Input("conditioning"),
io.Float.Input("width", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("height", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("temporal", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("x", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("y", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("z", default=0.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "conditioning"
def append(self, conditioning, width, height, temporal, x, y, z, strength):
@classmethod
def execute(cls, conditioning, width, height, temporal, x, y, z, strength) -> io.NodeOutput:
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x),
"strength": strength,
"set_area_to_bounds": False})
return (c, )
return io.NodeOutput(c)
append = execute # TODO: remove
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
"ConditioningSetAreaPercentageVideo": ConditioningSetAreaPercentageVideo,
}
class VideoModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ImageOnlyCheckpointLoader,
SVD_img2vid_Conditioning,
VideoLinearCFGGuidance,
VideoTriangleCFGGuidance,
ImageOnlyCheckpointSave,
ConditioningSetAreaPercentageVideo,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}
async def comfy_entrypoint() -> VideoModelExtension:
return VideoModelExtension()

View File

@ -1 +1 @@
comfyui_manager==4.1b4
comfyui_manager==4.1b5

View File

@ -32,7 +32,7 @@ async def cache_control(
)
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
response.headers.setdefault("Cache-Control", "no-cache")
response.headers.setdefault("Cache-Control", "no-store")
return response
# Early return for non-image files - no cache headers needed

View File

@ -1724,6 +1724,8 @@ class LoadImage:
output_masks = []
w, h = None, None
dtype = comfy.model_management.intermediate_dtype()
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
@ -1748,8 +1750,8 @@ class LoadImage:
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
output_images.append(image.to(dtype=dtype))
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
if img.format == "MPO":
break # ignore all frames except the first one for MPO format

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.41.19
comfyui-frontend-package==1.41.20
comfyui-workflow-templates==0.9.21
comfyui-embedded-docs==0.4.3
torch

View File

@ -310,7 +310,7 @@ class PromptServer():
@routes.get("/")
async def get_root(request):
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
response.headers['Cache-Control'] = 'no-cache'
response.headers['Cache-Control'] = 'no-store, must-revalidate'
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response

View File

@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
},
# JavaScript/CSS scenarios
{
"name": "js_no_cache",
"name": "js_no_store",
"path": "/script.js",
"status": 200,
"expected_cache": "no-cache",
"expected_cache": "no-store",
"should_have_header": True,
},
{
"name": "css_no_cache",
"name": "css_no_store",
"path": "/styles.css",
"status": 200,
"expected_cache": "no-cache",
"expected_cache": "no-store",
"should_have_header": True,
},
{
"name": "index_json_no_cache",
"name": "index_json_no_store",
"path": "/api/index.json",
"status": 200,
"expected_cache": "no-cache",
"expected_cache": "no-store",
"should_have_header": True,
},
{
"name": "localized_index_json_no_cache",
"name": "localized_index_json_no_store",
"path": "/templates/index.zh.json",
"status": 200,
"expected_cache": "no-cache",
"expected_cache": "no-store",
"should_have_header": True,
},
# Non-matching files