Add helper function to set timestep ranges for the conditioning

This commit is contained in:
kijai 2026-02-27 17:12:06 +02:00
parent 48c7b1464f
commit b4734f44d3
3 changed files with 33 additions and 17 deletions

View File

@ -1507,23 +1507,8 @@ class WAN21_SCAIL(WAN21):
pose_latents = torch.cat([pose_latents, pose_mask], dim=1)
out['pose_latents'] = comfy.conds.CONDRegular(pose_latents)
pose_start = kwargs.get("pose_start", 0.0)
pose_end = kwargs.get("pose_end", 1.0)
out['pose_start'] = comfy.conds.CONDConstant(pose_start)
out['pose_end'] = comfy.conds.CONDConstant(pose_end)
return out
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, pose_start=0.0, pose_end=1.0, **kwargs):
if t[0] >= self.model_sampling.percent_to_sigma(pose_start) or t[0] <= self.model_sampling.percent_to_sigma(pose_end):
kwargs.pop("pose_latents", None)
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._apply_model,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options)
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)

View File

@ -1505,8 +1505,8 @@ class WanSCAILToVideo(io.ComfyNode):
if pose_video is not None:
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent, "pose_start": pose_start, "pose_end": pose_end})
negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent, "pose_start": pose_start, "pose_end": pose_end})
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
out_latent = {}
out_latent["samples"] = latent

View File

@ -1,5 +1,6 @@
import hashlib
import torch
import logging
from comfy.cli_args import args
@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False):
return c
def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0):
"""
Apply values to conditioning only during [start_percent, end_percent], keeping the
original conditioning active outside that range. Respects existing per-entry ranges.
"""
if start_percent > end_percent:
logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})")
return conditioning
EPS = 1e-5 # the sampler gates entries with strict > / <, shift boundaries slightly to ensure only one conditioning is active per timestep
c = []
for t in conditioning:
cond_start = t[1].get("start_percent", 0.0)
cond_end = t[1].get("end_percent", 1.0)
intersect_start = max(start_percent, cond_start)
intersect_end = min(end_percent, cond_end)
if intersect_start >= intersect_end: # no overlap: emit unchanged
c.append(t)
continue
if intersect_start > cond_start: # part before the requested range
c.extend(conditioning_set_values([t], {"start_percent": cond_start, "end_percent": intersect_start - EPS}))
c.extend(conditioning_set_values([t], {**values, "start_percent": intersect_start, "end_percent": intersect_end}))
if intersect_end < cond_end: # part after the requested range
c.extend(conditioning_set_values([t], {"start_percent": intersect_end + EPS, "end_percent": cond_end}))
return c
def pillow(fn, arg):
prev_value = None
try: