mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 02:57:24 +08:00
* Initial HiDream01-image support * Cleanup nodes * Cleaner handling of empty placeholder models * Remove snap_to_predefined, prefer tooltip for the trained resolutions * Add model and block wrappers * Fix shift tooltip * Add node to work around the patch tile issue Experimental, runs multiple passes with the patch grid offset and blends with various different methods. * Qwen35 vision rotary_pos_emb cast fix * Fix embedding layout type * Some small optimizations * Cleanup, don't need this fallback * Prefix KV cache, cleanup Bit of speed, reduce redundant code * Get rid of redundant custom sampler, refactor noise scaling Our existing lcm sampler is mathematically same, just added the missing options to it instead and a node to control them. Refactored the noise scaling and fix it for the stochastic samplers, add a generic node to control the initial noise scale. * Update nodes_hidream_o1.py * Fix some cache validation cases * Keep existing sampling params * Remove redundant video vision path * Replace some numpy ops with torch * Fix RoPE index for batch size > 1 * Prefer torch preprocessing * Rename block_type to be compatible with existing patch nodes * Fixes and tweaks
257 lines
11 KiB
Python
257 lines
11 KiB
Python
from typing_extensions import override
|
|
|
|
import torch
|
|
|
|
import comfy.model_management
|
|
import comfy.patcher_extension
|
|
import node_helpers
|
|
from comfy_api.latest import ComfyExtension, io
|
|
|
|
|
|
class EmptyHiDreamO1LatentImage(io.ComfyNode):
    """Produce an all-zero pixel-space latent sized for HiDream-O1-Image."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: width/height/batch inputs, one latent output."""
        return io.Schema(
            node_id="EmptyHiDreamO1LatentImage",
            display_name="Empty HiDream-O1 Latent Image",
            category="latent/image",
            description=(
                "Empty pixel-space latent for HiDream-O1-Image. "
                "The model was trained at ~4 megapixels; lower resolutions "
                "go off-distribution and quality regresses noticeably. "
                "Trained resolutions: 2048x2048, 2304x1728, 1728x2304, "
                "2560x1440, 1440x2560, 2496x1664, 1664x2496, 3104x1312, "
                "1312x3104, 2304x1792, 1792x2304."
            ),
            inputs=[
                io.Int.Input(id="width", default=2048, min=64, max=4096, step=32),
                io.Int.Input(id="height", default=2048, min=64, max=4096, step=32),
                io.Int.Input(id="batch_size", default=1, min=1, max=64),
            ],
            outputs=[io.Latent().Output()],
        )

    @classmethod
    def execute(cls, *, width: int, height: int, batch_size: int = 1) -> io.NodeOutput:
        """Return a zero latent of shape (batch_size, 3, height, width).

        3 channels: this model samples directly in pixel space (RGB-like),
        not in a VAE latent space.
        """
        device = comfy.model_management.intermediate_device()
        samples = torch.zeros((batch_size, 3, height, width), device=device)
        return io.NodeOutput({"samples": samples})
|
|
|
|
|
|
class HiDreamO1ReferenceImages(io.ComfyNode):
    """Attach reference images to both positive and negative conditioning."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: two conditioning inputs plus 1-10 image slots."""
        return io.Schema(
            node_id="HiDreamO1ReferenceImages",
            display_name="HiDream-O1 Reference Images",
            category="conditioning/image",
            description=(
                # Fixed: the two literals previously joined without a space,
                # rendering as "...edit instructionor multiple..." in the UI.
                "Attach 1-10 reference images to conditioning, one for edit instruction "
                "or multiple for subject-driven personalization."
            ),
            inputs=[
                io.Conditioning.Input(id="positive"),
                io.Conditioning.Input(id="negative"),
                io.Autogrow.Input(
                    "images",
                    template=io.Autogrow.TemplateNames(
                        io.Image.Input("image"),
                        names=[f"image_{i}" for i in range(1, 11)],
                        min=1,
                    ),
                    tooltip=("Reference images. 1 image = instruction edit; 2-10 images = multi reference."
                    ),
                ),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
            ],
        )

    @classmethod
    def execute(cls, *, positive, negative, images: io.Autogrow.Type) -> io.NodeOutput:
        """Append connected reference images to both conditionings.

        Slots are collected in declared order image_1..image_10; unconnected
        slots are skipped. Returns (positive, negative) with the same list of
        references attached under "reference_latents" on each.
        """
        refs = [images[f"image_{i}"] for i in range(1, 11) if f"image_{i}" in images]
        # append=True: preserve any reference_latents set earlier upstream.
        positive = node_helpers.conditioning_set_values(positive, {"reference_latents": refs}, append=True)
        negative = node_helpers.conditioning_set_values(negative, {"reference_latents": refs}, append=True)
        return io.NodeOutput(positive, negative)
