diff --git a/comfy/ldm/helios/model.py b/comfy/ldm/helios/model.py index c1ea5f595..2faeea897 100644 --- a/comfy/ldm/helios/model.py +++ b/comfy/ldm/helios/model.py @@ -711,19 +711,9 @@ class HeliosModel(torch.nn.Module): ) f_long = self._rope_downsample_3d(f_long, (long_t, hs, ws), (4, 4, 4)) hidden_states = torch.cat([x_long, hidden_states], dim=1) - freqs = torch.cat([f_long, freqs], dim=1) + freqs = torch.cat([f_long, freqs], dim=1) history_context_length = hidden_states.shape[1] - original_context_length - mismatch = hidden_states.shape[1] != freqs.shape[1] - summary_key = ( - int(post_t), - int(post_h), - int(post_w), - int(original_context_length), - int(hidden_states.shape[1]), - int(freqs.shape[1]), - int(history_context_length), - ) if timestep.ndim == 0: timestep = timestep.unsqueeze(0) @@ -770,28 +760,28 @@ class HeliosModel(torch.nn.Module): def unpatchify(self, x, grid_sizes): """ Unpatchify the output from proj_out back to video format. - + Args: x: [batch, num_patches, out_dim * prod(patch_size)] grid_sizes: (num_frames, height, width) in patch space - + Returns: [batch, out_dim, num_frames, height, width] in pixel space """ b = x.shape[0] post_t, post_h, post_w = grid_sizes p_t, p_h, p_w = self.patch_size - + # Reshape: [B, T*H*W, out_dim*p_t*p_h*p_w] -> [B, T, H, W, p_t, p_h, p_w, out_dim] # Use -1 to let PyTorch infer the channel dimension (out_dim) hidden_states = x.reshape(b, post_t, post_h, post_w, p_t, p_h, p_w, -1) - + # Permute: [B, T, H, W, p_t, p_h, p_w, C] -> [B, C, T, p_t, H, p_h, W, p_w] hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) - + # Flatten patches: [B, C, T, p_t, H, p_h, W, p_w] -> [B, C, T*p_t, H*p_h, W*p_w] output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - + return output def _rope_downsample_3d(self, freqs, grid_sizes, kernel_size): b, _, one, d, i2, j2 = freqs.shape diff --git a/comfy_extras/nodes_helios.py b/comfy_extras/nodes_helios.py index 89a63c3fd..2fcf0a64f 100644 --- a/comfy_extras/nodes_helios.py +++ b/comfy_extras/nodes_helios.py @@ -412,7 +412,6 @@ def _helios_dmd_sample( for i in range(len(sigmas) - 1): sigma = sigmas[i] - sigma_next = sigmas[i + 1] timestep = all_timesteps[i] if i < len(all_timesteps) else i denoised = model(x, sigma * s_in, **extra_args)