ComfyUI/comfy/ldm/hidream_o1/attention.py
Jukka Seppänen 8e53f001a4
feat: Support HiDream-O1-Image (CORE-187) (#13817)
* Initial HiDream01-image support

* Cleanup nodes

* Cleaner handling of empty placeholder models

* Remove snap_to_predefined, prefer tooltip for the trained resolutions

* Add model and block wrappers

* Fix shift tooltip

* Add node to work around the patch tile issue

Experimental, runs multiple passes with the patch grid offset and blends with various different methods.

* Qwen35 vision rotary_pos_emb cast fix

* Fix embedding layout type

* Some small optimizations

* Cleanup, don't need this fallback

* Prefix KV cache, cleanup

Bit of speed, reduce redundant code

* Get rid of redundant custom sampler, refactor noise scaling

Our existing lcm sampler is mathematically same, just added the missing options to it instead and a node to control them. Refactored the noise scaling and fix it for the stochastic samplers, add a generic node to control the initial noise scale.

* Update nodes_hidream_o1.py

* Fix some cache validation cases

* Keep existing sampling params

* Remove redundant video vision path

* Replace some numpy ops with torch

* Fx RoPE index for batch size > 1

* Prefer torch preprocessing

* Rename block_type to be compatible with existing patch nodes

* Fixes and tweaks
2026-05-11 20:35:53 -07:00

42 lines
2.0 KiB
Python

"""HiDream-O1 two-pass attention: tokens [0, ar_len) are causal, [ar_len, T)
attend full K/V. Splitting Q at the boundary avoids the (B, 1, T, T) additive
mask the general-purpose path would build (~500 MB at T~16K) and lets the
gen half hit the user's preferred backend via optimized_attention.
"""
import torch
import comfy.ops
from comfy.ldm.modules.attention import optimized_attention
def make_two_pass_attention(ar_len: int, transformer_options=None):
"""Build a two-pass attention callable. AR pass uses SDPA-causal directly, gen pass routes through optimized_attention.
The AR pass goes through SDPA directand bypasses wrappers, it is only ~1% of T at typical edit sizes.
"""
def two_pass_attention(q, k, v, heads, **kwargs):
B, H, T, D = q.shape
if T < k.shape[2]: # KV-cache hot path: Q is shorter than K/V (cached AR prefix is in K/V only), all fresh Q positions are in the gen region, single full-attention call
out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
elif ar_len >= T:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
elif ar_len <= 0:
out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
else:
out_ar = comfy.ops.scaled_dot_product_attention(
q[:, :, :ar_len], k[:, :, :ar_len], v[:, :, :ar_len],
attn_mask=None, dropout_p=0.0, is_causal=True,
)
out_gen = optimized_attention(
q[:, :, ar_len:], k, v, heads,
mask=None, skip_reshape=True, skip_output_reshape=True,
transformer_options=transformer_options,
)
out = torch.cat([out_ar, out_gen], dim=2)
return out.transpose(1, 2).reshape(B, T, H * D)
return two_pass_attention