This commit is contained in:
Jukka Seppänen 2026-05-10 04:15:31 -04:00 committed by GitHub
commit 5fea7f7fa8
15 changed files with 1270 additions and 11 deletions

View File

@ -264,6 +264,51 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
@torch.no_grad()
def sample_euler_flash_flowmatch(model, x, sigmas, extra_args=None, callback=None, disable=None,
s_noise=7.5, s_noise_end=None, noise_clip_std=2.5,
noise_sampler=None):
"""HiDream-O1-Image-Dev "flash" sampler.
Step: x_next = sigma_next * noise * s_noise_i + (1 - sigma_next) * denoised,
with noise clamped to noise_clip_std stddevs and s_noise_i linearly
interpolated from s_noise to s_noise_end across steps. Equivalent to
sample_lcm + CONST_SCALED_NOISE when s_noise_end is None and noise_clip_std
is 0.
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
in_dtype = x.dtype
n_steps = max(1, len(sigmas) - 1)
s_start = float(s_noise)
s_end = float(s_noise if s_noise_end is None else s_noise_end)
for i in trange(n_steps, disable=disable):
sigma = sigmas[i]
sigma_next = sigmas[i + 1]
denoised = model(x, sigma * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
if sigma_next == 0:
x = denoised.to(in_dtype)
continue
noise = noise_sampler(sigma, sigma_next)
if noise_clip_std > 0:
clip_val = noise_clip_std * noise.std()
noise = noise.clamp(min=-clip_val, max=clip_val)
# Linear interpolation start -> end across steps, matching upstream
# pipeline.py's noise_scale_schedule construction.
t = (i / (n_steps - 1)) if n_steps > 1 else 0.0
s_noise_i = s_start + (s_end - s_start) * t
# Match upstream FlashFlowMatchEulerDiscreteScheduler.step: do the step
# math in fp32 to avoid bf16 accumulation drift across 28 steps.
x = (sigma_next * noise.float() * s_noise_i
+ (1.0 - sigma_next) * denoised.float()).to(in_dtype)
return x
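# --- Illustration only; not part of the upstream file. ---
# The per-step noise scale the loop above interpolates, written out standalone.
# The values (4 steps, 7.5 -> 6.0) are assumptions for the example, not defaults
# taken from any HiDream-O1 config.
def _flash_noise_schedule(n_steps, s_start, s_end):
    # Mirrors the in-loop math: t = i / (n_steps - 1), s_i = s_start + (s_end - s_start) * t.
    if n_steps <= 1:
        return [float(s_start)]
    return [s_start + (s_end - s_start) * (i / (n_steps - 1)) for i in range(n_steps)]

# _flash_noise_schedule(4, 7.5, 6.0) -> [7.5, 7.0, 6.5, 6.0]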
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""

View File

@ -792,6 +792,13 @@ class ZImagePixelSpace(ChromaRadiance):
"""
pass
class HiDreamO1Pixel(ChromaRadiance):
"""Pixel-space latent format for HiDream-O1.
No VAE; the model patches/unpatches raw RGB internally with patch_size=32.
"""
pass
class CogVideoX(LatentFormat):
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).

View File

@ -0,0 +1,46 @@
"""HiDream-O1 two-pass attention: tokens [0, ar_len) are causal, [ar_len, T)
attend full K/V. Splitting Q at the boundary avoids the (B, 1, T, T) additive
mask the general-purpose path would build (~500 MB at T~16K) and lets the
gen half hit the user's preferred backend via optimized_attention.
"""
import torch
import comfy.ops
from comfy.ldm.modules.attention import optimized_attention
def make_two_pass_attention(ar_len: int):
"""Build a two-pass attention callable: the AR pass uses causal SDPA directly; the gen pass routes through optimized_attention.
"""
def two_pass_attention(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
B, H, T, D = q.shape
else:
B, T, total_dim = q.shape
D = total_dim // heads
H = heads
q = q.view(B, T, H, D).transpose(1, 2)
k = k.view(B, T, H, D).transpose(1, 2)
v = v.view(B, T, H, D).transpose(1, 2)
if ar_len >= T:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
elif ar_len <= 0:
out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True)
else:
out_ar = comfy.ops.scaled_dot_product_attention(
q[:, :, :ar_len], k[:, :, :ar_len], v[:, :, :ar_len],
attn_mask=None, dropout_p=0.0, is_causal=True,
)
out_gen = optimized_attention(
q[:, :, ar_len:], k, v, heads,
mask=None, skip_reshape=True, skip_output_reshape=True,
)
out = torch.cat([out_ar, out_gen], dim=2)
if skip_output_reshape:
return out
return out.transpose(1, 2).reshape(B, T, H * D)
return two_pass_attention
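# --- Illustration only; not part of the upstream file. ---
# Sanity sketch of the mask the split avoids materializing, using plain
# torch.nn.functional SDPA on toy shapes (instead of comfy.ops /
# optimized_attention): AR queries [0, ar_len) are causal over the AR prefix,
# generation queries [ar_len, T) see every key, and the two-pass result matches
# one masked call over the whole sequence.
def _two_pass_equivalence_check(B=1, H=2, T=6, D=8, ar_len=3):
    import torch.nn.functional as F
    q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
    mask = torch.ones(T, T, dtype=torch.bool)  # True = may attend
    mask[:ar_len] = False
    mask[:ar_len, :ar_len] = torch.tril(torch.ones(ar_len, ar_len, dtype=torch.bool))
    full = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    out_ar = F.scaled_dot_product_attention(q[:, :, :ar_len], k[:, :, :ar_len], v[:, :, :ar_len], is_causal=True)
    out_gen = F.scaled_dot_product_attention(q[:, :, ar_len:], k, v)
    return torch.allclose(full, torch.cat([out_ar, out_gen], dim=2), atol=1e-5)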

View File

@ -0,0 +1,269 @@
"""HiDream-O1 conditioning prep — ref-image dual path + extra_conds assembly.
Each ref image goes through two paths: a 32x32 patchified stream concatenated
to the noised target, and a Qwen3-VL ViT path producing tokens that scatter
into input_ids at <|image_pad|> positions.
"""
from typing import List, Tuple
import einops
import numpy as np
import torch
from PIL import Image
from .utils import (PATCH_SIZE, calculate_dimensions, cond_image_size, ref_max_size, resize_pilimage)
# Qwen3-VL ViT preprocessing constants (preprocessor_config.json).
VIT_PATCH = 16
VIT_MERGE = 2
VIT_TEMPORAL_PATCH = 2
VIT_IMAGE_MEAN = [0.5, 0.5, 0.5]
VIT_IMAGE_STD = [0.5, 0.5, 0.5]
def _process_vit_image(pil: Image.Image, device, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
"""Qwen3-VL ViT preprocessing: returns (flatten_patches, image_grid_thw)."""
arr = np.asarray(pil, dtype=np.float32) / 255.0
img_t = torch.from_numpy(arr).permute(2, 0, 1).contiguous()
h, w = img_t.shape[-2:]
# H/W must be multiples of patch*merge.
factor = VIT_PATCH * VIT_MERGE
h_bar = max(round(h / factor) * factor, factor)
w_bar = max(round(w / factor) * factor, factor)
if (h, w) != (h_bar, w_bar):
img_t = torch.nn.functional.interpolate(
img_t.unsqueeze(0), size=(h_bar, w_bar), mode="bilinear", align_corners=False,
).squeeze(0)
mean = torch.tensor(VIT_IMAGE_MEAN).view(3, 1, 1)
std = torch.tensor(VIT_IMAGE_STD).view(3, 1, 1)
normalized = (img_t - mean) / std
grid_h = h_bar // VIT_PATCH
grid_w = w_bar // VIT_PATCH
grid_thw = torch.tensor([1, grid_h, grid_w], dtype=torch.long)
# Stack 2 copies for the temporal_patch dim, then patchify.
pixel_values = normalized.unsqueeze(0).repeat(VIT_TEMPORAL_PATCH, 1, 1, 1)
patches = pixel_values.reshape(
1, VIT_TEMPORAL_PATCH, 3,
grid_h // VIT_MERGE, VIT_MERGE, VIT_PATCH,
grid_w // VIT_MERGE, VIT_MERGE, VIT_PATCH,
)
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_h * grid_w,
3 * VIT_TEMPORAL_PATCH * VIT_PATCH * VIT_PATCH,
)
return flatten_patches.to(device=device, dtype=dtype), grid_thw.to(device=device)
def prepare_ref_images(
ref_images: List[torch.Tensor],
target_h: int,
target_w: int,
device: torch.device,
dtype: torch.dtype,
):
"""Build the dual-path tensors for K reference images at (target_h, target_w).
Returns None for K=0, else a dict with ref_patches,
ref_pixel_values, ref_image_grid_thw, per_ref_vit_tokens,
per_ref_patch_grids.
"""
K = len(ref_images)
if K == 0:
return None
max_size = ref_max_size(max(target_h, target_w), K)
cis = cond_image_size(K)
pils = []
for img in ref_images:
arr = np.round(img[0].clamp(0, 1).cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8)
pils.append(Image.fromarray(arr, "RGB"))
pils_resized = [resize_pilimage(p, max_size, PATCH_SIZE) for p in pils]
# 32-patch path.
ref_patches_per = []
per_ref_patch_grids = []
for pil_r in pils_resized:
arr = np.asarray(pil_r, dtype=np.float32) / 255.0
t = torch.from_numpy(arr).permute(2, 0, 1).contiguous()
t = (t - 0.5) / 0.5 # -> [-1, 1]
h_p, w_p = pil_r.height // PATCH_SIZE, pil_r.width // PATCH_SIZE
per_ref_patch_grids.append((h_p, w_p))
patches = einops.rearrange(
t, "C (H p1) (W p2) -> (H W) (C p1 p2)",
p1=PATCH_SIZE, p2=PATCH_SIZE,
)
ref_patches_per.append(patches)
ref_patches = torch.cat(ref_patches_per, dim=0).unsqueeze(0).to(device=device, dtype=dtype)
# ViT path.
pils_vlm = []
for pil_r in pils_resized:
cond_w, cond_h = calculate_dimensions(cis, pil_r.width / pil_r.height)
cond_w = max(cond_w, VIT_PATCH * VIT_MERGE)
cond_h = max(cond_h, VIT_PATCH * VIT_MERGE)
pils_vlm.append(pil_r.resize((cond_w, cond_h), resample=Image.LANCZOS))
pv_list, grid_list, per_ref_vit_tokens = [], [], []
for pil_v in pils_vlm:
pv, grid_thw = _process_vit_image(pil_v, device, dtype)
pv_list.append(pv)
grid_list.append(grid_thw)
# Post-merge token count = number of <|image_pad|> tokens this image
# expands to in input_ids.
gh, gw = int(grid_thw[1].item()), int(grid_thw[2].item())
per_ref_vit_tokens.append((gh // VIT_MERGE) * (gw // VIT_MERGE))
return {
"ref_patches": ref_patches,
"ref_pixel_values": torch.cat(pv_list, dim=0),
"ref_image_grid_thw": torch.stack(grid_list, dim=0),
"per_ref_vit_tokens": per_ref_vit_tokens,
"per_ref_patch_grids": per_ref_patch_grids,
}
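# Illustration only (worked example, not upstream code): for a 384x384 ref image,
# grid_h = grid_w = 384 // VIT_PATCH = 24, so grid_thw = (1, 24, 24) and the
# post-merge count recorded in per_ref_vit_tokens is
# (24 // VIT_MERGE) * (24 // VIT_MERGE) = 144, i.e. that ref expands to 144
# <|image_pad|> tokens in input_ids.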
def build_ref_input_ids(
text_input_ids: torch.Tensor,
per_ref_vit_tokens: List[int],
image_token_id: int,
vision_start_id: int,
vision_end_id: int,
):
"""Splice [vision_start, image_pad*N, vision_end] blocks into input_ids
after the [im_start, user, \\n] prefix (matches upstream chat template).
"""
ids = text_input_ids[0].tolist()
inserted = []
for n_pad in per_ref_vit_tokens:
inserted.extend([vision_start_id] + [image_token_id] * n_pad + [vision_end_id])
new_ids = ids[:3] + inserted + ids[3:] # 3 = len([im_start, user, \n])
return torch.tensor([new_ids], dtype=text_input_ids.dtype, device=text_input_ids.device)
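# Illustration only (toy ids, not real Qwen3-VL vocab values): with a prompt of
# [im_start, user, \n, tok_a, tok_b] = [1, 2, 3, 9, 9], one ref expanding to
# 2 <|image_pad|> tokens, and (image_token_id, vision_start_id, vision_end_id)
# = (8, 7, 5):
#   build_ref_input_ids(torch.tensor([[1, 2, 3, 9, 9]]), [2], 8, 7, 5)
#   -> tensor([[1, 2, 3, 7, 8, 8, 5, 9, 9]])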
def build_extra_conds(
text_input_ids: torch.Tensor,
noise: torch.Tensor,
ref_images: List[torch.Tensor] = None,
target_patch_size: int = 32,
):
"""Assemble all conditioning tensors for HiDreamO1Transformer.forward:
input_ids (with ref-vision tokens spliced in for the edit/IP path),
position_ids (MRoPE), token_types, vinput_mask, plus the ref
dual-path tensors when refs are provided.
"""
from .utils import get_rope_index_fix_point
from comfy.text_encoders.hidream_o1 import (
IMAGE_TOKEN_ID, VIDEO_TOKEN_ID, VISION_START_ID, VISION_END_ID,
)
if text_input_ids.dim() == 1:
text_input_ids = text_input_ids.unsqueeze(0)
text_input_ids = text_input_ids.long().to(noise.device)
B = noise.shape[0]
if text_input_ids.shape[0] == 1 and B > 1:
text_input_ids = text_input_ids.expand(B, -1)
H, W = noise.shape[-2], noise.shape[-1]
h_p, w_p = H // target_patch_size, W // target_patch_size
image_len = h_p * w_p
image_grid_thw_tgt = torch.tensor(
[[1, h_p, w_p]], dtype=torch.long, device=text_input_ids.device,
)
out = {}
if ref_images:
ref = prepare_ref_images(ref_images, H, W, device=noise.device, dtype=noise.dtype)
text_input_ids = build_ref_input_ids(
text_input_ids, ref["per_ref_vit_tokens"],
IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID,
)
new_txt_len = text_input_ids.shape[1]
# Each ref's patchified stream gets a [vision_start, image_pad x (N-1)]
# block in the position-id stream after the noised target.
ref_grid_lengths = [hp * wp for (hp, wp) in ref["per_ref_patch_grids"]]
tgt_vision = torch.full((1, image_len), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
tgt_vision[:, 0] = VISION_START_ID
ref_vision_blocks = []
for rl in ref_grid_lengths:
blk = torch.full((1, rl), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
blk[:, 0] = VISION_START_ID
ref_vision_blocks.append(blk)
ref_vision_cat = torch.cat([tgt_vision] + ref_vision_blocks, dim=1)
input_ids_pad = torch.cat([text_input_ids, ref_vision_cat], dim=-1)
total_ref_patches_len = sum(ref_grid_lengths)
total_len = new_txt_len + image_len + total_ref_patches_len
# K (ViT, post-merge) + 1 (target) + K (ref-patches) image grids.
K = len(ref_images)
igthw_cond = ref["ref_image_grid_thw"].clone()
igthw_cond[:, 1] //= 2
igthw_cond[:, 2] //= 2
image_grid_thw_ref = torch.tensor(
[[1, hp, wp] for (hp, wp) in ref["per_ref_patch_grids"]],
dtype=torch.long, device=text_input_ids.device,
)
igthw_all = torch.cat([
igthw_cond.to(text_input_ids.device),
image_grid_thw_tgt,
image_grid_thw_ref,
], dim=0)
position_ids, _ = get_rope_index_fix_point(
spatial_merge_size=1,
image_token_id=IMAGE_TOKEN_ID, video_token_id=VIDEO_TOKEN_ID,
vision_start_token_id=VISION_START_ID,
input_ids=input_ids_pad, image_grid_thw=igthw_all,
video_grid_thw=None, attention_mask=None,
skip_vision_start_token=[0] * K + [1] + [1] * K,
fix_point=4096,
)
# tms + target_image + ref_patches are all gen.
tms_pos = new_txt_len - 1
token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
token_types[:, tms_pos:] = 1
vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
vinput_mask[:, new_txt_len:] = True
# Leading batch dim sidesteps CONDRegular.process_cond's
# repeat_to_batch_size truncation (which narrows dim 0 to B).
out["ref_pixel_values"] = ref["ref_pixel_values"].unsqueeze(0)
out["ref_image_grid_thw"] = ref["ref_image_grid_thw"].unsqueeze(0)
out["ref_patches"] = ref["ref_patches"]
else:
# T2I: text + noised target only. vision_start replaces the first
# image token (upstream pipeline.py:51).
txt_len = text_input_ids.shape[1]
total_len = txt_len + image_len
vision_tokens = torch.full((B, image_len), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
vision_tokens[:, 0] = VISION_START_ID
input_ids_pad = torch.cat([text_input_ids, vision_tokens], dim=-1)
position_ids, _ = get_rope_index_fix_point(
spatial_merge_size=1,
image_token_id=IMAGE_TOKEN_ID, video_token_id=VIDEO_TOKEN_ID,
vision_start_token_id=VISION_START_ID,
input_ids=input_ids_pad, image_grid_thw=image_grid_thw_tgt,
video_grid_thw=None, attention_mask=None,
skip_vision_start_token=[1],
)
token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
token_types[:, txt_len - 1:] = 1
vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
vinput_mask[:, txt_len:] = True
# Collapse position_ids batch and add a leading dim so CONDRegular's
# batch-resize doesn't truncate the 3-axis MRoPE dim.
out["input_ids"] = text_input_ids
out["position_ids"] = position_ids[:, 0].unsqueeze(0)
out["token_types"] = token_types
out["vinput_mask"] = vinput_mask
return out
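# Illustration only (toy lengths, not upstream code): layout of the assembled
# sequence for the ref (edit/IP) branch with new_txt_len=10 (tms token at
# index 9), image_len=4, and one ref with 6 patch tokens, so total_len = 20:
#   indices  0..8   text (AR, token_type 0)
#   index    9      <|tms|>            token_type 1, vinput_mask False
#   indices 10..13  noised target      token_type 1, vinput_mask True
#   indices 14..19  ref patch stream   token_type 1, vinput_mask True
# token_types[:, 9:] = 1 and vinput_mask[:, 10:] = True follow directly from
# tms_pos = new_txt_len - 1 above.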

View File

@ -0,0 +1,231 @@
"""HiDream-O1-Image transformer.
Pixel-space DiT built on Qwen3-VL: the vision tower (Qwen35VisionModel)
encodes ref images, the Qwen3-VL-8B decoder (Llama2_ with interleaved MRoPE)
processes a unified text+image sequence, and 32x32 patch embed/unembed
shims map raw RGB in and out of LLM hidden space. The Qwen3-VL deepstack
mergers go unused; their weights are dropped at load.
"""
from dataclasses import dataclass, field
from typing import List, Optional
import einops
import torch
import torch.nn as nn
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.text_encoders.llama import Llama2_
from comfy.text_encoders.qwen35 import Qwen35VisionModel
from .attention import make_two_pass_attention
IMAGE_TOKEN_ID = 151655 # Qwen3-VL <|image_pad|>
TMS_TOKEN_ID = 151673 # HiDream-O1 <|tms_token|>
PATCH_SIZE = 32
@dataclass
class HiDreamO1TextConfig:
"""Qwen3-VL-8B text-decoder dims (matches public Qwen3-VL-8B-Instruct)."""
vocab_size: int = 151936
hidden_size: int = 4096
intermediate_size: int = 12288
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
head_dim: int = 128
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 5000000.0
rope_scale: Optional[float] = None
rope_dims: List[int] = field(default_factory=lambda: [24, 20, 20])
interleaved_mrope: bool = True
transformer_type: str = "llama"
rms_norm_add: bool = False
mlp_activation: str = "silu"
qkv_bias: bool = False
q_norm: str = "gemma3"
k_norm: str = "gemma3"
final_norm: bool = True
lm_head: bool = False
stop_tokens: List[int] = field(default_factory=lambda: [151643, 151645])
QWEN3VL_VISION_DEFAULTS = dict(
hidden_size=1152,
num_heads=16,
intermediate_size=4304,
depth=27,
patch_size=16,
temporal_patch_size=2,
in_channels=3,
spatial_merge_size=2,
num_position_embeddings=2304,
deepstack_visual_indexes=(8, 16, 24),
out_hidden_size=4096, # final merger projects directly into LLM hidden
)
class BottleneckPatchEmbed(nn.Module):
# 3072 -> 1024 -> 4096 (raw 32x32 RGB patch -> bottleneck -> LLM hidden).
def __init__(self, patch_size=32, in_chans=3, pca_dim=1024, embed_dim=4096, bias=True, device=None, dtype=None, ops=None):
super().__init__()
self.proj1 = ops.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False, device=device, dtype=dtype)
self.proj2 = ops.Linear(pca_dim, embed_dim, bias=bias, device=device, dtype=dtype)
def forward(self, x):
return self.proj2(self.proj1(x))
class FinalLayer(nn.Module):
# 4096 -> 3072 (LLM hidden -> flat pixel patch).
def __init__(self, hidden_size, patch_size=32, out_channels=3, device=None, dtype=None, ops=None):
super().__init__()
self.linear = ops.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True,
device=device, dtype=dtype)
def forward(self, x):
return self.linear(x)
class HiDreamO1Transformer(nn.Module):
"""HiDream-O1 unified pixel-level transformer."""
def __init__(self, image_model=None, dtype=None, device=None, operations=None,
text_config_overrides=None, vision_config_overrides=None, **kwargs):
super().__init__()
self.dtype = dtype
text_cfg = HiDreamO1TextConfig(**(text_config_overrides or {}))
vision_cfg = dict(QWEN3VL_VISION_DEFAULTS)
if vision_config_overrides:
vision_cfg.update(vision_config_overrides)
vision_cfg["out_hidden_size"] = text_cfg.hidden_size
self.text_config = text_cfg
self.vision_config = vision_cfg
self.hidden_size = text_cfg.hidden_size
self.patch_size = PATCH_SIZE
self.in_channels = 3
self.tms_token_id = TMS_TOKEN_ID
self.visual = Qwen35VisionModel(vision_cfg, device=device, dtype=dtype, ops=operations)
self.language_model = Llama2_(text_cfg, device=device, dtype=dtype, ops=operations)
self.t_embedder1 = TimestepEmbedder(
text_cfg.hidden_size, device=device, dtype=dtype, operations=operations,
)
self.x_embedder = BottleneckPatchEmbed(
patch_size=self.patch_size, in_chans=self.in_channels,
pca_dim=text_cfg.hidden_size // 4, embed_dim=text_cfg.hidden_size,
bias=True, device=device, dtype=dtype, ops=operations,
)
self.final_layer2 = FinalLayer(
text_cfg.hidden_size, patch_size=self.patch_size,
out_channels=self.in_channels, device=device, dtype=dtype, ops=operations,
)
def forward(self, x, timesteps, context=None, transformer_options={},
input_ids=None, attention_mask=None, position_ids=None,
token_types=None, vinput_mask=None,
ref_pixel_values=None, ref_image_grid_thw=None, ref_patches=None,
**kwargs):
"""Returns flow-match velocity (x - x_pred) / sigma"""
if input_ids is None or position_ids is None:
raise ValueError("HiDreamO1Transformer requires input_ids and position_ids in conditioning")
B, _, H, W = x.shape
h_p, w_p = H // self.patch_size, W // self.patch_size
tgt_image_len = h_p * w_p
z = einops.rearrange(
x, 'B C (H p1) (W p2) -> B (H W) (C p1 p2)',
p1=self.patch_size, p2=self.patch_size,
)
vinputs = torch.cat([z, ref_patches.to(z.dtype)], dim=1) if ref_patches is not None else z
if input_ids.dim() == 3:
input_ids = input_ids.squeeze(-1)
input_ids = input_ids.long()
inputs_embeds = self.language_model.embed_tokens(input_ids).to(x.dtype)
if ref_pixel_values is not None and ref_image_grid_thw is not None:
ref_pv = ref_pixel_values.to(inputs_embeds.device)
ref_grid = ref_image_grid_thw.to(inputs_embeds.device).long()
# Refs are model-level (same for cond/uncond), wrapped with a leading batch dim by extra_conds; [0] always recovers them.
if ref_pv.dim() == 3:
ref_pv = ref_pv[0]
if ref_grid.dim() == 3:
ref_grid = ref_grid[0]
image_embeds = self.visual(ref_pv, ref_grid).to(inputs_embeds.dtype)
image_mask = (input_ids == IMAGE_TOKEN_ID)
if image_mask[0].sum().item() != image_embeds.shape[0]:
raise ValueError(
f"Image-token count {image_mask[0].sum().item()} != ViT output count "
f"{image_embeds.shape[0]}; check tokenizer/processor alignment."
)
image_embeds_b = image_embeds.unsqueeze(0).expand(B, -1, -1).reshape(-1, image_embeds.shape[-1])
inputs_embeds = inputs_embeds.masked_scatter(
image_mask.unsqueeze(-1).expand_as(inputs_embeds), image_embeds_b,
)
sigma = timesteps.float() / 1000.0
t_pixeldit = 1.0 - sigma
t_emb = self.t_embedder1(t_pixeldit * 1000, inputs_embeds.dtype)
tms_mask_3d = (input_ids == self.tms_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = torch.where(tms_mask_3d, t_emb.unsqueeze(1).expand_as(inputs_embeds), inputs_embeds)
vinputs_embedded = self.x_embedder(vinputs.to(inputs_embeds.dtype))
inputs_embeds = torch.cat([inputs_embeds, vinputs_embedded], dim=1)
total_seq_len = inputs_embeds.shape[1]
# AR (text) tokens are contiguous at the start, so (==0).sum() gives ar_len.
if token_types is None:
txt_seq_len = input_ids.shape[1]
token_types = torch.zeros(B, total_seq_len, dtype=torch.long, device=x.device)
token_types[:, txt_seq_len:] = 1
else:
token_types = token_types.to(x.device)
if token_types.dim() == 1:
token_types = token_types.unsqueeze(0)
if token_types.shape[0] == 1 and B > 1:
token_types = token_types.expand(B, -1)
ar_len = int((token_types[0] == 0).sum().item())
# position_ids may arrive as (3, T) or wrapped (1, 3, T) / (3, 1, T) by CONDRegular.
position_ids = position_ids.to(x.device).long()
if position_ids.dim() == 3:
position_ids = position_ids[0] if position_ids.shape[1] == 3 else position_ids[:, 0]
freqs_cis = self.language_model.compute_freqs_cis(position_ids, x.device)
freqs_cis = tuple(t.to(x.dtype) for t in freqs_cis)
two_pass_attn = make_two_pass_attention(ar_len)
hidden_states = inputs_embeds
for layer in self.language_model.layers:
hidden_states, _ = layer(
x=hidden_states, attention_mask=None,
freqs_cis=freqs_cis, optimized_attention=two_pass_attn,
past_key_value=None,
)
if self.language_model.norm is not None:
hidden_states = self.language_model.norm(hidden_states)
x_pred = self.final_layer2(hidden_states)
if vinput_mask is not None:
vmask = vinput_mask.to(x.device).bool()
if vmask.dim() == 1:
vmask = vmask.unsqueeze(0)
if vmask.shape[0] == 1 and B > 1:
vmask = vmask.expand(B, -1)
x_pred_tgt = x_pred[vmask].view(B, -1, x_pred.shape[-1])[:, :tgt_image_len]
else:
txt_seq_len = input_ids.shape[1]
x_pred_tgt = x_pred[:, txt_seq_len:txt_seq_len + tgt_image_len]
# Do the final subtraction in fp32; bf16 here noticeably degrades samples.
x_pred_img = einops.rearrange(
x_pred_tgt, 'B (H W) (C p1 p2) -> B C (H p1) (W p2)',
H=h_p, W=w_p, p1=self.patch_size, p2=self.patch_size,
)
return (x.float() - x_pred_img.float()) / sigma.view(B, 1, 1, 1).clamp_min(1e-3)

View File

@ -0,0 +1,223 @@
"""HiDream-O1 input-prep helpers: image/resolution math and unified-sequence
RoPE position-id assembly. The fix_point offset in get_rope_index_fix_point
lets the target image and patchified ref images share spatial RoPE positions
(the same 2D image plane) despite living at different sequence indices.
"""
import math
from typing import Optional
import torch
from PIL import Image
PREDEFINED_RESOLUTIONS = [
(2048, 2048),
(2304, 1728),
(1728, 2304),
(2560, 1440),
(1440, 2560),
(2496, 1664),
(1664, 2496),
(3104, 1312),
(1312, 3104),
(2304, 1792),
(1792, 2304),
]
PATCH_SIZE = 32
CONDITION_IMAGE_SIZE = 384 # ViT-side base size for ref images
def find_closest_resolution(width, height):
"""Closest (W, H) in PREDEFINED_RESOLUTIONS by aspect ratio."""
img_ratio = width / height
best = None
min_diff = float("inf")
for w, h in PREDEFINED_RESOLUTIONS:
diff = abs(w / h - img_ratio)
if diff < min_diff:
min_diff = diff
best = (w, h)
return best
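# Illustration only (worked example, not upstream code): find_closest_resolution(1920, 1080)
# compares the 16:9 ratio (~1.778) against the table and returns (2560, 1440),
# the entry with the same aspect ratio.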
def resize_pilimage(pil_image, image_size, patch_size=16, resampler=Image.BICUBIC):
"""Resize to fit image_size**2 area, patch-aligned, center-cropped. Pre-halves
with BOX filter while the image is still very large.
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX,
)
m = patch_size
width, height = pil_image.width, pil_image.height
s_max = image_size * image_size
scale = math.sqrt(s_max / (width * height))
candidates = [
(round(width * scale) // m * m, round(height * scale) // m * m),
(round(width * scale) // m * m, math.floor(height * scale) // m * m),
(math.floor(width * scale) // m * m, round(height * scale) // m * m),
(math.floor(width * scale) // m * m, math.floor(height * scale) // m * m),
]
candidates = sorted(candidates, key=lambda x: x[0] * x[1], reverse=True)
new_size = candidates[-1]
for c in candidates:
if c[0] * c[1] <= s_max:
new_size = c
break
s1 = width / new_size[0]
s2 = height / new_size[1]
if s1 < s2:
pil_image = pil_image.resize([new_size[0], round(height / s1)], resample=resampler)
top = (round(height / s1) - new_size[1]) // 2
pil_image = pil_image.crop((0, top, new_size[0], top + new_size[1]))
else:
pil_image = pil_image.resize([round(width / s2), new_size[1]], resample=resampler)
left = (round(width / s2) - new_size[0]) // 2
pil_image = pil_image.crop((left, 0, left + new_size[0], new_size[1]))
return pil_image
def calculate_dimensions(max_size, ratio):
"""(W, H) for an aspect ratio fitting in max_size**2 area, 32-aligned."""
width = math.sqrt(max_size * max_size * ratio)
height = width / ratio
width = int(width / 32) * 32
height = int(height / 32) * 32
return width, height
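# Illustration only (worked example, not upstream code): calculate_dimensions(384, 1.5)
# gives width = sqrt(384 * 384 * 1.5) ~= 470.3 and height ~= 313.5, which
# 32-align down to (448, 288), a roughly 1.56:1 box inside the 384**2-pixel budget.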
def ref_max_size(target_max_dim, k):
"""K-dependent ref-image max dim before patchifying."""
if k == 1:
return target_max_dim
if k == 2:
return target_max_dim * 48 // 64
if k <= 4:
return target_max_dim // 2
if k <= 8:
return target_max_dim * 24 // 64
return target_max_dim // 4
def cond_image_size(k):
"""K-dependent ViT-side image size."""
if k <= 4:
return CONDITION_IMAGE_SIZE
if k <= 8:
return CONDITION_IMAGE_SIZE * 48 // 64
return CONDITION_IMAGE_SIZE // 2
def get_rope_index_fix_point(
spatial_merge_size: int,
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
skip_vision_start_token=None,
fix_point: int = 4096,
):
if video_grid_thw is not None:
video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
video_grid_thw[:, 0] = 1
mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3, input_ids.shape[0], input_ids.shape[1],
dtype=input_ids.dtype, device=input_ids.device,
)
image_index, video_index = 0, 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids_b in enumerate(total_input_ids):
input_ids_b = input_ids_b[attention_mask[i] == 1]
vision_start_indices = torch.argwhere(input_ids_b == vision_start_token_id).squeeze(1)
vision_tokens = input_ids_b[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids_b.tolist()
llm_pos_ids_list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t = image_grid_thw[image_index][0]
h = image_grid_thw[image_index][1]
w = image_grid_thw[image_index][2]
image_index += 1
remain_images -= 1
ed = ed_image
else:
t = video_grid_thw[video_index][0]
h = video_grid_thw[video_index][1]
w = video_grid_thw[video_index][2]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t = t.item()
llm_grid_h = h.item() // spatial_merge_size
llm_grid_w = w.item() // spatial_merge_size
text_len = ed - st
text_len -= skip_vision_start_token[image_index - 1]
text_len = max(0, text_len)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
if skip_vision_start_token[image_index - 1]:
if fix_point > 0:
fix_point = fix_point - st_idx
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fix_point + st_idx)
fix_point = 0
else:
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1).expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas

View File

@ -58,6 +58,7 @@ import comfy.ldm.cogvideo.model
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.ldm.hidream_o1.model
import comfy.model_management
import comfy.patcher_extension
@ -1690,6 +1691,43 @@ class ChromaRadiance(Chroma):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance)
class HiDreamO1(BaseModel):
"""HiDream-O1-Image: pixel-space DiT (no VAE). Refs from HiDreamO1ReferenceImages and tokens from the stub TE flow through
extra_conds; the heavy preprocessing lives in comfy.ldm.hidream_o1.conditioning."""
PATCH_SIZE = 32
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device,
unet_model=comfy.ldm.hidream_o1.model.HiDreamO1Transformer)
# HiDream-O1 trains with x_t = (1-t) x_clean + t * s_noise * noise
s_noise = float((model_config.sampling_settings or {}).get("s_noise", 8.0))
class _HiDreamO1Sampling(
comfy.model_sampling.ModelSamplingDiscreteFlow,
comfy.model_sampling.CONST_SCALED_NOISE,
):
pass
ms = _HiDreamO1Sampling(model_config)
ms._s_noise = s_noise
self.model_sampling = ms
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
text_input_ids = kwargs.get("text_input_ids", None)
noise = kwargs.get("noise", None)
if text_input_ids is None or noise is None:
return out
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
conds = build_extra_conds(
text_input_ids, noise,
ref_images=kwargs.get("hidream_o1_ref_images", None),
target_patch_size=self.PATCH_SIZE,
)
for k, v in conds.items():
out[k] = comfy.conds.CONDRegular(v)
return out
class ACEStep(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)

View File

@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
return {"image_model": "hidream_o1"}
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"

View File

@ -99,6 +99,18 @@ class CONST:
sigma = reshape_sigma(sigma, latent.ndim)
return latent / (1.0 - sigma)
class CONST_SCALED_NOISE(CONST):
"""CONST variant for flow-match models trained with x_t = (1-t)*x_clean +
t*s_noise*noise. Set _s_noise to the recipe value; default 1.0 == plain CONST.
"""
_s_noise = 1.0
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = reshape_sigma(sigma, noise.ndim)
return sigma * (self._s_noise * noise) + (1.0 - sigma) * latent_image
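# Illustration only (not upstream code): compared with plain CONST, whose
# noise_scaling is sigma * noise + (1 - sigma) * latent_image, the only change
# here is the _s_noise multiplier on the noise term. With _s_noise = 8.0 and
# sigma = 1.0 (full denoising from scratch), the model starts from noise with
# std ~8 instead of unit std.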
class X0(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output

View File

@ -723,7 +723,8 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece",
"euler_flash_flowmatch"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

View File

@ -28,6 +28,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo
import comfy.text_encoders.hidream_o1
from . import supported_models_base
from . import latent_formats
@ -1480,6 +1481,49 @@ class ChromaRadiance(Chroma):
def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device)
class HiDreamO1(supported_models_base.BASE):
unet_config = {
"image_model": "hidream_o1",
}
sampling_settings = {
"shift": 3.0,
"s_noise": 8.0,
}
latent_format = latent_formats.HiDreamO1Pixel
memory_usage_factor = 0.6
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
optimizations = {"fp8": False}
def get_model(self, state_dict, prefix="", device=None):
return model_base.HiDreamO1(self, device=device)
def process_unet_state_dict(self, state_dict):
# Drop unused Qwen3-VL deepstack merger weights; upstream discards them at inference.
for key in list(state_dict.keys()):
if "visual.deepstack_merger_list" in key:
del state_dict[key]
return state_dict
def process_vae_state_dict(self, state_dict):
# Pixel-space model: inject sentinel so VAE construction picks PixelspaceConversionVAE.
return {"pixel_space_vae": torch.tensor(1.0)}
def process_clip_state_dict(self, state_dict):
# Tokenizer-only TE: inject sentinel so load_state_dict_guess_config triggers CLIP init.
return {"_hidream_o1_te_sentinel": torch.zeros(1)}
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(
comfy.text_encoders.hidream_o1.HiDreamO1Tokenizer,
comfy.text_encoders.hidream_o1.HiDreamO1TE,
)
class ACEStep(supported_models_base.BASE):
unet_config = {
"audio_model": "ace",
@ -2018,6 +2062,7 @@ models = [
Hunyuan3Dv2,
Hunyuan3Dv2_1,
HiDream,
HiDreamO1,
Chroma,
ChromaRadiance,
ACEStep,

View File

@ -0,0 +1,124 @@
"""HiDream-O1-Image tokenizer-only text encoder.
The real Qwen3-VL backbone runs inside diffusion_model.* every step, so this
module just tokenizes the prompt into text_input_ids and emits them as
conditioning. Position ids / token_types / vinput_mask depend on target H/W
and are built later in model_base.HiDreamO1.extra_conds.
"""
import os
import torch
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
# Qwen3-VL special tokens
IM_START_ID = 151644
IM_END_ID = 151645
ASSISTANT_ID = 77091
USER_ID = 872
NEWLINE_ID = 198
VISION_START_ID = 151652
VISION_END_ID = 151653
IMAGE_TOKEN_ID = 151655
VIDEO_TOKEN_ID = 151656
# HiDream-O1-specific tokens
BOI_TOKEN_ID = 151669
BOR_TOKEN_ID = 151670
EOR_TOKEN_ID = 151671
BOT_TOKEN_ID = 151672
TMS_TOKEN_ID = 151673
class HiDreamO1QwenTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer"
)
super().__init__(
tokenizer_path,
pad_with_end=False,
embedding_size=4096,
embedding_key="hidream_o1",
tokenizer_class=Qwen2Tokenizer,
has_start_token=False,
has_end_token=False,
pad_to_max_length=False,
max_length=99999999,
min_length=1,
pad_token=151643,
tokenizer_data=tokenizer_data,
)
class HiDreamO1Tokenizer(sd1_clip.SD1Tokenizer):
"""Wraps prompt in the upstream chat template ending with boi/tms markers.
Image tokens get spliced in at sample time once target H/W is known.
"""
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
tokenizer_data=tokenizer_data,
name="hidream_o1",
tokenizer=HiDreamO1QwenTokenizer,
)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
text_tokens_dict = super().tokenize_with_weights(
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
)
text_tuples = text_tokens_dict["hidream_o1"][0]
text_tuples = [t for t in text_tuples if int(t[0]) != 151643] # strip pad
# <|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|boi|><|tms|>
def tok(tid):
return (tid, 1.0) if not return_word_ids else (tid, 1.0, 0)
prefix = [tok(IM_START_ID), tok(USER_ID), tok(NEWLINE_ID)]
suffix = [
tok(IM_END_ID), tok(NEWLINE_ID),
tok(IM_START_ID), tok(ASSISTANT_ID), tok(NEWLINE_ID),
tok(BOI_TOKEN_ID), tok(TMS_TOKEN_ID),
]
full = prefix + list(text_tuples) + suffix
return {"hidream_o1": [full]}
class HiDreamO1TE(torch.nn.Module):
"""Passthrough TE: emits int token ids; the Qwen3-VL backbone in
diffusion_model.* does the actual encoding.
dtypes advertises uint8 as a routing hint: supports_cast(cuda, uint8)
is False, so CLIP.__init__ downgrades load_device to CPU, which makes
CoreModelPatcher skip the VBAR allocator (it would fail on a zero-param TE).
"""
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = {torch.uint8}
self.device = torch.device("cpu") if device is None else torch.device(device)
def encode_token_weights(self, token_weight_pairs):
tok_pairs = token_weight_pairs["hidream_o1"][0]
ids = [int(t[0]) for t in tok_pairs]
input_ids = torch.tensor([ids], dtype=torch.long)
# Surrogate keeps the cross_attn slot non-empty for CONDITIONING
# plumbing; the model reads text_input_ids out of `extra` instead.
cross_attn = input_ids.unsqueeze(-1).to(torch.float32)
extra = {"text_input_ids": input_ids}
return cross_attn, None, extra
def load_sd(self, sd):
return []
def get_sd(self):
return {}
def reset_clip_options(self):
pass
def set_clip_options(self, options):
pass

View File

@ -397,7 +397,7 @@ class RMSNorm(nn.Module):
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None, interleaved_mrope=False):
if not isinstance(theta, list):
theta = [theta]
@ -415,6 +415,17 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
if rope_dims is not None and position_ids.shape[0] > 1 and interleaved_mrope:
# Qwen3-VL interleaved MRoPE: T-freqs by default, H/W replace every 3rd dim.
freqs_inter = freqs[0].clone()
for axis_idx, offset in ((1, 1), (2, 2)):
length = rope_dims[axis_idx] * 3
idx = slice(offset, length, 3)
freqs_inter[..., idx] = freqs[axis_idx, ..., idx]
emb = torch.cat((freqs_inter, freqs_inter), dim=-1)
cos = emb.cos().unsqueeze(0)
sin = emb.sin().unsqueeze(0)
else:
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
@ -689,6 +700,7 @@ class Llama2_(nn.Module):
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):

View File

@ -0,0 +1,202 @@
from typing_extensions import override
import torch
import comfy.model_management
import node_helpers
from comfy_api.latest import ComfyExtension, io
from comfy.ldm.hidream_o1.utils import find_closest_resolution
class EmptyHiDreamO1LatentImage(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EmptyHiDreamO1LatentImage",
display_name="Empty HiDream-O1 Latent Image",
category="latent/image",
description=(
"Empty pixel-space latent for HiDream-O1-Image. When "
"snap_to_predefined is on, dimensions are matched (by aspect "
"ratio) to the upstream HiDream-O1 PREDEFINED_RESOLUTIONS list."
),
inputs=[
io.Int.Input(id="width", default=2048, min=64, max=4096, step=32),
io.Int.Input(id="height", default=2048, min=64, max=4096, step=32),
io.Int.Input(id="batch_size", default=1, min=1, max=64),
io.Boolean.Input(
id="snap_to_predefined",
default=True,
tooltip=(
"Snap (W, H) to the closest aspect ratio in HiDream-O1's "
"PREDEFINED_RESOLUTIONS table for best parity with the "
"upstream CLI. Disable for arbitrary 32-aligned sizes."
),
),
],
outputs=[io.Latent().Output()],
)
@classmethod
def execute(cls, *, width: int, height: int, batch_size: int = 1,
snap_to_predefined: bool = True) -> io.NodeOutput:
if snap_to_predefined: #TODO: better way to handle this
sw, sh = find_closest_resolution(width, height)
width, height = sw, sh
width = (width // 32) * 32
height = (height // 32) * 32
latent = torch.zeros(
(batch_size, 3, height, width),
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput({"samples": latent})
class HiDreamO1ReferenceImages(io.ComfyNode):
"""Attach reference images to both positive and negative conditioning."""
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="HiDreamO1ReferenceImages",
display_name="HiDream-O1 Reference Images",
category="conditioning/image",
description=(
"Attach 1-10 reference images to conditioning: one for instruction-based "
"editing, or multiple for subject-driven personalization."
),
inputs=[
io.Conditioning.Input(id="positive"),
io.Conditioning.Input(id="negative"),
io.Autogrow.Input(
"images",
template=io.Autogrow.TemplateNames(
io.Image.Input("image"),
names=[f"image_{i}" for i in range(1, 11)],
min=1,
),
tooltip=(
"Reference images. K=1 -> instruction edit; "
"K=2..10 -> subject-driven personalization."
),
),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
@classmethod
def execute(cls, *, positive, negative, images: io.Autogrow.Type) -> io.NodeOutput:
refs = [images[f"image_{i}"] for i in range(1, 11) if f"image_{i}" in images]
positive = node_helpers.conditioning_set_values(positive, {"hidream_o1_ref_images": refs})
negative = node_helpers.conditioning_set_values(negative, {"hidream_o1_ref_images": refs})
return io.NodeOutput(positive, negative)
class HiDreamO1Sampling(io.ComfyNode):
"""Adjust HiDream-O1's flow-match sigma shift and noise scale together."""
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="HiDreamO1Sampling",
display_name="HiDream-O1 Sampling",
category="advanced/model",
description=(
"Patch HiDream-O1's sigma shift and noise scaling factor. "
"Base model defaults: shift=3.0, s_noise=8.0. "
"Dev/flash sampler defaults: shift=1.0, s_noise=7.5."
),
inputs=[
io.Model.Input(id="model"),
io.Float.Input(
id="shift", default=3.0, min=0.0, max=100.0, step=0.01,
tooltip="Flow-match sigma shift. Defaults: 3.0 for base, 1.0 for dev.",
),
io.Float.Input(
id="s_noise", default=8.0, min=0.0, max=64.0, step=0.1,
tooltip=("HiDream-O1 noise scale (CONST_SCALED_NOISE). Defaults: 8.0 for base, 7.5 for dev/flash."
),
),
],
outputs=[io.Model.Output()],
)
@classmethod
def execute(cls, *, model, shift: float, s_noise: float) -> io.NodeOutput:
import comfy.model_sampling
m = model.clone()
class _HiDreamO1SamplingPatched(
comfy.model_sampling.ModelSamplingDiscreteFlow,
comfy.model_sampling.CONST_SCALED_NOISE,
):
pass
ms = _HiDreamO1SamplingPatched(m.model.model_config)
ms.set_parameters(shift=float(shift), multiplier=1000)
ms._s_noise = float(s_noise)
m.add_object_patch("model_sampling", ms)
return io.NodeOutput(m)
class SamplerEulerFlashFlowmatch(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SamplerEulerFlashFlowmatch",
display_name="Sampler Euler Flash Flowmatch",
category="sampling/custom_sampling/samplers",
description=("HiDream-O1 dev/flash sampler with a tunable per-step noise schedule and noise clipping."),
inputs=[
io.Float.Input(
id="s_noise_start", default=7.5, min=0.0, max=64.0, step=0.1,
tooltip="Per-step noise scale at the first sampling step.",
),
io.Float.Input(
id="s_noise_end", default=7.5, min=0.0, max=64.0, step=0.1,
tooltip=(
"Per-step noise scale at the last step. Default: 7.5 for dev/flash. "
"Set it differently from s_noise_start to ramp the noise linearly across steps."
),
),
io.Float.Input(
id="noise_clip_std", default=2.5, min=0.0, max=10.0, step=0.1,
tooltip=("Clamp per-step noise to +/- N*std. 0 disables.")
),
],
outputs=[io.Sampler.Output()],
)
@classmethod
def execute(cls, *, s_noise_start: float, s_noise_end: float,
noise_clip_std: float) -> io.NodeOutput:
import comfy.samplers
import comfy.k_diffusion.sampling
sampler = comfy.samplers.KSAMPLER(
comfy.k_diffusion.sampling.sample_euler_flash_flowmatch,
extra_options={
"s_noise": float(s_noise_start),
"s_noise_end": float(s_noise_end),
"noise_clip_std": float(noise_clip_std),
},
)
return io.NodeOutput(sampler)
class HiDreamO1Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyHiDreamO1LatentImage,
HiDreamO1ReferenceImages,
HiDreamO1Sampling,
SamplerEulerFlashFlowmatch,
]
async def comfy_entrypoint() -> HiDreamO1Extension:
return HiDreamO1Extension()

View File

@ -2435,6 +2435,7 @@ async def init_builtin_extra_nodes():
"nodes_sam3.py",
"nodes_void.py",
"nodes_wandancer.py",
"nodes_hidream_o1.py",
]
import_failed = []