diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index 4e85dfd9a..50d3ec5cd 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -307,24 +307,23 @@ class Kandinsky5(nn.Module): h_start += rope_options.get("shift_y", 0.0) w_start += rope_options.get("shift_x", 0.0) else: - # this is experimental as the original code only had two fixed scales for 512p and 1024p - if t == 1: # image model - self.rope_scale_factor = (1.0, 1.0, 1.0) - else: + if self.model_dim == 4096: # pro video model,this is experimental as the original code only had two fixed scales for 512p and 1024p spatial_size = h * w scale_16384 = (1.0, 3.16, 3.16) scale_9216 = (1.0, 2.0, 2.0) if spatial_size <= 6144: - self.rope_scale_factor = scale_9216 + rope_scale_factor = scale_9216 elif spatial_size >= 14080: - self.rope_scale_factor = scale_16384 + rope_scale_factor = scale_16384 else: t = (spatial_size - 14080) / (6144 - 14080) - self.rope_scale_factor = tuple(a + (b - a) * t for a, b in zip(scale_16384, scale_9216)) + rope_scale_factor = tuple(a + (b - a) * t for a, b in zip(scale_16384, scale_9216)) + else: + rope_scale_factor = self.rope_scale_factor - t_len = (t_len - 1.0) // self.rope_scale_factor[0] + 1.0 - h_len = (h_len - 1.0) // self.rope_scale_factor[1] + 1.0 - w_len = (w_len - 1.0) // self.rope_scale_factor[2] + 1.0 + t_len = (t_len - 1.0) // rope_scale_factor[0] + 1.0 + h_len = (h_len - 1.0) // rope_scale_factor[1] + 1.0 + w_len = (w_len - 1.0) // rope_scale_factor[2] + 1.0 img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) @@ -345,8 +344,8 @@ class Kandinsky5(nn.Module): visual_embed = self.visual_embeddings(x) visual_shape = visual_embed.shape[:-1] - visual_embed = visual_embed.flatten(1, -2) + blocks_replace = patches_replace.get("dit", {}) transformer_options["total_blocks"] = len(self.visual_transformer_blocks) transformer_options["block_type"] = "double"