diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 90d730df6..51d1e5b9c 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -9,6 +9,14 @@ if TYPE_CHECKING: from uuid import UUID +def _extract_tensor(data, output_channels): + """Extract tensor from data, handling both single tensors and lists.""" + if isinstance(data, list): + # LTX2 AV tensors: [video, audio] + return data[0][:, :output_channels], data[1][:, :output_channels] + return data[:, :output_channels], None + + def easycache_forward_wrapper(executor, *args, **kwargs): # get values from args transformer_options: dict[str] = args[-1] @@ -17,7 +25,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if not transformer_options: transformer_options = args[-2] easycache: EasyCacheHolder = transformer_options["easycache"] - x: torch.Tensor = args[0][:, :easycache.output_channels] + x, ax = _extract_tensor(args[0], easycache.output_channels) sigmas = transformer_options["sigmas"] uuids = transformer_options["uuids"] if sigmas is not None and easycache.is_past_end_timestep(sigmas): @@ -35,7 +43,11 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if easycache.skip_current_step and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}") - return easycache.apply_cache_diff(x, uuids) + result = easycache.apply_cache_diff(x, uuids) + if ax is not None: + result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True) + return [result, result_audio] + return result if easycache.initial_step: easycache.first_cond_uuid = uuids[0] has_first_cond_uuid = easycache.has_first_cond_uuid(uuids) @@ -51,13 +63,18 @@ def easycache_forward_wrapper(executor, *args, **kwargs): logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") # other conds should also skip this step, and instead use their cached values easycache.skip_current_step = True - return easycache.apply_cache_diff(x, uuids) + result = easycache.apply_cache_diff(x, uuids) + if ax is not None: + result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True) + return [result, result_audio] + return result else: if easycache.verbose: logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") easycache.cumulative_change_rate = 0.0 - output: torch.Tensor = executor(*args, **kwargs) + full_output: torch.Tensor = executor(*args, **kwargs) + output, audio_output = _extract_tensor(full_output, easycache.output_channels) if has_first_cond_uuid and easycache.has_output_prev_norm(): output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean() if easycache.verbose: @@ -74,13 +91,15 @@ def easycache_forward_wrapper(executor, *args, **kwargs): logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}") # TODO: allow cache_diff to be offloaded easycache.update_cache_diff(output, next_x_prev, uuids) + if audio_output is not None: + easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True) if has_first_cond_uuid: easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids) easycache.output_prev_subsampled = easycache.subsample(output, uuids) easycache.output_prev_norm = output.flatten().abs().mean() if easycache.verbose: logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}") - return output + return full_output def lazycache_predict_noise_wrapper(executor, *args, **kwargs): # get values from args @@ -89,8 +108,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs): easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"] if easycache.is_past_end_timestep(timestep): return executor(*args, **kwargs) + x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels) # prepare next x_prev - x: torch.Tensor = args[0][:, :easycache.output_channels] next_x_prev = x input_change = None do_easycache = easycache.should_do_easycache(timestep) @@ -197,6 +216,7 @@ class EasyCacheHolder: self.output_prev_subsampled: torch.Tensor = None self.output_prev_norm: torch.Tensor = None self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {} + self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {} self.output_change_rates = [] self.approx_output_change_rates = [] self.total_steps_skipped = 0 @@ -245,20 +265,21 @@ class EasyCacheHolder: def can_apply_cache_diff(self, uuids: list[UUID]) -> bool: return all(uuid in self.uuid_cache_diffs for uuid in uuids) - def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): - if self.first_cond_uuid in uuids: + def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False): + if self.first_cond_uuid in uuids and not is_audio: self.total_steps_skipped += 1 + cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs batch_offset = x.shape[0] // len(uuids) for i, uuid in enumerate(uuids): # slice out only what is relevant to this cond batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)] # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video) - if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]: + if x.shape[1:] != cache_diffs[uuid].shape[1:]: if not self.allow_mismatch: raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good") slicing = [] skip_this_dim = True - for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape): + for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape): if skip_this_dim: skip_this_dim = False continue @@ -270,10 +291,11 @@ class EasyCacheHolder: else: slicing.append(slice(None)) batch_slice = batch_slice + slicing - x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device) + x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device) return x - def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]): + def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False): + cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video) if output.shape[1:] != x.shape[1:]: if not self.allow_mismatch: @@ -293,7 +315,7 @@ class EasyCacheHolder: diff = output - x batch_offset = diff.shape[0] // len(uuids) for i, uuid in enumerate(uuids): - self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...] + cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...] def has_first_cond_uuid(self, uuids: list[UUID]) -> bool: return self.first_cond_uuid in uuids @@ -324,6 +346,8 @@ class EasyCacheHolder: self.output_prev_norm = None del self.uuid_cache_diffs self.uuid_cache_diffs = {} + del self.uuid_cache_diffs_audio + self.uuid_cache_diffs_audio = {} self.total_steps_skipped = 0 self.state_metadata = None return self