diff --git a/comfy/model_base.py b/comfy/model_base.py index cc278a770..970c56e37 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1507,23 +1507,8 @@ class WAN21_SCAIL(WAN21): pose_latents = torch.cat([pose_latents, pose_mask], dim=1) out['pose_latents'] = comfy.conds.CONDRegular(pose_latents) - pose_start = kwargs.get("pose_start", 0.0) - pose_end = kwargs.get("pose_end", 1.0) - out['pose_start'] = comfy.conds.CONDConstant(pose_start) - out['pose_end'] = comfy.conds.CONDConstant(pose_end) - return out - def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, pose_start=0.0, pose_end=1.0, **kwargs): - if t[0] >= self.model_sampling.percent_to_sigma(pose_start) or t[0] <= self.model_sampling.percent_to_sigma(pose_end): - kwargs.pop("pose_latents", None) - - return comfy.patcher_extension.WrapperExecutor.new_class_executor( - self._apply_model, - self, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options) - ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) - def extra_conds_shapes(self, **kwargs): out = {} ref_latents = kwargs.get("reference_latents", None) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 9e7aab3e3..e50bfcd2c 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1505,8 +1505,8 @@ class WanSCAILToVideo(io.ComfyNode): if pose_video is not None: pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength - positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent, "pose_start": pose_start, "pose_end": pose_end}) - negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent, "pose_start": pose_start, "pose_end": pose_end}) + positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) out_latent = {} out_latent["samples"] = latent diff --git a/node_helpers.py b/node_helpers.py index 4ff960ef8..d3d834516 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -1,5 +1,6 @@ import hashlib import torch +import logging from comfy.cli_args import args @@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False): return c +def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0): + """ + Apply values to conditioning only during [start_percent, end_percent], keeping the + original conditioning active outside that range. Respects existing per-entry ranges. + """ + if start_percent > end_percent: + logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})") + return conditioning + + EPS = 1e-5 # the sampler gates entries with strict > / <, shift boundaries slightly to ensure only one conditioning is active per timestep + c = [] + for t in conditioning: + cond_start = t[1].get("start_percent", 0.0) + cond_end = t[1].get("end_percent", 1.0) + intersect_start = max(start_percent, cond_start) + intersect_end = min(end_percent, cond_end) + + if intersect_start >= intersect_end: # no overlap: emit unchanged + c.append(t) + continue + + if intersect_start > cond_start: # part before the requested range + c.extend(conditioning_set_values([t], {"start_percent": cond_start, "end_percent": intersect_start - EPS})) + + c.extend(conditioning_set_values([t], {**values, "start_percent": intersect_start, "end_percent": intersect_end})) + + if intersect_end < cond_end: # part after the requested range + c.extend(conditioning_set_values([t], {"start_percent": intersect_end + EPS, "end_percent": cond_end})) + return c + def pillow(fn, arg): prev_value = None try: