mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-04 08:47:35 +08:00
feat: Support SCAIL WanVideo model (#12614)
This commit is contained in:
parent
95e1059661
commit
1f6744162f
@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel):
|
|||||||
# unpatchify
|
# unpatchify
|
||||||
x = self.unpatchify(x, grid_sizes)
|
x = self.unpatchify(x, grid_sizes)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class SCAILWanModel(WanModel):
    """WanVideo variant for SCAIL pose-conditioned generation.

    Adds a dedicated patch embedding for a pose-latent stream and supports an
    optional reference latent concatenated along the time axis. Pose tokens are
    appended to the main token sequence for the transformer blocks and stripped
    again before unpatchify; reference frames are sliced off the output.
    """

    def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs):
        # model_type is accepted for config compatibility, but the parent is
        # always initialized as 'i2v' (same backbone layout).
        super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)

        # Separate Conv3d patch embedding for the pose latents, kept in float32
        # (pose input is cast to float before it, mirroring the main path).
        self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)

    def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
        """Core forward pass.

        x: main video latent; reference_latent (if given) is prepended on the
        time axis before patching. pose_latents (if given) become extra tokens
        appended after the main sequence. freqs must already cover the combined
        sequence (see rope_encode).

        NOTE(review): transformer_options defaults to a shared mutable dict and
        is written to below; callers in this codebase always pass one in.
        """
        if reference_latent is not None:
            x = torch.cat((reference_latent, x), dim=2)

        # embeddings
        x = self.patch_embedding(x.float()).to(x.dtype)
        grid_sizes = x.shape[2:]
        # Blocks read the patch grid from transformer_options.
        transformer_options["grid_sizes"] = grid_sizes
        x = x.flatten(2).transpose(1, 2)

        # Embed pose latents with their own patch embedding and append them as
        # extra tokens; remember how many so they can be dropped after the head.
        scail_pose_seq_len = 0
        if pose_latents is not None:
            scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
            scail_x = scail_x.flatten(2).transpose(1, 2)
            scail_pose_seq_len = scail_x.shape[1]
            x = torch.cat([x, scail_x], dim=1)
            del scail_x

        # time embeddings
        e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
        e = e.reshape(t.shape[0], -1, e.shape[-1])
        e0 = self.time_projection(e).unflatten(2, (6, self.dim))

        # context
        context = self.text_embedding(context)

        context_img_len = None
        if clip_fea is not None:
            if self.img_emb is not None:
                context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
                context = torch.cat([context_clip, context], dim=1)
            context_img_len = clip_fea.shape[-2]

        # Standard ComfyUI block-patching protocol: a patch registered under
        # ("double_block", i) replaces the direct block call.
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        transformer_options["total_blocks"] = len(self.blocks)
        transformer_options["block_type"] = "double"
        for i, block in enumerate(self.blocks):
            transformer_options["block_index"] = i
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)

        # head
        x = self.head(x, e)

        # Drop the appended pose tokens before unpatchify.
        if scail_pose_seq_len > 0:
            x = x[:, :-scail_pose_seq_len]

        # unpatchify
        x = self.unpatchify(x, grid_sizes)

        # Remove the prepended reference frames from the output.
        if reference_latent is not None:
            x = x[:, :, reference_latent.shape[2]:]

        return x

    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
        """RoPE frequencies for the combined main + pose token sequence.

        Main tokens use the parent encoding unchanged; pose tokens get their own
        encoding (concatenated on the sequence dim) with scaled/shifted spatial
        positions so a lower-resolution pose grid spans the same RoPE extent.
        """
        main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)

        if pose_latents is None:
            return main_freqs

        # Pose frames start after the reference frames on the time axis;
        # rounded to whole temporal patches.
        ref_t_patches = 0
        if reference_latent is not None:
            ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]

        F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]

        # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
        h_scale = h / H_pose
        w_scale = w / W_pose

        # 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code
        h_shift = (h_scale - 1) / 2
        w_shift = (w_scale - 1) / 2
        pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
        pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options)

        return torch.cat([main_freqs, pose_freqs], dim=1)

    def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
        """Entry point: pads inputs to the patch size, builds RoPE for the full
        (time_dim_concat + reference) length, and crops the output back to the
        caller's original t/h/w.
        """
        bs, c, t, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

        if pose_latents is not None:
            pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)

        t_len = t
        if time_dim_concat is not None:
            time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
            x = torch.cat([x, time_dim_concat], dim=2)
            t_len = x.shape[2]

        # reference_latent arrives via extra conds; it extends the RoPE time
        # range and is prepended/stripped inside forward_orig.
        reference_latent = None
        if "reference_latent" in kwargs:
            reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
            t_len += reference_latent.shape[2]

        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||||
