diff --git a/.github/workflows/test-execution.yml b/.github/workflows/test-execution.yml
new file mode 100644
index 000000000..00ef07ebf
--- /dev/null
+++ b/.github/workflows/test-execution.yml
@@ -0,0 +1,30 @@
+name: Execution Tests
+
+on:
+ push:
+ branches: [ main, master ]
+ pull_request:
+ branches: [ main, master ]
+
+jobs:
+ test:
+ strategy:
+ matrix:
+ os: [ubuntu-latest, windows-latest, macos-latest]
+ runs-on: ${{ matrix.os }}
+ continue-on-error: true
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.12'
+ - name: Install requirements
+ run: |
+ python -m pip install --upgrade pip
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+ pip install -r requirements.txt
+ pip install -r tests-unit/requirements.txt
+ - name: Run Execution Tests
+ run: |
+ python -m pytest tests/execution -v --skip-timing-checks
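For local debugging, the same execution-test target can be invoked through pytest's Python entry point. A minimal sketch, assuming the repository's test dependencies are installed and that `--skip-timing-checks` is the suite's own option registered in its conftest, run from the repository root:

```python
# Local equivalent of the CI step above; mirrors the workflow's pytest command.
import sys
import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["tests/execution", "-v", "--skip-timing-checks"]))
```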
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 638d01119..fa12cd875 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -143,8 +143,9 @@ class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
+ AutoTune = "autotune"
-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 7e47d8a55..7c0cadab5 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module):
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
+ all_intermediate = None
if intermediate_output is not None:
- if intermediate_output < 0:
+ if intermediate_output == "all":
+ all_intermediate = []
+ intermediate_output = None
+ elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
@@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
+ if all_intermediate is not None:
+ all_intermediate.append(x.unsqueeze(1).clone())
+
+ if all_intermediate is not None:
+ intermediate = torch.cat(all_intermediate, dim=1)
+
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
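The new `intermediate_output == "all"` path keeps every layer output by inserting a layer axis and concatenating along it, which is what lets the vision wrapper index the penultimate state later. A small standalone sketch of that stacking pattern (toy shapes, not the CLIP encoder):

```python
# Collect per-layer activations of shape [B, T, C] into one [B, num_layers, T, C] tensor,
# mirroring the unsqueeze(1) + torch.cat(dim=1) used above.
import torch

B, T, C, num_layers = 2, 5, 8, 4
x = torch.randn(B, T, C)
all_intermediate = []
for _ in range(num_layers):
    x = x + 1.0  # stand-in for a transformer layer
    all_intermediate.append(x.unsqueeze(1).clone())

stacked = torch.cat(all_intermediate, dim=1)
assert stacked.shape == (B, num_layers, T, C)
penultimate = stacked[:, -2]  # the indexing clip_vision.py applies to all_hidden_states
```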
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 00aab9164..447b1ce4a 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -50,7 +50,13 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
- model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
+ model_type = config.get("model_type", "clip_vision_model")
+ model_class = IMAGE_ENCODERS.get(model_type)
+ if model_type == "siglip_vision_model":
+ self.return_all_hidden_states = True
+ else:
+ self.return_all_hidden_states = False
+
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -68,12 +74,18 @@ class ClipVisionModel():
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
- out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+ out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
- outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+ if self.return_all_hidden_states:
+ all_hs = out[1].to(comfy.model_management.intermediate_device())
+ outputs["penultimate_hidden_states"] = all_hs[:, -2]
+ outputs["all_hidden_states"] = all_hs
+ else:
+ outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+
outputs["mm_projected"] = out[3]
return outputs
@@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
- elif "embeddings.patch_embeddings.projection.weight" in sd:
+
+ # Dinov2
+ elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
+ elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
+ json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
else:
return None
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 54d83c069..11273e32f 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -288,7 +288,10 @@ class ControlNet(ControlBase):
to_concat = []
for c in self.extra_concat_orig:
c = c.to(self.cond_hint.device)
- c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
+ c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
+ if c.ndim < self.cond_hint.ndim:
+ c = c.unsqueeze(2)
+ c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
@@ -628,11 +631,18 @@ def load_controlnet_flux_instantx(sd, model_options={}):
def load_controlnet_qwen_instantx(sd, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
- control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+ control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
+
+ extra_condition_channels = 0
+ concat_mask = False
+ if control_latent_channels == 68: #inpaint controlnet
+ extra_condition_channels = control_latent_channels - 64
+ concat_mask = True
+ control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.Wan21()
extra_conds = []
- control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+ control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):
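The Qwen InstantX loader now infers the variant from the input projection's weight shape: 68 input channels means 64 latent channels plus 4 extra condition channels, so the mask is concatenated. A minimal sketch of that dispatch on a bare state dict (the weight here is a random stand-in):

```python
# Derive extra_condition_channels / concat_mask from controlnet_x_embedder.weight,
# following the check used in load_controlnet_qwen_instantx above.
import torch

sd = {"controlnet_x_embedder.weight": torch.zeros(3072, 68)}  # stand-in tensor

control_latent_channels = sd["controlnet_x_embedder.weight"].shape[1]
extra_condition_channels = 0
concat_mask = False
if control_latent_channels == 68:  # inpaint controlnet
    extra_condition_channels = control_latent_channels - 64
    concat_mask = True

print(extra_condition_channels, concat_mask)  # 4 True
```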
diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py
index 976f98c65..9b6dace9d 100644
--- a/comfy/image_encoders/dino2.py
+++ b/comfy/image_encoders/dino2.py
@@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
+class Dinov2MLP(torch.nn.Module):
+ def __init__(self, hidden_size: int, dtype, device, operations):
+ super().__init__()
+
+ mlp_ratio = 4
+ hidden_features = int(hidden_size * mlp_ratio)
+ self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
+ self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.fc1(hidden_state)
+ hidden_state = torch.nn.functional.gelu(hidden_state)
+ hidden_state = self.fc2(hidden_state)
+ return hidden_state
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
@@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):
class Dino2Block(torch.nn.Module):
- def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
+ def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
- self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+ if use_swiglu_ffn:
+ self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+ else:
+ self.mlp = Dinov2MLP(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
@@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):
class Dino2Encoder(torch.nn.Module):
- def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
+ def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
super().__init__()
- self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
+ self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
+ for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
@@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
- for i, l in enumerate(self.layer):
- x = l(x, optimized_attention)
+ for i, layer in enumerate(self.layer):
+ x = layer(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
@@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
+ use_swiglu_ffn = config_dict["use_swiglu_ffn"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
- self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
+ self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
diff --git a/comfy/image_encoders/dino2_large.json b/comfy/image_encoders/dino2_large.json
new file mode 100644
index 000000000..43fbb58ff
--- /dev/null
+++ b/comfy/image_encoders/dino2_large.json
@@ -0,0 +1,22 @@
+{
+ "hidden_size": 1024,
+ "use_mask_token": true,
+ "patch_size": 14,
+ "image_size": 518,
+ "num_channels": 3,
+ "num_attention_heads": 16,
+ "initializer_range": 0.02,
+ "attention_probs_dropout_prob": 0.0,
+ "hidden_dropout_prob": 0.0,
+ "hidden_act": "gelu",
+ "mlp_ratio": 4,
+ "model_type": "dinov2",
+ "num_hidden_layers": 24,
+ "layer_norm_eps": 1e-6,
+ "qkv_bias": true,
+ "use_swiglu_ffn": false,
+ "layerscale_value": 1.0,
+ "drop_path_rate": 0.0,
+ "image_mean": [0.485, 0.456, 0.406],
+ "image_std": [0.229, 0.224, 0.225]
+}
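For reference, a rough sketch of reading a config like this and pulling out the fields the DINOv2 encoder consumes; the path is illustrative and assumes the repository root as the working directory:

```python
# Load an image-encoder config and read the fields Dinov2Model uses.
import json

with open("comfy/image_encoders/dino2_large.json") as f:  # hypothetical working directory
    config = json.load(f)

dim = config["hidden_size"]                # 1024
num_layers = config["num_hidden_layers"]   # 24
use_swiglu_ffn = config["use_swiglu_ffn"]  # False -> plain Dinov2MLP instead of SwiGLUFFN
print(dim, num_layers, use_swiglu_ffn)
```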
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index fe6844b17..2d7e09838 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
return sigmas
+def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
+ """Compute the result of h*phi_1(h) in exponential integrator methods."""
+ return torch.expm1(h)
+
+
+def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
+ """Compute the result of h*phi_2(h) in exponential integrator methods."""
+ return (torch.expm1(h) - h) / h
+
+
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -1550,13 +1560,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
- arXiv: https://arxiv.org/abs/2305.14267
+ arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
-
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1564,55 +1573,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+ fac = 1 / (2 * r)
+
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
if sigmas[i + 1] == 0:
x = denoised
- else:
- lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
- h = lambda_t - lambda_s
- h_eta = h * (eta + 1)
- lambda_s_1 = lambda_s + r * h
- fac = 1 / (2 * r)
- sigma_s_1 = sigma_fn(lambda_s_1)
+ continue
- # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
- alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
- alpha_t = sigmas[i + 1] * lambda_t.exp()
+ lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+ h = lambda_t - lambda_s
+ h_eta = h * (eta + 1)
+ lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
+ sigma_s_1 = sigma_fn(lambda_s_1)
- coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
- if inject_noise:
- # 0 < r < 1
- noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
- noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
- noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
+ alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+ alpha_t = sigmas[i + 1] * lambda_t.exp()
- # Step 1
- x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
- if inject_noise:
- x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
- denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
+ # Step 1
+ x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
+ if inject_noise:
+ sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
+ x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
+ denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
- # Step 2
- denoised_d = (1 - fac) * denoised + fac * denoised_2
- x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
- if inject_noise:
- x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
+ # Step 2
+ denoised_d = torch.lerp(denoised, denoised_2, fac)
+ x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+ if inject_noise:
+ segment_factor = (r - 1) * h * eta
+ sde_noise = sde_noise * segment_factor.exp()
+ sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
+ x = x + sde_noise * sigmas[i + 1] * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
- arXiv: https://arxiv.org/abs/2305.14267
+ arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
-
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1624,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
if sigmas[i + 1] == 0:
x = denoised
- else:
- lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
- h = lambda_t - lambda_s
- h_eta = h * (eta + 1)
- lambda_s_1 = lambda_s + r_1 * h
- lambda_s_2 = lambda_s + r_2 * h
- sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
+ continue
- # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
- alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
- alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
- alpha_t = sigmas[i + 1] * lambda_t.exp()
+ lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+ h = lambda_t - lambda_s
+ h_eta = h * (eta + 1)
+ lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
+ lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
+ sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
- coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
- if inject_noise:
- # 0 < r_1 < r_2 < 1
- noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
- noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
- noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
- noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
+ alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+ alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
+ alpha_t = sigmas[i + 1] * lambda_t.exp()
- # Step 1
- x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
- if inject_noise:
- x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
- denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
+ # Step 1
+ x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
+ if inject_noise:
+ sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
+ x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
+ denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
- # Step 2
- x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
- if inject_noise:
- x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
- denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
+ # Step 2
+ a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
+ a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
+ x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
+ if inject_noise:
+ segment_factor = (r_1 - r_2) * h * eta
+ sde_noise = sde_noise * segment_factor.exp()
+ sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
+ x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
+ denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
- # Step 3
- x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
- if inject_noise:
- x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
+ # Step 3
+ b3 = ei_h_phi_2(-h_eta) / r_2
+ b1 = ei_h_phi_1(-h_eta) - b3
+ x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
+ if inject_noise:
+ segment_factor = (r_2 - 1) * h * eta
+ sde_noise = sde_noise * segment_factor.exp()
+ sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
+ x = x + sde_noise * sigmas[i + 1] * s_noise
return x
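The refactor expresses the SEEDS coefficients through the exponential-integrator functions phi_1(h) = (e^h - 1)/h and phi_2(h) = (e^h - 1 - h)/h^2, so that h*phi_1(h) = expm1(h) and h*phi_2(h) = (expm1(h) - h)/h, which is what the earlier `(-r * h_eta).expm1()`-style coefficients computed. A quick numerical sanity check of those identities (a standalone sketch, not part of the sampler):

```python
# Check ei_h_phi_1 / ei_h_phi_2 against the direct phi_1 and phi_2 definitions.
import torch

h = torch.tensor(-0.37)
phi_1 = (torch.exp(h) - 1.0) / h            # phi_1(h) = (e^h - 1) / h
phi_2 = (torch.exp(h) - 1.0 - h) / h ** 2   # phi_2(h) = (e^h - 1 - h) / h^2

assert torch.allclose(torch.expm1(h), h * phi_1)            # ei_h_phi_1(h)
assert torch.allclose((torch.expm1(h) - h) / h, h * phi_2)  # ei_h_phi_2(h)
```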
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index caf4991fc..f975b5e11 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -533,11 +533,89 @@ class Wan22(Wan21):
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
+class HunyuanImage21(LatentFormat):
+ latent_channels = 64
+ latent_dimensions = 2
+ scale_factor = 0.75289
+
+ latent_rgb_factors = [
+ [-0.0154, -0.0397, -0.0521],
+ [ 0.0005, 0.0093, 0.0006],
+ [-0.0805, -0.0773, -0.0586],
+ [-0.0494, -0.0487, -0.0498],
+ [-0.0212, -0.0076, -0.0261],
+ [-0.0179, -0.0417, -0.0505],
+ [ 0.0158, 0.0310, 0.0239],
+ [ 0.0409, 0.0516, 0.0201],
+ [ 0.0350, 0.0553, 0.0036],
+ [-0.0447, -0.0327, -0.0479],
+ [-0.0038, -0.0221, -0.0365],
+ [-0.0423, -0.0718, -0.0654],
+ [ 0.0039, 0.0368, 0.0104],
+ [ 0.0655, 0.0217, 0.0122],
+ [ 0.0490, 0.1638, 0.2053],
+ [ 0.0932, 0.0829, 0.0650],
+ [-0.0186, -0.0209, -0.0135],
+ [-0.0080, -0.0076, -0.0148],
+ [-0.0284, -0.0201, 0.0011],
+ [-0.0642, -0.0294, -0.0777],
+ [-0.0035, 0.0076, -0.0140],
+ [ 0.0519, 0.0731, 0.0887],
+ [-0.0102, 0.0095, 0.0704],
+ [ 0.0068, 0.0218, -0.0023],
+ [-0.0726, -0.0486, -0.0519],
+ [ 0.0260, 0.0295, 0.0263],
+ [ 0.0250, 0.0333, 0.0341],
+ [ 0.0168, -0.0120, -0.0174],
+ [ 0.0226, 0.1037, 0.0114],
+ [ 0.2577, 0.1906, 0.1604],
+ [-0.0646, -0.0137, -0.0018],
+ [-0.0112, 0.0309, 0.0358],
+ [-0.0347, 0.0146, -0.0481],
+ [ 0.0234, 0.0179, 0.0201],
+ [ 0.0157, 0.0313, 0.0225],
+ [ 0.0423, 0.0675, 0.0524],
+ [-0.0031, 0.0027, -0.0255],
+ [ 0.0447, 0.0555, 0.0330],
+ [-0.0152, 0.0103, 0.0299],
+ [-0.0755, -0.0489, -0.0635],
+ [ 0.0853, 0.0788, 0.1017],
+ [-0.0272, -0.0294, -0.0471],
+ [ 0.0440, 0.0400, -0.0137],
+ [ 0.0335, 0.0317, -0.0036],
+ [-0.0344, -0.0621, -0.0984],
+ [-0.0127, -0.0630, -0.0620],
+ [-0.0648, 0.0360, 0.0924],
+ [-0.0781, -0.0801, -0.0409],
+ [ 0.0363, 0.0613, 0.0499],
+ [ 0.0238, 0.0034, 0.0041],
+ [-0.0135, 0.0258, 0.0310],
+ [ 0.0614, 0.1086, 0.0589],
+ [ 0.0428, 0.0350, 0.0205],
+ [ 0.0153, 0.0173, -0.0018],
+ [-0.0288, -0.0455, -0.0091],
+ [ 0.0344, 0.0109, -0.0157],
+ [-0.0205, -0.0247, -0.0187],
+ [ 0.0487, 0.0126, 0.0064],
+ [-0.0220, -0.0013, 0.0074],
+ [-0.0203, -0.0094, -0.0048],
+ [-0.0719, 0.0429, -0.0442],
+ [ 0.1042, 0.0497, 0.0356],
+ [-0.0659, -0.0578, -0.0280],
+ [-0.0060, -0.0322, -0.0234]]
+
+ latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
+
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 0.9990943042622529
+class Hunyuan3Dv2_1(LatentFormat):
+ scale_factor = 1.0039506158752403
+ latent_channels = 64
+ latent_dimensions = 1
+
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
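The `latent_rgb_factors` table added for HunyuanImage21 above is the per-channel projection ComfyUI's latent previews use: roughly, each latent channel is mapped onto RGB by a [C, 3] matrix plus a bias. A hedged sketch of that projection with random data (the preview code itself is outside this diff):

```python
# Project a 64-channel latent to an approximate RGB preview with per-channel factors.
import torch

C, H, W = 64, 32, 32
latent = torch.randn(C, H, W)
factors = torch.randn(C, 3)  # stand-in for HunyuanImage21.latent_rgb_factors
bias = torch.zeros(3)        # stand-in for latent_rgb_factors_bias

rgb = torch.einsum("chw,cr->hwr", latent, factors) + bias
assert rgb.shape == (H, W, 3)
```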
diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py
index 179c5b67e..d0d69bbdc 100644
--- a/comfy/ldm/audio/dit.py
+++ b/comfy/ldm/audio/dit.py
@@ -632,7 +632,7 @@ class ContinuousTransformer(nn.Module):
# Attention layers
if self.rotary_pos_emb is not None:
- rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
else:
rotary_pos_emb = None
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 1344c3a57..8ea7d4f57 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -106,6 +106,7 @@ class Flux(nn.Module):
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+ patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -117,9 +118,17 @@ class Flux(nn.Module):
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
- vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
+ vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
txt = self.txt_in(txt)
+ if "post_input" in patches:
+ for p in patches["post_input"]:
+ out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
+ img = out["img"]
+ txt = out["txt"]
+ img_ids = out["img_ids"]
+ txt_ids = out["txt_ids"]
+
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -233,12 +242,18 @@ class Flux(nn.Module):
h = 0
w = 0
index = 0
- index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
+ ref_latents_method = kwargs.get("ref_latents_method", "offset")
for ref in ref_latents:
- if index_ref_method:
+ if ref_latents_method == "index":
index += 1
h_offset = 0
w_offset = 0
+ elif ref_latents_method == "uxo":
+ index = 0
+ h_offset = h_len * patch_size + h
+ w_offset = w_len * patch_size + w
+ h += ref.shape[-2]
+ w += ref.shape[-1]
else:
index = 1
h_offset = 0
diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py
index 6e8cbf1d9..760944827 100644
--- a/comfy/ldm/hunyuan3d/vae.py
+++ b/comfy/ldm/hunyuan3d/vae.py
@@ -4,81 +4,458 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
-
-from typing import Union, Tuple, List, Callable, Optional
-
import numpy as np
-from einops import repeat, rearrange
+import math
from tqdm import tqdm
+
+from typing import Optional
+
import logging
import comfy.ops
ops = comfy.ops.disable_weight_init
-def generate_dense_grid_points(
- bbox_min: np.ndarray,
- bbox_max: np.ndarray,
- octree_resolution: int,
- indexing: str = "ij",
-):
- length = bbox_max - bbox_min
- num_cells = octree_resolution
+def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
- x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
- y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
- z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
- [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
- xyz = np.stack((xs, ys, zs), axis=-1)
- grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
+ # manually create the pointer vector
+ assert src.size(0) == batch.numel()
- return xyz, grid_size, length
+ batch_size = int(batch.max()) + 1
+ deg = src.new_zeros(batch_size, dtype = torch.long)
+
+ deg.scatter_add_(0, batch, torch.ones_like(batch))
+
+ ptr_vec = deg.new_zeros(batch_size + 1)
+ torch.cumsum(deg, 0, out=ptr_vec[1:])
+
+ #return fps_sampling(src, ptr_vec, ratio)
+    sampled_indices = []
+
+ for b in range(batch_size):
+ # start and the end of each batch
+ start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
+ # points from the point cloud
+ points = src[start:end]
+
+ num_points = points.size(0)
+ num_samples = max(1, math.ceil(num_points * sampling_ratio))
+
+ selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
+ distances = torch.full((num_points,), float("inf"), device = src.device)
+
+ # select a random start point
+ if start_random:
+ farthest = torch.randint(0, num_points, (1,), device = src.device)
+ else:
+ farthest = torch.tensor([0], device = src.device, dtype = torch.long)
+
+ for i in range(num_samples):
+ selected[i] = farthest
+ centroid = points[farthest].squeeze(0)
+ dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
+ distances = torch.minimum(distances, dist)
+ farthest = torch.argmax(distances)
+
+        sampled_indices.append(torch.arange(start, end)[selected])
+
+    return torch.cat(sampled_indices, dim = 0)
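A minimal usage sketch for the farthest-point-sampling helper above, assuming `fps` is in scope (it is defined at module level in `comfy/ldm/hunyuan3d/vae.py`): two point clouds are flattened into one tensor, and a batch index vector says which points belong to which cloud.

```python
# Example call into fps(): 2 batches of 100 3-D points each, keeping 10% per batch.
import torch

points = torch.randn(200, 3)                            # flattened [B * N, 3] point cloud
batch = torch.repeat_interleave(torch.arange(2), 100)   # batch id of each point

idx = fps(points, batch, sampling_ratio=0.1)            # global indices into `points`
sampled = points[idx]
assert sampled.shape == (20, 3)                         # ceil(100 * 0.1) points per batch
```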
+class PointCrossAttention(nn.Module):
+ def __init__(self,
+ num_latents: int,
+ downsample_ratio: float,
+ pc_size: int,
+ pc_sharpedge_size: int,
+ point_feats: int,
+ width: int,
+ heads: int,
+ layers: int,
+ fourier_embedder,
+ normal_pe: bool = False,
+ qkv_bias: bool = False,
+ use_ln_post: bool = True,
+ qk_norm: bool = True):
+
+ super().__init__()
+
+ self.fourier_embedder = fourier_embedder
+
+ self.pc_size = pc_size
+ self.normal_pe = normal_pe
+ self.downsample_ratio = downsample_ratio
+ self.pc_sharpedge_size = pc_sharpedge_size
+ self.num_latents = num_latents
+ self.point_feats = point_feats
+
+ self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
+
+ self.cross_attn = ResidualCrossAttentionBlock(
+ width = width,
+ heads = heads,
+ qkv_bias = qkv_bias,
+ qk_norm = qk_norm
+ )
+
+ self.self_attn = None
+ if layers > 0:
+ self.self_attn = Transformer(
+ width = width,
+ heads = heads,
+ qkv_bias = qkv_bias,
+ qk_norm = qk_norm,
+ layers = layers
+ )
+
+ if use_ln_post:
+ self.ln_post = nn.LayerNorm(width)
+ else:
+ self.ln_post = None
+
+ def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
+
+ """
+        Randomly subsample points from the point cloud to form input_pc,
+        FPS-sample those points again to form query_pc, and
+        take the Fourier embeddings of both the input and query point clouds.
+
+        Note: the FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
+        Goal: obtain a smaller representation (query_pc) of the entire scene structure by learning from a broader subset (input_pc),
+        which is more computationally efficient.
+
+        Features are additional per-point information carried alongside the point cloud.
+ """
+
+ B, _, D = point_cloud.shape
+
+ num_latents = int(self.num_latents)
+
+ num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
+ num_sharpedge_query = num_latents - num_random_query
+
+ # Split random and sharpedge surface points
+ random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
+
+ # assert statements
+ assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
+ assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
+
+ input_random_pc_size = int(num_random_query * self.downsample_ratio)
+ random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
+ self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
+
+ input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
+
+ if input_sharpedge_pc_size == 0:
+ sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
+ sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
+
+ else:
+ sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
+ self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
+
+ # concat the random and sharpedges
+ query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
+ input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
+
+ query = self.fourier_embedder(query_pc)
+ data = self.fourier_embedder(input_pc)
+
+ if self.point_feats > 0:
+ random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
+
+ input_random_surface_features, query_random_features = \
+ self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
+ input_pc_size = input_random_pc_size, idx_query = random_idx_query)
+
+ if input_sharpedge_pc_size == 0:
+ input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
+ dtype = input_random_surface_features.dtype, device = point_cloud.device)
+
+ query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
+ dtype = query_random_features.dtype, device = point_cloud.device)
+ else:
+
+ input_sharpedge_surface_features, query_sharpedge_features = \
+ self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
+ batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
+
+ query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
+ input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
+
+ if self.normal_pe:
+ # apply the fourier embeddings on the first 3 dims (xyz)
+ input_features_pe = self.fourier_embedder(input_features[..., :3])
+ query_features_pe = self.fourier_embedder(query_features[..., :3])
+ # replace the first 3 dims with the new PE ones
+ input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
+ query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
+
+ # concat at the channels dim
+ query = torch.cat([query, query_features], dim = -1)
+ data = torch.cat([data, input_features], dim = -1)
+
+        # don't return pc_info to avoid unnecessary memory usage
+ return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
+
+ def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
+
+ query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
+
+ # apply projections
+ query = self.input_proj(query)
+ data = self.input_proj(data)
+
+ # apply cross attention between query and data
+ latents = self.cross_attn(query, data)
+
+ if self.self_attn is not None:
+ latents = self.self_attn(latents)
+
+ if self.ln_post is not None:
+ latents = self.ln_post(latents)
+
+ return latents
-class VanillaVolumeDecoder:
+ def subsample(self, pc, num_query, input_pc_size: int):
+
+ """
+ num_query: number of points to keep after FPS
+ input_pc_size: number of points to select before FPS
+ """
+
+ B, _, D = pc.shape
+ query_ratio = num_query / input_pc_size
+
+ # random subsampling of points inside the point cloud
+ idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
+ input_pc = pc[:, idx_pc, :]
+
+ # flatten to allow applying fps across the whole batch
+        flattened_input_pc = input_pc.view(B * input_pc_size, D)
+
+ # construct a batch_down tensor to tell fps
+ # which points belong to which batch
+ N_down = int(flattent_input_pc.shape[0] / B)
+ batch_down = torch.arange(B).to(pc.device)
+ batch_down = torch.repeat_interleave(batch_down, N_down)
+
+        idx_query = fps(flattened_input_pc, batch_down, sampling_ratio = query_ratio)
+        query_pc = flattened_input_pc[idx_query].view(B, -1, D)
+
+ return query_pc, input_pc, idx_pc, idx_query
+
+ def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
+
+ B = batch_size
+
+ input_surface_features = features[:, idx_pc, :]
+        flattened_input_features = input_surface_features.view(B * input_pc_size, -1)
+        query_features = flattened_input_features[idx_query].view(B, -1,
+                                                flattened_input_features.shape[-1])
+
+ return input_surface_features, query_features
+
+def normalize_mesh(mesh, scale = 0.9999):
+ """Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
+
+ bbox = mesh.bounds
+ center = (bbox[1] + bbox[0]) / 2
+
+ max_extent = (bbox[1] - bbox[0]).max()
+ mesh.apply_translation(-center)
+ mesh.apply_scale((2 * scale) / max_extent)
+
+ return mesh
+
+def sample_pointcloud(mesh, num = 200000):
+ """ Uniformly sample points from the surface of the mesh """
+
+ points, face_idx = mesh.sample(num, return_index = True)
+ normals = mesh.face_normals[face_idx]
+ return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
+
+def detect_sharp_edges(mesh, threshold=0.985):
+ """Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
+
+ V, F = mesh.vertices, mesh.faces
+ VN, FN = mesh.vertex_normals, mesh.face_normals
+
+ sharp_mask = np.ones(V.shape[0])
+ for i in range(3):
+ indices = F[:, i]
+ alignment = np.einsum('ij,ij->i', VN[indices], FN)
+ dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
+ sharp_mask[indices] = np.min(dot_stack, axis=-1)
+
+ edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
+ edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
+ sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)
+
+ return edge_a[sharp_edges], edge_b[sharp_edges]
+
+
+def sharp_sample_pointcloud(mesh, num = 16384):
+ """ Sample points preferentially from sharp edges in the mesh. """
+
+ edge_a, edge_b = detect_sharp_edges(mesh)
+ V, VN = mesh.vertices, mesh.vertex_normals
+
+ va, vb = V[edge_a], V[edge_b]
+ na, nb = VN[edge_a], VN[edge_b]
+
+ edge_lengths = np.linalg.norm(vb - va, axis=-1)
+ weights = edge_lengths / edge_lengths.sum()
+
+ indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
+ t = np.random.rand(num, 1)
+
+ samples = t * va[indices] + (1 - t) * vb[indices]
+ normals = t * na[indices] + (1 - t) * nb[indices]
+
+ return samples.astype(np.float32), normals.astype(np.float32)
+
+def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
+ """Load a surface with optional sharp-edge annotations from a trimesh mesh."""
+
+ import trimesh
+
+ try:
+ mesh_full = trimesh.util.concatenate(mesh.dump())
+ except Exception:
+ mesh_full = trimesh.util.concatenate(mesh)
+
+ mesh_full = normalize_mesh(mesh_full)
+
+ faces = mesh_full.faces
+ vertices = mesh_full.vertices
+ origin_face_count = faces.shape[0]
+
+ mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
+ mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])
+
+ area_surface = mesh_surface.area
+ area_fill = mesh_fill.area
+ total_area = area_surface + area_fill
+
+ sample_num = 499712 // 2
+ fill_ratio = area_fill / total_area if total_area > 0 else 0
+
+ num_fill = int(sample_num * fill_ratio)
+ num_surface = sample_num - num_fill
+
+ surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
+ fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)
+
+ sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)
+
+ def assemble_tensor(points, normals, label=None):
+
+ data = torch.cat([points, normals], dim=1).half().to(device)
+
+ if label is not None:
+ label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
+ data = torch.cat([data, label_tensor], dim=1)
+
+ return data
+
+ surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
+ torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
+ label = 0 if sharpedge_flag else None)
+
+ sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
+ label = 1 if sharpedge_flag else None)
+
+ rng = np.random.default_rng()
+
+ surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
+ sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
+
+ full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
+
+ return full
+
+class SharpEdgeSurfaceLoader:
+ """ Load mesh surface and sharp edge samples. """
+
+ def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
+
+ self.num_uniform_points = num_uniform_points
+ self.num_sharp_points = num_sharp_points
+ self.total_points = num_uniform_points + num_sharp_points
+
+ def __call__(self, mesh_input, device = "cuda"):
+ mesh = self._load_mesh(mesh_input)
+ return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
+
+ @staticmethod
+ def _load_mesh(mesh_input):
+ import trimesh
+
+ if isinstance(mesh_input, str):
+ mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
+ else:
+ mesh = mesh_input
+
+ if isinstance(mesh, trimesh.Scene):
+ combined = None
+ for obj in mesh.geometry.values():
+ combined = obj if combined is None else combined + obj
+ return combined
+
+ return mesh
+
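A hedged usage sketch for the loader above; the mesh path is a placeholder and `trimesh` must be installed. With `sharpedge_flag` left at its default, each sampled point carries xyz, normal, and a sharp-edge label, so the result is a [1, num_uniform + num_sharp, 7] tensor on the requested device:

```python
# Hypothetical usage of SharpEdgeSurfaceLoader; "model.glb" is a placeholder path.
loader = SharpEdgeSurfaceLoader(num_uniform_points=8192, num_sharp_points=8192)
surface = loader("model.glb", device="cpu")  # requires trimesh
print(surface.shape)                         # expected: torch.Size([1, 16384, 7])
```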
+class DiagonalGaussianDistribution:
+ def __init__(self, params: torch.Tensor, feature_dim: int = -1):
+
+ # divide quant channels (8) into mean and log variance
+ self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
+
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.std = torch.exp(0.5 * self.logvar)
+
+ def sample(self):
+
+ eps = torch.randn_like(self.std)
+ z = self.mean + eps * self.std
+
+ return z
+
+################################################
+# Volume Decoder
+################################################
+
+class VanillaVolumeDecoder():
@torch.no_grad()
- def __call__(
- self,
- latents: torch.FloatTensor,
- geo_decoder: Callable,
- bounds: Union[Tuple[float], List[float], float] = 1.01,
- num_chunks: int = 10000,
- octree_resolution: int = None,
- enable_pbar: bool = True,
- **kwargs,
- ):
- device = latents.device
- dtype = latents.dtype
- batch_size = latents.shape[0]
+ def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
+ num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
- # 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
- bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
- xyz_samples, grid_size, length = generate_dense_grid_points(
- bbox_min=bbox_min,
- bbox_max=bbox_max,
- octree_resolution=octree_resolution,
- indexing="ij"
- )
- xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
+ bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
+
+ x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
+ y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
+ z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
+
+ [xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
+ xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
+ grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
- # 2. latents to 3d volume
batch_logits = []
- for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
+ for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
disable=not enable_pbar):
- chunk_queries = xyz_samples[start: start + num_chunks, :]
- chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
- logits = geo_decoder(queries=chunk_queries, latents=latents)
+
+ chunk_queries = xyz[start: start + num_chunks, :]
+ chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
+ logits = geo_decoder(queries = chunk_queries, latents = latents)
batch_logits.append(logits)
- grid_logits = torch.cat(batch_logits, dim=1)
- grid_logits = grid_logits.view((batch_size, *grid_size)).float()
+ grid_logits = torch.cat(batch_logits, dim = 1)
+ grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
return grid_logits
-
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
@@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
else:
return x
-
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = comfy.ops.scaled_dot_product_attention(q, k, v)
return out
-
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
@@ -232,38 +607,41 @@ class MLP(nn.Module):
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
-
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
- *,
heads: int,
+ n_data = None,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
+ self.n_data = n_data
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
- self.attn_processor = CrossAttentionProcessor()
-
def forward(self, q, kv):
+
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
+
attn_ch = width // self.heads // 2
q = q.view(bs, n_ctx, self.heads, -1)
+
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
- q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
- out = self.attn_processor(self, q, k, v)
- out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
- return out
+ q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
+ out = F.scaled_dot_product_attention(q, k, v)
+
+ out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
+
+ return out
class MultiheadCrossAttention(nn.Module):
def __init__(
@@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
x = self.c_proj(x)
return x
-
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
@@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
- q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
+ q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
@@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
drop_path_rate: float = 0.0
):
super().__init__()
- self.width = width
- self.heads = heads
+
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadAttention(
@@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
if self.downsample_ratio != 1:
self.latents_proj = ops.Linear(width * downsample_ratio, width)
- if self.enable_ln_post == False:
+ if not self.enable_ln_post:
qk_norm = False
self.cross_attn_decoder = ResidualCrossAttentionBlock(
width=width,
@@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):
class ShapeVAE(nn.Module):
def __init__(
- self,
- *,
- embed_dim: int,
- width: int,
- heads: int,
- num_decoder_layers: int,
- geo_decoder_downsample_ratio: int = 1,
- geo_decoder_mlp_expand_ratio: int = 4,
- geo_decoder_ln_post: bool = True,
- num_freqs: int = 8,
- include_pi: bool = True,
- qkv_bias: bool = True,
- qk_norm: bool = False,
- label_type: str = "binary",
- drop_path_rate: float = 0.0,
- scale_factor: float = 1.0,
+ self,
+ *,
+ num_latents: int = 4096,
+ embed_dim: int = 64,
+ width: int = 1024,
+ heads: int = 16,
+ num_decoder_layers: int = 16,
+ num_encoder_layers: int = 8,
+ pc_size: int = 81920,
+ pc_sharpedge_size: int = 0,
+ point_feats: int = 4,
+ downsample_ratio: int = 20,
+ geo_decoder_downsample_ratio: int = 1,
+ geo_decoder_mlp_expand_ratio: int = 4,
+ geo_decoder_ln_post: bool = True,
+ num_freqs: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ drop_path_rate: float = 0.0,
+ include_pi: bool = False,
+ scale_factor: float = 1.0039506158752403,
+ label_type: str = "binary",
):
super().__init__()
self.geo_decoder_ln_post = geo_decoder_ln_post
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+ self.encoder = PointCrossAttention(layers = num_encoder_layers,
+ num_latents = num_latents,
+ downsample_ratio = downsample_ratio,
+ heads = heads,
+ pc_size = pc_size,
+ width = width,
+ point_feats = point_feats,
+ fourier_embedder = self.fourier_embedder,
+ pc_sharpedge_size = pc_sharpedge_size)
+
self.post_kl = ops.Linear(embed_dim, width)
self.transformer = Transformer(
@@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
return grid_logits.movedim(-2, -1)
- def encode(self, x):
- return None
+ def encode(self, surface):
+
+ pc, feats = surface[:, :, :3], surface[:, :, 3:]
+ latents = self.encoder(pc, feats)
+
+ moments = self.pre_kl(latents)
+ posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
+
+ latents = posterior.sample()
+
+ return latents
diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
new file mode 100644
index 000000000..d48d9d642
--- /dev/null
+++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
@@ -0,0 +1,659 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.model_management
+
+class GELU(nn.Module):
+
+ def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
+ super().__init__()
+ self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+
+ if gate.device.type == "mps":
+ return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
+
+ return F.gelu(gate)
+
+ def forward(self, hidden_states):
+
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+
+ return hidden_states
+
+class FeedForward(nn.Module):
+
+ def __init__(self, dim: int, dim_out = None, mult: int = 4,
+ dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
+
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+
+ dim_out = dim_out if dim_out is not None else dim
+
+ act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
+
+ self.net = nn.ModuleList([])
+ self.net.append(act_fn)
+
+ self.net.append(nn.Dropout(dropout))
+ self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+class AddAuxLoss(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, loss):
+ # do nothing in forward (no computation)
+ ctx.requires_aux_loss = loss.requires_grad
+ ctx.dtype = loss.dtype
+
+ return x
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ # add the aux loss gradients
+ grad_loss = None
+        # give the aux loss a gradient of one so it flows back
+        # with the same weight as the main loss gradient
+ if ctx.requires_aux_loss:
+ grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
+
+ return grad_output, grad_loss
+
+class MoEGate(nn.Module):
+
+ def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
+
+ super().__init__()
+ self.top_k = num_experts_per_tok
+ self.n_routed_experts = num_experts
+
+ self.alpha = aux_loss_alpha
+
+ self.gating_dim = embed_dim
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
+ # flatten hidden states
+ hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+
+ # get logits and pass it to softmax
+ logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
+ scores = logits.softmax(dim = -1)
+
+ topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
+
+ if self.training and self.alpha > 0.0:
+ scores_for_aux = scores
+
+ # used bincount instead of one hot encoding
+ counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
+ ce = counts / topk_idx.numel() # normalized expert usage
+
+ # mean expert score
+ Pi = scores_for_aux.mean(0)
+
+ # expert balance loss
+ aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
+ else:
+ aux_loss = None
+
+ return topk_idx, topk_weight, aux_loss
+
+class MoEBlock(nn.Module):
+ def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
+ ff_inner_dim: int = None, operations = None, device = None, dtype = None):
+ super().__init__()
+
+ self.moe_top_k = moe_top_k
+ self.num_experts = num_experts
+
+ self.experts = nn.ModuleList([
+ FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
+ for _ in range(num_experts)
+ ])
+
+ self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
+ self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
+
+ def forward(self, hidden_states) -> torch.Tensor:
+
+ identity = hidden_states
+ orig_shape = hidden_states.shape
+ topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
+
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ flat_topk_idx = topk_idx.view(-1)
+
+ if self.training:
+
+ hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
+ y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
+
+ for i, expert in enumerate(self.experts):
+ tmp = expert(hidden_states[flat_topk_idx == i])
+ y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
+
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
+ y = y.view(*orig_shape)
+
+ y = AddAuxLoss.apply(y, aux_loss)
+ else:
+            y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx, flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
+
+ y = y + self.shared_experts(identity)
+
+ return y
+
+ @torch.no_grad()
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
+
+ expert_cache = torch.zeros_like(x)
+ idxs = flat_expert_indices.argsort()
+
+ # no need for .numpy().cpu() here
+ tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
+ token_idxs = idxs // self.moe_top_k
+
+ for i, end_idx in enumerate(tokens_per_expert):
+
+ start_idx = 0 if i == 0 else tokens_per_expert[i-1]
+
+ if start_idx == end_idx:
+ continue
+
+ expert = self.experts[i]
+ exp_token_idx = token_idxs[start_idx:end_idx]
+
+ expert_tokens = x[exp_token_idx]
+ expert_out = expert(expert_tokens)
+
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
+
+            # using index_add_ with a 1-D index tensor avoids building a large [N, D] index map
+            # and the extra memcopy required by scatter_reduce_, and avoids a dtype conversion
+ expert_cache.index_add_(0, exp_token_idx, expert_out)
+
+ return expert_cache
+
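At inference, the gate returns each token's top-k expert indices and softmax weights, and `moe_infer` groups tokens by expert and accumulates the weighted expert outputs with `index_add_`. A tiny standalone sketch of the same routing idea (a toy layer, not the `MoEBlock` above):

```python
# Toy top-k mixture-of-experts routing: each token is sent to its top-2 experts
# and the expert outputs are summed, weighted by the gate's softmax scores.
import torch
import torch.nn.functional as F

tokens, dim, num_experts, top_k = 6, 16, 4, 2
x = torch.randn(tokens, dim)
gate_w = torch.randn(num_experts, dim)
experts = [torch.nn.Linear(dim, dim) for _ in range(num_experts)]

scores = F.linear(x, gate_w).softmax(dim=-1)             # [tokens, num_experts]
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1)

out = torch.zeros_like(x)
for e in range(num_experts):
    token_ids, slot = (topk_idx == e).nonzero(as_tuple=True)
    if token_ids.numel() == 0:
        continue
    weighted = experts[e](x[token_ids]) * topk_weight[token_ids, slot].unsqueeze(-1)
    out.index_add_(0, token_ids, weighted)  # accumulate per-token expert contributions
```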
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
+ scale: float = 1.0, max_period: int = 10000):
+ super().__init__()
+
+ self.num_channels = num_channels
+ half_dim = num_channels // 2
+
+ # precompute the “inv_freq” vector once
+ exponent = -math.log(max_period) * torch.arange(
+ half_dim, dtype=torch.float32
+ ) / (half_dim - downscale_freq_shift)
+
+ inv_freq = torch.exp(exponent)
+
+        # pad inv_freq with a zero frequency when num_channels is odd so the
+        # concatenated sin/cos embedding is at least num_channels wide
+        if num_channels % 2 == 1:
+ inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
+
+ # register to buffer so it moves with the device
+ self.register_buffer("inv_freq", inv_freq, persistent = False)
+ self.scale = scale
+
+ def forward(self, timesteps: torch.Tensor):
+
+ x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
+
+ # elementwise sin and cos of the phases
+ sin_emb = x.sin()
+ cos_emb = x.cos()
+
+ emb = torch.cat([sin_emb, cos_emb], dim = 1)
+
+ # scale factor
+ if self.scale != 1.0:
+ emb = emb * self.scale
+
+ # if inv_freq was padded for an odd channel count, emb is one column too wide; trim it
+ if emb.shape[1] > self.num_channels:
+ emb = emb[:, :self.num_channels]
+
+ return emb
+
+class TimestepEmbedder(nn.Module):
+ def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
+ super().__init__()
+
+ self.mlp = nn.Sequential(
+ operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
+ nn.GELU(),
+ operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ if cond_proj_dim is not None:
+ self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
+
+ self.time_embed = Timesteps(hidden_size)
+
+ def forward(self, timesteps, condition):
+
+ timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
+
+ if condition is not None:
+ cond_embed = self.cond_proj(condition)
+ timestep_embed = timestep_embed + cond_embed
+
+ time_conditioned = self.mlp(timestep_embed)
+
+ # for broadcasting with image tokens
+ return time_conditioned.unsqueeze(1)
+
+class MLP(nn.Module):
+ def __init__(self, *, width: int, operations = None, device = None, dtype = None):
+ super().__init__()
+ self.width = width
+ self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
+ self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ return self.fc2(self.gelu(self.fc1(x)))
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ qdim,
+ kdim,
+ num_heads,
+ qkv_bias=True,
+ qk_norm=False,
+ norm_layer=nn.LayerNorm,
+ use_fp16: bool = False,
+ operations = None,
+ dtype = None,
+ device = None,
+ **kwargs,
+ ):
+ super().__init__()
+ self.qdim = qdim
+ self.kdim = kdim
+
+ self.num_heads = num_heads
+ self.head_dim = self.qdim // num_heads
+
+ self.scale = self.head_dim ** -0.5
+
+ self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
+ self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
+ self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
+
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ if norm_layer == nn.LayerNorm:
+ norm_layer = operations.LayerNorm
+ else:
+ norm_layer = operations.RMSNorm
+
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
+
+ def forward(self, x, y):
+
+ b, s1, _ = x.shape
+ _, s2, _ = y.shape
+
+ y = y.to(next(self.to_k.parameters()).dtype)
+
+ q = self.to_q(x)
+ k = self.to_k(y)
+ v = self.to_v(y)
+
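+ # concatenate k and v, reshape into per-head chunks, then split back into k and v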
+ kv = torch.cat((k, v), dim=-1)
+ split_size = kv.shape[-1] // self.num_heads // 2
+
+ kv = kv.view(1, -1, self.num_heads, split_size * 2)
+ k, v = torch.split(kv, split_size, dim=-1)
+
+ q = q.view(b, s1, self.num_heads, self.head_dim)
+ k = k.view(b, s2, self.num_heads, self.head_dim)
+ v = v.reshape(b, s2, self.num_heads * self.head_dim)
+
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+
+ x = optimized_attention(
+ q.reshape(b, s1, self.num_heads * self.head_dim),
+ k.reshape(b, s2, self.num_heads * self.head_dim),
+ v,
+ heads=self.num_heads,
+ )
+
+ out = self.out_proj(x)
+
+ return out
+
+class Attention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ qkv_bias = True,
+ qk_norm = False,
+ norm_layer = nn.LayerNorm,
+ use_fp16: bool = False,
+ operations = None,
+ device = None,
+ dtype = None
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = self.dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
+ self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
+ self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
+
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ if norm_layer == nn.LayerNorm:
+ norm_layer = operations.LayerNorm
+ else:
+ norm_layer = operations.RMSNorm
+
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
+
+ def forward(self, x):
+ B, N, _ = x.shape
+
+ query = self.to_q(x)
+ key = self.to_k(x)
+ value = self.to_v(x)
+
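+ # concatenate q, k and v, reshape into per-head chunks, then split them back apart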
+ qkv_combined = torch.cat((query, key, value), dim=-1)
+ split_size = qkv_combined.shape[-1] // self.num_heads // 3
+
+ qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ query = query.reshape(B, N, self.num_heads, self.head_dim)
+ key = key.reshape(B, N, self.num_heads, self.head_dim)
+ value = value.reshape(B, N, self.num_heads * self.head_dim)
+
+ query = self.q_norm(query)
+ key = self.k_norm(key)
+
+ x = optimized_attention(
+ query.reshape(B, N, self.num_heads * self.head_dim),
+ key.reshape(B, N, self.num_heads * self.head_dim),
+ value,
+ heads=self.num_heads,
+ )
+
+ x = self.out_proj(x)
+ return x
+
+class HunYuanDiTBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ c_emb_size,
+ num_heads,
+ text_states_dim=1024,
+ qk_norm=False,
+ norm_layer=nn.LayerNorm,
+ qk_norm_layer=True,
+ qkv_bias=True,
+ skip_connection=True,
+ timested_modulate=False,
+ use_moe: bool = False,
+ num_experts: int = 8,
+ moe_top_k: int = 2,
+ use_fp16: bool = False,
+ operations = None,
+ device = None, dtype = None
+ ):
+ super().__init__()
+
+ # eps can't be 1e-6 in fp16 mode because of numerical stability issues
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+
+ self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
+ norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
+
+ self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+
+ self.timested_modulate = timested_modulate
+ if self.timested_modulate:
+ self.default_modulation = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
+ )
+
+ self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
+ qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
+ device = device, dtype = dtype, operations = operations)
+
+ self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+
+ if skip_connection:
+ self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+ self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
+ else:
+ self.skip_linear = None
+
+ self.use_moe = use_moe
+
+ if self.use_moe:
+ self.moe = MoEBlock(
+ hidden_size,
+ num_experts = num_experts,
+ moe_top_k = moe_top_k,
+ dropout = 0.0,
+ ff_inner_dim = int(hidden_size * 4.0),
+ device = device, dtype = dtype,
+ operations = operations
+ )
+ else:
+ self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
+
+ def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
+
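+ # long skip connection: concatenate the skip tensor with the hidden states and project back to hidden_size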
+ if self.skip_linear is not None:
+ combined = torch.cat([skip_tensor, hidden_states], dim=-1)
+ hidden_states = self.skip_linear(combined)
+ hidden_states = self.skip_norm(hidden_states)
+
+ # self attention
+ if self.timested_modulate:
+ modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
+ hidden_states = hidden_states + modulation_shift
+
+ self_attn_out = self.attn1(self.norm1(hidden_states))
+ hidden_states = hidden_states + self_attn_out
+
+ # cross attention
+ hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
+
+ # MLP Layer
+ mlp_input = self.norm3(hidden_states)
+
+ if self.use_moe:
+ hidden_states = hidden_states + self.moe(mlp_input)
+ else:
+ hidden_states = hidden_states + self.mlp(mlp_input)
+
+ return hidden_states
+
+class FinalLayer(nn.Module):
+
+ def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
+ super().__init__()
+
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+ self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
+
+ def forward(self, x):
+ x = self.norm_final(x)
+ x = x[:, 1:]
+ x = self.linear(x)
+ return x
+
+class HunYuanDiTPlain(nn.Module):
+
+ # init with the default values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
+ def __init__(
+ self,
+ in_channels: int = 64,
+ hidden_size: int = 2048,
+ context_dim: int = 1024,
+ depth: int = 21,
+ num_heads: int = 16,
+ qk_norm: bool = True,
+ qkv_bias: bool = False,
+ num_moe_layers: int = 6,
+ guidance_cond_proj_dim = 2048,
+ norm_type = 'layer',
+ num_experts: int = 8,
+ moe_top_k: int = 2,
+ use_fp16: bool = False,
+ dtype = None,
+ device = None,
+ operations = None,
+ **kwargs
+ ):
+
+ self.dtype = dtype
+
+ super().__init__()
+
+ self.depth = depth
+
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+
+ norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
+ qk_norm = operations.RMSNorm
+
+ self.context_dim = context_dim
+ self.guidance_cond_proj_dim = guidance_cond_proj_dim
+
+ self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
+ self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
+
+ # HunYuanDiT blocks
+ self.blocks = nn.ModuleList([
+ HunYuanDiTBlock(hidden_size=hidden_size,
+ c_emb_size=hidden_size,
+ num_heads=num_heads,
+ text_states_dim=context_dim,
+ qk_norm=qk_norm,
+ norm_layer = norm,
+ qk_norm_layer = qk_norm,
+ skip_connection=layer > depth // 2,
+ qkv_bias=qkv_bias,
+ use_moe=depth - layer <= num_moe_layers,
+ num_experts=num_experts,
+ moe_top_k=moe_top_k,
+ use_fp16 = use_fp16,
+ device = device, dtype = dtype, operations = operations)
+ for layer in range(depth)
+ ])
+
+ self.depth = depth
+
+ self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
+
+ def forward(self, x, t, context, transformer_options = {}, **kwargs):
+
+ x = x.movedim(-1, -2)
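+ # swap the two halves of the batched conditioning; the output halves are swapped back below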
+ uncond_emb, cond_emb = context.chunk(2, dim = 0)
+
+ context = torch.cat([cond_emb, uncond_emb], dim = 0)
+ main_condition = context
+
+ t = 1.0 - t
+
+ time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
+
+ x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
+ x_embedded = self.x_embedder(x)
+
+ combined = torch.cat([time_embedded, x_embedded], dim=1)
+
+ def block_wrap(args):
+ return block(
+ args["x"],
+ args["t"],
+ args["cond"],
+ skip_tensor=args.get("skip"),)
+
+ skip_stack = []
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
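+ # blocks in the first half push their outputs onto skip_stack; blocks in the second half pop them as skip inputs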
+ for idx, block in enumerate(self.blocks):
+ if idx <= self.depth // 2:
+ skip_input = None
+ else:
+ skip_input = skip_stack.pop()
+
+ if ("block", idx) in blocks_replace:
+
+ combined = blocks_replace[("block", idx)](
+ {
+ "x": combined,
+ "t": time_embedded,
+ "cond": main_condition,
+ "skip": skip_input,
+ },
+ {"original_block": block_wrap},
+ )
+ else:
+ combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
+
+ if idx < self.depth // 2:
+ skip_stack.append(combined)
+
+ output = self.final_layer(combined)
+ output = output.movedim(-2, -1) * (-1.0)
+
+ cond_emb, uncond_emb = output.chunk(2, dim = 0)
+ return torch.cat([uncond_emb, cond_emb])
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index da1011596..7732182a4 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -40,6 +40,8 @@ class HunyuanVideoParams:
patch_size: list
qkv_bias: bool
guidance_embed: bool
+ byt5: bool
+ meanflow: bool
class SelfAttentionRef(nn.Module):
@@ -161,6 +163,30 @@ class TokenRefiner(nn.Module):
x = self.individual_token_refiner(x, c, mask)
return x
+
+class ByT5Mapper(nn.Module):
+ def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
+ self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
+ self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
+ self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
+ self.use_res = use_res
+ self.act_fn = nn.GELU()
+
+ def forward(self, x):
+ if self.use_res:
+ res = x
+ x = self.layernorm(x)
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.fc2(x)
+ x2 = self.act_fn(x)
+ x2 = self.fc3(x2)
+ if self.use_res:
+ x2 = x2 + res
+ return x2
+
class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
@@ -185,9 +211,13 @@ class HunyuanVideo(nn.Module):
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
- self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
+ self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
- self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+ if params.vec_in_dim is not None:
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+ else:
+ self.vector_in = None
+
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
@@ -215,6 +245,23 @@ class HunyuanVideo(nn.Module):
]
)
+ if params.byt5:
+ self.byt5_in = ByT5Mapper(
+ in_dim=1472,
+ out_dim=2048,
+ hidden_dim=2048,
+ out_dim1=self.hidden_size,
+ use_res=False,
+ dtype=dtype, device=device, operations=operations
+ )
+ else:
+ self.byt5_in = None
+
+ if params.meanflow:
+ self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
+ else:
+ self.time_r_in = None
+
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
@@ -226,7 +273,8 @@ class HunyuanVideo(nn.Module):
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
- y: Tensor,
+ y: Tensor = None,
+ txt_byt5=None,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
@@ -240,6 +288,14 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
+ if self.time_r_in is not None:
+ w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
+ if len(w) > 0:
+ timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
+ timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
+ vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
+ vec = (vec + vec_r) / 2
+
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
@@ -250,13 +306,17 @@ class HunyuanVideo(nn.Module):
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
- vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
- vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+ if self.vector_in is not None:
+ vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+ vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+ else:
+ vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
else:
- vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+ if self.vector_in is not None:
+ vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
modulation_dims = None
modulation_dims_txt = None
@@ -269,6 +329,12 @@ class HunyuanVideo(nn.Module):
txt = self.txt_in(txt, timesteps, txt_mask)
+ if self.byt5_in is not None and txt_byt5 is not None:
+ txt_byt5 = self.byt5_in(txt_byt5)
+ txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+ txt = torch.cat((txt, txt_byt5), dim=1)
+ txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
+
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -328,12 +394,16 @@ class HunyuanVideo(nn.Module):
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
- shape = initial_shape[-3:]
+ shape = initial_shape[-len(self.patch_size):]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
- img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
- img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+ if img.ndim == 8:
+ img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
+ img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+ else:
+ img = img.permute(0, 3, 1, 4, 2, 5)
+ img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
return img
def img_ids(self, x):
@@ -348,16 +418,30 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
- def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+ def img_ids_2d(self, x):
+ bs, c, h, w = x.shape
+ patch_size = self.patch_size
+ h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
+ w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
+ img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
+ img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+ img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+ return repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+ def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
- ).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
+ ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
- def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
- bs, c, t, h, w = x.shape
- img_ids = self.img_ids(x)
- txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
- out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
+ def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+ bs = x.shape[0]
+ if len(self.patch_size) == 3:
+ img_ids = self.img_ids(x)
+ txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+ else:
+ img_ids = self.img_ids_2d(x)
+ txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
+ out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
return out
diff --git a/comfy/ldm/hunyuan_video/vae.py b/comfy/ldm/hunyuan_video/vae.py
new file mode 100644
index 000000000..40c12b183
--- /dev/null
+++ b/comfy/ldm/hunyuan_video/vae.py
@@ -0,0 +1,136 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
+
+class PixelShuffle2D(nn.Module):
+ def __init__(self, in_dim, out_dim, op=ops.Conv2d):
+ super().__init__()
+ self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
+ self.ratio = (in_dim << 2) // out_dim
+
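+ # 2x spatial downsample: space-to-depth on the conv output plus a
+ # channel-averaged space-to-depth residual of the input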
+ def forward(self, x):
+ b, c, h, w = x.shape
+ h2, w2 = h >> 1, w >> 1
+ y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
+ r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
+ return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)
+
+
+class PixelUnshuffle2D(nn.Module):
+ def __init__(self, in_dim, out_dim, op=ops.Conv2d):
+ super().__init__()
+ self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
+ self.scale = (out_dim << 2) // in_dim
+
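+ # 2x spatial upsample: depth-to-space on the conv output plus a
+ # channel-repeated depth-to-space residual of the input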
+ def forward(self, x):
+ b, c, h, w = x.shape
+ h2, w2 = h << 1, w << 1
+ y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
+ r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
+ return y + r
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
+ ffactor_spatial, downsample_match_channel=True, **_):
+ super().__init__()
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+ self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)
+
+ self.down = nn.ModuleList()
+ ch = block_out_channels[0]
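+ # number of 2x downsample stages: log2(ffactor_spatial)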
+ depth = (ffactor_spatial >> 1).bit_length()
+
+ for i, tgt in enumerate(block_out_channels):
+ stage = nn.Module()
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=ops.Conv2d)
+ for j in range(num_res_blocks)])
+ ch = tgt
+ if i < depth:
+ nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
+ stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
+ ch = nxt
+ self.down.append(stage)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+ self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+
+ self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
+ self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)
+
+ def forward(self, x):
+ x = self.conv_in(x)
+
+ for stage in self.down:
+ for blk in stage.block:
+ x = blk(x)
+ if hasattr(stage, 'downsample'):
+ x = stage.downsample(x)
+
+ x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+
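+ # channel-group-averaged shortcut from the final features down to 2 * z_channels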
+ b, c, h, w = x.shape
+ grp = c // (self.z_channels << 1)
+ skip = x.view(b, c // grp, grp, h, w).mean(2)
+
+ return self.conv_out(F.silu(self.norm_out(x))) + skip
+
+
+class Decoder(nn.Module):
+ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
+ ffactor_spatial, upsample_match_channel=True, **_):
+ super().__init__()
+ block_out_channels = block_out_channels[::-1]
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+
+ ch = block_out_channels[0]
+ self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+ self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+
+ self.up = nn.ModuleList()
+ depth = (ffactor_spatial >> 1).bit_length()
+
+ for i, tgt in enumerate(block_out_channels):
+ stage = nn.Module()
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=ops.Conv2d)
+ for j in range(num_res_blocks + 1)])
+ ch = tgt
+ if i < depth:
+ nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
+ stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
+ ch = nxt
+ self.up.append(stage)
+
+ self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
+ self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)
+
+ def forward(self, z):
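+ # conv_in output plus a channel-repeated copy of the latent as a residual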
+ x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
+ x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+
+ for stage in self.up:
+ for blk in stage.block:
+ x = blk(x)
+ if hasattr(stage, 'upsample'):
+ x = stage.upsample(x)
+
+ return self.conv_out(F.silu(self.norm_out(x)))
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 1fd12b35a..8f598a848 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -145,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
- dropout, temb_channels=512, conv_op=ops.Conv2d):
+ dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@@ -183,7 +183,7 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
- def forward(self, x, temb):
+ def forward(self, x, temb=None):
h = x
h = self.norm1(h)
h = self.swish(h)
diff --git a/comfy/lora.py b/comfy/lora.py
index 00358884b..4a44f1318 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
+ for k in sdk:
+ hidden_size = model.model_config.unet_config.get("hidden_size", 0)
+ if k.endswith(".weight") and ".linear1." in k:
+ key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
if isinstance(model, comfy.model_base.GenmoMochi):
for k in sdk:
diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py
index 3e00b63db..9d8d21efe 100644
--- a/comfy/lora_convert.py
+++ b/comfy/lora_convert.py
@@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
def convert_lora_wan_fun(sd): #Wan Fun loras
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
+def convert_uso_lora(sd):
+ sd_out = {}
+ for k in sd:
+ tensor = sd[k]
+ k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight")
+ .replace(".up.weight", ".lora_up.weight")
+ .replace(".qkv_lora2.", ".txt_attn.qkv.")
+ .replace(".qkv_lora1.", ".img_attn.qkv.")
+ .replace(".proj_lora1.", ".img_attn.proj.")
+ .replace(".proj_lora2.", ".txt_attn.proj.")
+ .replace(".qkv_lora.", ".linear1_qkv.")
+ .replace(".proj_lora.", ".linear2.")
+ .replace(".processor.", ".")
+ )
+ sd_out[k_to] = tensor
+ return sd_out
+
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
return convert_lora_wan_fun(sd)
+ if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
+ return convert_uso_lora(sd)
return sd
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 56a6798be..993ff65e6 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -16,6 +16,8 @@
along with this program. If not, see .
"""
+import comfy.ldm.hunyuan3dv2_1
+import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -1282,6 +1284,21 @@ class Hunyuan3Dv2(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
+class Hunyuan3Dv2_1(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ guidance = kwargs.get("guidance", 5.0)
+ if guidance is not None:
+ out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+ return out
+
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
@@ -1391,3 +1408,27 @@ class QwenImage(BaseModel):
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
+
+class HunyuanImage21(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ attention_mask = kwargs.get("attention_mask", None)
+ if attention_mask is not None:
+ if torch.numel(attention_mask) != attention_mask.sum():
+ out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
+ if conditioning_byt5small is not None:
+ out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
+
+ guidance = kwargs.get("guidance", 6.0)
+ if guidance is not None:
+ out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+ return out
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 9f3ab64df..fe983cede 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -136,20 +136,40 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
+ in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
+ out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
dit_config["image_model"] = "hunyuan_video"
- dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
- dit_config["patch_size"] = [1, 2, 2]
- dit_config["out_channels"] = 16
- dit_config["vec_in_dim"] = 768
- dit_config["context_in_dim"] = 4096
- dit_config["hidden_size"] = 3072
+ dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
+ dit_config["patch_size"] = list(in_w.shape[2:])
+ dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
+ if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
+ dit_config["vec_in_dim"] = 768
+ else:
+ dit_config["vec_in_dim"] = None
+
+ if len(dit_config["patch_size"]) == 2:
+ dit_config["axes_dim"] = [64, 64]
+ else:
+ dit_config["axes_dim"] = [16, 56, 56]
+
+ if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys):
+ dit_config["meanflow"] = True
+ else:
+ dit_config["meanflow"] = False
+
+ dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
+ dit_config["hidden_size"] = in_w.shape[0]
dit_config["mlp_ratio"] = 4.0
- dit_config["num_heads"] = 24
+ dit_config["num_heads"] = in_w.shape[0] // 128
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
- dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 256
dit_config["qkv_bias"] = True
+ if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict:
+ dit_config["byt5"] = True
+ else:
+ dit_config["byt5"] = False
+
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
@@ -400,6 +420,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
+ if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
+
+ dit_config = {}
+ dit_config["image_model"] = "hunyuan3d2_1"
+ dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
+ dit_config["context_dim"] = 1024
+ dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
+ dit_config["mlp_ratio"] = 4.0
+ dit_config["num_heads"] = 16
+ dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
+ dit_config["qkv_bias"] = False
+ dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
+ return dit_config
+
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 4ac04b8b9..3dfa122c8 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -23,6 +23,7 @@ from enum import Enum
from comfy.cli_args import args, PerformanceFeature
import torch
import sys
+import importlib
import platform
import weakref
import gc
@@ -313,6 +314,24 @@ def is_amd():
return True
return False
+def amd_min_version(device=None, min_rdna_version=0):
+ if not is_amd():
+ return False
+
+ if is_device_cpu(device):
+ return False
+
+ arch = torch.cuda.get_device_properties(device).gcnArchName
+ if arch.startswith('gfx') and len(arch) == 7:
+ try:
+ cmp_rdna_version = int(arch[4]) + 2
+ except:
+ cmp_rdna_version = 0
+ if cmp_rdna_version >= min_rdna_version:
+ return True
+
+ return False
+
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.0
@@ -345,12 +364,13 @@ try:
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
- if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
- if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
- ENABLE_PYTORCH_ATTENTION = True
-# if torch_version_numeric >= (2, 8):
-# if any((a in arch) for a in ["gfx1201"]):
-# ENABLE_PYTORCH_ATTENTION = True
+ if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
+ if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
+ if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
+ ENABLE_PYTORCH_ATTENTION = True
+# if torch_version_numeric >= (2, 8):
+# if any((a in arch) for a in ["gfx1201"]):
+# ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
@@ -933,7 +953,9 @@ def vae_dtype(device=None, allowed_dtypes=[]):
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
- if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+ # also a problem on RDNA4 except fp32 is also slow there.
+ # This is due to large bf16 convolutions being extremely slow.
+ if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
return d
return torch.float32
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index e3d0b4840..c64778da0 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -513,6 +513,9 @@ class ModelPatcher:
def set_model_double_block_patch(self, patch):
self.set_model_patch(patch, "double_block")
+ def set_model_post_input_patch(self, patch):
+ self.set_model_patch(patch, "post_input")
+
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
diff --git a/comfy/ops.py b/comfy/ops.py
index 18e7db705..55e958adb 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -52,6 +52,9 @@ except (ModuleNotFoundError, TypeError):
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
+if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
+ torch.backends.cudnn.benchmark = True
+
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
diff --git a/comfy/sd.py b/comfy/sd.py
index bb5d61fb3..9dd9a74d4 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -17,6 +17,7 @@ import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
+import comfy.ldm.hunyuan_video.vae
import yaml
import math
import os
@@ -48,6 +49,7 @@ import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
+import comfy.text_encoders.hunyuan_image
import comfy.model_patcher
import comfy.lora
@@ -328,6 +330,19 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
+ elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
+ ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
+ self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ self.downscale_ratio = 32
+ self.upscale_ratio = 32
+ self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+ self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+ encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
+ decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
+
+ self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
+
elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -446,17 +461,29 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+ # Hunyuan3D v2.0 & v2.1 shape VAE
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
+
self.latent_dim = 1
- ln_post = "geo_decoder.ln_post.weight" in sd
- inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
- downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
- mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
- self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
- self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
- ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
- self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
+
+ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
+ batch, num_tokens, hidden_dim = shape
+ dtype_size = model_management.dtype_size(dtype)
+
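+ # activation memory plus an approximate per-layer K/V-cache term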
+ total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)
+ return total_mem
+
+ # better memory estimates based on token count, hidden size and K/V-cache depth
+ self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\
+ estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
+
+ self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \
+ estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
+
+ self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
@@ -773,6 +800,7 @@ class CLIPType(Enum):
ACE = 16
OMNIGEN2 = 17
QWEN_IMAGE = 18
+ HUNYUAN_IMAGE = 19
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -794,6 +822,7 @@ class TEModel(Enum):
GEMMA_2_2B = 9
QWEN25_3B = 10
QWEN25_7B = 11
+ BYT5_SMALL_GLYPH = 12
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -811,6 +840,9 @@ def detect_te_model(sd):
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
+ weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
+ if weight.shape[0] == 384:
+ return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
@@ -925,8 +957,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
elif te_model == TEModel.QWEN25_7B:
- clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
- clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
+ if clip_type == CLIPType.HUNYUAN_IMAGE:
+ clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
+ else:
+ clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@@ -970,6 +1006,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
+ elif clip_type == CLIPType.HUNYUAN_IMAGE:
+ clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 76260de00..aa953b462 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -20,6 +20,7 @@ import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
+import comfy.text_encoders.hunyuan_image
from . import supported_models_base
from . import latent_formats
@@ -1128,6 +1129,17 @@ class Hunyuan3Dv2(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None
+class Hunyuan3Dv2_1(Hunyuan3Dv2):
+ unet_config = {
+ "image_model": "hunyuan3d2_1",
+ }
+
+ latent_format = latent_formats.Hunyuan3Dv2_1
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Hunyuan3Dv2_1(self, device = device)
+ return out
+
class Hunyuan3Dv2mini(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1284,7 +1296,31 @@ class QwenImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
+class HunyuanImage21(HunyuanVideo):
+ unet_config = {
+ "image_model": "hunyuan_video",
+ "vec_in_dim": None,
+ }
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
+ sampling_settings = {
+ "shift": 5.0,
+ }
+
+ latent_format = latent_formats.HunyuanImage21
+
+ memory_usage_factor = 7.7
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.HunyuanImage21(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]
diff --git a/comfy/text_encoders/byt5_config_small_glyph.json b/comfy/text_encoders/byt5_config_small_glyph.json
new file mode 100644
index 000000000..0239c7164
--- /dev/null
+++ b/comfy/text_encoders/byt5_config_small_glyph.json
@@ -0,0 +1,22 @@
+{
+ "d_ff": 3584,
+ "d_kv": 64,
+ "d_model": 1472,
+ "decoder_start_token_id": 0,
+ "dropout_rate": 0.1,
+ "eos_token_id": 1,
+ "dense_act_fn": "gelu_pytorch_tanh",
+ "initializer_factor": 1.0,
+ "is_encoder_decoder": true,
+ "is_gated_act": true,
+ "layer_norm_epsilon": 1e-06,
+ "model_type": "t5",
+ "num_decoder_layers": 4,
+ "num_heads": 6,
+ "num_layers": 12,
+ "output_past": true,
+ "pad_token_id": 0,
+ "relative_attention_num_buckets": 32,
+ "tie_word_embeddings": false,
+ "vocab_size": 1510
+}
diff --git a/comfy/text_encoders/byt5_tokenizer/added_tokens.json b/comfy/text_encoders/byt5_tokenizer/added_tokens.json
new file mode 100644
index 000000000..93c190b56
--- /dev/null
+++ b/comfy/text_encoders/byt5_tokenizer/added_tokens.json
@@ -0,0 +1,127 @@
+{
+ "": 259,
+ "": 359,
+ "": 360,
+ "": 361,
+ "": 362,
+ "": 363,
+ "": 364,
+ "": 365,
+ "": 366,
+ "": 367,
+ "": 368,
+ "": 269,
+ "": 369,
+ "": 370,
+ "": 371,
+ "": 372,
+ "": 373,
+ "": 374,
+ "": 375,
+ "": 376,
+ "": 377,
+ "": 378,
+ "": 270,
+ "": 379,
+ "": 380,
+ "": 381,
+ "": 382,
+ "": 383,
+ "": 271,
+ "": 272,
+ "": 273,
+ "": 274,
+ "": 275,
+ "": 276,
+ "": 277,
+ "": 278,
+ "": 260,
+ "": 279,
+ "": 280,
+ "": 281,
+ "": 282,
+ "": 283,
+ "": 284,
+ "": 285,
+ "": 286,
+ "": 287,
+ "": 288,
+ "": 261,
+ "": 289,
+ "": 290,
+ "": 291,
+ "": 292,
+ "": 293,
+ "": 294,
+ "": 295,
+ "": 296,
+ "": 297,
+ "": 298,
+ "": 262,
+ "": 299,
+ "": 300,
+ "": 301,
+ "": 302,
+ "": 303,
+ "": 304,
+ "": 305,
+ "": 306,
+ "": 307,
+ "": 308,
+ "": 263,
+ "": 309,
+ "": 310,
+ "": 311,
+ "": 312,
+ "": 313,
+ "": 314,
+ "": 315,
+ "": 316,
+ "": 317,
+ "": 318,
+ "": 264,
+ "": 319,
+ "": 320,
+ "": 321,
+ "": 322,
+ "": 323,
+ "": 324,
+ "": 325,
+ "": 326,
+ "": 327,
+ "": 328,
+ "": 265,
+ "": 329,
+ "": 330,
+ "": 331,
+ "": 332,
+ "": 333,
+ "": 334,
+ "": 335,
+ "": 336,
+ "": 337,
+ "": 338,
+ "": 266,
+ "": 339,
+ "": 340,
+ "": 341,
+ "": 342,
+ "": 343,
+ "": 344,
+ "": 345,
+ "": 346,
+ "": 347,
+ "": 348,
+ "": 267,
+ "": 349,
+ "": 350,
+ "": 351,
+ "": 352,
+ "": 353,
+ "": 354,
+ "": 355,
+ "": 356,
+ "": 357,
+ "": 358,
+ "": 268
+}
diff --git a/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json
new file mode 100644
index 000000000..04fd58b5f
--- /dev/null
+++ b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json
@@ -0,0 +1,150 @@
+{
+ "additional_special_tokens": [
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ""
+ ],
+ "eos_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ }
+}
diff --git a/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json
new file mode 100644
index 000000000..5b1fe24c1
--- /dev/null
+++ b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json
@@ -0,0 +1,1163 @@
+{
+ "added_tokens_decoder": {
+ "0": {
+ "content": "",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "259": {
+ "content": "