diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 90c347d3d..77876c2e7 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -588,7 +588,7 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None): + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) @@ -601,10 +601,22 @@ class WanModel(torch.nn.Module): if steps_w is None: steps_w = w_len + h_start = 0 + w_start = 0 + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + t_len = t_len * rope_options.get("scale_t", 1.0) + h_len = h_len * rope_options.get("scale_y", 1.0) + w_len = w_len * rope_options.get("scale_x", 1.0) + + t_start += rope_options.get("shift_t", 0.0) + h_start += rope_options.get("shift_y", 0.0) + w_start += rope_options.get("shift_x", 0.0) + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) - img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) - img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) freqs = self.rope_embedder(img_ids).movedim(1, 2) @@ -630,7 +642,7 @@ class WanModel(torch.nn.Module): if self.ref_conv is not None and "reference_latent" in kwargs: t_len += 1 - freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 674a214ca..3e8799983 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -454,6 +454,19 @@ class ModelPatcher: def set_model_post_input_patch(self, patch): self.set_model_patch(patch, "post_input") + def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): + rope_options = self.model_options["transformer_options"].get("rope_options", {}) + rope_options["scale_x"] = scale_x + rope_options["scale_y"] = scale_y + rope_options["scale_t"] = scale_t + + rope_options["shift_x"] = shift_x + rope_options["shift_y"] = shift_y + rope_options["shift_t"] = shift_t + + self.model_options["transformer_options"]["rope_options"] = rope_options + + def add_object_patch(self, name, obj): self.object_patches[name] = obj diff --git a/comfy_extras/nodes_rope.py b/comfy_extras/nodes_rope.py new file mode 100644 index 000000000..d1feb031e --- /dev/null +++ b/comfy_extras/nodes_rope.py @@ -0,0 +1,47 @@ +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override + + +class ScaleROPE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ScaleROPE", + category="advanced/model_patches", + description="Scale and shift the ROPE of the model.", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1), + + io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1), + + io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1), + + + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput: + m = model.clone() + m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) + return io.NodeOutput(m) + + +class RopeExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ScaleROPE + ] + + +async def comfy_entrypoint() -> RopeExtension: + return RopeExtension() diff --git a/nodes.py b/nodes.py index 12e365ca9..5689f6fe1 100644 --- a/nodes.py +++ b/nodes.py @@ -2329,6 +2329,7 @@ async def init_builtin_extra_nodes(): "nodes_model_patch.py", "nodes_easycache.py", "nodes_audio_encoder.py", + "nodes_rope.py", ] import_failed = []