|
|||||||
@ -1502,6 +1502,44 @@ class WAN21_FlowRVS(WAN21):
|
|||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
self.image_to_video = image_to_video
|
self.image_to_video = image_to_video
|
||||||
|
|
||||||
|
class WAN21_SCAIL(WAN21):
    """model_base wrapper for the SCAIL WanVideo model.

    Instantiates SCAILWanModel and wires two extra conds — a reference latent
    and a pose-video latent — each with a 4-channel all-ones mask appended to
    match the model's 20-channel input.
    """

    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
        # Both extra conds count toward the memory-use estimate; pose latents
        # are weighted as if they occupied 1.5 frames on the time dimension.
        self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
        self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
        self.image_to_video = image_to_video

    def _masked_latent(self, latent):
        # Process the latent into model space and append a 4-channel ones mask.
        latent = self.process_latent_in(latent)
        return torch.cat([latent, torch.ones_like(latent[:, :4])], dim=1)

    def extra_conds(self, **kwargs):
        """Add 'reference_latent' and 'pose_latents' conds when provided."""
        out = super().extra_conds(**kwargs)

        ref_list = kwargs.get("reference_latents", None)
        if ref_list is not None:
            # Only the most recent reference latent is used.
            out['reference_latent'] = comfy.conds.CONDRegular(self._masked_latent(ref_list[-1]))

        pose = kwargs.get("pose_video_latent", None)
        if pose is not None:
            out['pose_latents'] = comfy.conds.CONDRegular(self._masked_latent(pose))

        return out

    def extra_conds_shapes(self, **kwargs):
        """Estimated shapes of the extra conds, for memory budgeting."""
        shapes = {}

        ref_list = kwargs.get("reference_latents", None)
        if ref_list is not None:
            total_elems = sum(math.prod(a.size()) for a in ref_list)
            shapes['reference_latent'] = [1, 20, total_elems // 16]

        pose = kwargs.get("pose_video_latent", None)
        if pose is not None:
            shapes['pose_latents'] = [pose.shape[0], 20, *pose.shape[2:]]

        return shapes
||||||
class Hunyuan3Dv2(BaseModel):
|
class Hunyuan3Dv2(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||||
|
|||||||
@ -498,6 +498,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["model_type"] = "humo"
|
dit_config["model_type"] = "humo"
|
||||||
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "animate"
|
dit_config["model_type"] = "animate"
|
||||||
|
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["model_type"] = "scail"
|
||||||
else:
|
else:
|
||||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "i2v"
|
dit_config["model_type"] = "i2v"
|
||||||
|
|||||||
@ -1268,6 +1268,16 @@ class WAN21_FlowRVS(WAN21_T2V):
|
|||||||
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
|
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class WAN21_SCAIL(WAN21_T2V):
    """supported_models entry matching the SCAIL WanVideo checkpoint variant."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "scail",
    }

    def get_model(self, state_dict, prefix="", device=None):
        # SCAIL is pose/reference driven rather than a plain image-to-video model.
        return model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
|
||||||
|
|
||||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "hunyuan3d2",
|
"image_model": "hunyuan3d2",
|
||||||
@ -1710,6 +1720,6 @@ class LongCatImage(supported_models_base.BASE):
|
|||||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -1456,6 +1456,63 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
|||||||
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
||||||
|
|
||||||
|
|
||||||
|
class WanSCAILToVideo(io.ComfyNode):
    """Node that prepares conditioning and an empty latent for the SCAIL
    WanVideo model: optionally VAE-encodes a reference image and a
    half-resolution pose video, attaching them to positive/negative
    conditioning.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanSCAILToVideo",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
                io.Image.Input("reference_image", optional=True),
                io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
                io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
                io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
                io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
        # Empty latent: 8x spatial and 4x temporal compression (16 channels),
        # matching the divisors used below.
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())

        # Reference image (first frame only) is resized to the generation size
        # and VAE-encoded.
        ref_latent = None
        if reference_image is not None:
            reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            ref_latent = vae.encode(reference_image[:, :, :, :3])

        if ref_latent is not None:
            positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
            # Negative gets a zero reference so CFG contrasts against "no reference".
            negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)

        if clip_vision_output is not None:
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        if pose_video is not None:
            # Pose video is trimmed to `length` frames and downscaled to half
            # resolution, then encoded and scaled by pose_strength.
            pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
            pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
            # Applied to BOTH positive and negative, only inside [pose_start, pose_end].
            positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
            negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)

        out_latent = {}
        out_latent["samples"] = latent
        return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
class WanExtension(ComfyExtension):
|
class WanExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -1476,6 +1533,7 @@ class WanExtension(ComfyExtension):
|
|||||||
WanAnimateToVideo,
|
WanAnimateToVideo,
|
||||||
Wan22ImageToVideoLatent,
|
Wan22ImageToVideoLatent,
|
||||||
WanInfiniteTalkToVideo,
|
WanInfiniteTalkToVideo,
|
||||||
|
WanSCAILToVideo,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> WanExtension:
|
async def comfy_entrypoint() -> WanExtension:
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import torch
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
|
||||||
@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False):
|
|||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0):
    """
    Apply *values* to *conditioning* only inside [start_percent, end_percent],
    leaving the original conditioning active outside that window. Existing
    per-entry start/end ranges are respected: each entry is split into up to
    three segments (before / inside / after the requested window).
    """
    if start_percent > end_percent:
        logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})")
        return conditioning

    # The sampler gates entries with strict > / < comparisons, so neighbouring
    # segment boundaries are nudged apart slightly to keep exactly one
    # conditioning active per timestep.
    EPS = 1e-5
    result = []
    for entry in conditioning:
        entry_start = entry[1].get("start_percent", 0.0)
        entry_end = entry[1].get("end_percent", 1.0)
        window_start = max(start_percent, entry_start)
        window_end = min(end_percent, entry_end)

        # No overlap with the requested window: pass the entry through as-is.
        if window_start >= window_end:
            result.append(entry)
            continue

        # Build the segment overrides in order, then materialize them.
        segments = []
        if window_start > entry_start:
            # Untouched part before the window.
            segments.append({"start_percent": entry_start, "end_percent": window_start - EPS})
        # The part inside the window carries the new values.
        segments.append({**values, "start_percent": window_start, "end_percent": window_end})
        if window_end < entry_end:
            # Untouched part after the window.
            segments.append({"start_percent": window_end + EPS, "end_percent": entry_end})

        for segment in segments:
            result.extend(conditioning_set_values([entry], segment))
    return result
|
||||||
|
|
||||||
def pillow(fn, arg):
|
def pillow(fn, arg):
|
||||||
prev_value = None
|
prev_value = None
|
||||||
try:
|
try:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user