ComfyUI/comfy_extras/nodes_hidream_o1.py
Jukka Seppänen 8e53f001a4
feat: Support HiDream-O1-Image (CORE-187) (#13817)
* Initial HiDream01-image support

* Cleanup nodes

* Cleaner handling of empty placeholder models

* Remove snap_to_predefined, prefer tooltip for the trained resolutions

* Add model and block wrappers

* Fix shift tooltip

* Add node to work around the patch tile issue

Experimental, runs multiple passes with the patch grid offset and blends with various different methods.

* Qwen35 vision rotary_pos_emb cast fix

* Fix embedding layout type

* Some small optimizations

* Cleanup, don't need this fallback

* Prefix KV cache, cleanup

Bit of speed, reduce redundant code

* Get rid of redundant custom sampler, refactor noise scaling

Our existing lcm sampler is mathematically the same; just added the missing options to it instead, plus a node to control them. Refactored the noise scaling and fixed it for the stochastic samplers, and added a generic node to control the initial noise scale.

* Update nodes_hidream_o1.py

* Fix some cache validation cases

* Keep existing sampling params

* Remove redundant video vision path

* Replace some numpy ops with torch

* Fix RoPE index for batch size > 1

* Prefer torch preprocessing

* Rename block_type to be compatible with existing patch nodes

* Fixes and tweaks
2026-05-11 20:35:53 -07:00

257 lines
11 KiB
Python

from typing_extensions import override
import torch
import comfy.model_management
import comfy.patcher_extension
import node_helpers
from comfy_api.latest import ComfyExtension, io
class EmptyHiDreamO1LatentImage(io.ComfyNode):
    """Node that produces an all-zero pixel-space latent for HiDream-O1-Image."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: width/height/batch_size inputs, one latent output."""
        return io.Schema(
            node_id="EmptyHiDreamO1LatentImage",
            display_name="Empty HiDream-O1 Latent Image",
            category="latent/image",
            description=(
                "Empty pixel-space latent for HiDream-O1-Image. The model was "
                "trained at ~4 megapixels; lower resolutions go off-distribution "
                "and quality regresses noticeably. Trained resolutions: "
                "2048x2048, 2304x1728, 1728x2304, 2560x1440, 1440x2560, "
                "2496x1664, 1664x2496, 3104x1312, 1312x3104, 2304x1792, 1792x2304."
            ),
            inputs=[
                io.Int.Input(id="width", default=2048, min=64, max=4096, step=32),
                io.Int.Input(id="height", default=2048, min=64, max=4096, step=32),
                io.Int.Input(id="batch_size", default=1, min=1, max=64),
            ],
            outputs=[io.Latent().Output()],
        )

    @classmethod
    def execute(cls, *, width: int, height: int, batch_size: int = 1) -> io.NodeOutput:
        """Build a zeroed (batch_size, 3, height, width) pixel-space latent."""
        target_device = comfy.model_management.intermediate_device()
        shape = (batch_size, 3, height, width)
        empty = torch.zeros(shape, device=target_device)
        return io.NodeOutput({"samples": empty})
class HiDreamO1ReferenceImages(io.ComfyNode):
    """Attach reference images to both positive and negative conditioning."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: conditioning pair in/out plus 1-10 reference images."""
        return io.Schema(
            node_id="HiDreamO1ReferenceImages",
            display_name="HiDream-O1 Reference Images",
            category="conditioning/image",
            # Fixed missing space at the implicit string-concatenation boundary:
            # previously rendered as "edit instructionor multiple".
            description=(
                "Attach 1-10 reference images to conditioning, one for edit instruction "
                "or multiple for subject-driven personalization."
            ),
            inputs=[
                io.Conditioning.Input(id="positive"),
                io.Conditioning.Input(id="negative"),
                io.Autogrow.Input(
                    "images",
                    template=io.Autogrow.TemplateNames(
                        io.Image.Input("image"),
                        names=[f"image_{i}" for i in range(1, 11)],
                        min=1,
                    ),
                    tooltip=("Reference images. 1 image = instruction edit; 2-10 images = multi reference."
                    ),
                ),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
            ],
        )

    @classmethod
    def execute(cls, *, positive, negative, images: io.Autogrow.Type) -> io.NodeOutput:
        """Append the connected images as reference_latents to both conditionings.

        Slots are gathered in order image_1..image_10; absent slots are skipped,
        and the same reference list is appended to positive and negative.
        """
        refs = [images[f"image_{i}"] for i in range(1, 11) if f"image_{i}" in images]
        positive = node_helpers.conditioning_set_values(positive, {"reference_latents": refs}, append=True)
        negative = node_helpers.conditioning_set_values(negative, {"reference_latents": refs}, append=True)
        return io.NodeOutput(positive, negative)
class HiDreamO1PatchSeamSmoothing(io.ComfyNode):
    """Blend the diffusion model's prediction over several shifted copies of the
    input during the late portion of sampling, cancelling seams that a fixed
    patch grid would otherwise imprint at patch boundaries.
    """

    # Patch side length in pixels; all shift presets below are offsets within it.
    PATCH_SIZE = 32
    # Pixel width of the linear fade applied to the blend mask near image edges.
    EDGE_FEATHER = 4
    # Shift presets per (pattern, N). 8-pass = 4-quadrant + 4 quarter-patch offsets.
    SHIFTS_BY_PATTERN = {
        ("single_shift", 2): [(0, 0), (16, 16)],
        ("single_shift", 4): [(0, 0), (16, 0), (0, 16), (16, 16)],
        ("single_shift", 8): [(0, 0), (16, 0), (0, 16), (16, 16),
                              (8, 8), (24, 8), (8, 24), (24, 24)],
        ("symmetric", 2): [(-8, -8), (8, 8)],
        ("symmetric", 4): [(-8, -8), (8, -8), (-8, 8), (8, 8)],
        ("symmetric", 8): [(-12, -12), (4, -12), (-12, 4), (4, 4),
                           (-4, -4), (12, -4), (-4, 12), (12, 12)],
    }
    # Maps the "passes" combo value to a list of pass counts; multi-entry lists
    # ramp the pass count up as sampling moves through the gated range.
    RAMP_LEVELS = {
        "2": [2],
        "4": [4],
        "ramp_2_4": [2, 4],
        "ramp_2_4_8": [2, 4, 8],
    }

    @staticmethod
    def _hann_tile(cy: int, cx: int, size: int = 32) -> torch.Tensor:
        """size x size Hann tile peaking at (cy, cx) within a patch.

        Distances are computed modulo `size`, so the window wraps toroidally;
        tiling it across the image keeps it periodic with the patch grid.
        """
        half = size // 2
        yy = torch.arange(size).view(size, 1)
        xx = torch.arange(size).view(1, size)
        # Wrapped distance from the peak along each axis, in [-half, half).
        dy = ((yy - cy + half) % size) - half
        dx = ((xx - cx + half) % size) - half
        # Separable 2D Hann window: 1 at the peak, 0 at the wrapped boundary.
        return 0.25 * (1 + torch.cos(torch.pi * dy / half)) * (1 + torch.cos(torch.pi * dx / half))

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: model in/out plus gating, pattern and blend options."""
        return io.Schema(
            node_id="HiDreamO1PatchSeamSmoothing",
            display_name="HiDream-O1 Patch Seam Smoothing",
            category="advanced/model",
            is_experimental=True,
            description=(
                "Average the model output across multiple shifted patch-grid "
                "positions during the late portion of sampling. Cancels seams."
            ),
            inputs=[
                io.Model.Input(id="model"),
                io.Float.Input(id="start_percent", default=0.8, min=0.0, max=1.0, step=0.01,
                    tooltip="Sampling progress (0=start, 1=end) at which the blend turns ON.",
                ),
                io.Float.Input(id="end_percent", default=1.0, min=0.0, max=1.0, step=0.01,
                    tooltip="Sampling progress at which the blend turns OFF.",
                ),
                io.Combo.Input(
                    id="pattern",
                    options=["single_shift", "symmetric"],
                    default="single_shift",
                    tooltip="Shift layout. single_shift: one pass at the natural patch grid + others offset. symmetric: all passes off-grid, shifts split around origin.",
                ),
                io.Combo.Input(
                    id="passes",
                    options=["2", "4", "ramp_2_4", "ramp_2_4_8"],
                    default="2",
                    tooltip="Number of passes per gated step. 2/4 = fixed. ramp_*: pass count increases as sampling approaches end (more smoothing where seams are most visible).",
                ),
                io.Combo.Input(
                    id="blend",
                    options=["average", "window", "median"],
                    default="average",
                    tooltip="average: equal-weight mean. window: Hann-windowed weighting favoring each pass away from its patch boundaries. median: per-pixel median, rejects wraparound-outlier passes.",
                ),
                io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01,
                    tooltip="Interpolation between the natural-grid pred (0) and the averaged result (1).",
                ),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, *, model, start_percent: float, end_percent: float, pattern: str, passes: str, blend: str, strength: float) -> io.NodeOutput:
        """Clone the model and install a diffusion-model wrapper that blends
        predictions across shifted patch grids inside the gated sigma range.
        """
        # Degenerate settings: nothing to blend, return the model untouched.
        if strength <= 0.0 or end_percent <= start_percent:
            return io.NodeOutput(model)
        P = cls.PATCH_SIZE
        half = P // 2
        # One shift list per ramp level; a single level means a fixed pass count.
        shift_levels = [cls.SHIFTS_BY_PATTERN[(pattern, n)] for n in cls.RAMP_LEVELS[passes]]
        if blend == "window":
            # Precompute a stacked (N, P, P) window per level. Each pass's window
            # peaks at the center of its shifted grid, i.e. away from that pass's
            # patch boundaries, so seams get low weight per pass.
            window_tile_levels = [
                torch.stack([cls._hann_tile((half - sy) % P, (half - sx) % P, P) for sy, sx in lst], dim=0)
                for lst in shift_levels
            ]
        else:
            window_tile_levels = [None] * len(shift_levels)
        m = model.clone()
        model_sampling = m.get_model_object("model_sampling")
        multiplier = float(model_sampling.multiplier)
        # Convert the percent gate to sigma-space bounds. NOTE(review): this
        # relies on percent_to_sigma being monotonically decreasing, giving
        # start_t >= end_t for start_percent < end_percent — confirm upstream.
        start_t = float(model_sampling.percent_to_sigma(start_percent)) * multiplier
        end_t = float(model_sampling.percent_to_sigma(end_percent)) * multiplier
        # Edge masks cached by (H, W, device, dtype) to avoid rebuilding each step.
        edge_ramp_cache: dict = {}

        def get_edge_ramp(H: int, W: int, device, dtype) -> torch.Tensor:
            """(H, W) mask: 0 within PATCH_SIZE px of any image edge, feathering
            up to 1 over EDGE_FEATHER px further in. Excludes the strip where
            torch.roll wraps content in from the opposite side of the image."""
            key = (H, W, device, dtype)
            cached = edge_ramp_cache.get(key)
            if cached is not None:
                return cached
            feather = cls.EDGE_FEATHER
            # Per-row / per-column distance (px) to the nearest image edge.
            ys = torch.minimum(torch.arange(H, device=device, dtype=torch.float32),
                               (H - 1) - torch.arange(H, device=device, dtype=torch.float32))
            xs = torch.minimum(torch.arange(W, device=device, dtype=torch.float32),
                               (W - 1) - torch.arange(W, device=device, dtype=torch.float32))
            # 0 for the first P px from an edge, linear ramp to 1 over `feather` px.
            y_mask = ((ys - P) / feather).clamp(0, 1)
            x_mask = ((xs - P) / feather).clamp(0, 1)
            ramp = (y_mask[:, None] * x_mask[None, :]).to(dtype)
            edge_ramp_cache[key] = ramp
            return ramp

        def smoothing_wrapper(executor, *args, **kwargs):
            # args[0] = input latent; args[1] = per-batch sigmas, the first item
            # is used for gating and phase selection.
            x = args[0]
            t = float(args[1][0])
            pred = executor(*args, **kwargs)
            # Outside the gated sigma window, pass the prediction through untouched.
            if not (end_t <= t <= start_t):
                return pred
            # Pick shift-level by sigma phase across the gated range.
            if len(shift_levels) == 1:
                level_idx = 0
            else:
                phase = (start_t - t) / max(start_t - end_t, 1e-8)
                level_idx = min(int(phase * len(shift_levels)), len(shift_levels) - 1)
            shifts = shift_levels[level_idx]
            window_tiles = window_tile_levels[level_idx]
            preds = []
            for sy, sx in shifts:
                if sy == 0 and sx == 0:
                    # Natural grid position: reuse the already-computed prediction.
                    preds.append(pred)
                    continue
                # Roll the input so the model's patch grid lands at a new offset,
                # then roll the prediction back into the original frame.
                x_rolled = torch.roll(x, shifts=(sy, sx), dims=(-2, -1))
                pred_rolled = executor(x_rolled, *args[1:], **kwargs)
                preds.append(torch.roll(pred_rolled, shifts=(-sy, -sx), dims=(-2, -1)))
            stacked = torch.stack(preds, dim=0)  # (N, B, C, H, W)
            _, _, _, H, W = stacked.shape
            if blend == "window":
                N = stacked.shape[0]
                tiles = window_tiles.to(device=stacked.device, dtype=stacked.dtype)
                # Tile the (N, P, P) windows across the image and normalize so
                # per-pixel weights sum to 1. NOTE(review): the repeat count
                # assumes H and W are multiples of P — confirm for all inputs.
                w = tiles.repeat(1, H // P, W // P)[:, :H, :W]
                sum_w = w.sum(dim=0, keepdim=True)
                # Where every window is ~0, fall back to a uniform average.
                w = torch.where(sum_w < 1e-3, torch.full_like(w, 1.0 / N), w / sum_w.clamp(min=1e-8))
                avg = (stacked * w[:, None, None, :, :]).sum(dim=0)
            elif blend == "median":
                # Per-pixel median rejects passes contaminated by roll wraparound.
                avg = torch.median(stacked, dim=0).values
            else:
                avg = stacked.mean(dim=0)
            # Mask out the P-px wraparound contamination strip at each edge.
            mask = get_edge_ramp(H, W, pred.device, pred.dtype)
            # Lerp between the natural-grid prediction and the blended result.
            return pred * (1.0 - mask * strength) + avg * (mask * strength)

        m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "hidream_o1_patch_seam_smoothing", smoothing_wrapper)
        return io.NodeOutput(m)
class HiDreamO1Extension(ComfyExtension):
    """Extension bundle registering all HiDream-O1 nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes contributed by this extension."""
        nodes: list[type[io.ComfyNode]] = [
            EmptyHiDreamO1LatentImage,
            HiDreamO1ReferenceImages,
            HiDreamO1PatchSeamSmoothing,
        ]
        return nodes
async def comfy_entrypoint() -> HiDreamO1Extension:
    """Entry point called by ComfyUI to obtain this extension instance."""
    extension = HiDreamO1Extension()
    return extension