"""HiDream-O1-Image transformer. Pixel-space DiT built on Qwen3-VL: the vision tower (Qwen35VisionModel) encodes ref images, the Qwen3-VL-8B decoder (Llama2_ with interleaved MRoPE) processes a unified text+image sequence, and 32x32 patch embed/unembed shims map raw RGB in and out of LLM hidden space. The Qwen3-VL deepstack mergers go unused — their weights are dropped at load. """ from dataclasses import dataclass, field from typing import List, Optional import einops import torch import torch.nn as nn from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.text_encoders.llama import Llama2_ from comfy.text_encoders.qwen35 import Qwen35VisionModel from .attention import make_two_pass_attention IMAGE_TOKEN_ID = 151655 # Qwen3-VL <|image_pad|> TMS_TOKEN_ID = 151673 # HiDream-O1 <|tms_token|> PATCH_SIZE = 32 @dataclass class HiDreamO1TextConfig: """Qwen3-VL-8B text-decoder dims (matches public Qwen3-VL-8B-Instruct).""" vocab_size: int = 151936 hidden_size: int = 4096 intermediate_size: int = 12288 num_hidden_layers: int = 36 num_attention_heads: int = 32 num_key_value_heads: int = 8 head_dim: int = 128 max_position_embeddings: int = 128000 rms_norm_eps: float = 1e-6 rope_theta: float = 5000000.0 rope_scale: Optional[float] = None rope_dims: List[int] = field(default_factory=lambda: [24, 20, 20]) interleaved_mrope: bool = True transformer_type: str = "llama" rms_norm_add: bool = False mlp_activation: str = "silu" qkv_bias: bool = False q_norm: str = "gemma3" k_norm: str = "gemma3" final_norm: bool = True lm_head: bool = False stop_tokens: List[int] = field(default_factory=lambda: [151643, 151645]) QWEN3VL_VISION_DEFAULTS = dict( hidden_size=1152, num_heads=16, intermediate_size=4304, depth=27, patch_size=16, temporal_patch_size=2, in_channels=3, spatial_merge_size=2, num_position_embeddings=2304, deepstack_visual_indexes=(8, 16, 24), out_hidden_size=4096, # final merger projects directly into LLM hidden ) class BottleneckPatchEmbed(nn.Module): # 3072 -> 1024 -> 4096 (raw 32x32 RGB patch -> bottleneck -> LLM hidden). def __init__(self, patch_size=32, in_chans=3, pca_dim=1024, embed_dim=4096, bias=True, device=None, dtype=None, ops=None): super().__init__() self.proj1 = ops.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False, device=device, dtype=dtype) self.proj2 = ops.Linear(pca_dim, embed_dim, bias=bias, device=device, dtype=dtype) def forward(self, x): return self.proj2(self.proj1(x)) class FinalLayer(nn.Module): # 4096 -> 3072 (LLM hidden -> flat pixel patch). 
class FinalLayer(nn.Module):
    # 4096 -> 3072 (LLM hidden -> flat pixel patch).
    def __init__(self, hidden_size, patch_size=32, out_channels=3, device=None, dtype=None, ops=None):
        super().__init__()
        self.linear = ops.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True,
                                 device=device, dtype=dtype)

    def forward(self, x):
        return self.linear(x)


class HiDreamO1Transformer(nn.Module):
    """HiDream-O1 unified pixel-level transformer."""

    def __init__(self, image_model=None, dtype=None, device=None, operations=None,
                 text_config_overrides=None, vision_config_overrides=None, **kwargs):
        super().__init__()
        self.dtype = dtype

        text_cfg = HiDreamO1TextConfig(**(text_config_overrides or {}))
        vision_cfg = dict(QWEN3VL_VISION_DEFAULTS)
        if vision_config_overrides:
            vision_cfg.update(vision_config_overrides)
        vision_cfg["out_hidden_size"] = text_cfg.hidden_size

        self.text_config = text_cfg
        self.vision_config = vision_cfg
        self.hidden_size = text_cfg.hidden_size
        self.patch_size = PATCH_SIZE
        self.in_channels = 3
        self.tms_token_id = TMS_TOKEN_ID

        self.visual = Qwen35VisionModel(vision_cfg, device=device, dtype=dtype, ops=operations)
        self.language_model = Llama2_(text_cfg, device=device, dtype=dtype, ops=operations)

        self.t_embedder1 = TimestepEmbedder(
            text_cfg.hidden_size,
            device=device, dtype=dtype, operations=operations,
        )
        self.x_embedder = BottleneckPatchEmbed(
            patch_size=self.patch_size,
            in_chans=self.in_channels,
            pca_dim=text_cfg.hidden_size // 4,
            embed_dim=text_cfg.hidden_size,
            bias=True,
            device=device, dtype=dtype, ops=operations,
        )
        self.final_layer2 = FinalLayer(
            text_cfg.hidden_size,
            patch_size=self.patch_size,
            out_channels=self.in_channels,
            device=device, dtype=dtype, ops=operations,
        )

    def forward(self, x, timesteps, context=None, transformer_options={}, input_ids=None,
                attention_mask=None, position_ids=None, token_types=None, vinput_mask=None,
                ref_pixel_values=None, ref_image_grid_thw=None, ref_patches=None, **kwargs):
        """Returns flow-match velocity (x - x_pred) / sigma."""
        if input_ids is None or position_ids is None:
            raise ValueError("HiDreamO1Transformer requires input_ids and position_ids in conditioning")

        B, _, H, W = x.shape
        h_p, w_p = H // self.patch_size, W // self.patch_size
        tgt_image_len = h_p * w_p

        z = einops.rearrange(
            x, 'B C (H p1) (W p2) -> B (H W) (C p1 p2)',
            p1=self.patch_size, p2=self.patch_size,
        )
        vinputs = torch.cat([z, ref_patches.to(z.dtype)], dim=1) if ref_patches is not None else z

        if input_ids.dim() == 3:
            input_ids = input_ids.squeeze(-1)
        input_ids = input_ids.long()
        inputs_embeds = self.language_model.embed_tokens(input_ids).to(x.dtype)
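        # Unified-sequence assembly, in order below: ViT features for the reference
        # images are scattered into the <|image_pad|> slots, the timestep embedding
        # overwrites the <|tms_token|> slot, and the patchified noisy target (plus
        # any ref patches) is appended after the text tokens.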
        if ref_pixel_values is not None and ref_image_grid_thw is not None:
            ref_pv = ref_pixel_values.to(inputs_embeds.device)
            ref_grid = ref_image_grid_thw.to(inputs_embeds.device).long()
            # Refs are model-level (same for cond/uncond), wrapped with a leading
            # batch dim by extra_conds; [0] always recovers them.
            if ref_pv.dim() == 3:
                ref_pv = ref_pv[0]
            if ref_grid.dim() == 3:
                ref_grid = ref_grid[0]
            image_embeds = self.visual(ref_pv, ref_grid).to(inputs_embeds.dtype)

            image_mask = (input_ids == IMAGE_TOKEN_ID)
            if image_mask[0].sum().item() != image_embeds.shape[0]:
                raise ValueError(
                    f"Image-token count {image_mask[0].sum().item()} != ViT output count "
                    f"{image_embeds.shape[0]}; check tokenizer/processor alignment."
                )
            image_embeds_b = image_embeds.unsqueeze(0).expand(B, -1, -1).reshape(-1, image_embeds.shape[-1])
            inputs_embeds = inputs_embeds.masked_scatter(
                image_mask.unsqueeze(-1).expand_as(inputs_embeds),
                image_embeds_b,
            )

        sigma = timesteps.float() / 1000.0
        t_pixeldit = 1.0 - sigma
        t_emb = self.t_embedder1(t_pixeldit * 1000, inputs_embeds.dtype)
        tms_mask_3d = (input_ids == self.tms_token_id).unsqueeze(-1).expand_as(inputs_embeds)
        inputs_embeds = torch.where(tms_mask_3d, t_emb.unsqueeze(1).expand_as(inputs_embeds), inputs_embeds)

        vinputs_embedded = self.x_embedder(vinputs.to(inputs_embeds.dtype))
        inputs_embeds = torch.cat([inputs_embeds, vinputs_embedded], dim=1)
        total_seq_len = inputs_embeds.shape[1]

        # AR (text) tokens are contiguous at the start, so (==0).sum() gives ar_len.
        if token_types is None:
            txt_seq_len = input_ids.shape[1]
            token_types = torch.zeros(B, total_seq_len, dtype=torch.long, device=x.device)
            token_types[:, txt_seq_len:] = 1
        else:
            token_types = token_types.to(x.device)
            if token_types.dim() == 1:
                token_types = token_types.unsqueeze(0)
            if token_types.shape[0] == 1 and B > 1:
                token_types = token_types.expand(B, -1)
        ar_len = int((token_types[0] == 0).sum().item())

        # position_ids may arrive as (3, T) or wrapped (1, 3, T) / (3, 1, T) by CONDRegular.
        position_ids = position_ids.to(x.device).long()
        if position_ids.dim() == 3:
            position_ids = position_ids[0] if position_ids.shape[1] == 3 else position_ids[:, 0]
        freqs_cis = self.language_model.compute_freqs_cis(position_ids, x.device)
        freqs_cis = tuple(t.to(x.dtype) for t in freqs_cis)

        two_pass_attn = make_two_pass_attention(ar_len)

        hidden_states = inputs_embeds
        for layer in self.language_model.layers:
            hidden_states, _ = layer(
                x=hidden_states,
                attention_mask=None,
                freqs_cis=freqs_cis,
                optimized_attention=two_pass_attn,
                past_key_value=None,
            )
        if self.language_model.norm is not None:
            hidden_states = self.language_model.norm(hidden_states)

        x_pred = self.final_layer2(hidden_states)

        if vinput_mask is not None:
            vmask = vinput_mask.to(x.device).bool()
            if vmask.dim() == 1:
                vmask = vmask.unsqueeze(0)
            if vmask.shape[0] == 1 and B > 1:
                vmask = vmask.expand(B, -1)
            x_pred_tgt = x_pred[vmask].view(B, -1, x_pred.shape[-1])[:, :tgt_image_len]
        else:
            txt_seq_len = input_ids.shape[1]
            x_pred_tgt = x_pred[:, txt_seq_len:txt_seq_len + tgt_image_len]

        # fp32 final subtraction; bf16 here noticeably degrades samples.
        x_pred_img = einops.rearrange(
            x_pred_tgt, 'B (H W) (C p1 p2) -> B C (H p1) (W p2)',
            H=h_p, W=w_p, p1=self.patch_size, p2=self.patch_size,
        )
        return (x.float() - x_pred_img.float()) / sigma.view(B, 1, 1, 1).clamp_min(1e-3)
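
# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by ComfyUI; the shapes below are
# assumptions for the demo): round-trips the 32x32 patchify/unpatchify
# rearranges used in forward() and checks the flow-match velocity convention
# on random tensors. No model weights are needed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    p = PATCH_SIZE
    x_demo = torch.randn(1, 3, 256, 256)  # B C H W, spatial dims multiples of 32
    patches = einops.rearrange(
        x_demo, 'B C (H p1) (W p2) -> B (H W) (C p1 p2)', p1=p, p2=p,
    )
    assert patches.shape == (1, (256 // p) ** 2, 3 * p * p)  # (1, 64, 3072)

    back = einops.rearrange(
        patches, 'B (H W) (C p1 p2) -> B C (H p1) (W p2)',
        H=256 // p, W=256 // p, p1=p, p2=p,
    )
    assert torch.equal(back, x_demo)  # patchify/unpatchify is an exact round trip

    # Velocity convention used in forward(): v = (x - x_pred) / sigma with
    # sigma = timesteps / 1000, clamped away from zero.
    sigma = torch.tensor([500.0]) / 1000.0
    x_pred_img = torch.zeros_like(x_demo)  # hypothetical prediction for the check
    v = (x_demo.float() - x_pred_img.float()) / sigma.view(1, 1, 1, 1).clamp_min(1e-3)
    assert torch.allclose(v, x_demo * 2.0)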