Merge 8982726f6d into 95f6652ef5
commit 5fea7f7fa8
comfy/k_diffusion/sampling.py
@@ -264,6 +264,51 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
            x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
    return x


@torch.no_grad()
def sample_euler_flash_flowmatch(model, x, sigmas, extra_args=None, callback=None, disable=None,
                                 s_noise=7.5, s_noise_end=None, noise_clip_std=2.5,
                                 noise_sampler=None):
    """HiDream-O1-Image-Dev "flash" sampler.

    Step: x_next = sigma_next * noise * s_noise_i + (1 - sigma_next) * denoised,
    with noise clamped to noise_clip_std stddevs and s_noise_i linearly
    interpolated from s_noise to s_noise_end across steps. Equivalent to
    sample_lcm + CONST_SCALED_NOISE when s_noise_end is None and noise_clip_std
    is 0.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    in_dtype = x.dtype
    n_steps = max(1, len(sigmas) - 1)
    s_start = float(s_noise)
    s_end = float(s_noise if s_noise_end is None else s_noise_end)
    for i in trange(n_steps, disable=disable):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
        if sigma_next == 0:
            x = denoised.to(in_dtype)
            continue
        noise = noise_sampler(sigma, sigma_next)
        if noise_clip_std > 0:
            clip_val = noise_clip_std * noise.std()
            noise = noise.clamp(min=-clip_val, max=clip_val)
        # Linear interpolation start -> end across steps, matching upstream
        # pipeline.py's noise_scale_schedule construction.
        t = (i / (n_steps - 1)) if n_steps > 1 else 0.0
        s_noise_i = s_start + (s_end - s_start) * t
        # Match upstream FlashFlowMatchEulerDiscreteScheduler.step: do the step
        # math in fp32 to avoid bf16 accumulation drift across 28 steps.
        x = (sigma_next * noise.float() * s_noise_i
             + (1.0 - sigma_next) * denoised.float()).to(in_dtype)
    return x


@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
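[Illustration, not part of the commit] A minimal standalone sketch of the update rule above, assuming only torch; the all-zero "denoised" is a stand-in for a real model call:

import torch

sigmas = torch.tensor([1.0, 0.5, 0.0])
x = torch.randn(1, 3, 8, 8)
denoised = torch.zeros_like(x)          # toy: pretend the model predicts all-zero
s_start, s_end, n_steps = 7.5, 7.5, len(sigmas) - 1
for i in range(n_steps):
    sigma_next = sigmas[i + 1]
    if sigma_next == 0:
        x = denoised
        continue
    noise = torch.randn_like(x)
    clip_val = 2.5 * noise.std()        # noise_clip_std = 2.5
    noise = noise.clamp(-clip_val, clip_val)
    t = i / (n_steps - 1) if n_steps > 1 else 0.0
    s_noise_i = s_start + (s_end - s_start) * t
    x = sigma_next * noise * s_noise_i + (1.0 - sigma_next) * denoised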
comfy/latent_formats.py
@@ -792,6 +792,13 @@ class ZImagePixelSpace(ChromaRadiance):
    """
    pass


class HiDreamO1Pixel(ChromaRadiance):
    """Pixel-space latent format for HiDream-O1.

    No VAE — model patches/unpatches raw RGB internally with patch_size=32.
    """
    pass


class CogVideoX(LatentFormat):
    """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
comfy/ldm/hidream_o1/attention.py (new file, 46 lines)
@@ -0,0 +1,46 @@
"""HiDream-O1 two-pass attention: tokens [0, ar_len) are causal, [ar_len, T)
attend full K/V. Splitting Q at the boundary avoids the (B, 1, T, T) additive
mask the general-purpose path would build (~500 MB at T~16K) and lets the
gen half hit the user's preferred backend via optimized_attention.
"""

import torch

import comfy.ops
from comfy.ldm.modules.attention import optimized_attention


def make_two_pass_attention(ar_len: int):
    """Build a two-pass attention callable. AR pass uses SDPA-causal directly,
    gen pass routes through optimized_attention.
    """
    def two_pass_attention(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
        if skip_reshape:
            B, H, T, D = q.shape
        else:
            B, T, total_dim = q.shape
            D = total_dim // heads
            H = heads
            q = q.view(B, T, H, D).transpose(1, 2)
            k = k.view(B, T, H, D).transpose(1, 2)
            v = v.view(B, T, H, D).transpose(1, 2)

        if ar_len >= T:
            out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
        elif ar_len <= 0:
            out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True)
        else:
            out_ar = comfy.ops.scaled_dot_product_attention(
                q[:, :, :ar_len], k[:, :, :ar_len], v[:, :, :ar_len],
                attn_mask=None, dropout_p=0.0, is_causal=True,
            )
            out_gen = optimized_attention(
                q[:, :, ar_len:], k, v, heads,
                mask=None, skip_reshape=True, skip_output_reshape=True,
            )
            out = torch.cat([out_ar, out_gen], dim=2)

        if skip_output_reshape:
            return out
        return out.transpose(1, 2).reshape(B, T, H * D)

    return two_pass_attention
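[Illustration, not part of the commit] A small check, assuming only torch, that splitting Q at the causal/full boundary reproduces the single-pass masked result the module docstring describes. AR rows attend only positions j <= i (all inside the prefix), so restricting their K/V to the prefix loses nothing:

import torch
import torch.nn.functional as F

B, H, T, D, ar_len = 1, 2, 6, 4, 3
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Reference: one pass with an explicit (T, T) bool mask (True = attend):
# causal rows for the AR prefix, fully visible rows for the gen half.
allowed = torch.ones(T, T, dtype=torch.bool)
allowed[:ar_len] = torch.tril(torch.ones(T, T, dtype=torch.bool))[:ar_len]
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)

# Two-pass: causal over the AR prefix with prefix K/V, full K/V for the rest.
out_ar = F.scaled_dot_product_attention(q[:, :, :ar_len], k[:, :, :ar_len], v[:, :, :ar_len], is_causal=True)
out_gen = F.scaled_dot_product_attention(q[:, :, ar_len:], k, v)
assert torch.allclose(torch.cat([out_ar, out_gen], dim=2), ref, atol=1e-5)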
comfy/ldm/hidream_o1/conditioning.py (new file, 269 lines)
@@ -0,0 +1,269 @@
"""HiDream-O1 conditioning prep — ref-image dual path + extra_conds assembly.

Each ref image goes through two paths: a 32x32 patchified stream concatenated
to the noised target, and a Qwen3-VL ViT path producing tokens that scatter
into input_ids at <|image_pad|> positions.
"""

from typing import List, Tuple

import einops
import numpy as np
import torch
from PIL import Image

from .utils import (PATCH_SIZE, calculate_dimensions, cond_image_size, ref_max_size, resize_pilimage)

# Qwen3-VL ViT preprocessing constants (preprocessor_config.json).
VIT_PATCH = 16
VIT_MERGE = 2
VIT_TEMPORAL_PATCH = 2
VIT_IMAGE_MEAN = [0.5, 0.5, 0.5]
VIT_IMAGE_STD = [0.5, 0.5, 0.5]


def _process_vit_image(pil: Image.Image, device, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
    """Qwen3-VL ViT preprocessing: returns (flatten_patches, image_grid_thw)."""
    arr = np.asarray(pil, dtype=np.float32) / 255.0
    img_t = torch.from_numpy(arr).permute(2, 0, 1).contiguous()
    h, w = img_t.shape[-2:]

    # H/W must be multiples of patch*merge.
    factor = VIT_PATCH * VIT_MERGE
    h_bar = max(round(h / factor) * factor, factor)
    w_bar = max(round(w / factor) * factor, factor)
    if (h, w) != (h_bar, w_bar):
        img_t = torch.nn.functional.interpolate(
            img_t.unsqueeze(0), size=(h_bar, w_bar), mode="bilinear", align_corners=False,
        ).squeeze(0)

    mean = torch.tensor(VIT_IMAGE_MEAN).view(3, 1, 1)
    std = torch.tensor(VIT_IMAGE_STD).view(3, 1, 1)
    normalized = (img_t - mean) / std

    grid_h = h_bar // VIT_PATCH
    grid_w = w_bar // VIT_PATCH
    grid_thw = torch.tensor([1, grid_h, grid_w], dtype=torch.long)

    # Stack 2 copies for the temporal_patch dim, then patchify.
    pixel_values = normalized.unsqueeze(0).repeat(VIT_TEMPORAL_PATCH, 1, 1, 1)
    patches = pixel_values.reshape(
        1, VIT_TEMPORAL_PATCH, 3,
        grid_h // VIT_MERGE, VIT_MERGE, VIT_PATCH,
        grid_w // VIT_MERGE, VIT_MERGE, VIT_PATCH,
    )
    patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
    flatten_patches = patches.reshape(
        grid_h * grid_w,
        3 * VIT_TEMPORAL_PATCH * VIT_PATCH * VIT_PATCH,
    )
    return flatten_patches.to(device=device, dtype=dtype), grid_thw.to(device=device)
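[Illustration, not part of the commit] The grid bookkeeping above in concrete numbers, for a hypothetical 384x256 input (already a multiple of patch*merge = 32):

VIT_PATCH, VIT_MERGE, VIT_TEMPORAL_PATCH = 16, 2, 2
h, w = 384, 256
grid_h, grid_w = h // VIT_PATCH, w // VIT_PATCH                 # 24 x 16 = 384 patches
patch_dim = 3 * VIT_TEMPORAL_PATCH * VIT_PATCH * VIT_PATCH      # 1536 floats per patch
vit_tokens = (grid_h // VIT_MERGE) * (grid_w // VIT_MERGE)      # 96 <|image_pad|> tokens post-merge
print(grid_h * grid_w, patch_dim, vit_tokens)                   # 384 1536 96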
def prepare_ref_images(
    ref_images: List[torch.Tensor],
    target_h: int,
    target_w: int,
    device: torch.device,
    dtype: torch.dtype,
):
    """Build the dual-path tensors for K reference images at (target_h, target_w).

    Returns None for K=0, else a dict with ref_patches,
    ref_pixel_values, ref_image_grid_thw, per_ref_vit_tokens,
    per_ref_patch_grids.
    """
    K = len(ref_images)
    if K == 0:
        return None
    max_size = ref_max_size(max(target_h, target_w), K)
    cis = cond_image_size(K)

    pils = []
    for img in ref_images:
        arr = np.round(img[0].clamp(0, 1).cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8)
        pils.append(Image.fromarray(arr, "RGB"))
    pils_resized = [resize_pilimage(p, max_size, PATCH_SIZE) for p in pils]

    # 32-patch path.
    ref_patches_per = []
    per_ref_patch_grids = []
    for pil_r in pils_resized:
        arr = np.asarray(pil_r, dtype=np.float32) / 255.0
        t = torch.from_numpy(arr).permute(2, 0, 1).contiguous()
        t = (t - 0.5) / 0.5  # -> [-1, 1]
        h_p, w_p = pil_r.height // PATCH_SIZE, pil_r.width // PATCH_SIZE
        per_ref_patch_grids.append((h_p, w_p))
        patches = einops.rearrange(
            t, "C (H p1) (W p2) -> (H W) (C p1 p2)",
            p1=PATCH_SIZE, p2=PATCH_SIZE,
        )
        ref_patches_per.append(patches)
    ref_patches = torch.cat(ref_patches_per, dim=0).unsqueeze(0).to(device=device, dtype=dtype)

    # ViT path.
    pils_vlm = []
    for pil_r in pils_resized:
        cond_w, cond_h = calculate_dimensions(cis, pil_r.width / pil_r.height)
        cond_w = max(cond_w, VIT_PATCH * VIT_MERGE)
        cond_h = max(cond_h, VIT_PATCH * VIT_MERGE)
        pils_vlm.append(pil_r.resize((cond_w, cond_h), resample=Image.LANCZOS))

    pv_list, grid_list, per_ref_vit_tokens = [], [], []
    for pil_v in pils_vlm:
        pv, grid_thw = _process_vit_image(pil_v, device, dtype)
        pv_list.append(pv)
        grid_list.append(grid_thw)
        # Post-merge token count = number of <|image_pad|> tokens this image
        # expands to in input_ids.
        gh, gw = int(grid_thw[1].item()), int(grid_thw[2].item())
        per_ref_vit_tokens.append((gh // VIT_MERGE) * (gw // VIT_MERGE))

    return {
        "ref_patches": ref_patches,
        "ref_pixel_values": torch.cat(pv_list, dim=0),
        "ref_image_grid_thw": torch.stack(grid_list, dim=0),
        "per_ref_vit_tokens": per_ref_vit_tokens,
        "per_ref_patch_grids": per_ref_patch_grids,
    }


def build_ref_input_ids(
    text_input_ids: torch.Tensor,
    per_ref_vit_tokens: List[int],
    image_token_id: int,
    vision_start_id: int,
    vision_end_id: int,
):
    """Splice [vision_start, image_pad*N, vision_end] blocks into input_ids
    after the [im_start, user, \\n] prefix (matches upstream chat template).
    """
    ids = text_input_ids[0].tolist()
    inserted = []
    for n_pad in per_ref_vit_tokens:
        inserted.extend([vision_start_id] + [image_token_id] * n_pad + [vision_end_id])
    new_ids = ids[:3] + inserted + ids[3:]  # 3 = len([im_start, user, \n])
    return torch.tensor([new_ids], dtype=text_input_ids.dtype, device=text_input_ids.device)


def build_extra_conds(
    text_input_ids: torch.Tensor,
    noise: torch.Tensor,
    ref_images: List[torch.Tensor] = None,
    target_patch_size: int = 32,
):
    """Assemble all conditioning tensors for HiDreamO1Transformer.forward:
    input_ids (with ref-vision tokens spliced in for the edit/IP path),
    position_ids (MRoPE), token_types, vinput_mask, plus the ref
    dual-path tensors when refs are provided.
    """
    from .utils import get_rope_index_fix_point
    from comfy.text_encoders.hidream_o1 import (
        IMAGE_TOKEN_ID, VIDEO_TOKEN_ID, VISION_START_ID, VISION_END_ID,
    )

    if text_input_ids.dim() == 1:
        text_input_ids = text_input_ids.unsqueeze(0)
    text_input_ids = text_input_ids.long().to(noise.device)
    B = noise.shape[0]
    if text_input_ids.shape[0] == 1 and B > 1:
        text_input_ids = text_input_ids.expand(B, -1)

    H, W = noise.shape[-2], noise.shape[-1]
    h_p, w_p = H // target_patch_size, W // target_patch_size
    image_len = h_p * w_p
    image_grid_thw_tgt = torch.tensor(
        [[1, h_p, w_p]], dtype=torch.long, device=text_input_ids.device,
    )

    out = {}
    if ref_images:
        ref = prepare_ref_images(ref_images, H, W, device=noise.device, dtype=noise.dtype)
        text_input_ids = build_ref_input_ids(
            text_input_ids, ref["per_ref_vit_tokens"],
            IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID,
        )
        new_txt_len = text_input_ids.shape[1]

        # Each ref's patchified stream gets a [vision_start, image_pad*N-1]
        # block in the position-id stream after the noised target.
        ref_grid_lengths = [hp * wp for (hp, wp) in ref["per_ref_patch_grids"]]
        tgt_vision = torch.full((1, image_len), IMAGE_TOKEN_ID,
                                dtype=text_input_ids.dtype, device=text_input_ids.device)
        tgt_vision[:, 0] = VISION_START_ID
        ref_vision_blocks = []
        for rl in ref_grid_lengths:
            blk = torch.full((1, rl), IMAGE_TOKEN_ID,
                             dtype=text_input_ids.dtype, device=text_input_ids.device)
            blk[:, 0] = VISION_START_ID
            ref_vision_blocks.append(blk)
        ref_vision_cat = torch.cat([tgt_vision] + ref_vision_blocks, dim=1)
        input_ids_pad = torch.cat([text_input_ids, ref_vision_cat], dim=-1)
        total_ref_patches_len = sum(ref_grid_lengths)
        total_len = new_txt_len + image_len + total_ref_patches_len

        # K (ViT, post-merge) + 1 (target) + K (ref-patches) image grids.
        K = len(ref_images)
        igthw_cond = ref["ref_image_grid_thw"].clone()
        igthw_cond[:, 1] //= 2
        igthw_cond[:, 2] //= 2
        image_grid_thw_ref = torch.tensor(
            [[1, hp, wp] for (hp, wp) in ref["per_ref_patch_grids"]],
            dtype=torch.long, device=text_input_ids.device,
        )
        igthw_all = torch.cat([
            igthw_cond.to(text_input_ids.device),
            image_grid_thw_tgt,
            image_grid_thw_ref,
        ], dim=0)
        position_ids, _ = get_rope_index_fix_point(
            spatial_merge_size=1,
            image_token_id=IMAGE_TOKEN_ID, video_token_id=VIDEO_TOKEN_ID,
            vision_start_token_id=VISION_START_ID,
            input_ids=input_ids_pad, image_grid_thw=igthw_all,
            video_grid_thw=None, attention_mask=None,
            skip_vision_start_token=[0] * K + [1] + [1] * K,
            fix_point=4096,
        )

        # tms + target_image + ref_patches are all gen.
        tms_pos = new_txt_len - 1
        token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
        token_types[:, tms_pos:] = 1
        vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
        vinput_mask[:, new_txt_len:] = True

        # Leading batch dim sidesteps CONDRegular.process_cond's
        # repeat_to_batch_size truncation (which narrows dim 0 to B).
        out["ref_pixel_values"] = ref["ref_pixel_values"].unsqueeze(0)
        out["ref_image_grid_thw"] = ref["ref_image_grid_thw"].unsqueeze(0)
        out["ref_patches"] = ref["ref_patches"]
    else:
        # T2I: text + noised target only. vision_start replaces the first
        # image token (upstream pipeline.py:51).
        txt_len = text_input_ids.shape[1]
        total_len = txt_len + image_len
        vision_tokens = torch.full((B, image_len), IMAGE_TOKEN_ID,
                                   dtype=text_input_ids.dtype, device=text_input_ids.device)
        vision_tokens[:, 0] = VISION_START_ID
        input_ids_pad = torch.cat([text_input_ids, vision_tokens], dim=-1)
        position_ids, _ = get_rope_index_fix_point(
            spatial_merge_size=1,
            image_token_id=IMAGE_TOKEN_ID, video_token_id=VIDEO_TOKEN_ID,
            vision_start_token_id=VISION_START_ID,
            input_ids=input_ids_pad, image_grid_thw=image_grid_thw_tgt,
            video_grid_thw=None, attention_mask=None,
            skip_vision_start_token=[1],
        )
        token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
        token_types[:, txt_len - 1:] = 1
        vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
        vinput_mask[:, txt_len:] = True

    # Collapse position_ids batch and add a leading dim so CONDRegular's
    # batch-resize doesn't truncate the 3-axis MRoPE dim.
    out["input_ids"] = text_input_ids
    out["position_ids"] = position_ids[:, 0].unsqueeze(0)
    out["token_types"] = token_types
    out["vinput_mask"] = vinput_mask
    return out
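[Illustration, not part of the commit] What build_ref_input_ids does to a tiny sequence; all ids below are toy values, not the real Qwen3-VL vocabulary:

ids = [1, 2, 3, 100, 101, 102]      # [im_start, user, \n] prefix + prompt ids
per_ref_vit_tokens = [2, 3]         # two ref images -> 2 and 3 pad tokens
VS, VE, PAD = 90, 91, 92            # vision_start / vision_end / image_pad
inserted = []
for n in per_ref_vit_tokens:
    inserted += [VS] + [PAD] * n + [VE]
new_ids = ids[:3] + inserted + ids[3:]
# [1, 2, 3, 90, 92, 92, 91, 90, 92, 92, 92, 91, 100, 101, 102]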
comfy/ldm/hidream_o1/model.py (new file, 231 lines)
@@ -0,0 +1,231 @@
"""HiDream-O1-Image transformer.

Pixel-space DiT built on Qwen3-VL: the vision tower (Qwen35VisionModel)
encodes ref images, the Qwen3-VL-8B decoder (Llama2_ with interleaved MRoPE)
processes a unified text+image sequence, and 32x32 patch embed/unembed
shims map raw RGB in and out of LLM hidden space. The Qwen3-VL deepstack
mergers go unused — their weights are dropped at load.
"""

from dataclasses import dataclass, field
from typing import List, Optional

import einops
import torch
import torch.nn as nn

from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.text_encoders.llama import Llama2_
from comfy.text_encoders.qwen35 import Qwen35VisionModel

from .attention import make_two_pass_attention


IMAGE_TOKEN_ID = 151655  # Qwen3-VL <|image_pad|>
TMS_TOKEN_ID = 151673  # HiDream-O1 <|tms_token|>
PATCH_SIZE = 32


@dataclass
class HiDreamO1TextConfig:
    """Qwen3-VL-8B text-decoder dims (matches public Qwen3-VL-8B-Instruct)."""
    vocab_size: int = 151936
    hidden_size: int = 4096
    intermediate_size: int = 12288
    num_hidden_layers: int = 36
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    head_dim: int = 128
    max_position_embeddings: int = 128000
    rms_norm_eps: float = 1e-6
    rope_theta: float = 5000000.0
    rope_scale: Optional[float] = None
    rope_dims: List[int] = field(default_factory=lambda: [24, 20, 20])
    interleaved_mrope: bool = True
    transformer_type: str = "llama"
    rms_norm_add: bool = False
    mlp_activation: str = "silu"
    qkv_bias: bool = False
    q_norm: str = "gemma3"
    k_norm: str = "gemma3"
    final_norm: bool = True
    lm_head: bool = False
    stop_tokens: List[int] = field(default_factory=lambda: [151643, 151645])


QWEN3VL_VISION_DEFAULTS = dict(
    hidden_size=1152,
    num_heads=16,
    intermediate_size=4304,
    depth=27,
    patch_size=16,
    temporal_patch_size=2,
    in_channels=3,
    spatial_merge_size=2,
    num_position_embeddings=2304,
    deepstack_visual_indexes=(8, 16, 24),
    out_hidden_size=4096,  # final merger projects directly into LLM hidden
)


class BottleneckPatchEmbed(nn.Module):
    # 3072 -> 1024 -> 4096 (raw 32x32 RGB patch -> bottleneck -> LLM hidden).
    def __init__(self, patch_size=32, in_chans=3, pca_dim=1024, embed_dim=4096, bias=True, device=None, dtype=None, ops=None):
        super().__init__()
        self.proj1 = ops.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False, device=device, dtype=dtype)
        self.proj2 = ops.Linear(pca_dim, embed_dim, bias=bias, device=device, dtype=dtype)

    def forward(self, x):
        return self.proj2(self.proj1(x))


class FinalLayer(nn.Module):
    # 4096 -> 3072 (LLM hidden -> flat pixel patch).
    def __init__(self, hidden_size, patch_size=32, out_channels=3, device=None, dtype=None, ops=None):
        super().__init__()
        self.linear = ops.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True,
                                 device=device, dtype=dtype)

    def forward(self, x):
        return self.linear(x)


class HiDreamO1Transformer(nn.Module):
    """HiDream-O1 unified pixel-level transformer."""

    def __init__(self, image_model=None, dtype=None, device=None, operations=None,
                 text_config_overrides=None, vision_config_overrides=None, **kwargs):
        super().__init__()
        self.dtype = dtype

        text_cfg = HiDreamO1TextConfig(**(text_config_overrides or {}))
        vision_cfg = dict(QWEN3VL_VISION_DEFAULTS)
        if vision_config_overrides:
            vision_cfg.update(vision_config_overrides)
        vision_cfg["out_hidden_size"] = text_cfg.hidden_size

        self.text_config = text_cfg
        self.vision_config = vision_cfg
        self.hidden_size = text_cfg.hidden_size
        self.patch_size = PATCH_SIZE
        self.in_channels = 3
        self.tms_token_id = TMS_TOKEN_ID

        self.visual = Qwen35VisionModel(vision_cfg, device=device, dtype=dtype, ops=operations)
        self.language_model = Llama2_(text_cfg, device=device, dtype=dtype, ops=operations)
        self.t_embedder1 = TimestepEmbedder(
            text_cfg.hidden_size, device=device, dtype=dtype, operations=operations,
        )
        self.x_embedder = BottleneckPatchEmbed(
            patch_size=self.patch_size, in_chans=self.in_channels,
            pca_dim=text_cfg.hidden_size // 4, embed_dim=text_cfg.hidden_size,
            bias=True, device=device, dtype=dtype, ops=operations,
        )
        self.final_layer2 = FinalLayer(
            text_cfg.hidden_size, patch_size=self.patch_size,
            out_channels=self.in_channels, device=device, dtype=dtype, ops=operations,
        )

    def forward(self, x, timesteps, context=None, transformer_options={},
                input_ids=None, attention_mask=None, position_ids=None,
                token_types=None, vinput_mask=None,
                ref_pixel_values=None, ref_image_grid_thw=None, ref_patches=None,
                **kwargs):
        """Returns flow-match velocity (x - x_pred) / sigma."""
        if input_ids is None or position_ids is None:
            raise ValueError("HiDreamO1Transformer requires input_ids and position_ids in conditioning")

        B, _, H, W = x.shape
        h_p, w_p = H // self.patch_size, W // self.patch_size
        tgt_image_len = h_p * w_p

        z = einops.rearrange(
            x, 'B C (H p1) (W p2) -> B (H W) (C p1 p2)',
            p1=self.patch_size, p2=self.patch_size,
        )
        vinputs = torch.cat([z, ref_patches.to(z.dtype)], dim=1) if ref_patches is not None else z

        if input_ids.dim() == 3:
            input_ids = input_ids.squeeze(-1)
        input_ids = input_ids.long()
        inputs_embeds = self.language_model.embed_tokens(input_ids).to(x.dtype)

        if ref_pixel_values is not None and ref_image_grid_thw is not None:
            ref_pv = ref_pixel_values.to(inputs_embeds.device)
            ref_grid = ref_image_grid_thw.to(inputs_embeds.device).long()
            # Refs are model-level (same for cond/uncond), wrapped with a leading
            # batch dim by extra_conds; [0] always recovers them.
            if ref_pv.dim() == 3:
                ref_pv = ref_pv[0]
            if ref_grid.dim() == 3:
                ref_grid = ref_grid[0]
            image_embeds = self.visual(ref_pv, ref_grid).to(inputs_embeds.dtype)
            image_mask = (input_ids == IMAGE_TOKEN_ID)
            if image_mask[0].sum().item() != image_embeds.shape[0]:
                raise ValueError(
                    f"Image-token count {image_mask[0].sum().item()} != ViT output count "
                    f"{image_embeds.shape[0]}; check tokenizer/processor alignment."
                )
            image_embeds_b = image_embeds.unsqueeze(0).expand(B, -1, -1).reshape(-1, image_embeds.shape[-1])
            inputs_embeds = inputs_embeds.masked_scatter(
                image_mask.unsqueeze(-1).expand_as(inputs_embeds), image_embeds_b,
            )

        sigma = timesteps.float() / 1000.0
        t_pixeldit = 1.0 - sigma
        t_emb = self.t_embedder1(t_pixeldit * 1000, inputs_embeds.dtype)
        tms_mask_3d = (input_ids == self.tms_token_id).unsqueeze(-1).expand_as(inputs_embeds)
        inputs_embeds = torch.where(tms_mask_3d, t_emb.unsqueeze(1).expand_as(inputs_embeds), inputs_embeds)

        vinputs_embedded = self.x_embedder(vinputs.to(inputs_embeds.dtype))
        inputs_embeds = torch.cat([inputs_embeds, vinputs_embedded], dim=1)
        total_seq_len = inputs_embeds.shape[1]

        # AR (text) tokens are contiguous at the start, so (==0).sum() gives ar_len.
        if token_types is None:
            txt_seq_len = input_ids.shape[1]
            token_types = torch.zeros(B, total_seq_len, dtype=torch.long, device=x.device)
            token_types[:, txt_seq_len:] = 1
        else:
            token_types = token_types.to(x.device)
            if token_types.dim() == 1:
                token_types = token_types.unsqueeze(0)
            if token_types.shape[0] == 1 and B > 1:
                token_types = token_types.expand(B, -1)
        ar_len = int((token_types[0] == 0).sum().item())

        # position_ids may arrive as (3, T) or wrapped (1, 3, T) / (3, 1, T) by CONDRegular.
        position_ids = position_ids.to(x.device).long()
        if position_ids.dim() == 3:
            position_ids = position_ids[0] if position_ids.shape[1] == 3 else position_ids[:, 0]
        freqs_cis = self.language_model.compute_freqs_cis(position_ids, x.device)
        freqs_cis = tuple(t.to(x.dtype) for t in freqs_cis)

        two_pass_attn = make_two_pass_attention(ar_len)
        hidden_states = inputs_embeds
        for layer in self.language_model.layers:
            hidden_states, _ = layer(
                x=hidden_states, attention_mask=None,
                freqs_cis=freqs_cis, optimized_attention=two_pass_attn,
                past_key_value=None,
            )
        if self.language_model.norm is not None:
            hidden_states = self.language_model.norm(hidden_states)

        x_pred = self.final_layer2(hidden_states)
        if vinput_mask is not None:
            vmask = vinput_mask.to(x.device).bool()
            if vmask.dim() == 1:
                vmask = vmask.unsqueeze(0)
            if vmask.shape[0] == 1 and B > 1:
                vmask = vmask.expand(B, -1)
            x_pred_tgt = x_pred[vmask].view(B, -1, x_pred.shape[-1])[:, :tgt_image_len]
        else:
            txt_seq_len = input_ids.shape[1]
            x_pred_tgt = x_pred[:, txt_seq_len:txt_seq_len + tgt_image_len]

        # fp32 final subtraction; bf16 here noticeably degrades samples.
        x_pred_img = einops.rearrange(
            x_pred_tgt, 'B (H W) (C p1 p2) -> B C (H p1) (W p2)',
            H=h_p, W=w_p, p1=self.patch_size, p2=self.patch_size,
        )
        return (x.float() - x_pred_img.float()) / sigma.view(B, 1, 1, 1).clamp_min(1e-3)
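[Illustration, not part of the commit] The 32x32 patchify/unpatchify pair used in forward() is an exact round trip; a quick self-contained check with einops:

import einops
import torch

x = torch.randn(2, 3, 64, 96)          # B C H W, 32-aligned
z = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (C p1 p2)', p1=32, p2=32)
print(z.shape)                         # (2, 6, 3072): a 2x3 grid of 3*32*32 patches
x2 = einops.rearrange(z, 'B (H W) (C p1 p2) -> B C (H p1) (W p2)', H=2, W=3, p1=32, p2=32)
assert torch.equal(x, x2)              # pure reshape/permute, lossless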
comfy/ldm/hidream_o1/utils.py (new file, 223 lines)
@@ -0,0 +1,223 @@
"""HiDream-O1 input-prep helpers: image/resolution math and unified-sequence
RoPE position-id assembly. The fix_point offset in get_rope_index_fix_point
lets the target image and patchified ref images share spatial RoPE positions
despite living at different sequence indices — same 2D image plane.
"""

import math
from typing import Optional

import torch
from PIL import Image


PREDEFINED_RESOLUTIONS = [
    (2048, 2048),
    (2304, 1728),
    (1728, 2304),
    (2560, 1440),
    (1440, 2560),
    (2496, 1664),
    (1664, 2496),
    (3104, 1312),
    (1312, 3104),
    (2304, 1792),
    (1792, 2304),
]

PATCH_SIZE = 32
CONDITION_IMAGE_SIZE = 384  # ViT-side base size for ref images


def find_closest_resolution(width, height):
    """Closest (W, H) in PREDEFINED_RESOLUTIONS by aspect ratio."""
    img_ratio = width / height
    best = None
    min_diff = float("inf")
    for w, h in PREDEFINED_RESOLUTIONS:
        diff = abs(w / h - img_ratio)
        if diff < min_diff:
            min_diff = diff
            best = (w, h)
    return best


def resize_pilimage(pil_image, image_size, patch_size=16, resampler=Image.BICUBIC):
    """Resize to fit image_size**2 area, patch-aligned, center-cropped. Pre-halves
    with BOX filter while the image is still very large.
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX,
        )

    m = patch_size
    width, height = pil_image.width, pil_image.height
    s_max = image_size * image_size
    scale = math.sqrt(s_max / (width * height))

    candidates = [
        (round(width * scale) // m * m, round(height * scale) // m * m),
        (round(width * scale) // m * m, math.floor(height * scale) // m * m),
        (math.floor(width * scale) // m * m, round(height * scale) // m * m),
        (math.floor(width * scale) // m * m, math.floor(height * scale) // m * m),
    ]
    candidates = sorted(candidates, key=lambda x: x[0] * x[1], reverse=True)
    new_size = candidates[-1]
    for c in candidates:
        if c[0] * c[1] <= s_max:
            new_size = c
            break

    s1 = width / new_size[0]
    s2 = height / new_size[1]
    if s1 < s2:
        pil_image = pil_image.resize([new_size[0], round(height / s1)], resample=resampler)
        top = (round(height / s1) - new_size[1]) // 2
        pil_image = pil_image.crop((0, top, new_size[0], top + new_size[1]))
    else:
        pil_image = pil_image.resize([round(width / s2), new_size[1]], resample=resampler)
        left = (round(width / s2) - new_size[0]) // 2
        pil_image = pil_image.crop((left, 0, left + new_size[0], new_size[1]))
    return pil_image


def calculate_dimensions(max_size, ratio):
    """(W, H) for an aspect ratio fitting in max_size**2 area, 32-aligned."""
    width = math.sqrt(max_size * max_size * ratio)
    height = width / ratio
    width = int(width / 32) * 32
    height = int(height / 32) * 32
    return width, height


def ref_max_size(target_max_dim, k):
    """K-dependent ref-image max dim before patchifying."""
    if k == 1:
        return target_max_dim
    if k == 2:
        return target_max_dim * 48 // 64
    if k <= 4:
        return target_max_dim // 2
    if k <= 8:
        return target_max_dim * 24 // 64
    return target_max_dim // 4


def cond_image_size(k):
    """K-dependent ViT-side image size."""
    if k <= 4:
        return CONDITION_IMAGE_SIZE
    if k <= 8:
        return CONDITION_IMAGE_SIZE * 48 // 64
    return CONDITION_IMAGE_SIZE // 2
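[Illustration, not part of the commit] The resolution helpers above in concrete numbers, for a hypothetical 2048-pixel target with K = 2 refs and a 16:9 ref image:

import math

target_max_dim, k = 2048, 2
ref_dim = target_max_dim * 48 // 64        # ref_max_size(2048, 2) -> 1536
vit_size = 384                             # cond_image_size(2): K <= 4 keeps the base size

# calculate_dimensions(384, 16/9): fit a 16:9 box into 384**2 area, 32-aligned.
w = math.sqrt(vit_size * vit_size * (16 / 9))       # 512.0
h = w / (16 / 9)                                    # 288.0
print(ref_dim, int(w / 32) * 32, int(h / 32) * 32)  # 1536 512 288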
def get_rope_index_fix_point(
    spatial_merge_size: int,
    image_token_id: int,
    video_token_id: int,
    vision_start_token_id: int,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    skip_vision_start_token=None,
    fix_point: int = 4096,
):
    if video_grid_thw is not None:
        video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
        video_grid_thw[:, 0] = 1

    mrope_position_deltas = []
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        position_ids = torch.ones(
            3, input_ids.shape[0], input_ids.shape[1],
            dtype=input_ids.dtype, device=input_ids.device,
        )
        image_index, video_index = 0, 0
        attention_mask = attention_mask.to(total_input_ids.device)
        for i, input_ids_b in enumerate(total_input_ids):
            input_ids_b = input_ids_b[attention_mask[i] == 1]
            vision_start_indices = torch.argwhere(input_ids_b == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids_b[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = input_ids_b.tolist()
            llm_pos_ids_list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums
            for _ in range(image_nums + video_nums):
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                if ed_image < ed_video:
                    t = image_grid_thw[image_index][0]
                    h = image_grid_thw[image_index][1]
                    w = image_grid_thw[image_index][2]
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t = video_grid_thw[video_index][0]
                    h = video_grid_thw[video_index][1]
                    w = video_grid_thw[video_index][2]
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                llm_grid_t = t.item()
                llm_grid_h = h.item() // spatial_merge_size
                llm_grid_w = w.item() // spatial_merge_size
                text_len = ed - st
                text_len -= skip_vision_start_token[image_index - 1]
                text_len = max(0, text_len)
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()

                if skip_vision_start_token[image_index - 1]:
                    if fix_point > 0:
                        fix_point = fix_point - st_idx
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fix_point + st_idx)
                    fix_point = 0
                else:
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas

    if attention_mask is not None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
        max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
        mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
    else:
        position_ids = (
            torch.arange(input_ids.shape[1], device=input_ids.device)
            .view(1, 1, -1).expand(3, input_ids.shape[0], -1)
        )
        mrope_position_deltas = torch.zeros(
            [input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype,
        )
    return position_ids, mrope_position_deltas
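[Illustration, not part of the commit] A toy sketch of the fix_point intent as described in the module docstring, with made-up sizes: the spatial (h, w) RoPE indices of a fix_point-anchored image block start at 4096 regardless of how much text precedes it, so the target image and the ref-patch blocks can share one 2D plane:

import torch

fix_point = 4096                 # text occupies sequence positions before this
h_idx = torch.arange(2).view(-1, 1).expand(2, 2).flatten()  # 0 0 1 1 for a 2x2 grid
w_idx = torch.arange(2).view(1, -1).expand(2, 2).flatten()  # 0 1 0 1
target_hw = torch.stack([h_idx, w_idx]) + fix_point
print(target_hw[0].tolist())     # [4096, 4096, 4097, 4097]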
comfy/model_base.py
@@ -58,6 +58,7 @@ import comfy.ldm.cogvideo.model
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.ldm.hidream_o1.model

import comfy.model_management
import comfy.patcher_extension
@@ -1690,6 +1691,43 @@ class ChromaRadiance(Chroma):
    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance)


class HiDreamO1(BaseModel):
    """HiDream-O1-Image: pixel-space DiT (no VAE). Refs from HiDreamO1ReferenceImages
    and tokens from the stub TE flow through extra_conds; the heavy preprocessing
    lives in comfy.ldm.hidream_o1.conditioning."""
    PATCH_SIZE = 32

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device,
                         unet_model=comfy.ldm.hidream_o1.model.HiDreamO1Transformer)
        # HiDream-O1 trains with x_t = (1-t) * x_clean + t * s_noise * noise
        s_noise = float((model_config.sampling_settings or {}).get("s_noise", 8.0))

        class _HiDreamO1Sampling(
            comfy.model_sampling.ModelSamplingDiscreteFlow,
            comfy.model_sampling.CONST_SCALED_NOISE,
        ):
            pass
        ms = _HiDreamO1Sampling(model_config)
        ms._s_noise = s_noise
        self.model_sampling = ms

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        text_input_ids = kwargs.get("text_input_ids", None)
        noise = kwargs.get("noise", None)
        if text_input_ids is None or noise is None:
            return out
        from comfy.ldm.hidream_o1.conditioning import build_extra_conds
        conds = build_extra_conds(
            text_input_ids, noise,
            ref_images=kwargs.get("hidream_o1_ref_images", None),
            target_patch_size=self.PATCH_SIZE,
        )
        for k, v in conds.items():
            out[k] = comfy.conds.CONDRegular(v)
        return out


class ACEStep(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
comfy/model_detection.py
@@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
        return dit_config

    if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
        return {"image_model": "hidream_o1"}

    if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
        dit_config = {}
        dit_config["image_model"] = "hidream"
comfy/model_sampling.py
@@ -99,6 +99,18 @@ class CONST:
        sigma = reshape_sigma(sigma, latent.ndim)
        return latent / (1.0 - sigma)


class CONST_SCALED_NOISE(CONST):
    """CONST variant for flow-match models trained with x_t = (1-t)*x_clean +
    t*s_noise*noise. Set _s_noise to the recipe value; default 1.0 == plain CONST.
    """

    _s_noise = 1.0

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
        sigma = reshape_sigma(sigma, noise.ndim)
        return sigma * (self._s_noise * noise) + (1.0 - sigma) * latent_image


class X0(EPS):
    def calculate_denoised(self, sigma, model_output, model_input):
        return model_output
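[Illustration, not part of the commit] A standalone sketch of what noise_scaling computes, assuming only torch; with _s_noise = 1.0 it reduces exactly to the plain CONST interpolation:

import torch

def noise_scaling(sigma, noise, latent_image, s_noise=8.0):
    # x_t = sigma * (s_noise * noise) + (1 - sigma) * x_clean
    return sigma * (s_noise * noise) + (1.0 - sigma) * latent_image

sigma = torch.tensor(1.0)                  # pure-noise end of the schedule
noise = torch.randn(1, 3, 64, 64)
x0 = torch.zeros(1, 3, 64, 64)
x_t = noise_scaling(sigma, noise, x0)
print(torch.allclose(x_t, 8.0 * noise))    # True: at sigma=1 the latent is 8x noise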
comfy/samplers.py
@@ -723,7 +723,8 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
                  "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece",
                  "euler_flash_flowmatch"]


class KSAMPLER(Sampler):
    def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
comfy/supported_models.py
@@ -28,6 +28,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo
import comfy.text_encoders.hidream_o1

from . import supported_models_base
from . import latent_formats
@@ -1480,6 +1481,49 @@ class ChromaRadiance(Chroma):
    def get_model(self, state_dict, prefix="", device=None):
        return model_base.ChromaRadiance(self, device=device)


class HiDreamO1(supported_models_base.BASE):
    unet_config = {
        "image_model": "hidream_o1",
    }

    sampling_settings = {
        "shift": 3.0,
        "s_noise": 8.0,
    }

    latent_format = latent_formats.HiDreamO1Pixel
    memory_usage_factor = 0.6
    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    optimizations = {"fp8": False}

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.HiDreamO1(self, device=device)

    def process_unet_state_dict(self, state_dict):
        # Drop unused Qwen3-VL deepstack merger weights; upstream discards them at inference.
        for key in list(state_dict.keys()):
            if "visual.deepstack_merger_list" in key:
                del state_dict[key]
        return state_dict

    def process_vae_state_dict(self, state_dict):
        # Pixel-space model: inject sentinel so VAE construction picks PixelspaceConversionVAE.
        return {"pixel_space_vae": torch.tensor(1.0)}

    def process_clip_state_dict(self, state_dict):
        # Tokenizer-only TE: inject sentinel so load_state_dict_guess_config triggers CLIP init.
        return {"_hidream_o1_te_sentinel": torch.zeros(1)}

    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(
            comfy.text_encoders.hidream_o1.HiDreamO1Tokenizer,
            comfy.text_encoders.hidream_o1.HiDreamO1TE,
        )


class ACEStep(supported_models_base.BASE):
    unet_config = {
        "audio_model": "ace",
@@ -2018,6 +2062,7 @@ models = [
    Hunyuan3Dv2,
    Hunyuan3Dv2_1,
    HiDream,
    HiDreamO1,
    Chroma,
    ChromaRadiance,
    ACEStep,
comfy/text_encoders/hidream_o1.py (new file, 124 lines)
@@ -0,0 +1,124 @@
"""HiDream-O1-Image tokenizer-only text encoder.

The real Qwen3-VL backbone runs inside diffusion_model.* every step, so this
module just tokenizes the prompt into text_input_ids and emits them as
conditioning. Position ids / token_types / vinput_mask depend on target H/W
and are built later in model_base.HiDreamO1.extra_conds.
"""

import os

import torch
from transformers import Qwen2Tokenizer

from comfy import sd1_clip


# Qwen3-VL special tokens
IM_START_ID = 151644
IM_END_ID = 151645
ASSISTANT_ID = 77091
USER_ID = 872
NEWLINE_ID = 198
VISION_START_ID = 151652
VISION_END_ID = 151653
IMAGE_TOKEN_ID = 151655
VIDEO_TOKEN_ID = 151656
# HiDream-O1-specific tokens
BOI_TOKEN_ID = 151669
BOR_TOKEN_ID = 151670
EOR_TOKEN_ID = 151671
BOT_TOKEN_ID = 151672
TMS_TOKEN_ID = 151673


class HiDreamO1QwenTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer"
        )
        super().__init__(
            tokenizer_path,
            pad_with_end=False,
            embedding_size=4096,
            embedding_key="hidream_o1",
            tokenizer_class=Qwen2Tokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_token=151643,
            tokenizer_data=tokenizer_data,
        )


class HiDreamO1Tokenizer(sd1_clip.SD1Tokenizer):
    """Wraps prompt in the upstream chat template ending with boi/tms markers.
    Image tokens get spliced in at sample time once target H/W is known.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            name="hidream_o1",
            tokenizer=HiDreamO1QwenTokenizer,
        )

    def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
        text_tokens_dict = super().tokenize_with_weights(
            text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
        )
        text_tuples = text_tokens_dict["hidream_o1"][0]
        text_tuples = [t for t in text_tuples if int(t[0]) != 151643]  # strip pad

        # <|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|boi|><|tms|>
        def tok(tid):
            return (tid, 1.0) if not return_word_ids else (tid, 1.0, 0)

        prefix = [tok(IM_START_ID), tok(USER_ID), tok(NEWLINE_ID)]
        suffix = [
            tok(IM_END_ID), tok(NEWLINE_ID),
            tok(IM_START_ID), tok(ASSISTANT_ID), tok(NEWLINE_ID),
            tok(BOI_TOKEN_ID), tok(TMS_TOKEN_ID),
        ]
        full = prefix + list(text_tuples) + suffix
        return {"hidream_o1": [full]}


class HiDreamO1TE(torch.nn.Module):
    """Passthrough TE: emits int token ids; the Qwen3-VL backbone in
    diffusion_model.* does the actual encoding.

    dtypes advertises uint8 as a routing hint: supports_cast(cuda, uint8)
    is False, so CLIP.__init__ downgrades load_device to CPU, which makes
    CoreModelPatcher skip the VBAR allocator (it would fail on a zero-param TE).
    """

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__()
        self.dtypes = {torch.uint8}
        self.device = torch.device("cpu") if device is None else torch.device(device)

    def encode_token_weights(self, token_weight_pairs):
        tok_pairs = token_weight_pairs["hidream_o1"][0]
        ids = [int(t[0]) for t in tok_pairs]
        input_ids = torch.tensor([ids], dtype=torch.long)
        # Surrogate keeps the cross_attn slot non-empty for CONDITIONING
        # plumbing; the model reads text_input_ids out of `extra` instead.
        cross_attn = input_ids.unsqueeze(-1).to(torch.float32)
        extra = {"text_input_ids": input_ids}
        return cross_attn, None, extra

    def load_sd(self, sd):
        return []

    def get_sd(self):
        return {}

    def reset_clip_options(self):
        pass

    def set_clip_options(self, options):
        pass
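[Illustration, not part of the commit] The token layout the tokenizer emits for a prompt, using the special-token ids defined above; the prompt ids themselves are hypothetical stand-ins for a real tokenization:

IM_START_ID, USER_ID, NEWLINE_ID = 151644, 872, 198
IM_END_ID, ASSISTANT_ID = 151645, 77091
BOI_TOKEN_ID, TMS_TOKEN_ID = 151669, 151673

prompt_ids = [9906, 1917]   # toy values, not real vocab entries
full = ([IM_START_ID, USER_ID, NEWLINE_ID] + prompt_ids
        + [IM_END_ID, NEWLINE_ID, IM_START_ID, ASSISTANT_ID, NEWLINE_ID,
           BOI_TOKEN_ID, TMS_TOKEN_ID])
# <|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n<|boi|><|tms|>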
@ -397,7 +397,7 @@ class RMSNorm(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
|
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None, interleaved_mrope=False):
|
||||||
if not isinstance(theta, list):
|
if not isinstance(theta, list):
|
||||||
theta = [theta]
|
theta = [theta]
|
||||||
|
|
||||||
@ -415,6 +415,17 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
|
|||||||
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||||
position_ids_expanded = position_ids[:, None, :].float()
|
position_ids_expanded = position_ids[:, None, :].float()
|
||||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||||
|
if rope_dims is not None and position_ids.shape[0] > 1 and interleaved_mrope:
|
||||||
|
# Qwen3-VL interleaved MRoPE: T-freqs by default, H/W replace every 3rd dim.
|
||||||
|
freqs_inter = freqs[0].clone()
|
||||||
|
for axis_idx, offset in ((1, 1), (2, 2)):
|
||||||
|
length = rope_dims[axis_idx] * 3
|
||||||
|
idx = slice(offset, length, 3)
|
||||||
|
freqs_inter[..., idx] = freqs[axis_idx, ..., idx]
|
||||||
|
emb = torch.cat((freqs_inter, freqs_inter), dim=-1)
|
||||||
|
cos = emb.cos().unsqueeze(0)
|
||||||
|
sin = emb.sin().unsqueeze(0)
|
||||||
|
else:
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
cos = emb.cos()
|
cos = emb.cos()
|
||||||
sin = emb.sin()
|
sin = emb.sin()
|
||||||
@ -689,6 +700,7 @@ class Llama2_(nn.Module):
            self.config.rope_theta,
            self.config.rope_scale,
            self.config.rope_dims,
            interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
            device=device)

    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):

comfy_extras/nodes_hidream_o1.py (new file, 202 lines)
@ -0,0 +1,202 @@
from typing_extensions import override

import torch

import comfy.model_management
import node_helpers
from comfy_api.latest import ComfyExtension, io

from comfy.ldm.hidream_o1.utils import find_closest_resolution

class EmptyHiDreamO1LatentImage(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="EmptyHiDreamO1LatentImage",
            display_name="Empty HiDream-O1 Latent Image",
            category="latent/image",
            description=(
                "Empty pixel-space latent for HiDream-O1-Image. When "
                "snap_to_predefined is on, dimensions are matched (by aspect "
                "ratio) to the upstream HiDream-O1 PREDEFINED_RESOLUTIONS list."
            ),
            inputs=[
                io.Int.Input(id="width", default=2048, min=64, max=4096, step=32),
                io.Int.Input(id="height", default=2048, min=64, max=4096, step=32),
                io.Int.Input(id="batch_size", default=1, min=1, max=64),
                io.Boolean.Input(
                    id="snap_to_predefined",
                    default=True,
                    tooltip=(
                        "Snap (W, H) to the closest aspect ratio in HiDream-O1's "
                        "PREDEFINED_RESOLUTIONS table for best parity with the "
                        "upstream CLI. Disable for arbitrary 32-aligned sizes."
                    ),
                ),
            ],
            outputs=[io.Latent.Output()],
        )

    @classmethod
    def execute(cls, *, width: int, height: int, batch_size: int = 1,
                snap_to_predefined: bool = True) -> io.NodeOutput:
        if snap_to_predefined:  # TODO: better way to handle this
            sw, sh = find_closest_resolution(width, height)
            width, height = sw, sh
        width = (width // 32) * 32
        height = (height // 32) * 32
        latent = torch.zeros(
            (batch_size, 3, height, width),
            device=comfy.model_management.intermediate_device(),
        )
        return io.NodeOutput({"samples": latent})
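
# Quick check (not part of the patch) of the alignment arithmetic above:
# flooring to a multiple of 32 gives the 32-aligned sizes the node advertises.
for w in (1000, 1024, 1055):
    print(w, (w // 32) * 32)   # -> 992, 1024, 1024
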
class HiDreamO1ReferenceImages(io.ComfyNode):
    """Attach reference images to both positive and negative conditioning."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="HiDreamO1ReferenceImages",
            display_name="HiDream-O1 Reference Images",
            category="conditioning/image",
            description=(
                "Attach 1-10 reference images to conditioning: one for an edit "
                "instruction, or multiple for subject-driven personalization."
            ),
            inputs=[
                io.Conditioning.Input(id="positive"),
                io.Conditioning.Input(id="negative"),
                io.Autogrow.Input(
                    "images",
                    template=io.Autogrow.TemplateNames(
                        io.Image.Input("image"),
                        names=[f"image_{i}" for i in range(1, 11)],
                        min=1,
                    ),
                    tooltip=(
                        "Reference images. K=1 -> instruction edit; "
                        "K=2..10 -> subject-driven personalization."
                    ),
                ),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
            ],
        )

    @classmethod
    def execute(cls, *, positive, negative, images: io.Autogrow.Type) -> io.NodeOutput:
        refs = [images[f"image_{i}"] for i in range(1, 11) if f"image_{i}" in images]
        positive = node_helpers.conditioning_set_values(positive, {"hidream_o1_ref_images": refs})
        negative = node_helpers.conditioning_set_values(negative, {"hidream_o1_ref_images": refs})
        return io.NodeOutput(positive, negative)
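
# What the attach amounts to (standalone sketch, not part of the patch),
# assuming ComfyUI's conditioning format of [tensor, dict] pairs:
# conditioning_set_values copies each pair with the extra key merged into the
# dict, so both CFG branches carry the same reference list.
import torch

refs = [torch.zeros(1, 64, 64, 3)]       # hypothetical reference image batch
cond = [[torch.zeros(1, 3, 1), {}]]      # hypothetical conditioning pair
cond = [[t, {**d, "hidream_o1_ref_images": refs}] for t, d in cond]
assert cond[0][1]["hidream_o1_ref_images"] is refs
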
class HiDreamO1Sampling(io.ComfyNode):
    """Adjust HiDream-O1's flow-match sigma shift and noise scale together."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="HiDreamO1Sampling",
            display_name="HiDream-O1 Sampling",
            category="advanced/model",
            description=(
                "Patch HiDream-O1's sigma shift and noise scaling factor. "
                "Base model defaults: shift=3.0, s_noise=8.0. "
                "Dev/flash sampler defaults: shift=1.0, s_noise=7.5."
            ),
            inputs=[
                io.Model.Input(id="model"),
                io.Float.Input(
                    id="shift", default=3.0, min=0.0, max=100.0, step=0.01,
                    tooltip="Flow-match sigma shift. Defaults: 3.0 for base, 1.0 for dev.",
                ),
                io.Float.Input(
                    id="s_noise", default=8.0, min=0.0, max=64.0, step=0.1,
                    tooltip="HiDream-O1 noise scale (CONST_SCALED_NOISE). Defaults: 8.0 for base, 7.5 for dev/flash.",
                ),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, *, model, shift: float, s_noise: float) -> io.NodeOutput:
        import comfy.model_sampling
        m = model.clone()

        class _HiDreamO1SamplingPatched(
            comfy.model_sampling.ModelSamplingDiscreteFlow,
            comfy.model_sampling.CONST_SCALED_NOISE,
        ):
            pass

        ms = _HiDreamO1SamplingPatched(m.model.model_config)
        ms.set_parameters(shift=float(shift), multiplier=1000)
        ms._s_noise = float(s_noise)
        m.add_object_patch("model_sampling", ms)
        return io.NodeOutput(m)
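
# Sketch (not part of the patch) of what the shift knob does, assuming
# ComfyUI's discrete-flow sigma mapping sigma = shift * t / (1 + (shift - 1) * t):
def shifted_sigma(t, shift):
    return shift * t / (1 + (shift - 1) * t)

for t in (0.25, 0.5, 0.75):
    # shift=1.0 leaves sigma == t; shift=3.0 keeps sigmas higher for longer
    print(t, shifted_sigma(t, 1.0), shifted_sigma(t, 3.0))  # -> 0.5, 0.75, 0.9 at shift=3.0
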
class SamplerEulerFlashFlowmatch(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="SamplerEulerFlashFlowmatch",
            display_name="Sampler Euler Flash Flowmatch",
            category="sampling/custom_sampling/samplers",
            description="HiDream-O1 dev/flash sampler with tunable per-step noise.",
            inputs=[
                io.Float.Input(
                    id="s_noise_start", default=7.5, min=0.0, max=64.0, step=0.1,
                    tooltip="Per-step noise scale at the first sampling step.",
                ),
                io.Float.Input(
                    id="s_noise_end", default=7.5, min=0.0, max=64.0, step=0.1,
                    tooltip=(
                        "Per-step noise scale at the last step. Default: 7.5 for dev/flash. "
                        "Set this different from s_noise_start to linearly ramp the noise across steps."
                    ),
                ),
                io.Float.Input(
                    id="noise_clip_std", default=2.5, min=0.0, max=10.0, step=0.1,
                    tooltip="Clamp per-step noise to +/- N*std. 0 disables.",
                ),
            ],
            outputs=[io.Sampler.Output()],
        )

    @classmethod
    def execute(cls, *, s_noise_start: float, s_noise_end: float,
                noise_clip_std: float) -> io.NodeOutput:
        import comfy.samplers
        import comfy.k_diffusion.sampling
        sampler = comfy.samplers.KSAMPLER(
            comfy.k_diffusion.sampling.sample_euler_flash_flowmatch,
            extra_options={
                "s_noise": float(s_noise_start),
                "s_noise_end": float(s_noise_end),
                "noise_clip_std": float(noise_clip_std),
            },
        )
        return io.NodeOutput(sampler)
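
# Sketch (not part of the patch) of the ramp the two knobs control: the
# per-step scale interpolates linearly from s_noise_start to s_noise_end.
n_steps = 5                      # hypothetical step count
s_start, s_end = 7.5, 5.0        # hypothetical ramp endpoints
for i in range(n_steps):
    t = i / (n_steps - 1) if n_steps > 1 else 0.0
    print(i, s_start + (s_end - s_start) * t)   # -> 7.5, 6.875, 6.25, 5.625, 5.0
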
class HiDreamO1Extension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            EmptyHiDreamO1LatentImage,
            HiDreamO1ReferenceImages,
            HiDreamO1Sampling,
            SamplerEulerFlashFlowmatch,
        ]


async def comfy_entrypoint() -> HiDreamO1Extension:
    return HiDreamO1Extension()