diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c53ac4b2b..11db46d94 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -242,6 +242,7 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -373,6 +374,7 @@ def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -686,6 +688,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1 lambda_fn = lambda sigma: ((1-sigma)/sigma).log() @@ -747,6 +750,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -832,6 +836,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) old_denoised = None h, h_last = None, None @@ -889,6 +894,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) denoised_1, denoised_2 = None, None h, h_1, h_2 = None, None, None @@ -1006,23 +1012,39 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) @torch.no_grad() -def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): +def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, s_noise=1.0, s_noise_end=None, noise_clip_std=0.0): + + # s_noise / s_noise_end: per-step noise multiplier, linearly interpolated across steps + # noise_clip_std: clamp injected noise to +/- N stddevs (0 disables). + extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) - for i in trange(len(sigmas) - 1, disable=disable): + n_steps = max(1, len(sigmas) - 1) + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + + s_start = float(s_noise) + s_end = s_start if s_noise_end is None else float(s_noise_end) + for i in trange(n_steps, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) x = denoised if sigmas[i + 1] > 0: - x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x) + noise = noise_sampler(sigmas[i], sigmas[i + 1]) + if noise_clip_std > 0: + clip_val = noise_clip_std * noise.std() + noise = noise.clamp(min=-clip_val, max=clip_val) + t = (i / (n_steps - 1)) if n_steps > 1 else 0.0 + s_noise_i = s_start + (s_end - s_start) * t + if s_noise_i != 1.0: + noise = noise * s_noise_i + x = model_sampling.noise_scaling(sigmas[i + 1], noise, x) return x - @torch.no_grad() def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/ @@ -1249,6 +1271,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) uncond_denoised = None @@ -1296,6 +1319,7 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) temp = [0] def post_cfg_function(args): @@ -1371,6 +1395,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() @@ -1504,6 +1529,7 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) s_in = x.new_ones([x.shape[0]]) def default_er_sde_noise_scaler(x): @@ -1574,9 +1600,10 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) - inject_noise = eta > 0 and s_noise > 0 model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) + inject_noise = eta > 0 and s_noise > 0 sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) @@ -1645,9 +1672,10 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) - inject_noise = eta > 0 and s_noise > 0 model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) + inject_noise = eta > 0 and s_noise > 0 sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) @@ -1713,6 +1741,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F s_in = x.new_ones([x.shape[0]]) model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 91bebed3d..d527eec4a 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -792,6 +792,13 @@ class ZImagePixelSpace(ChromaRadiance): """ pass + +class HiDreamO1Pixel(ChromaRadiance): + """Pixel-space latent format for HiDream-O1. + No VAE — model patches/unpatches raw RGB internally with patch_size=32. + """ + pass + class CogVideoX(LatentFormat): """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b). diff --git a/comfy/ldm/hidream_o1/attention.py b/comfy/ldm/hidream_o1/attention.py new file mode 100644 index 000000000..1b68f1771 --- /dev/null +++ b/comfy/ldm/hidream_o1/attention.py @@ -0,0 +1,41 @@ +"""HiDream-O1 two-pass attention: tokens [0, ar_len) are causal, [ar_len, T) +attend full K/V. Splitting Q at the boundary avoids the (B, 1, T, T) additive +mask the general-purpose path would build (~500 MB at T~16K) and lets the +gen half hit the user's preferred backend via optimized_attention. +""" + +import torch + +import comfy.ops +from comfy.ldm.modules.attention import optimized_attention + + +def make_two_pass_attention(ar_len: int, transformer_options=None): + """Build a two-pass attention callable. AR pass uses SDPA-causal directly, gen pass routes through optimized_attention. + The AR pass goes through SDPA directand bypasses wrappers, it is only ~1% of T at typical edit sizes. + """ + + def two_pass_attention(q, k, v, heads, **kwargs): + B, H, T, D = q.shape + + if T < k.shape[2]: # KV-cache hot path: Q is shorter than K/V (cached AR prefix is in K/V only), all fresh Q positions are in the gen region, single full-attention call + out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options) + elif ar_len >= T: + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True) + elif ar_len <= 0: + out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options) + else: + out_ar = comfy.ops.scaled_dot_product_attention( + q[:, :, :ar_len], k[:, :, :ar_len], v[:, :, :ar_len], + attn_mask=None, dropout_p=0.0, is_causal=True, + ) + out_gen = optimized_attention( + q[:, :, ar_len:], k, v, heads, + mask=None, skip_reshape=True, skip_output_reshape=True, + transformer_options=transformer_options, + ) + out = torch.cat([out_ar, out_gen], dim=2) + + return out.transpose(1, 2).reshape(B, T, H * D) + + return two_pass_attention diff --git a/comfy/ldm/hidream_o1/conditioning.py b/comfy/ldm/hidream_o1/conditioning.py new file mode 100644 index 000000000..7496f0035 --- /dev/null +++ b/comfy/ldm/hidream_o1/conditioning.py @@ -0,0 +1,230 @@ +"""HiDream-O1 conditioning prep — ref-image dual path + extra_conds assembly. + +Each ref image goes through two paths: a 32x32 patchified stream concatenated +to the noised target, and a Qwen3-VL ViT path producing tokens that scatter +into input_ids at <|image_pad|> positions. +""" + +from typing import List + +import torch + +import comfy.utils +from comfy.text_encoders.qwen_vl import process_qwen2vl_images + +from .utils import (PATCH_SIZE, calculate_dimensions, cond_image_size, ref_max_size, resize_tensor) + +# Qwen3-VL ViT preprocessing constants (preprocessor_config.json). +VIT_PATCH = 16 +VIT_MERGE = 2 +VIT_IMAGE_MEAN = [0.5, 0.5, 0.5] +VIT_IMAGE_STD = [0.5, 0.5, 0.5] + + +def prepare_ref_images( + ref_images: List[torch.Tensor], + target_h: int, + target_w: int, + device: torch.device, + dtype: torch.dtype, +): + """Build the dual-path tensors for K reference images at (target_h, target_w). + + Returns None for K=0, else a dict with ref_patches, ref_pixel_values, + ref_image_grid_thw, per_ref_vit_tokens, per_ref_patch_grids. + """ + K = len(ref_images) + if K == 0: + return None + max_size = ref_max_size(max(target_h, target_w), K) + cis = cond_image_size(K) + + refs_t = [img[0].clamp(0, 1).permute(2, 0, 1).unsqueeze(0).contiguous().float() for img in ref_images] + refs_t = [resize_tensor(t, max_size, PATCH_SIZE) for t in refs_t] + + # 32-patch path. + ref_patches_per = [] + per_ref_patch_grids = [] + for t in refs_t: + t_norm = (t.squeeze(0) - 0.5) / 0.5 # (3, H, W) in [-1, 1] + h_p, w_p = t_norm.shape[-2] // PATCH_SIZE, t_norm.shape[-1] // PATCH_SIZE + per_ref_patch_grids.append((h_p, w_p)) + patches = ( + t_norm.reshape(3, h_p, PATCH_SIZE, w_p, PATCH_SIZE) + .permute(1, 3, 0, 2, 4) + .reshape(h_p * w_p, 3 * PATCH_SIZE * PATCH_SIZE) + ) + ref_patches_per.append(patches) + ref_patches = torch.cat(ref_patches_per, dim=0).unsqueeze(0).to(device=device, dtype=dtype) + + # ViT path. + refs_vlm_t = [] + for t in refs_t: + _, _, h, w = t.shape + cond_w, cond_h = calculate_dimensions(cis, w / h) + cond_w = max(cond_w, VIT_PATCH * VIT_MERGE) + cond_h = max(cond_h, VIT_PATCH * VIT_MERGE) + refs_vlm_t.append(comfy.utils.common_upscale(t, cond_w, cond_h, "lanczos", "disabled")) + + pv_list, grid_list, per_ref_vit_tokens = [], [], [] + for t_v in refs_vlm_t: + pv, grid_thw = process_qwen2vl_images( + t_v.permute(0, 2, 3, 1), + min_pixels=0, max_pixels=10**12, + patch_size=VIT_PATCH, merge_size=VIT_MERGE, + image_mean=VIT_IMAGE_MEAN, image_std=VIT_IMAGE_STD, + ) + grid_thw = grid_thw[0] + pv_list.append(pv.to(device=device, dtype=dtype)) + grid_list.append(grid_thw.to(device=device)) + # Post-merge token count = number of <|image_pad|> tokens this image expands to in input_ids. + gh, gw = int(grid_thw[1].item()), int(grid_thw[2].item()) + per_ref_vit_tokens.append((gh // VIT_MERGE) * (gw // VIT_MERGE)) + + return { + "ref_patches": ref_patches, + "ref_pixel_values": torch.cat(pv_list, dim=0), + "ref_image_grid_thw": torch.stack(grid_list, dim=0), + "per_ref_vit_tokens": per_ref_vit_tokens, + "per_ref_patch_grids": per_ref_patch_grids, + } + + +def build_ref_input_ids( + text_input_ids: torch.Tensor, + per_ref_vit_tokens: List[int], + image_token_id: int, + vision_start_id: int, + vision_end_id: int, +): + """Splice [vision_start, image_pad*N, vision_end] blocks into input_ids + after the [im_start, user, \\n] prefix (matches original chat template). + """ + ids = text_input_ids[0].tolist() + inserted = [] + for n_pad in per_ref_vit_tokens: + inserted.extend([vision_start_id] + [image_token_id] * n_pad + [vision_end_id]) + new_ids = ids[:3] + inserted + ids[3:] # 3 = len([im_start, user, \n]) + return torch.tensor([new_ids], dtype=text_input_ids.dtype, device=text_input_ids.device) + + +def build_extra_conds( + text_input_ids: torch.Tensor, + noise: torch.Tensor, + ref_images: List[torch.Tensor] = None, + target_patch_size: int = 32, +): + """Assemble all conditioning tensors for HiDreamO1Transformer.forward: + input_ids (with ref-vision tokens spliced in for the edit/IP path), + position_ids (MRoPE), token_types, vinput_mask, plus the ref + dual-path tensors when refs are provided. + """ + from .utils import get_rope_index_fix_point + from comfy.text_encoders.hidream_o1 import ( + IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID, + ) + + if text_input_ids.dim() == 1: + text_input_ids = text_input_ids.unsqueeze(0) + text_input_ids = text_input_ids.long().to(noise.device) + B = noise.shape[0] + if text_input_ids.shape[0] == 1 and B > 1: + text_input_ids = text_input_ids.expand(B, -1) + + H, W = noise.shape[-2], noise.shape[-1] + h_p, w_p = H // target_patch_size, W // target_patch_size + image_len = h_p * w_p + image_grid_thw_tgt = torch.tensor( + [[1, h_p, w_p]], dtype=torch.long, device=text_input_ids.device, + ) + + out = {} + if ref_images: + ref = prepare_ref_images(ref_images, H, W, device=noise.device, dtype=noise.dtype) + text_input_ids = build_ref_input_ids( + text_input_ids, ref["per_ref_vit_tokens"], + IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID, + ) + new_txt_len = text_input_ids.shape[1] + + # Each ref's patchified stream gets a [vision_start, image_pad*N-1] + # block in the position-id stream after the noised target. + ref_grid_lengths = [hp * wp for (hp, wp) in ref["per_ref_patch_grids"]] + tgt_vision = torch.full((1, image_len), IMAGE_TOKEN_ID, + dtype=text_input_ids.dtype, device=text_input_ids.device) + tgt_vision[:, 0] = VISION_START_ID + ref_vision_blocks = [] + for rl in ref_grid_lengths: + blk = torch.full((1, rl), IMAGE_TOKEN_ID, + dtype=text_input_ids.dtype, device=text_input_ids.device) + blk[:, 0] = VISION_START_ID + ref_vision_blocks.append(blk) + ref_vision_cat = torch.cat([tgt_vision] + ref_vision_blocks, dim=1) + input_ids_pad = torch.cat([text_input_ids, ref_vision_cat], dim=-1) + total_ref_patches_len = sum(ref_grid_lengths) + total_len = new_txt_len + image_len + total_ref_patches_len + + # K (ViT, post-merge) + 1 (target) + K (ref-patches) image grids. + K = len(ref_images) + igthw_cond = ref["ref_image_grid_thw"].clone() + igthw_cond[:, 1] //= 2 + igthw_cond[:, 2] //= 2 + image_grid_thw_ref = torch.tensor( + [[1, hp, wp] for (hp, wp) in ref["per_ref_patch_grids"]], + dtype=torch.long, device=text_input_ids.device, + ) + igthw_all = torch.cat([ + igthw_cond.to(text_input_ids.device), + image_grid_thw_tgt, + image_grid_thw_ref, + ], dim=0) + position_ids, _ = get_rope_index_fix_point( + spatial_merge_size=1, + image_token_id=IMAGE_TOKEN_ID, + vision_start_token_id=VISION_START_ID, + input_ids=input_ids_pad, image_grid_thw=igthw_all, + attention_mask=None, + skip_vision_start_token=[0] * K + [1] + [1] * K, + fix_point=4096, + ) + + # tms + target_image + ref_patches are all gen. + tms_pos = new_txt_len - 1 + ar_len = tms_pos + token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device) + token_types[:, tms_pos:] = 1 + vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device) + vinput_mask[:, new_txt_len:] = True + + # Leading batch dim sidesteps CONDRegular.process_cond's repeat_to_batch_size truncation + out["ref_pixel_values"] = ref["ref_pixel_values"].unsqueeze(0) + out["ref_image_grid_thw"] = ref["ref_image_grid_thw"].unsqueeze(0) + out["ref_patches"] = ref["ref_patches"] + else: + # T2I: text + noised target only, vision_start replaces the first image token + txt_len = text_input_ids.shape[1] + total_len = txt_len + image_len + vision_tokens = torch.full((B, image_len), IMAGE_TOKEN_ID, + dtype=text_input_ids.dtype, device=text_input_ids.device) + vision_tokens[:, 0] = VISION_START_ID + input_ids_pad = torch.cat([text_input_ids, vision_tokens], dim=-1) + position_ids, _ = get_rope_index_fix_point( + spatial_merge_size=1, + image_token_id=IMAGE_TOKEN_ID, + vision_start_token_id=VISION_START_ID, + input_ids=input_ids_pad, image_grid_thw=image_grid_thw_tgt, + attention_mask=None, + skip_vision_start_token=[1], + ) + ar_len = txt_len - 1 + token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device) + token_types[:, ar_len:] = 1 + vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device) + vinput_mask[:, txt_len:] = True + + out["input_ids"] = text_input_ids + out["position_ids"] = position_ids[:, 0].unsqueeze(0) # Collapse position_ids batch and add a leading dim so CONDRegular's batch-resize doesn't truncate the 3-axis MRoPE dim + out["token_types"] = token_types + out["vinput_mask"] = vinput_mask + out["ar_len"] = ar_len + return out diff --git a/comfy/ldm/hidream_o1/model.py b/comfy/ldm/hidream_o1/model.py new file mode 100644 index 000000000..a223e706f --- /dev/null +++ b/comfy/ldm/hidream_o1/model.py @@ -0,0 +1,306 @@ +"""HiDream-O1-Image transformer. + +Pixel-space DiT built on Qwen3-VL: the vision tower (Qwen35VisionModel) +encodes ref images, the Qwen3-VL-8B decoder (Llama2_ with interleaved MRoPE) +processes a unified text+image sequence, and 32x32 patch embed/unembed +shims map raw RGB in and out of LLM hidden space. The Qwen3-VL deepstack +mergers go unused — their weights are dropped at load. +""" + +from dataclasses import dataclass, field +from typing import List, Optional + +import einops +import torch +import torch.nn as nn + +import comfy.patcher_extension +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder +from comfy.text_encoders.llama import Llama2_ +from comfy.text_encoders.qwen35 import Qwen35VisionModel + +from .attention import make_two_pass_attention + + +IMAGE_TOKEN_ID = 151655 # Qwen3-VL <|image_pad|> +TMS_TOKEN_ID = 151673 # HiDream-O1 <|tms_token|> +PATCH_SIZE = 32 + + +@dataclass +class HiDreamO1TextConfig: + """Qwen3-VL-8B text-decoder dims (matches public Qwen3-VL-8B-Instruct).""" + vocab_size: int = 151936 + hidden_size: int = 4096 + intermediate_size: int = 12288 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + head_dim: int = 128 + max_position_embeddings: int = 128000 + rms_norm_eps: float = 1e-6 + rope_theta: float = 5000000.0 + rope_scale: Optional[float] = None + rope_dims: List[int] = field(default_factory=lambda: [24, 20, 20]) + interleaved_mrope: bool = True + transformer_type: str = "llama" + rms_norm_add: bool = False + mlp_activation: str = "silu" + qkv_bias: bool = False + q_norm: str = "gemma3" + k_norm: str = "gemma3" + final_norm: bool = True + lm_head: bool = False + stop_tokens: List[int] = field(default_factory=lambda: [151643, 151645]) + + +QWEN3VL_VISION_DEFAULTS = dict( + hidden_size=1152, + num_heads=16, + intermediate_size=4304, + depth=27, + patch_size=16, + temporal_patch_size=2, + in_channels=3, + spatial_merge_size=2, + num_position_embeddings=2304, + deepstack_visual_indexes=(8, 16, 24), + out_hidden_size=4096, # final merger projects directly into LLM hidden +) + + +class BottleneckPatchEmbed(nn.Module): + # 3072 -> 1024 -> 4096 (raw 32x32 RGB patch -> bottleneck -> LLM hidden). + def __init__(self, patch_size=32, in_chans=3, pca_dim=1024, embed_dim=4096, bias=True, device=None, dtype=None, ops=None): + super().__init__() + self.proj1 = ops.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False, device=device, dtype=dtype) + self.proj2 = ops.Linear(pca_dim, embed_dim, bias=bias, device=device, dtype=dtype) + + def forward(self, x): + return self.proj2(self.proj1(x)) + + +class FinalLayer(nn.Module): + # 4096 -> 3072 (LLM hidden -> flat pixel patch). + def __init__(self, hidden_size, patch_size=32, out_channels=3, device=None, dtype=None, ops=None): + super().__init__() + self.linear = ops.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, device=device, dtype=dtype) + + def forward(self, x): + return self.linear(x) + + +class HiDreamO1Transformer(nn.Module): + """HiDream-O1 unified pixel-level transformer.""" + + def __init__(self, image_model=None, dtype=None, device=None, operations=None, + text_config_overrides=None, vision_config_overrides=None, **kwargs): + super().__init__() + self.dtype = dtype + + text_cfg = HiDreamO1TextConfig(**(text_config_overrides or {})) + vision_cfg = dict(QWEN3VL_VISION_DEFAULTS) + if vision_config_overrides: + vision_cfg.update(vision_config_overrides) + vision_cfg["out_hidden_size"] = text_cfg.hidden_size + + self.text_config = text_cfg + self.vision_config = vision_cfg + self.hidden_size = text_cfg.hidden_size + self.patch_size = PATCH_SIZE + self.in_channels = 3 + self.tms_token_id = TMS_TOKEN_ID + + self.visual = Qwen35VisionModel(vision_cfg, device=device, dtype=dtype, ops=operations) + self.language_model = Llama2_(text_cfg, device=device, dtype=dtype, ops=operations) + self.t_embedder1 = TimestepEmbedder( + text_cfg.hidden_size, device=device, dtype=dtype, operations=operations, + ) + self.x_embedder = BottleneckPatchEmbed( + patch_size=self.patch_size, in_chans=self.in_channels, + pca_dim=text_cfg.hidden_size // 4, embed_dim=text_cfg.hidden_size, + bias=True, device=device, dtype=dtype, ops=operations, + ) + self.final_layer2 = FinalLayer( + text_cfg.hidden_size, patch_size=self.patch_size, + out_channels=self.in_channels, device=device, dtype=dtype, ops=operations, + ) + + self._visual_cache = None + self._kv_cache_entries = [] + + def clear_kv_cache(self): + self._kv_cache_entries = [] + self._visual_cache = None + + def forward(self, x, timesteps, context=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timesteps, context, transformer_options, **kwargs) + + def _forward(self, x, timesteps, context=None, transformer_options={}, input_ids=None, attention_mask=None, position_ids=None, + vinput_mask=None, ar_len=None, ref_pixel_values=None, ref_image_grid_thw=None, ref_patches=None, **kwargs): + """Returns flow-match velocity (x - x_pred) / sigma""" + + if input_ids is None or position_ids is None: + raise ValueError("HiDreamO1Transformer requires input_ids and position_ids in conditioning") + + B, _, H, W = x.shape + h_p, w_p = H // self.patch_size, W // self.patch_size + tgt_image_len = h_p * w_p + + z = einops.rearrange( + x, 'B C (H p1) (W p2) -> B (H W) (C p1 p2)', + p1=self.patch_size, p2=self.patch_size, + ) + vinputs = torch.cat([z, ref_patches.to(z.dtype)], dim=1) if ref_patches is not None else z + + inputs_embeds = self.language_model.embed_tokens(input_ids).to(x.dtype) + + if ref_pixel_values is not None and ref_image_grid_thw is not None: + # ViT output is constant across sampling steps within a generation + # identity-key by the input tensor so refs don't recompute every step. + cached = self._visual_cache + if cached is not None and cached[0] is ref_pixel_values: + image_embeds = cached[1] + else: + ref_pv = ref_pixel_values.to(inputs_embeds.device) + ref_grid = ref_image_grid_thw.to(inputs_embeds.device).long() + # extra_conds wraps with a leading batch dim; refs are model-level so [0] always recovers them. + if ref_pv.dim() == 3: + ref_pv = ref_pv[0] + if ref_grid.dim() == 3: + ref_grid = ref_grid[0] + image_embeds = self.visual(ref_pv, ref_grid).to(inputs_embeds.dtype) + self._visual_cache = (ref_pixel_values, image_embeds) + # image_pad positions identical across batch (input_ids shared cond/uncond). + image_idx = (input_ids[0] == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0] + if image_idx.shape[0] != image_embeds.shape[0]: + raise ValueError( + f"Image-token count {image_idx.shape[0]} != ViT output count " + f"{image_embeds.shape[0]}; check tokenizer/processor alignment." + ) + inputs_embeds[:, image_idx] = image_embeds.unsqueeze(0).expand(B, -1, -1) + + sigma = timesteps.float() / 1000.0 + t_pixeldit = 1.0 - sigma + t_emb = self.t_embedder1(t_pixeldit * 1000, inputs_embeds.dtype) + tms_mask_3d = (input_ids == self.tms_token_id).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = torch.where(tms_mask_3d, t_emb.unsqueeze(1).expand_as(inputs_embeds), inputs_embeds) + + vinputs_embedded = self.x_embedder(vinputs.to(inputs_embeds.dtype)) + inputs_embeds = torch.cat([inputs_embeds, vinputs_embedded], dim=1) + + # extra_conds stores position_ids as (1, 3, T); process_cond repeats dim 0 to B. Take row 0. + freqs_cis = self.language_model.compute_freqs_cis(position_ids[0].to(x.device), x.device) + freqs_cis = tuple(t.to(x.dtype) for t in freqs_cis) + + two_pass_attn = make_two_pass_attention(ar_len, transformer_options=transformer_options) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.language_model.layers) + transformer_options["block_type"] = "double" + + # Cache prefix K/V across steps. Key includes input_ids (prompt), ref_id + # (refs scatter into inputs_embeds), and position_ids (RoPE baked into cached K). + can_cache = not blocks_replace and ar_len > 0 + cache_len = ar_len if can_cache else 0 + ref_id = id(ref_pixel_values) if ref_pixel_values is not None else None + pos_ids_key = position_ids[..., :cache_len] if can_cache else position_ids + cache_entries = self._kv_cache_entries + # Drop stale entries from a previous device (model was unloaded and reloaded). + if cache_entries and cache_entries[0]["input_ids"].device != input_ids.device: + cache_entries = [] + self._kv_cache_entries = [] + kv_cache = None + if can_cache: + for entry in cache_entries: + ck = entry["input_ids"] + ep = entry["position_ids"] + if (entry["cache_len"] == cache_len + and ck.shape == input_ids.shape and torch.equal(ck, input_ids) + and entry["ref_id"] == ref_id + and ep.shape == pos_ids_key.shape and torch.equal(ep, pos_ids_key)): + kv_cache = entry + break + + if kv_cache is not None: + # Hot path: project Q/K/V only for fresh positions; past_key_value prepends cached AR K/V. + hidden_states = inputs_embeds[:, cache_len:] + sliced_freqs = tuple(t[..., cache_len:, :] for t in freqs_cis) + for i, layer in enumerate(self.language_model.layers): + transformer_options["block_index"] = i + K_i, V_i = kv_cache["kv"][i] + hidden_states, _ = layer( + x=hidden_states, attention_mask=None, freqs_cis=sliced_freqs, optimized_attention=two_pass_attn, + past_key_value=(K_i, V_i, cache_len), + ) + else: + # Cold path: run full sequence; if cacheable, snapshot K/V at AR positions. + snapshots = [] if can_cache else None + past_kv_cold = () if can_cache else None + hidden_states = inputs_embeds + for i, layer in enumerate(self.language_model.layers): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args, _layer=layer): + out = {} + out["x"], _ = _layer( + x=args["x"], attention_mask=args.get("attention_mask"), + freqs_cis=args["freqs_cis"], optimized_attention=args["optimized_attention"], + past_key_value=None, + ) + return out + out = blocks_replace[("double_block", i)]( + {"x": hidden_states, "attention_mask": None, + "freqs_cis": freqs_cis, "optimized_attention": two_pass_attn, + "transformer_options": transformer_options}, + {"original_block": block_wrap}, + ) + hidden_states = out["x"] + else: + hidden_states, present_kv = layer( + x=hidden_states, attention_mask=None, + freqs_cis=freqs_cis, optimized_attention=two_pass_attn, + past_key_value=past_kv_cold, + ) + if snapshots is not None: + K, V, _ = present_kv + snapshots.append((K[:, :, :cache_len].contiguous(), + V[:, :, :cache_len].contiguous())) + if snapshots is not None: + # Cap at 2 entries (cond + uncond). Multi-cond workflows LRU-evict. + new_entry = { + "input_ids": input_ids.clone(), + "cache_len": cache_len, + "kv": snapshots, + "ref_id": ref_id, + "position_ids": pos_ids_key.clone(), + } + self._kv_cache_entries = (cache_entries + [new_entry])[-2:] + + if self.language_model.norm is not None: + hidden_states = self.language_model.norm(hidden_states) + + # Slice target-image positions before the final projection so the Linear only runs on tgt_image_len tokens. + # In the hot path hidden_states starts at original position cache_len, so masks/indices shift by cache_len. + sliced_offset = cache_len if kv_cache is not None else 0 + if vinput_mask is not None: + vmask = vinput_mask.to(x.device).bool() + if sliced_offset > 0: + vmask = vmask[:, sliced_offset:] + target_hidden = hidden_states[vmask].view(B, -1, hidden_states.shape[-1])[:, :tgt_image_len] + else: + txt_seq_len = input_ids.shape[1] + start = txt_seq_len - sliced_offset + target_hidden = hidden_states[:, start:start + tgt_image_len] + x_pred_tgt = self.final_layer2(target_hidden) + + # fp32 final subtraction, bf16 here noticeably degrades samples. + x_pred_img = einops.rearrange( + x_pred_tgt, 'B (H W) (C p1 p2) -> B C (H p1) (W p2)', + H=h_p, W=w_p, p1=self.patch_size, p2=self.patch_size, + ) + return (x.float() - x_pred_img.float()) / sigma.view(B, 1, 1, 1).clamp_min(1e-3) diff --git a/comfy/ldm/hidream_o1/utils.py b/comfy/ldm/hidream_o1/utils.py new file mode 100644 index 000000000..5a1249c72 --- /dev/null +++ b/comfy/ldm/hidream_o1/utils.py @@ -0,0 +1,173 @@ +"""HiDream-O1 input-prep helpers: image/resolution math and unified-sequence +RoPE position-id assembly. The fix_point offset in get_rope_index_fix_point +lets the target image and patchified ref images share spatial RoPE positions +despite living at different sequence indices — same 2D image plane. +""" + +import math +from typing import Optional + +import torch + + +PATCH_SIZE = 32 +CONDITION_IMAGE_SIZE = 384 # ViT-side base size for ref images + + +def resize_tensor(img_t, image_size, patch_size=16): + """img_t: (1, 3, H, W) float [0, 1]. Fit to image_size**2 area, patch-aligned, center-cropped.""" + + while min(img_t.shape[-2], img_t.shape[-1]) >= 2 * image_size: # Pre-halves with 2x2 box averaging while the image is still very large + img_t = torch.nn.functional.avg_pool2d(img_t, kernel_size=2, stride=2) + + _, _, height, width = img_t.shape + m = patch_size + s_max = image_size * image_size + scale = math.sqrt(s_max / (width * height)) + + candidates = [ + (round(width * scale) // m * m, round(height * scale) // m * m), + (round(width * scale) // m * m, math.floor(height * scale) // m * m), + (math.floor(width * scale) // m * m, round(height * scale) // m * m), + (math.floor(width * scale) // m * m, math.floor(height * scale) // m * m), + ] + candidates = sorted(candidates, key=lambda x: x[0] * x[1], reverse=True) + new_size = candidates[-1] + for c in candidates: + if c[0] * c[1] <= s_max: + new_size = c + break + + new_w, new_h = new_size + s1 = width / new_w + s2 = height / new_h + if s1 < s2: + resize_w, resize_h = new_w, round(height / s1) + else: + resize_w, resize_h = round(width / s2), new_h + img_t = torch.nn.functional.interpolate(img_t, size=(resize_h, resize_w), mode="bicubic") + top = (resize_h - new_h) // 2 + left = (resize_w - new_w) // 2 + return img_t[..., top:top + new_h, left:left + new_w] + + +def calculate_dimensions(max_size, ratio): + """(W, H) for an aspect ratio fitting in max_size**2 area, 32-aligned.""" + width = math.sqrt(max_size * max_size * ratio) + height = width / ratio + width = int(width / 32) * 32 + height = int(height / 32) * 32 + return width, height + + +def ref_max_size(target_max_dim, k): + """K-dependent ref-image max dim before patchifying.""" + if k == 1: + return target_max_dim + if k == 2: + return target_max_dim * 48 // 64 + if k <= 4: + return target_max_dim // 2 + if k <= 8: + return target_max_dim * 24 // 64 + return target_max_dim // 4 + + +def cond_image_size(k): + """K-dependent ViT-side image size.""" + if k <= 4: + return CONDITION_IMAGE_SIZE + if k <= 8: + return CONDITION_IMAGE_SIZE * 48 // 64 + return CONDITION_IMAGE_SIZE // 2 + + +def get_rope_index_fix_point( + spatial_merge_size: int, + image_token_id: int, + vision_start_token_id: int, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + skip_vision_start_token=None, + fix_point: int = 4096, +): + mrope_position_deltas = [] + if input_ids is not None and image_grid_thw is not None: + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], + dtype=input_ids.dtype, device=input_ids.device, + ) + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids_b in enumerate(total_input_ids): + fp = fix_point + image_index = 0 + input_ids_b = input_ids_b[attention_mask[i] == 1] + vision_start_indices = torch.argwhere(input_ids_b == vision_start_token_id).squeeze(1) + vision_tokens = input_ids_b[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + input_tokens = input_ids_b.tolist() + llm_pos_ids_list = [] + st = 0 + remain_images = image_nums + for _ in range(image_nums): + if image_token_id in input_tokens and remain_images > 0: + ed = input_tokens.index(image_token_id, st) + else: + ed = len(input_tokens) + 1 + t = image_grid_thw[image_index][0] + h = image_grid_thw[image_index][1] + w = image_grid_thw[image_index][2] + image_index += 1 + remain_images -= 1 + llm_grid_t = t.item() + llm_grid_h = h.item() // spatial_merge_size + llm_grid_w = w.item() // spatial_merge_size + text_len = ed - st + text_len -= skip_vision_start_token[image_index - 1] + text_len = max(0, text_len) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + + if skip_vision_start_token[image_index - 1]: + if fp > 0: + fp = fp - st_idx + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fp + st_idx) + fp = 0 + else: + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1).expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype, + ) + return position_ids, mrope_position_deltas diff --git a/comfy/model_base.py b/comfy/model_base.py index dbed239e5..0736321b3 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -58,6 +58,8 @@ import comfy.ldm.cogvideo.model import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.ernie.model import comfy.ldm.sam3.detector +import comfy.ldm.hidream_o1.model +from comfy.ldm.hidream_o1.conditioning import build_extra_conds import comfy.model_management import comfy.patcher_extension @@ -1674,6 +1676,32 @@ class HiDream(BaseModel): out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond)) return out +class HiDreamO1(BaseModel): + """HiDream-O1-Image: pixel-space DiT (no VAE). Refs from HiDreamO1ReferenceImages and tokens from the stub TE flow through + extra_conds; the heavy preprocessing lives in comfy.ldm.hidream_o1.conditioning.""" + PATCH_SIZE = 32 + + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream_o1.model.HiDreamO1Transformer) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + text_input_ids = kwargs.get("text_input_ids", None) + noise = kwargs.get("noise", None) + if text_input_ids is None or noise is None: + return out + + conds = build_extra_conds( + text_input_ids, noise, + ref_images=kwargs.get("reference_latents", None), + target_patch_size=self.PATCH_SIZE, + ) + for k, v in conds.items(): + # ar_len is a Python int (precomputed to avoid a GPU sync in forward). + cls = comfy.conds.CONDConstant if k == "ar_len" else comfy.conds.CONDRegular + out[k] = cls(v) + return out + class Chroma(Flux): def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma): super().__init__(model_config, model_type, device=device, unet_model=unet_model) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8ae456481..bc0b933bc 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys return dit_config + if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1 + return {"image_model": "hidream_o1"} + if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream dit_config = {} dit_config["image_model"] = "hidream" diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index cf2b5db5f..5af336e76 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -93,7 +93,8 @@ class CONST: def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): sigma = reshape_sigma(sigma, noise.ndim) - return sigma * noise + (1.0 - sigma) * latent_image + s = getattr(self, "noise_scale", 1.0) + return sigma * (s * noise) + (1.0 - sigma) * latent_image def inverse_noise_scaling(self, sigma, latent): sigma = reshape_sigma(sigma, latent.ndim) @@ -288,7 +289,11 @@ class ModelSamplingDiscreteFlow(torch.nn.Module): else: sampling_settings = {} - self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000)) + self.set_noise_scale(sampling_settings.get("noise_scale", 1.0)) + self.set_parameters( + shift=sampling_settings.get("shift", 1.0), + multiplier=sampling_settings.get("multiplier", 1000), + ) def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000): self.shift = shift @@ -296,6 +301,9 @@ class ModelSamplingDiscreteFlow(torch.nn.Module): ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier) self.register_buffer('sigmas', ts) + def set_noise_scale(self, noise_scale): + self.noise_scale = float(noise_scale) + @property def sigma_min(self): return self.sigmas[0] diff --git a/comfy/ops.py b/comfy/ops.py index 77ad1d527..117cdd327 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1285,7 +1285,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict: self.quant_format = quant_format qconfig = QUANT_ALGOS[quant_format] - layout_cls = get_layout_class(qconfig["comfy_tensor_layout"]) + self.layout_type = qconfig["comfy_tensor_layout"] + layout_cls = get_layout_class(self.layout_type) weight = state_dict.pop(weight_key) manually_loaded_keys = [weight_key] diff --git a/comfy/sd.py b/comfy/sd.py index 749bdd710..ab2718892 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -239,7 +239,8 @@ class CLIP: model_management.archive_model_dtypes(self.cond_stage_model) self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + te_disable_dynamic = disable_dynamic or getattr(self.cond_stage_model, "disable_offload", False) + ModelPatcher = comfy.model_patcher.ModelPatcher if te_disable_dynamic else comfy.model_patcher.CoreModelPatcher self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) @@ -776,6 +777,7 @@ class VAE: self.latent_channels = 3 self.latent_dim = 2 self.output_channels = 3 + self.disable_offload = True elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE sample_rate = 16000 if sample_rate == 16000: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 40417f922..8d2e02f68 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -28,6 +28,7 @@ import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.ernie import comfy.text_encoders.cogvideo +import comfy.text_encoders.hidream_o1 from . import supported_models_base from . import latent_formats @@ -1431,6 +1432,50 @@ class HiDream(supported_models_base.BASE): def clip_target(self, state_dict={}): return None # TODO +class HiDreamO1(supported_models_base.BASE): + unet_config = { + "image_model": "hidream_o1", + } + + sampling_settings = { + "shift": 3.0, + "noise_scale": 8.0, + } + + latent_format = latent_formats.HiDreamO1Pixel + memory_usage_factor = 0.6 + # fp16 not supported: LM MLP down_proj activations fp16 overflow, causing NaNs + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + optimizations = {"fp8": False} + + def get_model(self, state_dict, prefix="", device=None): + return model_base.HiDreamO1(self, device=device) + + def process_unet_state_dict(self, state_dict): + # Drop unused Qwen3-VL deepstack merger weights; upstream discards them at inference. + for key in list(state_dict.keys()): + if "visual.deepstack_merger_list" in key: + del state_dict[key] + return state_dict + + def process_vae_state_dict(self, state_dict): + # Pixel-space model: inject sentinel so VAE construction picks PixelspaceConversionVAE. + return {"pixel_space_vae": torch.tensor(1.0)} + + def process_clip_state_dict(self, state_dict): + # Tokenizer-only TE: inject sentinel so load_state_dict_guess_config triggers CLIP init. + return {"_hidream_o1_te_sentinel": torch.zeros(1)} + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget( + comfy.text_encoders.hidream_o1.HiDreamO1Tokenizer, + comfy.text_encoders.hidream_o1.HiDreamO1TE, + ) + class Chroma(supported_models_base.BASE): unet_config = { "image_model": "chroma", @@ -2018,6 +2063,7 @@ models = [ Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, + HiDreamO1, Chroma, ChromaRadiance, ACEStep, diff --git a/comfy/text_encoders/hidream_o1.py b/comfy/text_encoders/hidream_o1.py new file mode 100644 index 000000000..5d287b784 --- /dev/null +++ b/comfy/text_encoders/hidream_o1.py @@ -0,0 +1,119 @@ +"""HiDream-O1-Image tokenizer-only text encoder. + +The real Qwen3-VL backbone runs inside diffusion_model.* every step, so this +module just tokenizes the prompt into text_input_ids and emits them as +conditioning. Position ids / token_types / vinput_mask depend on target H/W +and are built later in model_base.HiDreamO1.extra_conds. +""" + +import os + +import torch +from transformers import Qwen2Tokenizer + +from comfy import sd1_clip + + +# Qwen3-VL special tokens +IM_START_ID = 151644 +IM_END_ID = 151645 +ASSISTANT_ID = 77091 +USER_ID = 872 +NEWLINE_ID = 198 +VISION_START_ID = 151652 +VISION_END_ID = 151653 +IMAGE_TOKEN_ID = 151655 +VIDEO_TOKEN_ID = 151656 +# HiDream-O1-specific tokens +BOI_TOKEN_ID = 151669 +BOR_TOKEN_ID = 151670 +EOR_TOKEN_ID = 151671 +BOT_TOKEN_ID = 151672 +TMS_TOKEN_ID = 151673 + + +class HiDreamO1QwenTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer" + ) + super().__init__( + tokenizer_path, + pad_with_end=False, + embedding_size=4096, + embedding_key="hidream_o1", + tokenizer_class=Qwen2Tokenizer, + has_start_token=False, + has_end_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=1, + pad_token=151643, + tokenizer_data=tokenizer_data, + ) + + +class HiDreamO1Tokenizer(sd1_clip.SD1Tokenizer): + """Wraps prompt in the upstream chat template ending with boi/tms markers. + Image tokens get spliced in at sample time once target H/W is known. + """ + + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__( + embedding_directory=embedding_directory, + tokenizer_data=tokenizer_data, + name="hidream_o1", + tokenizer=HiDreamO1QwenTokenizer, + ) + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + text_tokens_dict = super().tokenize_with_weights( + text, return_word_ids=return_word_ids, disable_weights=True, **kwargs + ) + text_tuples = text_tokens_dict["hidream_o1"][0] + text_tuples = [t for t in text_tuples if int(t[0]) != 151643] # strip pad + + # <|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|boi|><|tms|> + def tok(tid): + return (tid, 1.0) if not return_word_ids else (tid, 1.0, 0) + + prefix = [tok(IM_START_ID), tok(USER_ID), tok(NEWLINE_ID)] + suffix = [ + tok(IM_END_ID), tok(NEWLINE_ID), + tok(IM_START_ID), tok(ASSISTANT_ID), tok(NEWLINE_ID), + tok(BOI_TOKEN_ID), tok(TMS_TOKEN_ID), + ] + full = prefix + list(text_tuples) + suffix + return {"hidream_o1": [full]} + + +class HiDreamO1TE(torch.nn.Module): + """Passthrough TE: emits int token ids; the Qwen3-VL backbone in diffusion_model does the actual encoding.""" + + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__() + self.dtypes = {torch.float32} + self.disable_offload = True # skips dynamic VRAM management for this zero-parameter module + self.device = torch.device("cpu") if device is None else torch.device(device) + + def encode_token_weights(self, token_weight_pairs): + tok_pairs = token_weight_pairs["hidream_o1"][0] + ids = [int(t[0]) for t in tok_pairs] + input_ids = torch.tensor([ids], dtype=torch.long) + # Surrogate keeps the cross_attn slot non-empty for CONDITIONING + # plumbing; the model reads text_input_ids out of `extra` instead. + cross_attn = input_ids.unsqueeze(-1).to(torch.float32) + extra = {"text_input_ids": input_ids} + return cross_attn, None, extra + + def load_sd(self, sd): + return [] + + def get_sd(self): + return {} + + def reset_clip_options(self): + pass + + def set_clip_options(self, options): + pass diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index a34c41144..5087228ca 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -397,7 +397,7 @@ class RMSNorm(nn.Module): -def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None): +def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None, interleaved_mrope=False): if not isinstance(theta, list): theta = [theta] @@ -415,16 +415,27 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - if rope_dims is not None and position_ids.shape[0] > 1: - mrope_section = rope_dims * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + if rope_dims is not None and position_ids.shape[0] > 1 and interleaved_mrope: + # Qwen3-VL interleaved MRoPE: T-freqs by default, H/W replace every 3rd dim. + freqs_inter = freqs[0].clone() + for axis_idx, offset in ((1, 1), (2, 2)): + length = rope_dims[axis_idx] * 3 + idx = slice(offset, length, 3) + freqs_inter[..., idx] = freqs[axis_idx, ..., idx] + emb = torch.cat((freqs_inter, freqs_inter), dim=-1) + cos = emb.cos().unsqueeze(0) + sin = emb.sin().unsqueeze(0) else: - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + if rope_dims is not None and position_ids.shape[0] > 1: + mrope_section = rope_dims * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + else: + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) sin_split = sin.shape[-1] // 2 out.append((cos, sin[..., : sin_split], -sin[..., sin_split :])) @@ -689,6 +700,7 @@ class Llama2_(nn.Module): self.config.rope_theta, self.config.rope_scale, self.config.rope_dims, + interleaved_mrope=getattr(self.config, "interleaved_mrope", False), device=device) def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None): diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py index d8ed9cd32..d0bbec3e6 100644 --- a/comfy/text_encoders/qwen35.py +++ b/comfy/text_encoders/qwen35.py @@ -651,7 +651,7 @@ class Qwen35VisionModel(nn.Module): x = self.patch_embed(x) pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device) x = x + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.rot_pos_emb(grid_thw).to(x.device) seq_len = x.shape[0] x = x.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py index 7e8411fa4..567c37be0 100644 --- a/comfy_extras/nodes_advanced_samplers.py +++ b/comfy_extras/nodes_advanced_samplers.py @@ -86,6 +86,37 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No return x +class SamplerLCM(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerLCM", + category="sampling/samplers", + description=("LCM sampler with tunable per-step noise. s_noise is a multiplier on the model's training noise scale"), + inputs=[ + io.Float.Input("s_noise", default=1.0, min=0.0, max=64.0, step=0.01, + tooltip="Per-step noise multiplier at the first step (1.0 = match training)."), + io.Float.Input("s_noise_end", default=1.0, min=0.0, max=64.0, step=0.01, + tooltip="Per-step noise multiplier at the last step. Set equal to s_noise for a constant schedule."), + io.Float.Input("noise_clip_std", default=0.0, min=0.0, max=10.0, step=0.01, + tooltip="Clamp per-step noise to +/- N*std. 0 disables."), + ], + outputs=[io.Sampler.Output()], + ) + + @classmethod + def execute(cls, s_noise, s_noise_end, noise_clip_std) -> io.NodeOutput: + sampler = comfy.samplers.ksampler( + "lcm", + { + "s_noise": float(s_noise), + "s_noise_end": float(s_noise_end), + "noise_clip_std": float(noise_clip_std), + }, + ) + return io.NodeOutput(sampler) + + class SamplerEulerCFGpp(io.ComfyNode): @classmethod def define_schema(cls) -> io.Schema: @@ -114,6 +145,7 @@ class AdvancedSamplersExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ SamplerLCMUpscale, + SamplerLCM, SamplerEulerCFGpp, ] diff --git a/comfy_extras/nodes_hidream_o1.py b/comfy_extras/nodes_hidream_o1.py new file mode 100644 index 000000000..f393745f6 --- /dev/null +++ b/comfy_extras/nodes_hidream_o1.py @@ -0,0 +1,256 @@ +from typing_extensions import override + +import torch + +import comfy.model_management +import comfy.patcher_extension +import node_helpers +from comfy_api.latest import ComfyExtension, io + + +class EmptyHiDreamO1LatentImage(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyHiDreamO1LatentImage", + display_name="Empty HiDream-O1 Latent Image", + category="latent/image", + description=( + "Empty pixel-space latent for HiDream-O1-Image. The model was " + "trained at ~4 megapixels; lower resolutions go off-distribution " + "and quality regresses noticeably. Trained resolutions: " + "2048x2048, 2304x1728, 1728x2304, 2560x1440, 1440x2560, " + "2496x1664, 1664x2496, 3104x1312, 1312x3104, 2304x1792, 1792x2304." + ), + inputs=[ + io.Int.Input(id="width", default=2048, min=64, max=4096, step=32), + io.Int.Input(id="height", default=2048, min=64, max=4096, step=32), + io.Int.Input(id="batch_size", default=1, min=1, max=64), + ], + outputs=[io.Latent().Output()], + ) + + @classmethod + def execute(cls, *, width: int, height: int, batch_size: int = 1) -> io.NodeOutput: + latent = torch.zeros( + (batch_size, 3, height, width), + device=comfy.model_management.intermediate_device(), + ) + return io.NodeOutput({"samples": latent}) + + +class HiDreamO1ReferenceImages(io.ComfyNode): + """Attach reference images to both positive and negative conditioning.""" + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="HiDreamO1ReferenceImages", + display_name="HiDream-O1 Reference Images", + category="conditioning/image", + description=( + "Attach 1-10 reference images to conditioning, one for edit instruction" + "or multiple for subject-driven personalization." + ), + inputs=[ + io.Conditioning.Input(id="positive"), + io.Conditioning.Input(id="negative"), + io.Autogrow.Input( + "images", + template=io.Autogrow.TemplateNames( + io.Image.Input("image"), + names=[f"image_{i}" for i in range(1, 11)], + min=1, + ), + tooltip=("Reference images. 1 image = instruction edit; 2-10 images = multi reference." + ), + ), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, *, positive, negative, images: io.Autogrow.Type) -> io.NodeOutput: + refs = [images[f"image_{i}"] for i in range(1, 11) if f"image_{i}" in images] + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": refs}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": refs}, append=True) + return io.NodeOutput(positive, negative) + + +class HiDreamO1PatchSeamSmoothing(io.ComfyNode): + PATCH_SIZE = 32 + EDGE_FEATHER = 4 + + # Shift presets per (pattern, N). 8-pass = 4-quadrant + 4 quarter-patch offsets. + SHIFTS_BY_PATTERN = { + ("single_shift", 2): [(0, 0), (16, 16)], + ("single_shift", 4): [(0, 0), (16, 0), (0, 16), (16, 16)], + ("single_shift", 8): [(0, 0), (16, 0), (0, 16), (16, 16), + (8, 8), (24, 8), (8, 24), (24, 24)], + ("symmetric", 2): [(-8, -8), (8, 8)], + ("symmetric", 4): [(-8, -8), (8, -8), (-8, 8), (8, 8)], + ("symmetric", 8): [(-12, -12), (4, -12), (-12, 4), (4, 4), + (-4, -4), (12, -4), (-4, 12), (12, 12)], + } + RAMP_LEVELS = { + "2": [2], + "4": [4], + "ramp_2_4": [2, 4], + "ramp_2_4_8": [2, 4, 8], + } + + @staticmethod + def _hann_tile(cy: int, cx: int, size: int = 32) -> torch.Tensor: + """size x size Hann tile peaking at (cy, cx) within a patch.""" + half = size // 2 + yy = torch.arange(size).view(size, 1) + xx = torch.arange(size).view(1, size) + dy = ((yy - cy + half) % size) - half + dx = ((xx - cx + half) % size) - half + return 0.25 * (1 + torch.cos(torch.pi * dy / half)) * (1 + torch.cos(torch.pi * dx / half)) + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="HiDreamO1PatchSeamSmoothing", + display_name="HiDream-O1 Patch Seam Smoothing", + category="advanced/model", + is_experimental=True, + description=( + "Average the model output across multiple shifted patch-grid " + "positions during the late portion of sampling. Cancels seams." + ), + inputs=[ + io.Model.Input(id="model"), + io.Float.Input(id="start_percent", default=0.8, min=0.0, max=1.0, step=0.01, + tooltip="Sampling progress (0=start, 1=end) at which the blend turns ON.", + ), + io.Float.Input(id="end_percent", default=1.0, min=0.0, max=1.0, step=0.01, + tooltip="Sampling progress at which the blend turns OFF.", + ), + io.Combo.Input( + id="pattern", + options=["single_shift", "symmetric"], + default="single_shift", + tooltip="Shift layout. single_shift: one pass at the natural patch grid + others offset. symmetric: all passes off-grid, shifts split around origin.", + ), + io.Combo.Input( + id="passes", + options=["2", "4", "ramp_2_4", "ramp_2_4_8"], + default="2", + tooltip="Number of passes per gated step. 2/4 = fixed. ramp_*: pass count increases as sampling approaches end (more smoothing where seams are most visible).", + ), + io.Combo.Input( + id="blend", + options=["average", "window", "median"], + default="average", + tooltip="average: equal-weight mean. window: Hann-windowed weighting favoring each pass away from its patch boundaries. median: per-pixel median, rejects wraparound-outlier passes.", + ), + io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01, + tooltip="Interpolation between the natural-grid pred (0) and the averaged result (1).", + ), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def execute(cls, *, model, start_percent: float, end_percent: float, pattern: str, passes: str, blend: str, strength: float) -> io.NodeOutput: + if strength <= 0.0 or end_percent <= start_percent: + return io.NodeOutput(model) + + P = cls.PATCH_SIZE + half = P // 2 + shift_levels = [cls.SHIFTS_BY_PATTERN[(pattern, n)] for n in cls.RAMP_LEVELS[passes]] + + if blend == "window": + window_tile_levels = [ + torch.stack([cls._hann_tile((half - sy) % P, (half - sx) % P, P) for sy, sx in lst], dim=0) + for lst in shift_levels + ] + else: + window_tile_levels = [None] * len(shift_levels) + + m = model.clone() + model_sampling = m.get_model_object("model_sampling") + multiplier = float(model_sampling.multiplier) + start_t = float(model_sampling.percent_to_sigma(start_percent)) * multiplier + end_t = float(model_sampling.percent_to_sigma(end_percent)) * multiplier + + edge_ramp_cache: dict = {} + + def get_edge_ramp(H: int, W: int, device, dtype) -> torch.Tensor: + key = (H, W, device, dtype) + cached = edge_ramp_cache.get(key) + if cached is not None: + return cached + feather = cls.EDGE_FEATHER + ys = torch.minimum(torch.arange(H, device=device, dtype=torch.float32), + (H - 1) - torch.arange(H, device=device, dtype=torch.float32)) + xs = torch.minimum(torch.arange(W, device=device, dtype=torch.float32), + (W - 1) - torch.arange(W, device=device, dtype=torch.float32)) + y_mask = ((ys - P) / feather).clamp(0, 1) + x_mask = ((xs - P) / feather).clamp(0, 1) + ramp = (y_mask[:, None] * x_mask[None, :]).to(dtype) + edge_ramp_cache[key] = ramp + return ramp + + def smoothing_wrapper(executor, *args, **kwargs): + x = args[0] + t = float(args[1][0]) + pred = executor(*args, **kwargs) + if not (end_t <= t <= start_t): + return pred + # Pick shift-level by sigma phase across the gated range. + if len(shift_levels) == 1: + level_idx = 0 + else: + phase = (start_t - t) / max(start_t - end_t, 1e-8) + level_idx = min(int(phase * len(shift_levels)), len(shift_levels) - 1) + shifts = shift_levels[level_idx] + window_tiles = window_tile_levels[level_idx] + + preds = [] + for sy, sx in shifts: + if sy == 0 and sx == 0: + preds.append(pred) + continue + x_rolled = torch.roll(x, shifts=(sy, sx), dims=(-2, -1)) + pred_rolled = executor(x_rolled, *args[1:], **kwargs) + preds.append(torch.roll(pred_rolled, shifts=(-sy, -sx), dims=(-2, -1))) + stacked = torch.stack(preds, dim=0) # (N, B, C, H, W) + _, _, _, H, W = stacked.shape + if blend == "window": + N = stacked.shape[0] + tiles = window_tiles.to(device=stacked.device, dtype=stacked.dtype) + w = tiles.repeat(1, H // P, W // P)[:, :H, :W] + sum_w = w.sum(dim=0, keepdim=True) + w = torch.where(sum_w < 1e-3, torch.full_like(w, 1.0 / N), w / sum_w.clamp(min=1e-8)) + avg = (stacked * w[:, None, None, :, :]).sum(dim=0) + elif blend == "median": + avg = torch.median(stacked, dim=0).values + else: + avg = stacked.mean(dim=0) + + # Mask out the P-px wraparound contamination strip at each edge. + mask = get_edge_ramp(H, W, pred.device, pred.dtype) + return pred * (1.0 - mask * strength) + avg * (mask * strength) + + m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "hidream_o1_patch_seam_smoothing", smoothing_wrapper) + return io.NodeOutput(m) + + +class HiDreamO1Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyHiDreamO1LatentImage, + HiDreamO1ReferenceImages, + HiDreamO1PatchSeamSmoothing, + ] + + +async def comfy_entrypoint() -> HiDreamO1Extension: + return HiDreamO1Extension() diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 8bf6a1afa..33b940a0f 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -300,6 +300,29 @@ class RescaleCFG: m.set_model_sampler_cfg_function(rescale_cfg) return (m, ) +class ModelNoiseScale: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "noise_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 64.0, "step": 0.01, + "tooltip": "Absolute training noise scale. For example HiDream-O1 base: 8.0, dev: 7.5."}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, noise_scale): + m = model.clone() + original = m.model.model_sampling + ms = type(original)(m.model.model_config) + ms.set_parameters(shift=original.shift, multiplier=original.multiplier) + ms.set_noise_scale(noise_scale) + m.add_object_patch("model_sampling", ms) + return (m, ) + + class ModelComputeDtype: SEARCH_ALIASES = ["model precision", "change dtype"] @classmethod @@ -327,6 +350,7 @@ NODE_CLASS_MAPPINGS = { "ModelSamplingSD3": ModelSamplingSD3, "ModelSamplingAuraFlow": ModelSamplingAuraFlow, "ModelSamplingFlux": ModelSamplingFlux, + "ModelNoiseScale": ModelNoiseScale, "RescaleCFG": RescaleCFG, "ModelComputeDtype": ModelComputeDtype, } diff --git a/nodes.py b/nodes.py index ec66e54d7..78aaaef74 100644 --- a/nodes.py +++ b/nodes.py @@ -2435,6 +2435,7 @@ async def init_builtin_extra_nodes(): "nodes_sam3.py", "nodes_void.py", "nodes_wandancer.py", + "nodes_hidream_o1.py", ] import_failed = []