feat: Support HiDream-O1-Image (CORE-187) (#13817)

* Initial HiDream-O1-Image support

* Cleanup nodes

* Cleaner handling of empty placeholder models

* Remove snap_to_predefined, prefer tooltip for the trained resolutions

* Add model and block wrappers

* Fix shift tooltip

* Add node to work around the patch tile issue

Experimental: runs multiple passes with the patch grid offset and blends them with various methods.

* Qwen35 vision rotary_pos_emb cast fix

* Fix embedding layout type

* Some small optimizations

* Cleanup, don't need this fallback

* Prefix KV cache, cleanup

A bit of speed; reduces redundant code

* Get rid of redundant custom sampler, refactor noise scaling

Our existing lcm sampler is mathematically the same; instead of a custom sampler, the missing options were added to it along with a node to control them. Refactored the noise scaling, fixed it for the stochastic samplers, and added a generic node to control the initial noise scale.
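
As a quick illustration (a sketch, not part of the diff; the helper name is made up): the per-step multiplier in the reworked lcm sampler is a linear ramp from s_noise to s_noise_end.

import torch

def lcm_noise_multipliers(s_start: float, s_end: float, n_steps: int) -> torch.Tensor:
    # Mirrors sample_lcm below: step i uses s_start + (s_end - s_start) * i / (n_steps - 1).
    if n_steps <= 1:
        return torch.tensor([s_start])
    return s_start + (s_end - s_start) * torch.linspace(0.0, 1.0, n_steps)

print(lcm_noise_multipliers(1.0, 0.5, 5))  # tensor([1.0000, 0.8750, 0.7500, 0.6250, 0.5000])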

* Update nodes_hidream_o1.py

* Fix some cache validation cases

* Keep existing sampling params

* Remove redundant video vision path

* Replace some numpy ops with torch

* Fix RoPE index for batch size > 1

* Prefer torch preprocessing

* Rename block_type to be compatible with existing patch nodes

* Fixes and tweaks
Jukka Seppänen 2026-05-12 06:35:53 +03:00 committed by GitHub
parent 0a7d2ffd68
commit 8e53f001a4
19 changed files with 1339 additions and 21 deletions

View File

@ -242,6 +242,7 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -373,6 +374,7 @@ def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -686,6 +688,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
@ -747,6 +750,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -832,6 +836,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
old_denoised = None
h, h_last = None, None
@ -889,6 +894,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
denoised_1, denoised_2 = None, None
h, h_1, h_2 = None, None, None
@ -1006,23 +1012,39 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None,
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, s_noise=1.0, s_noise_end=None, noise_clip_std=0.0):
# s_noise / s_noise_end: per-step noise multiplier, linearly interpolated across steps
# noise_clip_std: clamp injected noise to +/- N stddevs (0 disables).
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
n_steps = max(1, len(sigmas) - 1)
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
s_start = float(s_noise)
s_end = s_start if s_noise_end is None else float(s_noise_end)
for i in trange(n_steps, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
x = denoised
if sigmas[i + 1] > 0:
x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
noise = noise_sampler(sigmas[i], sigmas[i + 1])
if noise_clip_std > 0:
clip_val = noise_clip_std * noise.std()
noise = noise.clamp(min=-clip_val, max=clip_val)
t = (i / (n_steps - 1)) if n_steps > 1 else 0.0
s_noise_i = s_start + (s_end - s_start) * t
if s_noise_i != 1.0:
noise = noise * s_noise_i
x = model_sampling.noise_scaling(sigmas[i + 1], noise, x)
return x
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
@ -1249,6 +1271,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
uncond_denoised = None
@ -1296,6 +1319,7 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
temp = [0]
def post_cfg_function(args):
@ -1371,6 +1395,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
@ -1504,6 +1529,7 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
def default_er_sde_noise_scaler(x):
@ -1574,9 +1600,10 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
inject_noise = eta > 0 and s_noise > 0
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
@ -1645,9 +1672,10 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
inject_noise = eta > 0 and s_noise > 0
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
@ -1713,6 +1741,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
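
All of these getattr hooks are backward compatible by construction: a model_sampling object that never calls set_noise_scale has no noise_scale attribute, so the multiplier stays 1.0, while a model that declares a trained scale (HiDream-O1 sets noise_scale to 8.0 later in this diff) scales every sampler's injected noise. A minimal sketch with a stand-in object:

class _DeclaredScale:  # stand-in for a model_sampling that called set_noise_scale
    noise_scale = 8.0

s_noise = 1.0
assert s_noise * getattr(_DeclaredScale(), "noise_scale", 1.0) == 8.0
assert s_noise * getattr(object(), "noise_scale", 1.0) == 1.0  # legacy models unchanged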

View File

@ -792,6 +792,13 @@ class ZImagePixelSpace(ChromaRadiance):
"""
pass
class HiDreamO1Pixel(ChromaRadiance):
"""Pixel-space latent format for HiDream-O1.
No VAE; the model patches/unpatches raw RGB internally with patch_size=32.
"""
pass
class CogVideoX(LatentFormat):
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).

View File

@ -0,0 +1,41 @@
"""HiDream-O1 two-pass attention: tokens [0, ar_len) are causal, [ar_len, T)
attend full K/V. Splitting Q at the boundary avoids the (B, 1, T, T) additive
mask the general-purpose path would build (~500 MB at T~16K) and lets the
gen half hit the user's preferred backend via optimized_attention.
"""
import torch
import comfy.ops
from comfy.ldm.modules.attention import optimized_attention
def make_two_pass_attention(ar_len: int, transformer_options=None):
"""Build a two-pass attention callable. AR pass uses SDPA-causal directly, gen pass routes through optimized_attention.
The AR pass goes through SDPA directand bypasses wrappers, it is only ~1% of T at typical edit sizes.
"""
def two_pass_attention(q, k, v, heads, **kwargs):
B, H, T, D = q.shape
if T < k.shape[2]: # KV-cache hot path: Q is shorter than K/V (cached AR prefix is in K/V only), all fresh Q positions are in the gen region, single full-attention call
out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
elif ar_len >= T:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
elif ar_len <= 0:
out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
else:
out_ar = comfy.ops.scaled_dot_product_attention(
q[:, :, :ar_len], k[:, :, :ar_len], v[:, :, :ar_len],
attn_mask=None, dropout_p=0.0, is_causal=True,
)
out_gen = optimized_attention(
q[:, :, ar_len:], k, v, heads,
mask=None, skip_reshape=True, skip_output_reshape=True,
transformer_options=transformer_options,
)
out = torch.cat([out_ar, out_gen], dim=2)
return out.transpose(1, 2).reshape(B, T, H * D)
return two_pass_attention
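
A minimal shape sketch of the returned callable (assumes a working ComfyUI install so optimized_attention resolves; all sizes are illustrative):

import torch
from comfy.ldm.hidream_o1.attention import make_two_pass_attention

attn = make_two_pass_attention(ar_len=128)
q = torch.randn(1, 32, 4096, 128)  # (B, heads, T, head_dim)
k = torch.randn(1, 32, 4096, 128)
v = torch.randn(1, 32, 4096, 128)
out = attn(q, k, v, heads=32)      # first 128 tokens causal, rest full -> (1, 4096, 32 * 128)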

View File

@ -0,0 +1,230 @@
"""HiDream-O1 conditioning prep — ref-image dual path + extra_conds assembly.
Each ref image goes through two paths: a 32x32 patchified stream concatenated
to the noised target, and a Qwen3-VL ViT path producing tokens that scatter
into input_ids at <|image_pad|> positions.
"""
from typing import List
import torch
import comfy.utils
from comfy.text_encoders.qwen_vl import process_qwen2vl_images
from .utils import (PATCH_SIZE, calculate_dimensions, cond_image_size, ref_max_size, resize_tensor)
# Qwen3-VL ViT preprocessing constants (preprocessor_config.json).
VIT_PATCH = 16
VIT_MERGE = 2
VIT_IMAGE_MEAN = [0.5, 0.5, 0.5]
VIT_IMAGE_STD = [0.5, 0.5, 0.5]
def prepare_ref_images(
ref_images: List[torch.Tensor],
target_h: int,
target_w: int,
device: torch.device,
dtype: torch.dtype,
):
"""Build the dual-path tensors for K reference images at (target_h, target_w).
Returns None for K=0, else a dict with ref_patches, ref_pixel_values,
ref_image_grid_thw, per_ref_vit_tokens, per_ref_patch_grids.
"""
K = len(ref_images)
if K == 0:
return None
max_size = ref_max_size(max(target_h, target_w), K)
cis = cond_image_size(K)
refs_t = [img[0].clamp(0, 1).permute(2, 0, 1).unsqueeze(0).contiguous().float() for img in ref_images]
refs_t = [resize_tensor(t, max_size, PATCH_SIZE) for t in refs_t]
# 32-patch path.
ref_patches_per = []
per_ref_patch_grids = []
for t in refs_t:
t_norm = (t.squeeze(0) - 0.5) / 0.5 # (3, H, W) in [-1, 1]
h_p, w_p = t_norm.shape[-2] // PATCH_SIZE, t_norm.shape[-1] // PATCH_SIZE
per_ref_patch_grids.append((h_p, w_p))
patches = (
t_norm.reshape(3, h_p, PATCH_SIZE, w_p, PATCH_SIZE)
.permute(1, 3, 0, 2, 4)
.reshape(h_p * w_p, 3 * PATCH_SIZE * PATCH_SIZE)
)
ref_patches_per.append(patches)
ref_patches = torch.cat(ref_patches_per, dim=0).unsqueeze(0).to(device=device, dtype=dtype)
# ViT path.
refs_vlm_t = []
for t in refs_t:
_, _, h, w = t.shape
cond_w, cond_h = calculate_dimensions(cis, w / h)
cond_w = max(cond_w, VIT_PATCH * VIT_MERGE)
cond_h = max(cond_h, VIT_PATCH * VIT_MERGE)
refs_vlm_t.append(comfy.utils.common_upscale(t, cond_w, cond_h, "lanczos", "disabled"))
pv_list, grid_list, per_ref_vit_tokens = [], [], []
for t_v in refs_vlm_t:
pv, grid_thw = process_qwen2vl_images(
t_v.permute(0, 2, 3, 1),
min_pixels=0, max_pixels=10**12,
patch_size=VIT_PATCH, merge_size=VIT_MERGE,
image_mean=VIT_IMAGE_MEAN, image_std=VIT_IMAGE_STD,
)
grid_thw = grid_thw[0]
pv_list.append(pv.to(device=device, dtype=dtype))
grid_list.append(grid_thw.to(device=device))
# Post-merge token count = number of <|image_pad|> tokens this image expands to in input_ids.
gh, gw = int(grid_thw[1].item()), int(grid_thw[2].item())
per_ref_vit_tokens.append((gh // VIT_MERGE) * (gw // VIT_MERGE))
return {
"ref_patches": ref_patches,
"ref_pixel_values": torch.cat(pv_list, dim=0),
"ref_image_grid_thw": torch.stack(grid_list, dim=0),
"per_ref_vit_tokens": per_ref_vit_tokens,
"per_ref_patch_grids": per_ref_patch_grids,
}
def build_ref_input_ids(
text_input_ids: torch.Tensor,
per_ref_vit_tokens: List[int],
image_token_id: int,
vision_start_id: int,
vision_end_id: int,
):
"""Splice [vision_start, image_pad*N, vision_end] blocks into input_ids
after the [im_start, user, \\n] prefix (matches original chat template).
"""
ids = text_input_ids[0].tolist()
inserted = []
for n_pad in per_ref_vit_tokens:
inserted.extend([vision_start_id] + [image_token_id] * n_pad + [vision_end_id])
new_ids = ids[:3] + inserted + ids[3:] # 3 = len([im_start, user, \n])
return torch.tensor([new_ids], dtype=text_input_ids.dtype, device=text_input_ids.device)
def build_extra_conds(
text_input_ids: torch.Tensor,
noise: torch.Tensor,
ref_images: List[torch.Tensor] = None,
target_patch_size: int = 32,
):
"""Assemble all conditioning tensors for HiDreamO1Transformer.forward:
input_ids (with ref-vision tokens spliced in for the edit/IP path),
position_ids (MRoPE), token_types, vinput_mask, plus the ref
dual-path tensors when refs are provided.
"""
from .utils import get_rope_index_fix_point
from comfy.text_encoders.hidream_o1 import (
IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID,
)
if text_input_ids.dim() == 1:
text_input_ids = text_input_ids.unsqueeze(0)
text_input_ids = text_input_ids.long().to(noise.device)
B = noise.shape[0]
if text_input_ids.shape[0] == 1 and B > 1:
text_input_ids = text_input_ids.expand(B, -1)
H, W = noise.shape[-2], noise.shape[-1]
h_p, w_p = H // target_patch_size, W // target_patch_size
image_len = h_p * w_p
image_grid_thw_tgt = torch.tensor(
[[1, h_p, w_p]], dtype=torch.long, device=text_input_ids.device,
)
out = {}
if ref_images:
ref = prepare_ref_images(ref_images, H, W, device=noise.device, dtype=noise.dtype)
text_input_ids = build_ref_input_ids(
text_input_ids, ref["per_ref_vit_tokens"],
IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID,
)
new_txt_len = text_input_ids.shape[1]
# Each ref's patchified stream gets a [vision_start] + [image_pad]*(N-1)
# block in the position-id stream after the noised target.
ref_grid_lengths = [hp * wp for (hp, wp) in ref["per_ref_patch_grids"]]
tgt_vision = torch.full((1, image_len), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
tgt_vision[:, 0] = VISION_START_ID
ref_vision_blocks = []
for rl in ref_grid_lengths:
blk = torch.full((1, rl), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
blk[:, 0] = VISION_START_ID
ref_vision_blocks.append(blk)
ref_vision_cat = torch.cat([tgt_vision] + ref_vision_blocks, dim=1)
input_ids_pad = torch.cat([text_input_ids, ref_vision_cat], dim=-1)
total_ref_patches_len = sum(ref_grid_lengths)
total_len = new_txt_len + image_len + total_ref_patches_len
# K (ViT, post-merge) + 1 (target) + K (ref-patches) image grids.
K = len(ref_images)
igthw_cond = ref["ref_image_grid_thw"].clone()
igthw_cond[:, 1] //= 2
igthw_cond[:, 2] //= 2
image_grid_thw_ref = torch.tensor(
[[1, hp, wp] for (hp, wp) in ref["per_ref_patch_grids"]],
dtype=torch.long, device=text_input_ids.device,
)
igthw_all = torch.cat([
igthw_cond.to(text_input_ids.device),
image_grid_thw_tgt,
image_grid_thw_ref,
], dim=0)
position_ids, _ = get_rope_index_fix_point(
spatial_merge_size=1,
image_token_id=IMAGE_TOKEN_ID,
vision_start_token_id=VISION_START_ID,
input_ids=input_ids_pad, image_grid_thw=igthw_all,
attention_mask=None,
skip_vision_start_token=[0] * K + [1] + [1] * K,
fix_point=4096,
)
# tms + target_image + ref_patches are all gen.
tms_pos = new_txt_len - 1
ar_len = tms_pos
token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
token_types[:, tms_pos:] = 1
vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
vinput_mask[:, new_txt_len:] = True
# Leading batch dim sidesteps CONDRegular.process_cond's repeat_to_batch_size truncation
out["ref_pixel_values"] = ref["ref_pixel_values"].unsqueeze(0)
out["ref_image_grid_thw"] = ref["ref_image_grid_thw"].unsqueeze(0)
out["ref_patches"] = ref["ref_patches"]
else:
# T2I: text + noised target only; vision_start replaces the first image token
txt_len = text_input_ids.shape[1]
total_len = txt_len + image_len
vision_tokens = torch.full((B, image_len), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
vision_tokens[:, 0] = VISION_START_ID
input_ids_pad = torch.cat([text_input_ids, vision_tokens], dim=-1)
position_ids, _ = get_rope_index_fix_point(
spatial_merge_size=1,
image_token_id=IMAGE_TOKEN_ID,
vision_start_token_id=VISION_START_ID,
input_ids=input_ids_pad, image_grid_thw=image_grid_thw_tgt,
attention_mask=None,
skip_vision_start_token=[1],
)
ar_len = txt_len - 1
token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
token_types[:, ar_len:] = 1
vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
vinput_mask[:, txt_len:] = True
out["input_ids"] = text_input_ids
out["position_ids"] = position_ids[:, 0].unsqueeze(0) # Collapse position_ids batch and add a leading dim so CONDRegular's batch-resize doesn't truncate the 3-axis MRoPE dim
out["token_types"] = token_types
out["vinput_mask"] = vinput_mask
out["ar_len"] = ar_len
return out
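
For the T2I branch the sequence accounting above is easy to check by hand; a sketch with illustrative numbers (a 2048x2048 target and a 32-token prompt):

H = W = 2048
target_patch_size = 32
h_p, w_p = H // target_patch_size, W // target_patch_size  # 64, 64
image_len = h_p * w_p                                      # 4096 target-image tokens
txt_len = 32                                               # illustrative prompt length
total_len = txt_len + image_len                            # 4128
ar_len = txt_len - 1  # the trailing <|tms|> token and all image tokens are gen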

View File

@ -0,0 +1,306 @@
"""HiDream-O1-Image transformer.
Pixel-space DiT built on Qwen3-VL: the vision tower (Qwen35VisionModel)
encodes ref images, the Qwen3-VL-8B decoder (Llama2_ with interleaved MRoPE)
processes a unified text+image sequence, and 32x32 patch embed/unembed
shims map raw RGB in and out of LLM hidden space. The Qwen3-VL deepstack
mergers go unused; their weights are dropped at load.
"""
from dataclasses import dataclass, field
from typing import List, Optional
import einops
import torch
import torch.nn as nn
import comfy.patcher_extension
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.text_encoders.llama import Llama2_
from comfy.text_encoders.qwen35 import Qwen35VisionModel
from .attention import make_two_pass_attention
IMAGE_TOKEN_ID = 151655 # Qwen3-VL <|image_pad|>
TMS_TOKEN_ID = 151673 # HiDream-O1 <|tms_token|>
PATCH_SIZE = 32
@dataclass
class HiDreamO1TextConfig:
"""Qwen3-VL-8B text-decoder dims (matches public Qwen3-VL-8B-Instruct)."""
vocab_size: int = 151936
hidden_size: int = 4096
intermediate_size: int = 12288
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
head_dim: int = 128
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 5000000.0
rope_scale: Optional[float] = None
rope_dims: List[int] = field(default_factory=lambda: [24, 20, 20])
interleaved_mrope: bool = True
transformer_type: str = "llama"
rms_norm_add: bool = False
mlp_activation: str = "silu"
qkv_bias: bool = False
q_norm: str = "gemma3"
k_norm: str = "gemma3"
final_norm: bool = True
lm_head: bool = False
stop_tokens: List[int] = field(default_factory=lambda: [151643, 151645])
QWEN3VL_VISION_DEFAULTS = dict(
hidden_size=1152,
num_heads=16,
intermediate_size=4304,
depth=27,
patch_size=16,
temporal_patch_size=2,
in_channels=3,
spatial_merge_size=2,
num_position_embeddings=2304,
deepstack_visual_indexes=(8, 16, 24),
out_hidden_size=4096, # final merger projects directly into LLM hidden
)
class BottleneckPatchEmbed(nn.Module):
# 3072 -> 1024 -> 4096 (raw 32x32 RGB patch -> bottleneck -> LLM hidden).
def __init__(self, patch_size=32, in_chans=3, pca_dim=1024, embed_dim=4096, bias=True, device=None, dtype=None, ops=None):
super().__init__()
self.proj1 = ops.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False, device=device, dtype=dtype)
self.proj2 = ops.Linear(pca_dim, embed_dim, bias=bias, device=device, dtype=dtype)
def forward(self, x):
return self.proj2(self.proj1(x))
class FinalLayer(nn.Module):
# 4096 -> 3072 (LLM hidden -> flat pixel patch).
def __init__(self, hidden_size, patch_size=32, out_channels=3, device=None, dtype=None, ops=None):
super().__init__()
self.linear = ops.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.linear(x)
class HiDreamO1Transformer(nn.Module):
"""HiDream-O1 unified pixel-level transformer."""
def __init__(self, image_model=None, dtype=None, device=None, operations=None,
text_config_overrides=None, vision_config_overrides=None, **kwargs):
super().__init__()
self.dtype = dtype
text_cfg = HiDreamO1TextConfig(**(text_config_overrides or {}))
vision_cfg = dict(QWEN3VL_VISION_DEFAULTS)
if vision_config_overrides:
vision_cfg.update(vision_config_overrides)
vision_cfg["out_hidden_size"] = text_cfg.hidden_size
self.text_config = text_cfg
self.vision_config = vision_cfg
self.hidden_size = text_cfg.hidden_size
self.patch_size = PATCH_SIZE
self.in_channels = 3
self.tms_token_id = TMS_TOKEN_ID
self.visual = Qwen35VisionModel(vision_cfg, device=device, dtype=dtype, ops=operations)
self.language_model = Llama2_(text_cfg, device=device, dtype=dtype, ops=operations)
self.t_embedder1 = TimestepEmbedder(
text_cfg.hidden_size, device=device, dtype=dtype, operations=operations,
)
self.x_embedder = BottleneckPatchEmbed(
patch_size=self.patch_size, in_chans=self.in_channels,
pca_dim=text_cfg.hidden_size // 4, embed_dim=text_cfg.hidden_size,
bias=True, device=device, dtype=dtype, ops=operations,
)
self.final_layer2 = FinalLayer(
text_cfg.hidden_size, patch_size=self.patch_size,
out_channels=self.in_channels, device=device, dtype=dtype, ops=operations,
)
self._visual_cache = None
self._kv_cache_entries = []
def clear_kv_cache(self):
self._kv_cache_entries = []
self._visual_cache = None
def forward(self, x, timesteps, context=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timesteps, context, transformer_options, **kwargs)
def _forward(self, x, timesteps, context=None, transformer_options={}, input_ids=None, attention_mask=None, position_ids=None,
vinput_mask=None, ar_len=None, ref_pixel_values=None, ref_image_grid_thw=None, ref_patches=None, **kwargs):
"""Returns flow-match velocity (x - x_pred) / sigma"""
if input_ids is None or position_ids is None:
raise ValueError("HiDreamO1Transformer requires input_ids and position_ids in conditioning")
B, _, H, W = x.shape
h_p, w_p = H // self.patch_size, W // self.patch_size
tgt_image_len = h_p * w_p
z = einops.rearrange(
x, 'B C (H p1) (W p2) -> B (H W) (C p1 p2)',
p1=self.patch_size, p2=self.patch_size,
)
vinputs = torch.cat([z, ref_patches.to(z.dtype)], dim=1) if ref_patches is not None else z
inputs_embeds = self.language_model.embed_tokens(input_ids).to(x.dtype)
if ref_pixel_values is not None and ref_image_grid_thw is not None:
# ViT output is constant across sampling steps within a generation;
# identity-key by the input tensor so refs aren't recomputed every step.
cached = self._visual_cache
if cached is not None and cached[0] is ref_pixel_values:
image_embeds = cached[1]
else:
ref_pv = ref_pixel_values.to(inputs_embeds.device)
ref_grid = ref_image_grid_thw.to(inputs_embeds.device).long()
# extra_conds wraps with a leading batch dim; refs are model-level so [0] always recovers them.
if ref_pv.dim() == 3:
ref_pv = ref_pv[0]
if ref_grid.dim() == 3:
ref_grid = ref_grid[0]
image_embeds = self.visual(ref_pv, ref_grid).to(inputs_embeds.dtype)
self._visual_cache = (ref_pixel_values, image_embeds)
# image_pad positions identical across batch (input_ids shared cond/uncond).
image_idx = (input_ids[0] == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
if image_idx.shape[0] != image_embeds.shape[0]:
raise ValueError(
f"Image-token count {image_idx.shape[0]} != ViT output count "
f"{image_embeds.shape[0]}; check tokenizer/processor alignment."
)
inputs_embeds[:, image_idx] = image_embeds.unsqueeze(0).expand(B, -1, -1)
sigma = timesteps.float() / 1000.0
t_pixeldit = 1.0 - sigma
t_emb = self.t_embedder1(t_pixeldit * 1000, inputs_embeds.dtype)
tms_mask_3d = (input_ids == self.tms_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = torch.where(tms_mask_3d, t_emb.unsqueeze(1).expand_as(inputs_embeds), inputs_embeds)
vinputs_embedded = self.x_embedder(vinputs.to(inputs_embeds.dtype))
inputs_embeds = torch.cat([inputs_embeds, vinputs_embedded], dim=1)
# extra_conds stores position_ids as (1, 3, T); process_cond repeats dim 0 to B. Take row 0.
freqs_cis = self.language_model.compute_freqs_cis(position_ids[0].to(x.device), x.device)
freqs_cis = tuple(t.to(x.dtype) for t in freqs_cis)
two_pass_attn = make_two_pass_attention(ar_len, transformer_options=transformer_options)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.language_model.layers)
transformer_options["block_type"] = "double"
# Cache prefix K/V across steps. Key includes input_ids (prompt), ref_id
# (refs scatter into inputs_embeds), and position_ids (RoPE baked into cached K).
can_cache = not blocks_replace and ar_len > 0
cache_len = ar_len if can_cache else 0
ref_id = id(ref_pixel_values) if ref_pixel_values is not None else None
pos_ids_key = position_ids[..., :cache_len] if can_cache else position_ids
cache_entries = self._kv_cache_entries
# Drop stale entries from a previous device (model was unloaded and reloaded).
if cache_entries and cache_entries[0]["input_ids"].device != input_ids.device:
cache_entries = []
self._kv_cache_entries = []
kv_cache = None
if can_cache:
for entry in cache_entries:
ck = entry["input_ids"]
ep = entry["position_ids"]
if (entry["cache_len"] == cache_len
and ck.shape == input_ids.shape and torch.equal(ck, input_ids)
and entry["ref_id"] == ref_id
and ep.shape == pos_ids_key.shape and torch.equal(ep, pos_ids_key)):
kv_cache = entry
break
if kv_cache is not None:
# Hot path: project Q/K/V only for fresh positions; past_key_value prepends cached AR K/V.
hidden_states = inputs_embeds[:, cache_len:]
sliced_freqs = tuple(t[..., cache_len:, :] for t in freqs_cis)
for i, layer in enumerate(self.language_model.layers):
transformer_options["block_index"] = i
K_i, V_i = kv_cache["kv"][i]
hidden_states, _ = layer(
x=hidden_states, attention_mask=None, freqs_cis=sliced_freqs, optimized_attention=two_pass_attn,
past_key_value=(K_i, V_i, cache_len),
)
else:
# Cold path: run full sequence; if cacheable, snapshot K/V at AR positions.
snapshots = [] if can_cache else None
past_kv_cold = () if can_cache else None
hidden_states = inputs_embeds
for i, layer in enumerate(self.language_model.layers):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args, _layer=layer):
out = {}
out["x"], _ = _layer(
x=args["x"], attention_mask=args.get("attention_mask"),
freqs_cis=args["freqs_cis"], optimized_attention=args["optimized_attention"],
past_key_value=None,
)
return out
out = blocks_replace[("double_block", i)](
{"x": hidden_states, "attention_mask": None,
"freqs_cis": freqs_cis, "optimized_attention": two_pass_attn,
"transformer_options": transformer_options},
{"original_block": block_wrap},
)
hidden_states = out["x"]
else:
hidden_states, present_kv = layer(
x=hidden_states, attention_mask=None,
freqs_cis=freqs_cis, optimized_attention=two_pass_attn,
past_key_value=past_kv_cold,
)
if snapshots is not None:
K, V, _ = present_kv
snapshots.append((K[:, :, :cache_len].contiguous(),
V[:, :, :cache_len].contiguous()))
if snapshots is not None:
# Cap at 2 entries (cond + uncond). Multi-cond workflows LRU-evict.
new_entry = {
"input_ids": input_ids.clone(),
"cache_len": cache_len,
"kv": snapshots,
"ref_id": ref_id,
"position_ids": pos_ids_key.clone(),
}
self._kv_cache_entries = (cache_entries + [new_entry])[-2:]
if self.language_model.norm is not None:
hidden_states = self.language_model.norm(hidden_states)
# Slice target-image positions before the final projection so the Linear only runs on tgt_image_len tokens.
# In the hot path hidden_states starts at original position cache_len, so masks/indices shift by cache_len.
sliced_offset = cache_len if kv_cache is not None else 0
if vinput_mask is not None:
vmask = vinput_mask.to(x.device).bool()
if sliced_offset > 0:
vmask = vmask[:, sliced_offset:]
target_hidden = hidden_states[vmask].view(B, -1, hidden_states.shape[-1])[:, :tgt_image_len]
else:
txt_seq_len = input_ids.shape[1]
start = txt_seq_len - sliced_offset
target_hidden = hidden_states[:, start:start + tgt_image_len]
x_pred_tgt = self.final_layer2(target_hidden)
# fp32 final subtraction; bf16 here noticeably degrades samples.
x_pred_img = einops.rearrange(
x_pred_tgt, 'B (H W) (C p1 p2) -> B C (H p1) (W p2)',
H=h_p, W=w_p, p1=self.patch_size, p2=self.patch_size,
)
return (x.float() - x_pred_img.float()) / sigma.view(B, 1, 1, 1).clamp_min(1e-3)
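
Since _forward returns flow-match velocity, a sampler recovers the clean-image prediction by inverting the final subtraction: x_pred = x - sigma * v. A one-assert sanity sketch:

import torch

x = torch.randn(1, 3, 64, 64)
x_pred = torch.randn_like(x)
sigma = torch.tensor(0.5)
v = (x - x_pred) / sigma                      # what the model returns
assert torch.allclose(x - sigma * v, x_pred)  # the inverse recovers the prediction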

View File

@ -0,0 +1,173 @@
"""HiDream-O1 input-prep helpers: image/resolution math and unified-sequence
RoPE position-id assembly. The fix_point offset in get_rope_index_fix_point
lets the target image and patchified ref images share spatial RoPE positions
despite living at different sequence indices (same 2D image plane).
"""
import math
from typing import Optional
import torch
PATCH_SIZE = 32
CONDITION_IMAGE_SIZE = 384 # ViT-side base size for ref images
def resize_tensor(img_t, image_size, patch_size=16):
"""img_t: (1, 3, H, W) float [0, 1]. Fit to image_size**2 area, patch-aligned, center-cropped."""
while min(img_t.shape[-2], img_t.shape[-1]) >= 2 * image_size: # Pre-halves with 2x2 box averaging while the image is still very large
img_t = torch.nn.functional.avg_pool2d(img_t, kernel_size=2, stride=2)
_, _, height, width = img_t.shape
m = patch_size
s_max = image_size * image_size
scale = math.sqrt(s_max / (width * height))
candidates = [
(round(width * scale) // m * m, round(height * scale) // m * m),
(round(width * scale) // m * m, math.floor(height * scale) // m * m),
(math.floor(width * scale) // m * m, round(height * scale) // m * m),
(math.floor(width * scale) // m * m, math.floor(height * scale) // m * m),
]
candidates = sorted(candidates, key=lambda x: x[0] * x[1], reverse=True)
new_size = candidates[-1]
for c in candidates:
if c[0] * c[1] <= s_max:
new_size = c
break
new_w, new_h = new_size
s1 = width / new_w
s2 = height / new_h
if s1 < s2:
resize_w, resize_h = new_w, round(height / s1)
else:
resize_w, resize_h = round(width / s2), new_h
img_t = torch.nn.functional.interpolate(img_t, size=(resize_h, resize_w), mode="bicubic")
top = (resize_h - new_h) // 2
left = (resize_w - new_w) // 2
return img_t[..., top:top + new_h, left:left + new_w]
def calculate_dimensions(max_size, ratio):
"""(W, H) for an aspect ratio fitting in max_size**2 area, 32-aligned."""
width = math.sqrt(max_size * max_size * ratio)
height = width / ratio
width = int(width / 32) * 32
height = int(height / 32) * 32
return width, height
def ref_max_size(target_max_dim, k):
"""K-dependent ref-image max dim before patchifying."""
if k == 1:
return target_max_dim
if k == 2:
return target_max_dim * 48 // 64
if k <= 4:
return target_max_dim // 2
if k <= 8:
return target_max_dim * 24 // 64
return target_max_dim // 4
def cond_image_size(k):
"""K-dependent ViT-side image size."""
if k <= 4:
return CONDITION_IMAGE_SIZE
if k <= 8:
return CONDITION_IMAGE_SIZE * 48 // 64
return CONDITION_IMAGE_SIZE // 2
def get_rope_index_fix_point(
spatial_merge_size: int,
image_token_id: int,
vision_start_token_id: int,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
skip_vision_start_token=None,
fix_point: int = 4096,
):
mrope_position_deltas = []
if input_ids is not None and image_grid_thw is not None:
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3, input_ids.shape[0], input_ids.shape[1],
dtype=input_ids.dtype, device=input_ids.device,
)
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids_b in enumerate(total_input_ids):
fp = fix_point
image_index = 0
input_ids_b = input_ids_b[attention_mask[i] == 1]
vision_start_indices = torch.argwhere(input_ids_b == vision_start_token_id).squeeze(1)
vision_tokens = input_ids_b[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
input_tokens = input_ids_b.tolist()
llm_pos_ids_list = []
st = 0
remain_images = image_nums
for _ in range(image_nums):
if image_token_id in input_tokens and remain_images > 0:
ed = input_tokens.index(image_token_id, st)
else:
ed = len(input_tokens) + 1
t = image_grid_thw[image_index][0]
h = image_grid_thw[image_index][1]
w = image_grid_thw[image_index][2]
image_index += 1
remain_images -= 1
llm_grid_t = t.item()
llm_grid_h = h.item() // spatial_merge_size
llm_grid_w = w.item() // spatial_merge_size
text_len = ed - st
text_len -= skip_vision_start_token[image_index - 1]
text_len = max(0, text_len)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
if skip_vision_start_token[image_index - 1]:
if fp > 0:
fp = fp - st_idx
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fp + st_idx)
fp = 0
else:
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1).expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
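
The sizing helpers are pure arithmetic, so a few hand-checkable values make their behavior concrete (a doctest-style sketch, assuming the helpers import from comfy.ldm.hidream_o1.utils):

from comfy.ldm.hidream_o1.utils import calculate_dimensions, cond_image_size, ref_max_size

# calculate_dimensions: (W, H) fitting the area budget, 32-aligned.
assert calculate_dimensions(384, 1.0) == (384, 384)
assert calculate_dimensions(384, 2.0) == (512, 256)   # sqrt(384*384*2) ~= 543 -> 512; 543/2 ~= 271 -> 256
# ref_max_size: the ref's longer side shrinks as K grows.
assert ref_max_size(2048, 1) == 2048
assert ref_max_size(2048, 2) == 1536                  # 2048 * 48 // 64
assert ref_max_size(2048, 4) == 1024
# cond_image_size: ViT-side base size also shrinks with K.
assert cond_image_size(4) == 384
assert cond_image_size(8) == 288                      # 384 * 48 // 64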

View File

@ -58,6 +58,8 @@ import comfy.ldm.cogvideo.model
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.ldm.hidream_o1.model
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
import comfy.model_management
import comfy.patcher_extension
@ -1674,6 +1676,32 @@ class HiDream(BaseModel):
out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
return out
class HiDreamO1(BaseModel):
"""HiDream-O1-Image: pixel-space DiT (no VAE). Refs from HiDreamO1ReferenceImages and tokens from the stub TE flow through
extra_conds; the heavy preprocessing lives in comfy.ldm.hidream_o1.conditioning."""
PATCH_SIZE = 32
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream_o1.model.HiDreamO1Transformer)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
text_input_ids = kwargs.get("text_input_ids", None)
noise = kwargs.get("noise", None)
if text_input_ids is None or noise is None:
return out
conds = build_extra_conds(
text_input_ids, noise,
ref_images=kwargs.get("reference_latents", None),
target_patch_size=self.PATCH_SIZE,
)
for k, v in conds.items():
# ar_len is a Python int (precomputed to avoid a GPU sync in forward).
cls = comfy.conds.CONDConstant if k == "ar_len" else comfy.conds.CONDRegular
out[k] = cls(v)
return out
class Chroma(Flux):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma):
super().__init__(model_config, model_type, device=device, unet_model=unet_model)

View File

@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
return {"image_model": "hidream_o1"}
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"

View File

@ -93,7 +93,8 @@ class CONST:
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = reshape_sigma(sigma, noise.ndim)
return sigma * noise + (1.0 - sigma) * latent_image
s = getattr(self, "noise_scale", 1.0)
return sigma * (s * noise) + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
sigma = reshape_sigma(sigma, latent.ndim)
@ -288,7 +289,11 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000))
self.set_noise_scale(sampling_settings.get("noise_scale", 1.0))
self.set_parameters(
shift=sampling_settings.get("shift", 1.0),
multiplier=sampling_settings.get("multiplier", 1000),
)
def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
self.shift = shift
@ -296,6 +301,9 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier)
self.register_buffer('sigmas', ts)
def set_noise_scale(self, noise_scale):
self.noise_scale = float(noise_scale)
@property
def sigma_min(self):
return self.sigmas[0]

View File

@ -1285,7 +1285,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]

View File

@ -239,7 +239,8 @@ class CLIP:
model_management.archive_model_dtypes(self.cond_stage_model)
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
te_disable_dynamic = disable_dynamic or getattr(self.cond_stage_model, "disable_offload", False)
ModelPatcher = comfy.model_patcher.ModelPatcher if te_disable_dynamic else comfy.model_patcher.CoreModelPatcher
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
# Match the torch.float32 hardcoded upcast in the TE implementation
self.patcher.set_model_compute_dtype(torch.float32)
@ -776,6 +777,7 @@ class VAE:
self.latent_channels = 3
self.latent_dim = 2
self.output_channels = 3
self.disable_offload = True
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
sample_rate = 16000
if sample_rate == 16000:

View File

@ -28,6 +28,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo
import comfy.text_encoders.hidream_o1
from . import supported_models_base
from . import latent_formats
@ -1431,6 +1432,50 @@ class HiDream(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None # TODO
class HiDreamO1(supported_models_base.BASE):
unet_config = {
"image_model": "hidream_o1",
}
sampling_settings = {
"shift": 3.0,
"noise_scale": 8.0,
}
latent_format = latent_formats.HiDreamO1Pixel
memory_usage_factor = 0.6
# fp16 not supported: LM MLP down_proj activations overflow in fp16, causing NaNs
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
optimizations = {"fp8": False}
def get_model(self, state_dict, prefix="", device=None):
return model_base.HiDreamO1(self, device=device)
def process_unet_state_dict(self, state_dict):
# Drop unused Qwen3-VL deepstack merger weights; upstream discards them at inference.
for key in list(state_dict.keys()):
if "visual.deepstack_merger_list" in key:
del state_dict[key]
return state_dict
def process_vae_state_dict(self, state_dict):
# Pixel-space model: inject sentinel so VAE construction picks PixelspaceConversionVAE.
return {"pixel_space_vae": torch.tensor(1.0)}
def process_clip_state_dict(self, state_dict):
# Tokenizer-only TE: inject sentinel so load_state_dict_guess_config triggers CLIP init.
return {"_hidream_o1_te_sentinel": torch.zeros(1)}
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(
comfy.text_encoders.hidream_o1.HiDreamO1Tokenizer,
comfy.text_encoders.hidream_o1.HiDreamO1TE,
)
class Chroma(supported_models_base.BASE):
unet_config = {
"image_model": "chroma",
@ -2018,6 +2063,7 @@ models = [
Hunyuan3Dv2,
Hunyuan3Dv2_1,
HiDream,
HiDreamO1,
Chroma,
ChromaRadiance,
ACEStep,

View File

@ -0,0 +1,119 @@
"""HiDream-O1-Image tokenizer-only text encoder.
The real Qwen3-VL backbone runs inside diffusion_model.* every step, so this
module just tokenizes the prompt into text_input_ids and emits them as
conditioning. Position ids / token_types / vinput_mask depend on target H/W
and are built later in model_base.HiDreamO1.extra_conds.
"""
import os
import torch
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
# Qwen3-VL special tokens
IM_START_ID = 151644
IM_END_ID = 151645
ASSISTANT_ID = 77091
USER_ID = 872
NEWLINE_ID = 198
VISION_START_ID = 151652
VISION_END_ID = 151653
IMAGE_TOKEN_ID = 151655
VIDEO_TOKEN_ID = 151656
# HiDream-O1-specific tokens
BOI_TOKEN_ID = 151669
BOR_TOKEN_ID = 151670
EOR_TOKEN_ID = 151671
BOT_TOKEN_ID = 151672
TMS_TOKEN_ID = 151673
class HiDreamO1QwenTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer"
)
super().__init__(
tokenizer_path,
pad_with_end=False,
embedding_size=4096,
embedding_key="hidream_o1",
tokenizer_class=Qwen2Tokenizer,
has_start_token=False,
has_end_token=False,
pad_to_max_length=False,
max_length=99999999,
min_length=1,
pad_token=151643,
tokenizer_data=tokenizer_data,
)
class HiDreamO1Tokenizer(sd1_clip.SD1Tokenizer):
"""Wraps prompt in the upstream chat template ending with boi/tms markers.
Image tokens get spliced in at sample time once target H/W is known.
"""
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
tokenizer_data=tokenizer_data,
name="hidream_o1",
tokenizer=HiDreamO1QwenTokenizer,
)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
text_tokens_dict = super().tokenize_with_weights(
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
)
text_tuples = text_tokens_dict["hidream_o1"][0]
text_tuples = [t for t in text_tuples if int(t[0]) != 151643] # strip pad
# <|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|boi|><|tms|>
def tok(tid):
return (tid, 1.0) if not return_word_ids else (tid, 1.0, 0)
prefix = [tok(IM_START_ID), tok(USER_ID), tok(NEWLINE_ID)]
suffix = [
tok(IM_END_ID), tok(NEWLINE_ID),
tok(IM_START_ID), tok(ASSISTANT_ID), tok(NEWLINE_ID),
tok(BOI_TOKEN_ID), tok(TMS_TOKEN_ID),
]
full = prefix + list(text_tuples) + suffix
return {"hidream_o1": [full]}
class HiDreamO1TE(torch.nn.Module):
"""Passthrough TE: emits int token ids; the Qwen3-VL backbone in diffusion_model does the actual encoding."""
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = {torch.float32}
self.disable_offload = True # skips dynamic VRAM management for this zero-parameter module
self.device = torch.device("cpu") if device is None else torch.device(device)
def encode_token_weights(self, token_weight_pairs):
tok_pairs = token_weight_pairs["hidream_o1"][0]
ids = [int(t[0]) for t in tok_pairs]
input_ids = torch.tensor([ids], dtype=torch.long)
# Surrogate keeps the cross_attn slot non-empty for CONDITIONING
# plumbing; the model reads text_input_ids out of `extra` instead.
cross_attn = input_ids.unsqueeze(-1).to(torch.float32)
extra = {"text_input_ids": input_ids}
return cross_attn, None, extra
def load_sd(self, sd):
return []
def get_sd(self):
return {}
def reset_clip_options(self):
pass
def set_clip_options(self, options):
pass
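
Because the TE is a passthrough, output shapes follow directly from the token count; a sketch of what encode_token_weights produces for the template above (the two middle ids stand in for an arbitrary prompt):

import torch

# [im_start, user, \n] + prompt + [im_end, \n, im_start, assistant, \n, boi, tms]
ids = [151644, 872, 198] + [1234, 5678] + [151645, 198, 151644, 77091, 198, 151669, 151673]
input_ids = torch.tensor([ids], dtype=torch.long)       # (1, 12)
cross_attn = input_ids.unsqueeze(-1).to(torch.float32)  # (1, 12, 1) surrogate
extra = {"text_input_ids": input_ids}                   # the payload the model actually reads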

View File

@ -397,7 +397,7 @@ class RMSNorm(nn.Module):
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None, interleaved_mrope=False):
if not isinstance(theta, list):
theta = [theta]
@ -415,16 +415,27 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
if rope_dims is not None and position_ids.shape[0] > 1 and interleaved_mrope:
# Qwen3-VL interleaved MRoPE: T-freqs by default, H/W replace every 3rd dim.
freqs_inter = freqs[0].clone()
for axis_idx, offset in ((1, 1), (2, 2)):
length = rope_dims[axis_idx] * 3
idx = slice(offset, length, 3)
freqs_inter[..., idx] = freqs[axis_idx, ..., idx]
emb = torch.cat((freqs_inter, freqs_inter), dim=-1)
cos = emb.cos().unsqueeze(0)
sin = emb.sin().unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
sin_split = sin.shape[-1] // 2
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
@ -689,6 +700,7 @@ class Llama2_(nn.Module):
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
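
The interleaved layout is easy to verify by counting dims: with rope_dims = [24, 20, 20] and a 64-dim half-spectrum, H owns indices 1, 4, ..., 58 and W owns 2, 5, ..., 59 (20 each), leaving 24 dims on T. A counting sketch:

rope_dims = [24, 20, 20]
owner = ["T"] * 64  # T-frequencies everywhere by default
for axis, (name, offset) in enumerate((("H", 1), ("W", 2)), start=1):
    for i in range(offset, rope_dims[axis] * 3, 3):
        owner[i] = name
assert [owner.count(a) for a in ("T", "H", "W")] == rope_dims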

View File

@ -651,7 +651,7 @@ class Qwen35VisionModel(nn.Module):
x = self.patch_embed(x)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
x = x + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
rotary_pos_emb = self.rot_pos_emb(grid_thw).to(x.device)
seq_len = x.shape[0]
x = x.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

View File

@ -86,6 +86,37 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No
return x
class SamplerLCM(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SamplerLCM",
category="sampling/samplers",
description=("LCM sampler with tunable per-step noise. s_noise is a multiplier on the model's training noise scale"),
inputs=[
io.Float.Input("s_noise", default=1.0, min=0.0, max=64.0, step=0.01,
tooltip="Per-step noise multiplier at the first step (1.0 = match training)."),
io.Float.Input("s_noise_end", default=1.0, min=0.0, max=64.0, step=0.01,
tooltip="Per-step noise multiplier at the last step. Set equal to s_noise for a constant schedule."),
io.Float.Input("noise_clip_std", default=0.0, min=0.0, max=10.0, step=0.01,
tooltip="Clamp per-step noise to +/- N*std. 0 disables."),
],
outputs=[io.Sampler.Output()],
)
@classmethod
def execute(cls, s_noise, s_noise_end, noise_clip_std) -> io.NodeOutput:
sampler = comfy.samplers.ksampler(
"lcm",
{
"s_noise": float(s_noise),
"s_noise_end": float(s_noise_end),
"noise_clip_std": float(noise_clip_std),
},
)
return io.NodeOutput(sampler)
class SamplerEulerCFGpp(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
@ -114,6 +145,7 @@ class AdvancedSamplersExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SamplerLCMUpscale,
SamplerLCM,
SamplerEulerCFGpp,
]

View File

@ -0,0 +1,256 @@
from typing_extensions import override
import torch
import comfy.model_management
import comfy.patcher_extension
import node_helpers
from comfy_api.latest import ComfyExtension, io
class EmptyHiDreamO1LatentImage(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EmptyHiDreamO1LatentImage",
display_name="Empty HiDream-O1 Latent Image",
category="latent/image",
description=(
"Empty pixel-space latent for HiDream-O1-Image. The model was "
"trained at ~4 megapixels; lower resolutions go off-distribution "
"and quality regresses noticeably. Trained resolutions: "
"2048x2048, 2304x1728, 1728x2304, 2560x1440, 1440x2560, "
"2496x1664, 1664x2496, 3104x1312, 1312x3104, 2304x1792, 1792x2304."
),
inputs=[
io.Int.Input(id="width", default=2048, min=64, max=4096, step=32),
io.Int.Input(id="height", default=2048, min=64, max=4096, step=32),
io.Int.Input(id="batch_size", default=1, min=1, max=64),
],
outputs=[io.Latent().Output()],
)
@classmethod
def execute(cls, *, width: int, height: int, batch_size: int = 1) -> io.NodeOutput:
latent = torch.zeros(
(batch_size, 3, height, width),
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput({"samples": latent})
class HiDreamO1ReferenceImages(io.ComfyNode):
"""Attach reference images to both positive and negative conditioning."""
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="HiDreamO1ReferenceImages",
display_name="HiDream-O1 Reference Images",
category="conditioning/image",
description=(
"Attach 1-10 reference images to conditioning, one for edit instruction"
"or multiple for subject-driven personalization."
),
inputs=[
io.Conditioning.Input(id="positive"),
io.Conditioning.Input(id="negative"),
io.Autogrow.Input(
"images",
template=io.Autogrow.TemplateNames(
io.Image.Input("image"),
names=[f"image_{i}" for i in range(1, 11)],
min=1,
),
tooltip=("Reference images. 1 image = instruction edit; 2-10 images = multi reference."
),
),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
@classmethod
def execute(cls, *, positive, negative, images: io.Autogrow.Type) -> io.NodeOutput:
refs = [images[f"image_{i}"] for i in range(1, 11) if f"image_{i}" in images]
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": refs}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": refs}, append=True)
return io.NodeOutput(positive, negative)
class HiDreamO1PatchSeamSmoothing(io.ComfyNode):
PATCH_SIZE = 32
EDGE_FEATHER = 4
# Shift presets per (pattern, N). 8-pass = 4-quadrant + 4 quarter-patch offsets.
SHIFTS_BY_PATTERN = {
("single_shift", 2): [(0, 0), (16, 16)],
("single_shift", 4): [(0, 0), (16, 0), (0, 16), (16, 16)],
("single_shift", 8): [(0, 0), (16, 0), (0, 16), (16, 16),
(8, 8), (24, 8), (8, 24), (24, 24)],
("symmetric", 2): [(-8, -8), (8, 8)],
("symmetric", 4): [(-8, -8), (8, -8), (-8, 8), (8, 8)],
("symmetric", 8): [(-12, -12), (4, -12), (-12, 4), (4, 4),
(-4, -4), (12, -4), (-4, 12), (12, 12)],
}
RAMP_LEVELS = {
"2": [2],
"4": [4],
"ramp_2_4": [2, 4],
"ramp_2_4_8": [2, 4, 8],
}
@staticmethod
def _hann_tile(cy: int, cx: int, size: int = 32) -> torch.Tensor:
"""size x size Hann tile peaking at (cy, cx) within a patch."""
half = size // 2
yy = torch.arange(size).view(size, 1)
xx = torch.arange(size).view(1, size)
dy = ((yy - cy + half) % size) - half
dx = ((xx - cx + half) % size) - half
return 0.25 * (1 + torch.cos(torch.pi * dy / half)) * (1 + torch.cos(torch.pi * dx / half))
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="HiDreamO1PatchSeamSmoothing",
display_name="HiDream-O1 Patch Seam Smoothing",
category="advanced/model",
is_experimental=True,
description=(
"Average the model output across multiple shifted patch-grid "
"positions during the late portion of sampling. Cancels seams."
),
inputs=[
io.Model.Input(id="model"),
io.Float.Input(id="start_percent", default=0.8, min=0.0, max=1.0, step=0.01,
tooltip="Sampling progress (0=start, 1=end) at which the blend turns ON.",
),
io.Float.Input(id="end_percent", default=1.0, min=0.0, max=1.0, step=0.01,
tooltip="Sampling progress at which the blend turns OFF.",
),
io.Combo.Input(
id="pattern",
options=["single_shift", "symmetric"],
default="single_shift",
tooltip="Shift layout. single_shift: one pass at the natural patch grid + others offset. symmetric: all passes off-grid, shifts split around origin.",
),
io.Combo.Input(
id="passes",
options=["2", "4", "ramp_2_4", "ramp_2_4_8"],
default="2",
tooltip="Number of passes per gated step. 2/4 = fixed. ramp_*: pass count increases as sampling approaches end (more smoothing where seams are most visible).",
),
io.Combo.Input(
id="blend",
options=["average", "window", "median"],
default="average",
tooltip="average: equal-weight mean. window: Hann-windowed weighting favoring each pass away from its patch boundaries. median: per-pixel median, rejects wraparound-outlier passes.",
),
io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01,
tooltip="Interpolation between the natural-grid pred (0) and the averaged result (1).",
),
],
outputs=[io.Model.Output()],
)
@classmethod
def execute(cls, *, model, start_percent: float, end_percent: float, pattern: str, passes: str, blend: str, strength: float) -> io.NodeOutput:
if strength <= 0.0 or end_percent <= start_percent:
return io.NodeOutput(model)
P = cls.PATCH_SIZE
half = P // 2
shift_levels = [cls.SHIFTS_BY_PATTERN[(pattern, n)] for n in cls.RAMP_LEVELS[passes]]
if blend == "window":
window_tile_levels = [
torch.stack([cls._hann_tile((half - sy) % P, (half - sx) % P, P) for sy, sx in lst], dim=0)
for lst in shift_levels
]
else:
window_tile_levels = [None] * len(shift_levels)
m = model.clone()
model_sampling = m.get_model_object("model_sampling")
multiplier = float(model_sampling.multiplier)
start_t = float(model_sampling.percent_to_sigma(start_percent)) * multiplier
end_t = float(model_sampling.percent_to_sigma(end_percent)) * multiplier
edge_ramp_cache: dict = {}
def get_edge_ramp(H: int, W: int, device, dtype) -> torch.Tensor:
key = (H, W, device, dtype)
cached = edge_ramp_cache.get(key)
if cached is not None:
return cached
feather = cls.EDGE_FEATHER
ys = torch.minimum(torch.arange(H, device=device, dtype=torch.float32),
(H - 1) - torch.arange(H, device=device, dtype=torch.float32))
xs = torch.minimum(torch.arange(W, device=device, dtype=torch.float32),
(W - 1) - torch.arange(W, device=device, dtype=torch.float32))
y_mask = ((ys - P) / feather).clamp(0, 1)
x_mask = ((xs - P) / feather).clamp(0, 1)
ramp = (y_mask[:, None] * x_mask[None, :]).to(dtype)
edge_ramp_cache[key] = ramp
return ramp
def smoothing_wrapper(executor, *args, **kwargs):
x = args[0]
t = float(args[1][0])
pred = executor(*args, **kwargs)
if not (end_t <= t <= start_t):
return pred
# Pick shift-level by sigma phase across the gated range.
if len(shift_levels) == 1:
level_idx = 0
else:
phase = (start_t - t) / max(start_t - end_t, 1e-8)
level_idx = min(int(phase * len(shift_levels)), len(shift_levels) - 1)
shifts = shift_levels[level_idx]
window_tiles = window_tile_levels[level_idx]
preds = []
for sy, sx in shifts:
if sy == 0 and sx == 0:
preds.append(pred)
continue
x_rolled = torch.roll(x, shifts=(sy, sx), dims=(-2, -1))
pred_rolled = executor(x_rolled, *args[1:], **kwargs)
preds.append(torch.roll(pred_rolled, shifts=(-sy, -sx), dims=(-2, -1)))
stacked = torch.stack(preds, dim=0) # (N, B, C, H, W)
_, _, _, H, W = stacked.shape
if blend == "window":
N = stacked.shape[0]
tiles = window_tiles.to(device=stacked.device, dtype=stacked.dtype)
w = tiles.repeat(1, H // P, W // P)[:, :H, :W]
sum_w = w.sum(dim=0, keepdim=True)
w = torch.where(sum_w < 1e-3, torch.full_like(w, 1.0 / N), w / sum_w.clamp(min=1e-8))
avg = (stacked * w[:, None, None, :, :]).sum(dim=0)
elif blend == "median":
avg = torch.median(stacked, dim=0).values
else:
avg = stacked.mean(dim=0)
# Mask out the P-px wraparound contamination strip at each edge.
mask = get_edge_ramp(H, W, pred.device, pred.dtype)
return pred * (1.0 - mask * strength) + avg * (mask * strength)
m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "hidream_o1_patch_seam_smoothing", smoothing_wrapper)
return io.NodeOutput(m)
class HiDreamO1Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyHiDreamO1LatentImage,
HiDreamO1ReferenceImages,
HiDreamO1PatchSeamSmoothing,
]
async def comfy_entrypoint() -> HiDreamO1Extension:
return HiDreamO1Extension()
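
Stripped of gating, windowing, and edge feathering, the core move of HiDreamO1PatchSeamSmoothing fits in a few lines: roll the input by a sub-patch offset, predict, roll back, average. A toy sketch with a hypothetical predict callable:

import torch

def seam_smoothed(predict, x, shifts=((0, 0), (16, 16))):
    # Each pass sees a different patch-grid alignment; rolling the prediction
    # back re-registers it, so seams land in different places and average out.
    preds = []
    for sy, sx in shifts:
        x_r = torch.roll(x, shifts=(sy, sx), dims=(-2, -1))
        preds.append(torch.roll(predict(x_r), shifts=(-sy, -sx), dims=(-2, -1)))
    return torch.stack(preds, dim=0).mean(dim=0)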

View File

@ -300,6 +300,29 @@ class RescaleCFG:
m.set_model_sampler_cfg_function(rescale_cfg)
return (m, )
class ModelNoiseScale:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"noise_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 64.0, "step": 0.01,
"tooltip": "Absolute training noise scale. For example HiDream-O1 base: 8.0, dev: 7.5."}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, noise_scale):
m = model.clone()
original = m.model.model_sampling
ms = type(original)(m.model.model_config)
ms.set_parameters(shift=original.shift, multiplier=original.multiplier)
ms.set_noise_scale(noise_scale)
m.add_object_patch("model_sampling", ms)
return (m, )
class ModelComputeDtype:
SEARCH_ALIASES = ["model precision", "change dtype"]
@classmethod
@ -327,6 +350,7 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingSD3": ModelSamplingSD3,
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
"ModelSamplingFlux": ModelSamplingFlux,
"ModelNoiseScale": ModelNoiseScale,
"RescaleCFG": RescaleCFG,
"ModelComputeDtype": ModelComputeDtype,
}

View File

@ -2435,6 +2435,7 @@ async def init_builtin_extra_nodes():
"nodes_sam3.py",
"nodes_void.py",
"nodes_wandancer.py",
"nodes_hidream_o1.py",
]
import_failed = []