|
|
|
|
|
|
class HiDreamO1PatchSeamSmoothing(io.ComfyNode):
    """Blend the model output across shifted patch-grid positions to hide seams.

    Wraps the diffusion model so that, for steps whose sigma falls inside a
    configurable sampling-progress window, the model is evaluated several
    times with the input rolled by sub-patch offsets. Each extra prediction
    is rolled back into alignment and the set is blended (mean, Hann-window
    weighted, or per-pixel median), then mixed into the unshifted prediction
    by `strength`, faded out near the image border where torch.roll's
    wraparound contaminates the rolled passes.
    """

    # Edge length (px) of the patch grid whose boundaries produce the seams.
    PATCH_SIZE = 32
    # Width (px) of the linear fade used by the border wraparound mask.
    EDGE_FEATHER = 4

    # Shift presets per (pattern, N). 8-pass = 4-quadrant + 4 quarter-patch offsets.
    # "single_shift" keeps one pass at (0, 0) (the natural grid);
    # "symmetric" places all passes off-grid, split around the origin.
    SHIFTS_BY_PATTERN = {
        ("single_shift", 2): [(0, 0), (16, 16)],
        ("single_shift", 4): [(0, 0), (16, 0), (0, 16), (16, 16)],
        ("single_shift", 8): [(0, 0), (16, 0), (0, 16), (16, 16),
                              (8, 8), (24, 8), (8, 24), (24, 24)],
        ("symmetric", 2): [(-8, -8), (8, 8)],
        ("symmetric", 4): [(-8, -8), (8, -8), (-8, 8), (8, 8)],
        ("symmetric", 8): [(-12, -12), (4, -12), (-12, 4), (4, 4),
                           (-4, -4), (12, -4), (-4, 12), (12, 12)],
    }
    # Maps the "passes" combo value to the sequence of pass counts used as
    # sampling progresses through the gated window (ramps add passes late,
    # where seams are most visible).
    RAMP_LEVELS = {
        "2": [2],
        "4": [4],
        "ramp_2_4": [2, 4],
        "ramp_2_4_8": [2, 4, 8],
    }

    @staticmethod
    def _hann_tile(cy: int, cx: int, size: int = 32) -> torch.Tensor:
        """size x size Hann tile peaking at (cy, cx) within a patch.

        The tile wraps modulo `size`, so the peak can sit anywhere in the
        patch; weight falls to zero half a patch away from the peak in each
        axis. Used to favor each pass away from its own patch boundaries.
        """
        half = size // 2
        yy = torch.arange(size).view(size, 1)
        xx = torch.arange(size).view(1, size)
        # Wrapped (circular) distance of each pixel from the peak, in [-half, half).
        dy = ((yy - cy + half) % size) - half
        dx = ((xx - cx + half) % size) - half
        # Separable 2D Hann window: 0.25*(1+cos)*(1+cos) peaks at 1, floors at 0.
        return 0.25 * (1 + torch.cos(torch.pi * dy / half)) * (1 + torch.cos(torch.pi * dx / half))

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: a model input, gating/blend controls, model output."""
        return io.Schema(
            node_id="HiDreamO1PatchSeamSmoothing",
            display_name="HiDream-O1 Patch Seam Smoothing",
            category="advanced/model",
            is_experimental=True,
            description=(
                "Average the model output across multiple shifted patch-grid "
                "positions during the late portion of sampling. Cancels seams."
            ),
            inputs=[
                io.Model.Input(id="model"),
                io.Float.Input(id="start_percent", default=0.8, min=0.0, max=1.0, step=0.01,
                               tooltip="Sampling progress (0=start, 1=end) at which the blend turns ON.",
                               ),
                io.Float.Input(id="end_percent", default=1.0, min=0.0, max=1.0, step=0.01,
                               tooltip="Sampling progress at which the blend turns OFF.",
                               ),
                io.Combo.Input(
                    id="pattern",
                    options=["single_shift", "symmetric"],
                    default="single_shift",
                    tooltip="Shift layout. single_shift: one pass at the natural patch grid + others offset. symmetric: all passes off-grid, shifts split around origin.",
                ),
                io.Combo.Input(
                    id="passes",
                    options=["2", "4", "ramp_2_4", "ramp_2_4_8"],
                    default="2",
                    tooltip="Number of passes per gated step. 2/4 = fixed. ramp_*: pass count increases as sampling approaches end (more smoothing where seams are most visible).",
                ),
                io.Combo.Input(
                    id="blend",
                    options=["average", "window", "median"],
                    default="average",
                    tooltip="average: equal-weight mean. window: Hann-windowed weighting favoring each pass away from its patch boundaries. median: per-pixel median, rejects wraparound-outlier passes.",
                ),
                io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01,
                               tooltip="Interpolation between the natural-grid pred (0) and the averaged result (1).",
                               ),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, *, model, start_percent: float, end_percent: float, pattern: str, passes: str, blend: str, strength: float) -> io.NodeOutput:
        """Clone `model` and attach the seam-smoothing diffusion-model wrapper.

        Returns the input model unchanged when the node can have no effect
        (zero strength or an empty/inverted gating window).
        """
        if strength <= 0.0 or end_percent <= start_percent:
            return io.NodeOutput(model)

        P = cls.PATCH_SIZE
        half = P // 2
        # One shift list per ramp level, e.g. ramp_2_4 -> [2-pass list, 4-pass list].
        shift_levels = [cls.SHIFTS_BY_PATTERN[(pattern, n)] for n in cls.RAMP_LEVELS[passes]]

        if blend == "window":
            # Precompute, per level, a (N, P, P) stack of Hann tiles. Peak is
            # placed at (half - shift) mod P so that, after the prediction is
            # rolled back by -shift, each pass's weight peaks at its own
            # patch centers (away from its patch boundaries).
            window_tile_levels = [
                torch.stack([cls._hann_tile((half - sy) % P, (half - sx) % P, P) for sy, sx in lst], dim=0)
                for lst in shift_levels
            ]
        else:
            window_tile_levels = [None] * len(shift_levels)

        m = model.clone()
        model_sampling = m.get_model_object("model_sampling")
        # NOTE(review): multiplier presumably rescales percent_to_sigma output
        # into the timestep units seen by the wrapper — confirm against the
        # model_sampling implementation.
        multiplier = float(model_sampling.multiplier)
        # Sigma decreases over sampling, so start_t (earlier) >= end_t (later).
        start_t = float(model_sampling.percent_to_sigma(start_percent)) * multiplier
        end_t = float(model_sampling.percent_to_sigma(end_percent)) * multiplier

        # Border fade masks, cached per (H, W, device, dtype) across steps.
        edge_ramp_cache: dict = {}

        def get_edge_ramp(H: int, W: int, device, dtype) -> torch.Tensor:
            """(H, W) mask: 0 within P px of any border, ramping to 1 over EDGE_FEATHER px."""
            key = (H, W, device, dtype)
            cached = edge_ramp_cache.get(key)
            if cached is not None:
                return cached
            feather = cls.EDGE_FEATHER
            # Distance of each row/column from its nearest image border.
            ys = torch.minimum(torch.arange(H, device=device, dtype=torch.float32),
                               (H - 1) - torch.arange(H, device=device, dtype=torch.float32))
            xs = torch.minimum(torch.arange(W, device=device, dtype=torch.float32),
                               (W - 1) - torch.arange(W, device=device, dtype=torch.float32))
            # 0 inside the P-px contamination strip, linear ramp over `feather` px.
            y_mask = ((ys - P) / feather).clamp(0, 1)
            x_mask = ((xs - P) / feather).clamp(0, 1)
            ramp = (y_mask[:, None] * x_mask[None, :]).to(dtype)
            edge_ramp_cache[key] = ramp
            return ramp

        def smoothing_wrapper(executor, *args, **kwargs):
            """DIFFUSION_MODEL wrapper: multi-pass shifted evaluation + blend."""
            x = args[0]
            # args[1] is the per-batch timestep/sigma tensor; gate on element 0
            # (assumes the whole batch shares one sigma — TODO confirm).
            t = float(args[1][0])
            # Always run the natural-grid pass first; it is also the fallback.
            pred = executor(*args, **kwargs)
            if not (end_t <= t <= start_t):
                return pred
            # Pick shift-level by sigma phase across the gated range
            # (phase 0 at start_t, 1 at end_t), so ramps add passes late.
            if len(shift_levels) == 1:
                level_idx = 0
            else:
                phase = (start_t - t) / max(start_t - end_t, 1e-8)
                level_idx = min(int(phase * len(shift_levels)), len(shift_levels) - 1)
            shifts = shift_levels[level_idx]
            window_tiles = window_tile_levels[level_idx]

            preds = []
            for sy, sx in shifts:
                if sy == 0 and sx == 0:
                    # Natural grid: reuse the pass already computed above.
                    preds.append(pred)
                    continue
                # Shift the input so the model's patch grid lands elsewhere,
                # then roll the prediction back into alignment.
                x_rolled = torch.roll(x, shifts=(sy, sx), dims=(-2, -1))
                pred_rolled = executor(x_rolled, *args[1:], **kwargs)
                preds.append(torch.roll(pred_rolled, shifts=(-sy, -sx), dims=(-2, -1)))
            stacked = torch.stack(preds, dim=0)  # (N, B, C, H, W)
            _, _, _, H, W = stacked.shape
            if blend == "window":
                N = stacked.shape[0]
                tiles = window_tiles.to(device=stacked.device, dtype=stacked.dtype)
                # Tile the (N, P, P) windows across the full image, then
                # normalize per pixel so the pass weights sum to 1.
                # NOTE(review): assumes H and W are multiples of P (true for
                # the step-32 latent nodes) — confirm for other sources.
                w = tiles.repeat(1, H // P, W // P)[:, :H, :W]
                sum_w = w.sum(dim=0, keepdim=True)
                # Where all windows are ~0, fall back to equal weighting.
                w = torch.where(sum_w < 1e-3, torch.full_like(w, 1.0 / N), w / sum_w.clamp(min=1e-8))
                avg = (stacked * w[:, None, None, :, :]).sum(dim=0)
            elif blend == "median":
                avg = torch.median(stacked, dim=0).values
            else:
                avg = stacked.mean(dim=0)

            # Mask out the P-px wraparound contamination strip at each edge:
            # near borders keep the natural-grid pred, elsewhere lerp by strength.
            mask = get_edge_ramp(H, W, pred.device, pred.dtype)
            return pred * (1.0 - mask * strength) + avg * (mask * strength)

        m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "hidream_o1_patch_seam_smoothing", smoothing_wrapper)
        return io.NodeOutput(m)
|
|
|
|
|
|
class HiDreamO1Extension(ComfyExtension):
    """Registers the HiDream-O1 node set with ComfyUI."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return every node class this extension contributes."""
        nodes: list[type[io.ComfyNode]] = [
            EmptyHiDreamO1LatentImage,
            HiDreamO1ReferenceImages,
            HiDreamO1PatchSeamSmoothing,
        ]
        return nodes
|
|
|
|
|
|
async def comfy_entrypoint() -> HiDreamO1Extension:
    """Entry point ComfyUI awaits to obtain this module's extension."""
    extension = HiDreamO1Extension()
    return extension
|