mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 12:50:18 +08:00
Rever RoPE scaling to simpler one
This commit is contained in:
parent
e74db2404f
commit
f1a5f6f5b3
@ -307,23 +307,14 @@ class Kandinsky5(nn.Module):
|
|||||||
h_start += rope_options.get("shift_y", 0.0)
|
h_start += rope_options.get("shift_y", 0.0)
|
||||||
w_start += rope_options.get("shift_x", 0.0)
|
w_start += rope_options.get("shift_x", 0.0)
|
||||||
else:
|
else:
|
||||||
if self.model_dim == 4096: # pro video model,this is experimental as the original code only had two fixed scales for 512p and 1024p
|
rope_scale_factor = self.rope_scale_factor
|
||||||
spatial_size = h * w
|
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
|
||||||
scale_16384 = (1.0, 3.16, 3.16)
|
if h * w >= 14080:
|
||||||
scale_9216 = (1.0, 2.0, 2.0)
|
rope_scale_factor = (1.0, 3.16, 3.16)
|
||||||
if spatial_size <= 6144:
|
|
||||||
rope_scale_factor = scale_9216
|
|
||||||
elif spatial_size >= 14080:
|
|
||||||
rope_scale_factor = scale_16384
|
|
||||||
else:
|
|
||||||
t = (spatial_size - 14080) / (6144 - 14080)
|
|
||||||
rope_scale_factor = tuple(a + (b - a) * t for a, b in zip(scale_16384, scale_9216))
|
|
||||||
else:
|
|
||||||
rope_scale_factor = self.rope_scale_factor
|
|
||||||
|
|
||||||
t_len = (t_len - 1.0) // rope_scale_factor[0] + 1.0
|
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
|
||||||
h_len = (h_len - 1.0) // rope_scale_factor[1] + 1.0
|
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
|
||||||
w_len = (w_len - 1.0) // rope_scale_factor[2] + 1.0
|
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
|
||||||
|
|
||||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||||
|
|||||||
@ -1690,6 +1690,6 @@ class Kandinsky5_image(Kandinsky5):
|
|||||||
|
|
||||||
def concat_cond(self, **kwargs):
|
def concat_cond(self, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def process_latent_out(self, latent): # input is still 5D, return single frame to decode with Flux VAE
|
def process_latent_out(self, latent): # input is still 5D, return single frame to decode with Flux VAE
|
||||||
return self.latent_format.process_out(latent)[:, :, 0]
|
return self.latent_format.process_out(latent)[:, :, 0]
|
||||||
|
|||||||
@ -67,7 +67,7 @@ def adaptive_mean_std_normalization(source, reference):
|
|||||||
# normalization
|
# normalization
|
||||||
normalized = (source - source_mean) / (source_std + 1e-8)
|
normalized = (source - source_mean) / (source_std + 1e-8)
|
||||||
normalized = normalized * reference_std + reference_mean
|
normalized = normalized * reference_std + reference_mean
|
||||||
|
|
||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
@ -97,9 +97,9 @@ class NormalizeVideoLatentFrames(io.ComfyNode):
|
|||||||
|
|
||||||
first_frames = samples[:, :, :frames_to_normalize]
|
first_frames = samples[:, :, :frames_to_normalize]
|
||||||
reference_frames_data = samples[:, :, frames_to_normalize:frames_to_normalize+min(reference_frames, samples.shape[2]-frames_to_normalize)]
|
reference_frames_data = samples[:, :, frames_to_normalize:frames_to_normalize+min(reference_frames, samples.shape[2]-frames_to_normalize)]
|
||||||
|
|
||||||
normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
|
normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
|
||||||
|
|
||||||
samples[:, :, :frames_to_normalize] = normalized_first_frames
|
samples[:, :, :frames_to_normalize] = normalized_first_frames
|
||||||
s["samples"] = samples
|
s["samples"] = samples
|
||||||
return io.NodeOutput(s)
|
return io.NodeOutput(s)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user