From 974aab796d1dcd170bbac95ff3cc9b98feab058f Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 10 May 2026 03:21:31 +0300 Subject: [PATCH 1/2] Initial HiDream01-image support --- comfy/k_diffusion/sampling.py | 45 +++++ comfy/latent_formats.py | 7 + comfy/ldm/hidream_o1/attention.py | 46 +++++ comfy/ldm/hidream_o1/conditioning.py | 269 +++++++++++++++++++++++++++ comfy/ldm/hidream_o1/model.py | 231 +++++++++++++++++++++++ comfy/ldm/hidream_o1/utils.py | 223 ++++++++++++++++++++++ comfy/model_base.py | 38 ++++ comfy/model_detection.py | 3 + comfy/model_sampling.py | 12 ++ comfy/samplers.py | 3 +- comfy/supported_models.py | 45 +++++ comfy/text_encoders/hidream_o1.py | 124 ++++++++++++ comfy/text_encoders/llama.py | 32 +++- comfy_extras/nodes_hidream_o1.py | 220 ++++++++++++++++++++++ nodes.py | 1 + 15 files changed, 1288 insertions(+), 11 deletions(-) create mode 100644 comfy/ldm/hidream_o1/attention.py create mode 100644 comfy/ldm/hidream_o1/conditioning.py create mode 100644 comfy/ldm/hidream_o1/model.py create mode 100644 comfy/ldm/hidream_o1/utils.py create mode 100644 comfy/text_encoders/hidream_o1.py create mode 100644 comfy_extras/nodes_hidream_o1.py diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c53ac4b2b..715c14fc0 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -264,6 +264,51 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff return x + +@torch.no_grad() +def sample_euler_flash_flowmatch(model, x, sigmas, extra_args=None, callback=None, disable=None, + s_noise=7.5, s_noise_end=None, noise_clip_std=2.5, + noise_sampler=None): + """HiDream-O1-Image-Dev "flash" sampler. + + Step: x_next = sigma_next * noise * s_noise_i + (1 - sigma_next) * denoised, + with noise clamped to noise_clip_std stddevs and s_noise_i linearly + interpolated from s_noise to s_noise_end across steps. Equivalent to + sample_lcm + CONST_SCALED_NOISE when s_noise_end is None and noise_clip_std + is 0. + """ + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + in_dtype = x.dtype + n_steps = max(1, len(sigmas) - 1) + s_start = float(s_noise) + s_end = float(s_noise if s_noise_end is None else s_noise_end) + for i in trange(n_steps, disable=disable): + sigma = sigmas[i] + sigma_next = sigmas[i + 1] + denoised = model(x, sigma * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised}) + if sigma_next == 0: + x = denoised.to(in_dtype) + continue + noise = noise_sampler(sigma, sigma_next) + if noise_clip_std > 0: + clip_val = noise_clip_std * noise.std() + noise = noise.clamp(min=-clip_val, max=clip_val) + # Linear interpolation start -> end across steps, matching upstream + # pipeline.py's noise_scale_schedule construction. + t = (i / (n_steps - 1)) if n_steps > 1 else 0.0 + s_noise_i = s_start + (s_end - s_start) * t + # Match upstream FlashFlowMatchEulerDiscreteScheduler.step: do the step + # math in fp32 to avoid bf16 accumulation drift across 28 steps. 
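+        # Worked example with illustrative numbers (not from any recipe): at
+        # sigma_next = 0.75 and s_noise_i = 7.5 this computes
+        # x = 5.625 * noise + 0.25 * denoised, so the re-injected noise
+        # dominates early and vanishes as sigma_next -> 0.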
+ x = (sigma_next * noise.float() * s_noise_i + + (1.0 - sigma_next) * denoised.float()).to(in_dtype) + return x + + @torch.no_grad() def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 91bebed3d..d527eec4a 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -792,6 +792,13 @@ class ZImagePixelSpace(ChromaRadiance): """ pass + +class HiDreamO1Pixel(ChromaRadiance): + """Pixel-space latent format for HiDream-O1. + No VAE — model patches/unpatches raw RGB internally with patch_size=32. + """ + pass + class CogVideoX(LatentFormat): """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b). diff --git a/comfy/ldm/hidream_o1/attention.py b/comfy/ldm/hidream_o1/attention.py new file mode 100644 index 000000000..8b2bb64e5 --- /dev/null +++ b/comfy/ldm/hidream_o1/attention.py @@ -0,0 +1,46 @@ +"""HiDream-O1 two-pass attention: tokens [0, ar_len) are causal, [ar_len, T) +attend full K/V. Splitting Q at the boundary avoids the (B, 1, T, T) additive +mask the general-purpose path would build (~500 MB at T~16K) and lets the +gen half hit the user's preferred backend via optimized_attention. +""" + +import torch + +import comfy.ops +from comfy.ldm.modules.attention import optimized_attention + + +def make_two_pass_attention(ar_len: int): + """Build a two-pass attention callable. AR pass uses SDPA-causal directly, gen pass routes through optimized_attention. + """ + def two_pass_attention(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if skip_reshape: + B, H, T, D = q.shape + else: + B, T, total_dim = q.shape + D = total_dim // heads + H = heads + q = q.view(B, T, H, D).transpose(1, 2) + k = k.view(B, T, H, D).transpose(1, 2) + v = v.view(B, T, H, D).transpose(1, 2) + + if ar_len >= T: + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True) + elif ar_len <= 0: + out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True) + else: + out_ar = comfy.ops.scaled_dot_product_attention( + q[:, :, :ar_len], k[:, :, :ar_len], v[:, :, :ar_len], + attn_mask=None, dropout_p=0.0, is_causal=True, + ) + out_gen = optimized_attention( + q[:, :, ar_len:], k, v, heads, + mask=None, skip_reshape=True, skip_output_reshape=True, + ) + out = torch.cat([out_ar, out_gen], dim=2) + + if skip_output_reshape: + return out + return out.transpose(1, 2).reshape(B, T, H * D) + + return two_pass_attention diff --git a/comfy/ldm/hidream_o1/conditioning.py b/comfy/ldm/hidream_o1/conditioning.py new file mode 100644 index 000000000..dc9b18dd6 --- /dev/null +++ b/comfy/ldm/hidream_o1/conditioning.py @@ -0,0 +1,269 @@ +"""HiDream-O1 conditioning prep — ref-image dual path + extra_conds assembly. + +Each ref image goes through two paths: a 32x32 patchified stream concatenated +to the noised target, and a Qwen3-VL ViT path producing tokens that scatter +into input_ids at <|image_pad|> positions. +""" + +from typing import List, Tuple + +import einops +import numpy as np +import torch +from PIL import Image + +from .utils import (PATCH_SIZE, calculate_dimensions, cond_image_size, ref_max_size, resize_pilimage) + +# Qwen3-VL ViT preprocessing constants (preprocessor_config.json). 
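+# Note: VIT_PATCH * VIT_MERGE = 32 px per merged ViT token, numerically equal
+# to the 32-patch stream's PATCH_SIZE.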
+VIT_PATCH = 16 +VIT_MERGE = 2 +VIT_TEMPORAL_PATCH = 2 +VIT_IMAGE_MEAN = [0.5, 0.5, 0.5] +VIT_IMAGE_STD = [0.5, 0.5, 0.5] + + +def _process_vit_image(pil: Image.Image, device, dtype) -> Tuple[torch.Tensor, torch.Tensor]: + """Qwen3-VL ViT preprocessing: returns (flatten_patches, image_grid_thw).""" + arr = np.asarray(pil, dtype=np.float32) / 255.0 + img_t = torch.from_numpy(arr).permute(2, 0, 1).contiguous() + h, w = img_t.shape[-2:] + + # H/W must be multiples of patch*merge. + factor = VIT_PATCH * VIT_MERGE + h_bar = max(round(h / factor) * factor, factor) + w_bar = max(round(w / factor) * factor, factor) + if (h, w) != (h_bar, w_bar): + img_t = torch.nn.functional.interpolate( + img_t.unsqueeze(0), size=(h_bar, w_bar), mode="bilinear", align_corners=False, + ).squeeze(0) + + mean = torch.tensor(VIT_IMAGE_MEAN).view(3, 1, 1) + std = torch.tensor(VIT_IMAGE_STD).view(3, 1, 1) + normalized = (img_t - mean) / std + + grid_h = h_bar // VIT_PATCH + grid_w = w_bar // VIT_PATCH + grid_thw = torch.tensor([1, grid_h, grid_w], dtype=torch.long) + + # Stack 2 copies for the temporal_patch dim, then patchify. + pixel_values = normalized.unsqueeze(0).repeat(VIT_TEMPORAL_PATCH, 1, 1, 1) + patches = pixel_values.reshape( + 1, VIT_TEMPORAL_PATCH, 3, + grid_h // VIT_MERGE, VIT_MERGE, VIT_PATCH, + grid_w // VIT_MERGE, VIT_MERGE, VIT_PATCH, + ) + patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_h * grid_w, + 3 * VIT_TEMPORAL_PATCH * VIT_PATCH * VIT_PATCH, + ) + return flatten_patches.to(device=device, dtype=dtype), grid_thw.to(device=device) + + +def prepare_ref_images( + ref_images: List[torch.Tensor], + target_h: int, + target_w: int, + device: torch.device, + dtype: torch.dtype, +): + """Build the dual-path tensors for K reference images at (target_h, target_w). + + Returns None for K=0, else a dict with ref_patches, + ref_pixel_values, ref_image_grid_thw, per_ref_vit_tokens, + per_ref_patch_grids. + """ + K = len(ref_images) + if K == 0: + return None + max_size = ref_max_size(max(target_h, target_w), K) + cis = cond_image_size(K) + + pils = [] + for img in ref_images: + arr = np.round(img[0].clamp(0, 1).cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8) + pils.append(Image.fromarray(arr, "RGB")) + pils_resized = [resize_pilimage(p, max_size, PATCH_SIZE) for p in pils] + + # 32-patch path. + ref_patches_per = [] + per_ref_patch_grids = [] + for pil_r in pils_resized: + arr = np.asarray(pil_r, dtype=np.float32) / 255.0 + t = torch.from_numpy(arr).permute(2, 0, 1).contiguous() + t = (t - 0.5) / 0.5 # -> [-1, 1] + h_p, w_p = pil_r.height // PATCH_SIZE, pil_r.width // PATCH_SIZE + per_ref_patch_grids.append((h_p, w_p)) + patches = einops.rearrange( + t, "C (H p1) (W p2) -> (H W) (C p1 p2)", + p1=PATCH_SIZE, p2=PATCH_SIZE, + ) + ref_patches_per.append(patches) + ref_patches = torch.cat(ref_patches_per, dim=0).unsqueeze(0).to(device=device, dtype=dtype) + + # ViT path. + pils_vlm = [] + for pil_r in pils_resized: + cond_w, cond_h = calculate_dimensions(cis, pil_r.width / pil_r.height) + cond_w = max(cond_w, VIT_PATCH * VIT_MERGE) + cond_h = max(cond_h, VIT_PATCH * VIT_MERGE) + pils_vlm.append(pil_r.resize((cond_w, cond_h), resample=Image.LANCZOS)) + + pv_list, grid_list, per_ref_vit_tokens = [], [], [] + for pil_v in pils_vlm: + pv, grid_thw = _process_vit_image(pil_v, device, dtype) + pv_list.append(pv) + grid_list.append(grid_thw) + # Post-merge token count = number of <|image_pad|> tokens this image + # expands to in input_ids. 
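+        # e.g. an illustrative 384x384 ref: grid (24, 24) ->
+        # (24 // 2) * (24 // 2) = 144 <|image_pad|> tokens.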
+ gh, gw = int(grid_thw[1].item()), int(grid_thw[2].item()) + per_ref_vit_tokens.append((gh // VIT_MERGE) * (gw // VIT_MERGE)) + + return { + "ref_patches": ref_patches, + "ref_pixel_values": torch.cat(pv_list, dim=0), + "ref_image_grid_thw": torch.stack(grid_list, dim=0), + "per_ref_vit_tokens": per_ref_vit_tokens, + "per_ref_patch_grids": per_ref_patch_grids, + } + + +def build_ref_input_ids( + text_input_ids: torch.Tensor, + per_ref_vit_tokens: List[int], + image_token_id: int, + vision_start_id: int, + vision_end_id: int, +): + """Splice [vision_start, image_pad*N, vision_end] blocks into input_ids + after the [im_start, user, \\n] prefix (matches upstream chat template). + """ + ids = text_input_ids[0].tolist() + inserted = [] + for n_pad in per_ref_vit_tokens: + inserted.extend([vision_start_id] + [image_token_id] * n_pad + [vision_end_id]) + new_ids = ids[:3] + inserted + ids[3:] # 3 = len([im_start, user, \n]) + return torch.tensor([new_ids], dtype=text_input_ids.dtype, device=text_input_ids.device) + + +def build_extra_conds( + text_input_ids: torch.Tensor, + noise: torch.Tensor, + ref_images: List[torch.Tensor] = None, + target_patch_size: int = 32, +): + """Assemble all conditioning tensors for HiDreamO1Transformer.forward: + input_ids (with ref-vision tokens spliced in for the edit/IP path), + position_ids (MRoPE), token_types, vinput_mask, plus the ref + dual-path tensors when refs are provided. + """ + from .utils import get_rope_index_fix_point + from comfy.text_encoders.hidream_o1 import ( + IMAGE_TOKEN_ID, VIDEO_TOKEN_ID, VISION_START_ID, VISION_END_ID, + ) + + if text_input_ids.dim() == 1: + text_input_ids = text_input_ids.unsqueeze(0) + text_input_ids = text_input_ids.long().to(noise.device) + B = noise.shape[0] + if text_input_ids.shape[0] == 1 and B > 1: + text_input_ids = text_input_ids.expand(B, -1) + + H, W = noise.shape[-2], noise.shape[-1] + h_p, w_p = H // target_patch_size, W // target_patch_size + image_len = h_p * w_p + image_grid_thw_tgt = torch.tensor( + [[1, h_p, w_p]], dtype=torch.long, device=text_input_ids.device, + ) + + out = {} + if ref_images: + ref = prepare_ref_images(ref_images, H, W, device=noise.device, dtype=noise.dtype) + text_input_ids = build_ref_input_ids( + text_input_ids, ref["per_ref_vit_tokens"], + IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID, + ) + new_txt_len = text_input_ids.shape[1] + + # Each ref's patchified stream gets a [vision_start, image_pad*N-1] + # block in the position-id stream after the noised target. + ref_grid_lengths = [hp * wp for (hp, wp) in ref["per_ref_patch_grids"]] + tgt_vision = torch.full((1, image_len), IMAGE_TOKEN_ID, + dtype=text_input_ids.dtype, device=text_input_ids.device) + tgt_vision[:, 0] = VISION_START_ID + ref_vision_blocks = [] + for rl in ref_grid_lengths: + blk = torch.full((1, rl), IMAGE_TOKEN_ID, + dtype=text_input_ids.dtype, device=text_input_ids.device) + blk[:, 0] = VISION_START_ID + ref_vision_blocks.append(blk) + ref_vision_cat = torch.cat([tgt_vision] + ref_vision_blocks, dim=1) + input_ids_pad = torch.cat([text_input_ids, ref_vision_cat], dim=-1) + total_ref_patches_len = sum(ref_grid_lengths) + total_len = new_txt_len + image_len + total_ref_patches_len + + # K (ViT, post-merge) + 1 (target) + K (ref-patches) image grids. 
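+        # Order mirrors the vision_start occurrences in input_ids_pad: the K
+        # spliced ViT blocks sit in the text, then the target block, then the
+        # K appended ref-patch blocks (hence skip_vision_start_token =
+        # [0]*K + [1] + [1]*K below).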
+ K = len(ref_images) + igthw_cond = ref["ref_image_grid_thw"].clone() + igthw_cond[:, 1] //= 2 + igthw_cond[:, 2] //= 2 + image_grid_thw_ref = torch.tensor( + [[1, hp, wp] for (hp, wp) in ref["per_ref_patch_grids"]], + dtype=torch.long, device=text_input_ids.device, + ) + igthw_all = torch.cat([ + igthw_cond.to(text_input_ids.device), + image_grid_thw_tgt, + image_grid_thw_ref, + ], dim=0) + position_ids, _ = get_rope_index_fix_point( + spatial_merge_size=1, + image_token_id=IMAGE_TOKEN_ID, video_token_id=VIDEO_TOKEN_ID, + vision_start_token_id=VISION_START_ID, + input_ids=input_ids_pad, image_grid_thw=igthw_all, + video_grid_thw=None, attention_mask=None, + skip_vision_start_token=[0] * K + [1] + [1] * K, + fix_point=4096, + ) + + # tms + target_image + ref_patches are all gen. + tms_pos = new_txt_len - 1 + token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device) + token_types[:, tms_pos:] = 1 + vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device) + vinput_mask[:, new_txt_len:] = True + + # Leading batch dim sidesteps CONDRegular.process_cond's + # repeat_to_batch_size truncation (which narrows dim 0 to B). + out["ref_pixel_values"] = ref["ref_pixel_values"].unsqueeze(0) + out["ref_image_grid_thw"] = ref["ref_image_grid_thw"].unsqueeze(0) + out["ref_patches"] = ref["ref_patches"] + else: + # T2I: text + noised target only. vision_start replaces the first + # image token (upstream pipeline.py:51). + txt_len = text_input_ids.shape[1] + total_len = txt_len + image_len + vision_tokens = torch.full((B, image_len), IMAGE_TOKEN_ID, + dtype=text_input_ids.dtype, device=text_input_ids.device) + vision_tokens[:, 0] = VISION_START_ID + input_ids_pad = torch.cat([text_input_ids, vision_tokens], dim=-1) + position_ids, _ = get_rope_index_fix_point( + spatial_merge_size=1, + image_token_id=IMAGE_TOKEN_ID, video_token_id=VIDEO_TOKEN_ID, + vision_start_token_id=VISION_START_ID, + input_ids=input_ids_pad, image_grid_thw=image_grid_thw_tgt, + video_grid_thw=None, attention_mask=None, + skip_vision_start_token=[1], + ) + token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device) + token_types[:, txt_len - 1:] = 1 + vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device) + vinput_mask[:, txt_len:] = True + + # Collapse position_ids batch and add a leading dim so CONDRegular's + # batch-resize doesn't truncate the 3-axis MRoPE dim. + out["input_ids"] = text_input_ids + out["position_ids"] = position_ids[:, 0].unsqueeze(0) + out["token_types"] = token_types + out["vinput_mask"] = vinput_mask + return out diff --git a/comfy/ldm/hidream_o1/model.py b/comfy/ldm/hidream_o1/model.py new file mode 100644 index 000000000..caf83082c --- /dev/null +++ b/comfy/ldm/hidream_o1/model.py @@ -0,0 +1,231 @@ +"""HiDream-O1-Image transformer. + +Pixel-space DiT built on Qwen3-VL: the vision tower (Qwen35VisionModel) +encodes ref images, the Qwen3-VL-8B decoder (Llama2_ with interleaved MRoPE) +processes a unified text+image sequence, and 32x32 patch embed/unembed +shims map raw RGB in and out of LLM hidden space. The Qwen3-VL deepstack +mergers go unused — their weights are dropped at load. 
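+
+Unified sequence per forward pass: [chat-template text with ViT ref tokens
+scattered at <|image_pad|> and <|tms|> swapped for the timestep embed] +
+[noised target patches] + [patchified ref patches]; the two-pass attention
+treats everything from the tms token onward as the "gen" half.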
+""" + +from dataclasses import dataclass, field +from typing import List, Optional + +import einops +import torch +import torch.nn as nn + +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder +from comfy.text_encoders.llama import Llama2_ +from comfy.text_encoders.qwen35 import Qwen35VisionModel + +from .attention import make_two_pass_attention + + +IMAGE_TOKEN_ID = 151655 # Qwen3-VL <|image_pad|> +TMS_TOKEN_ID = 151673 # HiDream-O1 <|tms_token|> +PATCH_SIZE = 32 + + +@dataclass +class HiDreamO1TextConfig: + """Qwen3-VL-8B text-decoder dims (matches public Qwen3-VL-8B-Instruct).""" + vocab_size: int = 151936 + hidden_size: int = 4096 + intermediate_size: int = 12288 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + head_dim: int = 128 + max_position_embeddings: int = 128000 + rms_norm_eps: float = 1e-6 + rope_theta: float = 5000000.0 + rope_scale: Optional[float] = None + rope_dims: List[int] = field(default_factory=lambda: [24, 20, 20]) + interleaved_mrope: bool = True + transformer_type: str = "llama" + rms_norm_add: bool = False + mlp_activation: str = "silu" + qkv_bias: bool = False + q_norm: str = "gemma3" + k_norm: str = "gemma3" + final_norm: bool = True + lm_head: bool = False + stop_tokens: List[int] = field(default_factory=lambda: [151643, 151645]) + + +QWEN3VL_VISION_DEFAULTS = dict( + hidden_size=1152, + num_heads=16, + intermediate_size=4304, + depth=27, + patch_size=16, + temporal_patch_size=2, + in_channels=3, + spatial_merge_size=2, + num_position_embeddings=2304, + deepstack_visual_indexes=(8, 16, 24), + out_hidden_size=4096, # final merger projects directly into LLM hidden +) + + +class BottleneckPatchEmbed(nn.Module): + # 3072 -> 1024 -> 4096 (raw 32x32 RGB patch -> bottleneck -> LLM hidden). + def __init__(self, patch_size=32, in_chans=3, pca_dim=1024, embed_dim=4096, bias=True, device=None, dtype=None, ops=None): + super().__init__() + self.proj1 = ops.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False, device=device, dtype=dtype) + self.proj2 = ops.Linear(pca_dim, embed_dim, bias=bias, device=device, dtype=dtype) + + def forward(self, x): + return self.proj2(self.proj1(x)) + + +class FinalLayer(nn.Module): + # 4096 -> 3072 (LLM hidden -> flat pixel patch). 
+ def __init__(self, hidden_size, patch_size=32, out_channels=3, device=None, dtype=None, ops=None): + super().__init__() + self.linear = ops.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, + device=device, dtype=dtype) + + def forward(self, x): + return self.linear(x) + + +class HiDreamO1Transformer(nn.Module): + """HiDream-O1 unified pixel-level transformer.""" + + def __init__(self, image_model=None, dtype=None, device=None, operations=None, + text_config_overrides=None, vision_config_overrides=None, **kwargs): + super().__init__() + self.dtype = dtype + + text_cfg = HiDreamO1TextConfig(**(text_config_overrides or {})) + vision_cfg = dict(QWEN3VL_VISION_DEFAULTS) + if vision_config_overrides: + vision_cfg.update(vision_config_overrides) + vision_cfg["out_hidden_size"] = text_cfg.hidden_size + + self.text_config = text_cfg + self.vision_config = vision_cfg + self.hidden_size = text_cfg.hidden_size + self.patch_size = PATCH_SIZE + self.in_channels = 3 + self.tms_token_id = TMS_TOKEN_ID + + self.visual = Qwen35VisionModel(vision_cfg, device=device, dtype=dtype, ops=operations) + self.language_model = Llama2_(text_cfg, device=device, dtype=dtype, ops=operations) + self.t_embedder1 = TimestepEmbedder( + text_cfg.hidden_size, device=device, dtype=dtype, operations=operations, + ) + self.x_embedder = BottleneckPatchEmbed( + patch_size=self.patch_size, in_chans=self.in_channels, + pca_dim=text_cfg.hidden_size // 4, embed_dim=text_cfg.hidden_size, + bias=True, device=device, dtype=dtype, ops=operations, + ) + self.final_layer2 = FinalLayer( + text_cfg.hidden_size, patch_size=self.patch_size, + out_channels=self.in_channels, device=device, dtype=dtype, ops=operations, + ) + + def forward(self, x, timesteps, context=None, transformer_options={}, + input_ids=None, attention_mask=None, position_ids=None, + token_types=None, vinput_mask=None, + ref_pixel_values=None, ref_image_grid_thw=None, ref_patches=None, + **kwargs): + """Returns flow-match velocity (x - x_pred) / sigma""" + if input_ids is None or position_ids is None: + raise ValueError("HiDreamO1Transformer requires input_ids and position_ids in conditioning") + + B, _, H, W = x.shape + h_p, w_p = H // self.patch_size, W // self.patch_size + tgt_image_len = h_p * w_p + + z = einops.rearrange( + x, 'B C (H p1) (W p2) -> B (H W) (C p1 p2)', + p1=self.patch_size, p2=self.patch_size, + ) + vinputs = torch.cat([z, ref_patches.to(z.dtype)], dim=1) if ref_patches is not None else z + + if input_ids.dim() == 3: + input_ids = input_ids.squeeze(-1) + input_ids = input_ids.long() + inputs_embeds = self.language_model.embed_tokens(input_ids).to(x.dtype) + + if ref_pixel_values is not None and ref_image_grid_thw is not None: + ref_pv = ref_pixel_values.to(inputs_embeds.device) + ref_grid = ref_image_grid_thw.to(inputs_embeds.device).long() + # Refs are model-level (same for cond/uncond), wrapped with a leading batch dim by extra_conds; [0] always recovers them. + if ref_pv.dim() == 3: + ref_pv = ref_pv[0] + if ref_grid.dim() == 3: + ref_grid = ref_grid[0] + image_embeds = self.visual(ref_pv, ref_grid).to(inputs_embeds.dtype) + image_mask = (input_ids == IMAGE_TOKEN_ID) + if image_mask[0].sum().item() != image_embeds.shape[0]: + raise ValueError( + f"Image-token count {image_mask[0].sum().item()} != ViT output count " + f"{image_embeds.shape[0]}; check tokenizer/processor alignment." 
+ ) + image_embeds_b = image_embeds.unsqueeze(0).expand(B, -1, -1).reshape(-1, image_embeds.shape[-1]) + inputs_embeds = inputs_embeds.masked_scatter( + image_mask.unsqueeze(-1).expand_as(inputs_embeds), image_embeds_b, + ) + + sigma = timesteps.float() / 1000.0 + t_pixeldit = 1.0 - sigma + t_emb = self.t_embedder1(t_pixeldit * 1000, inputs_embeds.dtype) + tms_mask_3d = (input_ids == self.tms_token_id).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = torch.where(tms_mask_3d, t_emb.unsqueeze(1).expand_as(inputs_embeds), inputs_embeds) + + vinputs_embedded = self.x_embedder(vinputs.to(inputs_embeds.dtype)) + inputs_embeds = torch.cat([inputs_embeds, vinputs_embedded], dim=1) + total_seq_len = inputs_embeds.shape[1] + + # AR (text) tokens are contiguous at the start, so (==0).sum() gives ar_len. + if token_types is None: + txt_seq_len = input_ids.shape[1] + token_types = torch.zeros(B, total_seq_len, dtype=torch.long, device=x.device) + token_types[:, txt_seq_len:] = 1 + else: + token_types = token_types.to(x.device) + if token_types.dim() == 1: + token_types = token_types.unsqueeze(0) + if token_types.shape[0] == 1 and B > 1: + token_types = token_types.expand(B, -1) + ar_len = int((token_types[0] == 0).sum().item()) + + # position_ids may arrive as (3, T) or wrapped (1, 3, T) / (3, 1, T) by CONDRegular. + position_ids = position_ids.to(x.device).long() + if position_ids.dim() == 3: + position_ids = position_ids[0] if position_ids.shape[1] == 3 else position_ids[:, 0] + freqs_cis = self.language_model.compute_freqs_cis(position_ids, x.device) + freqs_cis = tuple(t.to(x.dtype) for t in freqs_cis) + + two_pass_attn = make_two_pass_attention(ar_len) + hidden_states = inputs_embeds + for layer in self.language_model.layers: + hidden_states, _ = layer( + x=hidden_states, attention_mask=None, + freqs_cis=freqs_cis, optimized_attention=two_pass_attn, + past_key_value=None, + ) + if self.language_model.norm is not None: + hidden_states = self.language_model.norm(hidden_states) + + x_pred = self.final_layer2(hidden_states) + if vinput_mask is not None: + vmask = vinput_mask.to(x.device).bool() + if vmask.dim() == 1: + vmask = vmask.unsqueeze(0) + if vmask.shape[0] == 1 and B > 1: + vmask = vmask.expand(B, -1) + x_pred_tgt = x_pred[vmask].view(B, -1, x_pred.shape[-1])[:, :tgt_image_len] + else: + txt_seq_len = input_ids.shape[1] + x_pred_tgt = x_pred[:, txt_seq_len:txt_seq_len + tgt_image_len] + + # fp32 final subtraction, bf16 here noticeably degrades samples. + x_pred_img = einops.rearrange( + x_pred_tgt, 'B (H W) (C p1 p2) -> B C (H p1) (W p2)', + H=h_p, W=w_p, p1=self.patch_size, p2=self.patch_size, + ) + return (x.float() - x_pred_img.float()) / sigma.view(B, 1, 1, 1).clamp_min(1e-3) diff --git a/comfy/ldm/hidream_o1/utils.py b/comfy/ldm/hidream_o1/utils.py new file mode 100644 index 000000000..a6f13587e --- /dev/null +++ b/comfy/ldm/hidream_o1/utils.py @@ -0,0 +1,223 @@ +"""HiDream-O1 input-prep helpers: image/resolution math and unified-sequence +RoPE position-id assembly. The fix_point offset in get_rope_index_fix_point +lets the target image and patchified ref images share spatial RoPE positions +despite living at different sequence indices — same 2D image plane. 
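+Concretely, the first grid flagged in skip_vision_start_token gets its (t, h, w)
+positions offset to start at fix_point (4096) rather than at its running
+sequence index, decoupling spatial RoPE from how much text precedes the image.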
+""" + +import math +from typing import Optional + +import torch +from PIL import Image + + +PREDEFINED_RESOLUTIONS = [ + (2048, 2048), + (2304, 1728), + (1728, 2304), + (2560, 1440), + (1440, 2560), + (2496, 1664), + (1664, 2496), + (3104, 1312), + (1312, 3104), + (2304, 1792), + (1792, 2304), +] + +PATCH_SIZE = 32 +CONDITION_IMAGE_SIZE = 384 # ViT-side base size for ref images + + +def find_closest_resolution(width, height): + """Closest (W, H) in PREDEFINED_RESOLUTIONS by aspect ratio.""" + img_ratio = width / height + best = None + min_diff = float("inf") + for w, h in PREDEFINED_RESOLUTIONS: + diff = abs(w / h - img_ratio) + if diff < min_diff: + min_diff = diff + best = (w, h) + return best + + +def resize_pilimage(pil_image, image_size, patch_size=16, resampler=Image.BICUBIC): + """Resize to fit image_size**2 area, patch-aligned, center-cropped. Pre-halves + with BOX filter while the image is still very large. + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX, + ) + + m = patch_size + width, height = pil_image.width, pil_image.height + s_max = image_size * image_size + scale = math.sqrt(s_max / (width * height)) + + candidates = [ + (round(width * scale) // m * m, round(height * scale) // m * m), + (round(width * scale) // m * m, math.floor(height * scale) // m * m), + (math.floor(width * scale) // m * m, round(height * scale) // m * m), + (math.floor(width * scale) // m * m, math.floor(height * scale) // m * m), + ] + candidates = sorted(candidates, key=lambda x: x[0] * x[1], reverse=True) + new_size = candidates[-1] + for c in candidates: + if c[0] * c[1] <= s_max: + new_size = c + break + + s1 = width / new_size[0] + s2 = height / new_size[1] + if s1 < s2: + pil_image = pil_image.resize([new_size[0], round(height / s1)], resample=resampler) + top = (round(height / s1) - new_size[1]) // 2 + pil_image = pil_image.crop((0, top, new_size[0], top + new_size[1])) + else: + pil_image = pil_image.resize([round(width / s2), new_size[1]], resample=resampler) + left = (round(width / s2) - new_size[0]) // 2 + pil_image = pil_image.crop((left, 0, left + new_size[0], new_size[1])) + return pil_image + + +def calculate_dimensions(max_size, ratio): + """(W, H) for an aspect ratio fitting in max_size**2 area, 32-aligned.""" + width = math.sqrt(max_size * max_size * ratio) + height = width / ratio + width = int(width / 32) * 32 + height = int(height / 32) * 32 + return width, height + + +def ref_max_size(target_max_dim, k): + """K-dependent ref-image max dim before patchifying.""" + if k == 1: + return target_max_dim + if k == 2: + return target_max_dim * 48 // 64 + if k <= 4: + return target_max_dim // 2 + if k <= 8: + return target_max_dim * 24 // 64 + return target_max_dim // 4 + + +def cond_image_size(k): + """K-dependent ViT-side image size.""" + if k <= 4: + return CONDITION_IMAGE_SIZE + if k <= 8: + return CONDITION_IMAGE_SIZE * 48 // 64 + return CONDITION_IMAGE_SIZE // 2 + + +def get_rope_index_fix_point( + spatial_merge_size: int, + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + skip_vision_start_token=None, + fix_point: int = 4096, +): + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + 
video_grid_thw[:, 0] = 1 + + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], + dtype=input_ids.dtype, device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids_b in enumerate(total_input_ids): + input_ids_b = input_ids_b[attention_mask[i] == 1] + vision_start_indices = torch.argwhere(input_ids_b == vision_start_token_id).squeeze(1) + vision_tokens = input_ids_b[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids_b.tolist() + llm_pos_ids_list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t = image_grid_thw[image_index][0] + h = image_grid_thw[image_index][1] + w = image_grid_thw[image_index][2] + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t = video_grid_thw[video_index][0] + h = video_grid_thw[video_index][1] + w = video_grid_thw[video_index][2] + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t = t.item() + llm_grid_h = h.item() // spatial_merge_size + llm_grid_w = w.item() // spatial_merge_size + text_len = ed - st + text_len -= skip_vision_start_token[image_index - 1] + text_len = max(0, text_len) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + + if skip_vision_start_token[image_index - 1]: + if fix_point > 0: + fix_point = fix_point - st_idx + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fix_point + st_idx) + fix_point = 0 + else: + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, 
-1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1).expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype, + ) + return position_ids, mrope_position_deltas diff --git a/comfy/model_base.py b/comfy/model_base.py index 57a1e44d2..4d90e3d0e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -57,6 +57,7 @@ import comfy.ldm.cogvideo.model import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.ernie.model import comfy.ldm.sam3.detector +import comfy.ldm.hidream_o1.model import comfy.model_management import comfy.patcher_extension @@ -1665,6 +1666,43 @@ class ChromaRadiance(Chroma): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance) + +class HiDreamO1(BaseModel): + """HiDream-O1-Image: pixel-space DiT (no VAE). Refs from HiDreamO1ReferenceImages and tokens from the stub TE flow through + extra_conds; the heavy preprocessing lives in comfy.ldm.hidream_o1.conditioning.""" + PATCH_SIZE = 32 + + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.hidream_o1.model.HiDreamO1Transformer) + # HiDream-O1 trains with x_t = (1-t) x_clean + t * s_noise * noise + s_noise = float((model_config.sampling_settings or {}).get("s_noise", 8.0)) + + class _HiDreamO1Sampling( + comfy.model_sampling.ModelSamplingDiscreteFlow, + comfy.model_sampling.CONST_SCALED_NOISE, + ): + pass + ms = _HiDreamO1Sampling(model_config) + ms._s_noise = s_noise + self.model_sampling = ms + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + text_input_ids = kwargs.get("text_input_ids", None) + noise = kwargs.get("noise", None) + if text_input_ids is None or noise is None: + return out + from comfy.ldm.hidream_o1.conditioning import build_extra_conds + conds = build_extra_conds( + text_input_ids, noise, + ref_images=kwargs.get("hidream_o1_ref_images", None), + target_patch_size=self.PATCH_SIZE, + ) + for k, v in conds.items(): + out[k] = comfy.conds.CONDRegular(v) + return out + class ACEStep(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index d9b67dcdf..f7bbadf00 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -618,6 +618,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys return dit_config + if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1 + return {"image_model": "hidream_o1"} + if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream dit_config = {} dit_config["image_model"] = "hidream" diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index cf2b5db5f..1d27e763e 100644 --- a/comfy/model_sampling.py +++ 
b/comfy/model_sampling.py @@ -99,6 +99,18 @@ class CONST: sigma = reshape_sigma(sigma, latent.ndim) return latent / (1.0 - sigma) + +class CONST_SCALED_NOISE(CONST): + """CONST variant for flow-match models trained with x_t = (1-t)*x_clean + + t*s_noise*noise. Set _s_noise to the recipe value; default 1.0 == plain CONST. + """ + + _s_noise = 1.0 + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = reshape_sigma(sigma, noise.ndim) + return sigma * (self._s_noise * noise) + (1.0 - sigma) * latent_image + class X0(EPS): def calculate_denoised(self, sigma, model_output, model_input): return model_output diff --git a/comfy/samplers.py b/comfy/samplers.py index 0a4d062db..7d98c0aa7 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -723,7 +723,8 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", - "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"] + "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece", + "euler_flash_flowmatch"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6a9613602..02bc50540 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -28,6 +28,7 @@ import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.ernie import comfy.text_encoders.cogvideo +import comfy.text_encoders.hidream_o1 from . import supported_models_base from . import latent_formats @@ -1449,6 +1450,49 @@ class ChromaRadiance(Chroma): def get_model(self, state_dict, prefix="", device=None): return model_base.ChromaRadiance(self, device=device) +class HiDreamO1(supported_models_base.BASE): + unet_config = { + "image_model": "hidream_o1", + } + + sampling_settings = { + "shift": 3.0, + "s_noise": 8.0, + } + + latent_format = latent_formats.HiDreamO1Pixel + memory_usage_factor = 0.6 + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + optimizations = {"fp8": False} + + def get_model(self, state_dict, prefix="", device=None): + return model_base.HiDreamO1(self, device=device) + + def process_unet_state_dict(self, state_dict): + # Drop unused Qwen3-VL deepstack merger weights; upstream discards them at inference. + for key in list(state_dict.keys()): + if "visual.deepstack_merger_list" in key: + del state_dict[key] + return state_dict + + def process_vae_state_dict(self, state_dict): + # Pixel-space model: inject sentinel so VAE construction picks PixelspaceConversionVAE. + return {"pixel_space_vae": torch.tensor(1.0)} + + def process_clip_state_dict(self, state_dict): + # Tokenizer-only TE: inject sentinel so load_state_dict_guess_config triggers CLIP init. 
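+        # The key is arbitrary and never matches a real weight;
+        # HiDreamO1TE.load_sd ignores whatever state dict arrives.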
+ return {"_hidream_o1_te_sentinel": torch.zeros(1)} + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget( + comfy.text_encoders.hidream_o1.HiDreamO1Tokenizer, + comfy.text_encoders.hidream_o1.HiDreamO1TE, + ) + class ACEStep(supported_models_base.BASE): unet_config = { "audio_model": "ace", @@ -1986,6 +2030,7 @@ models = [ Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, + HiDreamO1, Chroma, ChromaRadiance, ACEStep, diff --git a/comfy/text_encoders/hidream_o1.py b/comfy/text_encoders/hidream_o1.py new file mode 100644 index 000000000..79b87e412 --- /dev/null +++ b/comfy/text_encoders/hidream_o1.py @@ -0,0 +1,124 @@ +"""HiDream-O1-Image tokenizer-only text encoder. + +The real Qwen3-VL backbone runs inside diffusion_model.* every step, so this +module just tokenizes the prompt into text_input_ids and emits them as +conditioning. Position ids / token_types / vinput_mask depend on target H/W +and are built later in model_base.HiDreamO1.extra_conds. +""" + +import os + +import torch +from transformers import Qwen2Tokenizer + +from comfy import sd1_clip + + +# Qwen3-VL special tokens +IM_START_ID = 151644 +IM_END_ID = 151645 +ASSISTANT_ID = 77091 +USER_ID = 872 +NEWLINE_ID = 198 +VISION_START_ID = 151652 +VISION_END_ID = 151653 +IMAGE_TOKEN_ID = 151655 +VIDEO_TOKEN_ID = 151656 +# HiDream-O1-specific tokens +BOI_TOKEN_ID = 151669 +BOR_TOKEN_ID = 151670 +EOR_TOKEN_ID = 151671 +BOT_TOKEN_ID = 151672 +TMS_TOKEN_ID = 151673 + + +class HiDreamO1QwenTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer" + ) + super().__init__( + tokenizer_path, + pad_with_end=False, + embedding_size=4096, + embedding_key="hidream_o1", + tokenizer_class=Qwen2Tokenizer, + has_start_token=False, + has_end_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=1, + pad_token=151643, + tokenizer_data=tokenizer_data, + ) + + +class HiDreamO1Tokenizer(sd1_clip.SD1Tokenizer): + """Wraps prompt in the upstream chat template ending with boi/tms markers. + Image tokens get spliced in at sample time once target H/W is known. + """ + + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__( + embedding_directory=embedding_directory, + tokenizer_data=tokenizer_data, + name="hidream_o1", + tokenizer=HiDreamO1QwenTokenizer, + ) + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + text_tokens_dict = super().tokenize_with_weights( + text, return_word_ids=return_word_ids, disable_weights=True, **kwargs + ) + text_tuples = text_tokens_dict["hidream_o1"][0] + text_tuples = [t for t in text_tuples if int(t[0]) != 151643] # strip pad + + # <|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|boi|><|tms|> + def tok(tid): + return (tid, 1.0) if not return_word_ids else (tid, 1.0, 0) + + prefix = [tok(IM_START_ID), tok(USER_ID), tok(NEWLINE_ID)] + suffix = [ + tok(IM_END_ID), tok(NEWLINE_ID), + tok(IM_START_ID), tok(ASSISTANT_ID), tok(NEWLINE_ID), + tok(BOI_TOKEN_ID), tok(TMS_TOKEN_ID), + ] + full = prefix + list(text_tuples) + suffix + return {"hidream_o1": [full]} + + +class HiDreamO1TE(torch.nn.Module): + """Passthrough TE: emits int token ids; the Qwen3-VL backbone in + diffusion_model.* does the actual encoding. 
+ + dtypes advertises uint8 as a routing hint: supports_cast(cuda, uint8) + is False, so CLIP.__init__ downgrades load_device to CPU, which makes + CoreModelPatcher skip the VBAR allocator (it would fail on a zero-param TE). + """ + + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__() + self.dtypes = {torch.uint8} + self.device = torch.device("cpu") if device is None else torch.device(device) + + def encode_token_weights(self, token_weight_pairs): + tok_pairs = token_weight_pairs["hidream_o1"][0] + ids = [int(t[0]) for t in tok_pairs] + input_ids = torch.tensor([ids], dtype=torch.long) + # Surrogate keeps the cross_attn slot non-empty for CONDITIONING + # plumbing; the model reads text_input_ids out of `extra` instead. + cross_attn = input_ids.unsqueeze(-1).to(torch.float32) + extra = {"text_input_ids": input_ids} + return cross_attn, None, extra + + def load_sd(self, sd): + return [] + + def get_sd(self): + return {} + + def reset_clip_options(self): + pass + + def set_clip_options(self, options): + pass diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index a34c41144..5087228ca 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -397,7 +397,7 @@ class RMSNorm(nn.Module): -def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None): +def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None, interleaved_mrope=False): if not isinstance(theta, list): theta = [theta] @@ -415,16 +415,27 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - if rope_dims is not None and position_ids.shape[0] > 1: - mrope_section = rope_dims * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + if rope_dims is not None and position_ids.shape[0] > 1 and interleaved_mrope: + # Qwen3-VL interleaved MRoPE: T-freqs by default, H/W replace every 3rd dim. 
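+        # e.g. rope_dims=[24, 20, 20] over 64 half-head freq dims: H takes
+        # indices 1, 4, ..., 58 and W takes 2, 5, ..., 59 (20 slots each);
+        # T keeps 0, 3, ..., 57 plus the tail 60..63 (24 total).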
+ freqs_inter = freqs[0].clone() + for axis_idx, offset in ((1, 1), (2, 2)): + length = rope_dims[axis_idx] * 3 + idx = slice(offset, length, 3) + freqs_inter[..., idx] = freqs[axis_idx, ..., idx] + emb = torch.cat((freqs_inter, freqs_inter), dim=-1) + cos = emb.cos().unsqueeze(0) + sin = emb.sin().unsqueeze(0) else: - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + if rope_dims is not None and position_ids.shape[0] > 1: + mrope_section = rope_dims * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + else: + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) sin_split = sin.shape[-1] // 2 out.append((cos, sin[..., : sin_split], -sin[..., sin_split :])) @@ -689,6 +700,7 @@ class Llama2_(nn.Module): self.config.rope_theta, self.config.rope_scale, self.config.rope_dims, + interleaved_mrope=getattr(self.config, "interleaved_mrope", False), device=device) def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None): diff --git a/comfy_extras/nodes_hidream_o1.py b/comfy_extras/nodes_hidream_o1.py new file mode 100644 index 000000000..389e6d3fc --- /dev/null +++ b/comfy_extras/nodes_hidream_o1.py @@ -0,0 +1,220 @@ +from typing_extensions import override + +import torch + +import comfy.model_management +import node_helpers +from comfy_api.latest import ComfyExtension, io + +from comfy.ldm.hidream_o1.utils import find_closest_resolution + + +class EmptyHiDreamO1LatentImage(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyHiDreamO1LatentImage", + category="latent/hidream_o1", + description=( + "Empty pixel-space latent for HiDream-O1-Image. When " + "snap_to_predefined is on, dimensions are matched (by aspect " + "ratio) to the upstream HiDream-O1 PREDEFINED_RESOLUTIONS list." + ), + inputs=[ + io.Int.Input(id="width", default=2048, min=64, max=4096, step=32), + io.Int.Input(id="height", default=2048, min=64, max=4096, step=32), + io.Int.Input(id="batch_size", default=1, min=1, max=64), + io.Boolean.Input( + id="snap_to_predefined", + default=True, + tooltip=( + "Snap (W, H) to the closest aspect ratio in HiDream-O1's " + "PREDEFINED_RESOLUTIONS table for best parity with the " + "upstream CLI. Disable for arbitrary 32-aligned sizes." + ), + ), + ], + outputs=[io.Latent().Output()], + ) + + @classmethod + def execute(cls, *, width: int, height: int, batch_size: int = 1, + snap_to_predefined: bool = True) -> io.NodeOutput: + if snap_to_predefined: + sw, sh = find_closest_resolution(width, height) + width, height = sw, sh + width = (width // 32) * 32 + height = (height // 32) * 32 + latent = torch.zeros( + (batch_size, 3, height, width), + device=comfy.model_management.intermediate_device(), + ) + return io.NodeOutput({"samples": latent}) + + +class HiDreamO1ReferenceImages(io.ComfyNode): + """Attach reference images to both positive and negative conditioning. + + Refs are model-level inputs, not per-prompt CONDITIONING — they must ride + on both CFG branches, otherwise CFG amplifies "with-refs vs no-refs" + instead of "edit prompt vs empty prompt with same refs" and saturation + blows out. 
+ """ + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="HiDreamO1ReferenceImages", + category="conditioning/hidream_o1", + description=( + "Attach 1-10 reference images to BOTH positive and negative " + "conditioning for HiDream-O1 edit (K=1) or subject-driven " + "personalization (K=2..10). Refs must ride on both CFG " + "branches; this node enforces that." + ), + inputs=[ + io.Conditioning.Input(id="positive"), + io.Conditioning.Input(id="negative"), + io.Autogrow.Input( + "images", + template=io.Autogrow.TemplateNames( + io.Image.Input("image"), + names=[f"image_{i}" for i in range(1, 11)], + min=1, + ), + tooltip=( + "Reference images. K=1 -> instruction edit; " + "K=2..10 -> subject-driven personalization." + ), + ), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, *, positive, negative, images: io.Autogrow.Type) -> io.NodeOutput: + # Numeric-suffix order; alphabetic sort would give image_1, image_10, image_2, ... + refs = [images[f"image_{i}"] for i in range(1, 11) if f"image_{i}" in images] + positive = node_helpers.conditioning_set_values( + positive, {"hidream_o1_ref_images": refs}, + ) + negative = node_helpers.conditioning_set_values( + negative, {"hidream_o1_ref_images": refs}, + ) + return io.NodeOutput(positive, negative) + + +class HiDreamO1Sampling(io.ComfyNode): + """Adjust HiDream-O1's flow-match sigma shift and noise scale together.""" + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="HiDreamO1Sampling", + category="advanced/model/hidream_o1", + description=( + "Patch HiDream-O1's sigma shift and noise scaling factor. " + "Full recipe: shift=3.0, s_noise=8.0. " + "Dev/flash recipe: shift=1.0, s_noise=7.5." + ), + inputs=[ + io.Model.Input(id="model"), + io.Float.Input( + id="shift", default=3.0, min=0.0, max=100.0, step=0.01, + tooltip="Flow-match sigma shift. 3.0 for full, 1.0 for dev.", + ), + io.Float.Input( + id="s_noise", default=8.0, min=0.0, max=64.0, step=0.1, + tooltip=( + "HiDream-O1 noise scale (CONST_SCALED_NOISE._s_noise). " + "8.0 for full, 7.5 for dev/flash." + ), + ), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def execute(cls, *, model, shift: float, s_noise: float) -> io.NodeOutput: + import comfy.model_sampling + m = model.clone() + + class _HiDreamO1SamplingPatched( + comfy.model_sampling.ModelSamplingDiscreteFlow, + comfy.model_sampling.CONST_SCALED_NOISE, + ): + pass + + ms = _HiDreamO1SamplingPatched(m.model.model_config) + ms.set_parameters(shift=float(shift), multiplier=1000) + ms._s_noise = float(s_noise) + m.add_object_patch("model_sampling", ms) + return io.NodeOutput(m) + + +class SamplerEulerFlashFlowmatch(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerEulerFlashFlowmatch", + category="sampling/custom_sampling/samplers", + description=( + "HiDream-O1 dev/flash sampler with tunable per-step noise " + "schedule (start, end, clip_std). Wire into SamplerCustom." + ), + inputs=[ + io.Float.Input( + id="s_noise_start", default=7.5, min=0.0, max=64.0, step=0.1, + tooltip="Per-step noise scale at the first sampling step.", + ), + io.Float.Input( + id="s_noise_end", default=7.5, min=0.0, max=64.0, step=0.1, + tooltip=( + "Per-step noise scale at the last step. Equals " + "s_noise_start for upstream-default behaviour; differ " + "to ramp the noise across the trajectory." 
+ ), + ), + io.Float.Input( + id="noise_clip_std", default=2.5, min=0.0, max=10.0, step=0.1, + tooltip=( + "Clamp per-step noise to +/- N*std. 0 disables. " + "Upstream dev recipe: 2.5." + ), + ), + ], + outputs=[io.Sampler.Output()], + ) + + @classmethod + def execute(cls, *, s_noise_start: float, s_noise_end: float, + noise_clip_std: float) -> io.NodeOutput: + import comfy.samplers + import comfy.k_diffusion.sampling + sampler = comfy.samplers.KSAMPLER( + comfy.k_diffusion.sampling.sample_euler_flash_flowmatch, + extra_options={ + "s_noise": float(s_noise_start), + "s_noise_end": float(s_noise_end), + "noise_clip_std": float(noise_clip_std), + }, + ) + return io.NodeOutput(sampler) + + +class HiDreamO1Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyHiDreamO1LatentImage, + HiDreamO1ReferenceImages, + HiDreamO1Sampling, + SamplerEulerFlashFlowmatch, + ] + + +async def comfy_entrypoint() -> HiDreamO1Extension: + return HiDreamO1Extension() diff --git a/nodes.py b/nodes.py index 5755f0bb8..89a921f05 100644 --- a/nodes.py +++ b/nodes.py @@ -2434,6 +2434,7 @@ async def init_builtin_extra_nodes(): "nodes_frame_interpolation.py", "nodes_sam3.py", "nodes_void.py", + "nodes_hidream_o1.py", ] import_failed = [] From 8982726f6d3837fe33c9cd65eba82192a2a1fd3b Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 10 May 2026 11:15:22 +0300 Subject: [PATCH 2/2] Cleanup nodes --- comfy_extras/nodes_hidream_o1.py | 60 +++++++++++--------------------- 1 file changed, 21 insertions(+), 39 deletions(-) diff --git a/comfy_extras/nodes_hidream_o1.py b/comfy_extras/nodes_hidream_o1.py index 389e6d3fc..27d22fb10 100644 --- a/comfy_extras/nodes_hidream_o1.py +++ b/comfy_extras/nodes_hidream_o1.py @@ -14,7 +14,8 @@ class EmptyHiDreamO1LatentImage(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="EmptyHiDreamO1LatentImage", - category="latent/hidream_o1", + display_name="Empty HiDream-O1 Latent Image", + category="latent/image", description=( "Empty pixel-space latent for HiDream-O1-Image. When " "snap_to_predefined is on, dimensions are matched (by aspect " @@ -40,7 +41,7 @@ class EmptyHiDreamO1LatentImage(io.ComfyNode): @classmethod def execute(cls, *, width: int, height: int, batch_size: int = 1, snap_to_predefined: bool = True) -> io.NodeOutput: - if snap_to_predefined: + if snap_to_predefined: #TODO: better way to handle this sw, sh = find_closest_resolution(width, height) width, height = sw, sh width = (width // 32) * 32 @@ -53,24 +54,17 @@ class EmptyHiDreamO1LatentImage(io.ComfyNode): class HiDreamO1ReferenceImages(io.ComfyNode): - """Attach reference images to both positive and negative conditioning. - - Refs are model-level inputs, not per-prompt CONDITIONING — they must ride - on both CFG branches, otherwise CFG amplifies "with-refs vs no-refs" - instead of "edit prompt vs empty prompt with same refs" and saturation - blows out. - """ + """Attach reference images to both positive and negative conditioning.""" @classmethod def define_schema(cls) -> io.Schema: return io.Schema( node_id="HiDreamO1ReferenceImages", - category="conditioning/hidream_o1", + display_name="HiDream-O1 Reference Images", + category="conditioning/image", description=( - "Attach 1-10 reference images to BOTH positive and negative " - "conditioning for HiDream-O1 edit (K=1) or subject-driven " - "personalization (K=2..10). Refs must ride on both CFG " - "branches; this node enforces that." 
+ "Attach 1-10 reference images to conditioning, one for edit instruction" + "or multiple for subject-driven personalization." ), inputs=[ io.Conditioning.Input(id="positive"), @@ -96,14 +90,9 @@ class HiDreamO1ReferenceImages(io.ComfyNode): @classmethod def execute(cls, *, positive, negative, images: io.Autogrow.Type) -> io.NodeOutput: - # Numeric-suffix order; alphabetic sort would give image_1, image_10, image_2, ... refs = [images[f"image_{i}"] for i in range(1, 11) if f"image_{i}" in images] - positive = node_helpers.conditioning_set_values( - positive, {"hidream_o1_ref_images": refs}, - ) - negative = node_helpers.conditioning_set_values( - negative, {"hidream_o1_ref_images": refs}, - ) + positive = node_helpers.conditioning_set_values(positive, {"hidream_o1_ref_images": refs}) + negative = node_helpers.conditioning_set_values(negative, {"hidream_o1_ref_images": refs}) return io.NodeOutput(positive, negative) @@ -114,23 +103,22 @@ class HiDreamO1Sampling(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="HiDreamO1Sampling", - category="advanced/model/hidream_o1", + display_name="HiDream-O1 Sampling", + category="advanced/model", description=( "Patch HiDream-O1's sigma shift and noise scaling factor. " - "Full recipe: shift=3.0, s_noise=8.0. " - "Dev/flash recipe: shift=1.0, s_noise=7.5." + "Base model defaults: shift=3.0, s_noise=8.0. " + "Dev/flash sampler defaults: shift=1.0, s_noise=7.5." ), inputs=[ io.Model.Input(id="model"), io.Float.Input( id="shift", default=3.0, min=0.0, max=100.0, step=0.01, - tooltip="Flow-match sigma shift. 3.0 for full, 1.0 for dev.", + tooltip="Flow-match sigma shift. Defaults: 3.0 for base, 1.0 for dev.", ), io.Float.Input( id="s_noise", default=8.0, min=0.0, max=64.0, step=0.1, - tooltip=( - "HiDream-O1 noise scale (CONST_SCALED_NOISE._s_noise). " - "8.0 for full, 7.5 for dev/flash." + tooltip=("HiDream-O1 noise scale (CONST_SCALED_NOISE). Defaults: 8.0 for base, 7.5 for dev/flash." ), ), ], @@ -160,11 +148,9 @@ class SamplerEulerFlashFlowmatch(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="SamplerEulerFlashFlowmatch", + display_name="Sampler Euler Flash Flowmatch", category="sampling/custom_sampling/samplers", - description=( - "HiDream-O1 dev/flash sampler with tunable per-step noise " - "schedule (start, end, clip_std). Wire into SamplerCustom." - ), + description=("HiDream-O1 dev/flash sampler with tunable per-step noise"), inputs=[ io.Float.Input( id="s_noise_start", default=7.5, min=0.0, max=64.0, step=0.1, @@ -173,17 +159,13 @@ class SamplerEulerFlashFlowmatch(io.ComfyNode): io.Float.Input( id="s_noise_end", default=7.5, min=0.0, max=64.0, step=0.1, tooltip=( - "Per-step noise scale at the last step. Equals " - "s_noise_start for upstream-default behaviour; differ " - "to ramp the noise across the trajectory." + "Per-step noise scale at the last step. Default: 7.5 for dev/flash. " + "Differ from s_noise_start to linearly ramp noise across steps." ), ), io.Float.Input( id="noise_clip_std", default=2.5, min=0.0, max=10.0, step=0.1, - tooltip=( - "Clamp per-step noise to +/- N*std. 0 disables. " - "Upstream dev recipe: 2.5." - ), + tooltip=("Clamp per-step noise to +/- N*std. 0 disables.") ), ], outputs=[io.Sampler.Output()],