Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-06 19:42:34 +08:00)

Merge ebe2e774e7 into 85fc35e8fa
commit fea35e29be
@@ -9,6 +9,14 @@ if TYPE_CHECKING:
 from uuid import UUID
 
 
+def _extract_tensor(data, output_channels):
+    """Extract tensor from data, handling both single tensors and lists."""
+    if isinstance(data, list):
+        # LTX2 AV tensors: [video, audio]
+        return data[0][:, :output_channels], data[1][:, :output_channels]
+    return data[:, :output_channels], None
+
+
 def easycache_forward_wrapper(executor, *args, **kwargs):
     # get values from args
     transformer_options: dict[str] = args[-1]
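A quick illustration (not part of the diff) of the new helper's contract, restating the definition above and exercising it with made-up shapes on standard torch tensors:

import torch

def _extract_tensor(data, output_channels):
    """Extract tensor from data, handling both single tensors and lists."""
    if isinstance(data, list):
        # LTX2 AV tensors: [video, audio]
        return data[0][:, :output_channels], data[1][:, :output_channels]
    return data[:, :output_channels], None

# Hypothetical latent shapes, for illustration only.
video = torch.randn(2, 32, 8, 16, 16)  # [batch, channels, frames, height, width]
audio = torch.randn(2, 32, 128)        # [batch, channels, samples]

x, ax = _extract_tensor(video, 16)            # plain tensor in
assert x.shape == (2, 16, 8, 16, 16) and ax is None
x, ax = _extract_tensor([video, audio], 16)   # LTX2 AV pair in
assert x.shape == (2, 16, 8, 16, 16) and ax.shape == (2, 16, 128)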
@@ -17,7 +25,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
     if not transformer_options:
         transformer_options = args[-2]
     easycache: EasyCacheHolder = transformer_options["easycache"]
-    x: torch.Tensor = args[0][:, :easycache.output_channels]
+    x, ax = _extract_tensor(args[0], easycache.output_channels)
     sigmas = transformer_options["sigmas"]
     uuids = transformer_options["uuids"]
     if sigmas is not None and easycache.is_past_end_timestep(sigmas):
@@ -35,7 +43,11 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
     if easycache.skip_current_step and can_apply_cache_diff:
         if easycache.verbose:
             logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
-        return easycache.apply_cache_diff(x, uuids)
+        result = easycache.apply_cache_diff(x, uuids)
+        if ax is not None:
+            result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
+            return [result, result_audio]
+        return result
     if easycache.initial_step:
         easycache.first_cond_uuid = uuids[0]
     has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
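Both skip paths (here and in the threshold branch in the next hunk) now share the same shape: the video diff is applied as before, and when an audio tensor rides along, its diff is applied against the separate audio cache and the pair is returned in the same [video, audio] layout the model produced. A hypothetical refactor (not in the commit) that captures the shared pattern:

def _apply_cached(easycache, x, ax, uuids):
    # Hypothetical helper; both skip branches in this diff inline this logic.
    result = easycache.apply_cache_diff(x, uuids)
    if ax is None:
        return result
    return [result, easycache.apply_cache_diff(ax, uuids, is_audio=True)]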
@@ -51,13 +63,18 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
                 if easycache.verbose:
                     logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                 # other conds should also skip this step, and instead use their cached values
                 easycache.skip_current_step = True
-                return easycache.apply_cache_diff(x, uuids)
+                result = easycache.apply_cache_diff(x, uuids)
+                if ax is not None:
+                    result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
+                    return [result, result_audio]
+                return result
             else:
                 if easycache.verbose:
                     logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                 easycache.cumulative_change_rate = 0.0
 
-    output: torch.Tensor = executor(*args, **kwargs)
+    full_output: torch.Tensor = executor(*args, **kwargs)
+    output, audio_output = _extract_tensor(full_output, easycache.output_channels)
     if has_first_cond_uuid and easycache.has_output_prev_norm():
         output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
         if easycache.verbose:
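For readers new to EasyCache, a toy sketch (made-up numbers, not from the commit) of the reuse decision these branches implement: an estimated per-step change rate accumulates until it crosses the threshold, at which point the model actually runs and the accumulator resets, exactly as the else branch above does.

reuse_threshold = 0.2
cumulative_change_rate = 0.0
for approx_rate in [0.07, 0.06, 0.09]:       # estimated output change per step
    cumulative_change_rate += approx_rate
    if cumulative_change_rate < reuse_threshold:
        pass                                  # skip: reuse cached diff (steps 1-2, at 0.07 and 0.13)
    else:
        cumulative_change_rate = 0.0          # run the model for real (step 3, at 0.22)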
@@ -74,13 +91,15 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
             logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
     # TODO: allow cache_diff to be offloaded
     easycache.update_cache_diff(output, next_x_prev, uuids)
+    if audio_output is not None:
+        easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
     if has_first_cond_uuid:
         easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
         easycache.output_prev_subsampled = easycache.subsample(output, uuids)
         easycache.output_prev_norm = output.flatten().abs().mean()
         if easycache.verbose:
             logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
-    return output
+    return full_output
 
 def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
     # get values from args
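Note that the bookkeeping in this hunk operates only on the trimmed video tensor, while the caller receives `full_output` untouched (a plain tensor or the [video, audio] list). The change metric in the surrounding context lines is a mean absolute difference over subsampled outputs; a tiny worked example with made-up tensors:

import torch

prev = torch.tensor([0.0, 1.0, 2.0, 3.0])    # output_prev_subsampled
curr = torch.tensor([0.5, 1.0, 2.0, 2.0])    # subsampled current output
output_change = (curr - prev).flatten().abs().mean()
# (0.5 + 0.0 + 0.0 + 1.0) / 4 = 0.375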
@@ -89,8 +108,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
     easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
     if easycache.is_past_end_timestep(timestep):
         return executor(*args, **kwargs)
+    x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels)
     # prepare next x_prev
-    x: torch.Tensor = args[0][:, :easycache.output_channels]
     next_x_prev = x
     input_change = None
     do_easycache = easycache.should_do_easycache(timestep)
@@ -197,6 +216,7 @@ class EasyCacheHolder:
         self.output_prev_subsampled: torch.Tensor = None
         self.output_prev_norm: torch.Tensor = None
         self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
+        self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
         self.output_change_rates = []
         self.approx_output_change_rates = []
         self.total_steps_skipped = 0
@@ -245,20 +265,21 @@ class EasyCacheHolder:
     def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
         return all(uuid in self.uuid_cache_diffs for uuid in uuids)
 
-    def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
-        if self.first_cond_uuid in uuids:
+    def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
+        if self.first_cond_uuid in uuids and not is_audio:
             self.total_steps_skipped += 1
+        cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
         batch_offset = x.shape[0] // len(uuids)
         for i, uuid in enumerate(uuids):
             # slice out only what is relevant to this cond
             batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
             # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
-            if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
+            if x.shape[1:] != cache_diffs[uuid].shape[1:]:
                 if not self.allow_mismatch:
                     raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
                 slicing = []
                 skip_this_dim = True
-                for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
+                for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
                     if skip_this_dim:
                         skip_this_dim = False
                         continue
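To make the batch arithmetic above concrete, a minimal sketch (made-up sizes, string stand-ins for the real UUID objects) of how one concatenated batch is split across conds:

import torch

x = torch.zeros(4, 8, 16)                     # 2 conds concatenated on dim 0
uuids = ["cond-a", "cond-b"]                  # stand-ins for real UUIDs
cache_diffs = {u: torch.full((2, 8, 16), i + 1.0) for i, u in enumerate(uuids)}
batch_offset = x.shape[0] // len(uuids)       # 4 // 2 = 2 rows per cond
for i, u in enumerate(uuids):
    x[i*batch_offset:(i+1)*batch_offset] += cache_diffs[u]
# rows 0-1 now hold cond-a's diff (1.0), rows 2-3 cond-b's (2.0)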
@@ -270,10 +291,11 @@ class EasyCacheHolder:
                     else:
                         slicing.append(slice(None))
                 batch_slice = batch_slice + slicing
-            x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
+            x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
         return x
 
-    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
+    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
+        cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
         # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
         if output.shape[1:] != x.shape[1:]:
             if not self.allow_mismatch:
@@ -293,7 +315,7 @@ class EasyCacheHolder:
         diff = output - x
         batch_offset = diff.shape[0] // len(uuids)
         for i, uuid in enumerate(uuids):
-            self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
+            cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
 
     def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
         return self.first_cond_uuid in uuids
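The cache stores residuals, not outputs: update_cache_diff records output - x per cond, and apply_cache_diff later reconstructs an approximate output by adding that residual back onto the new input. A sketch with made-up tensors:

import torch

x_t = torch.randn(2, 8, 16)      # model input at the step that really ran
out_t = torch.randn(2, 8, 16)    # model output at that step
diff = out_t - x_t               # what update_cache_diff stores per uuid

x_skip = torch.randn(2, 8, 16)   # input at a later, skipped step
approx_out = x_skip + diff       # what apply_cache_diff reconstructs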
@@ -324,6 +346,8 @@ class EasyCacheHolder:
         self.output_prev_norm = None
         del self.uuid_cache_diffs
         self.uuid_cache_diffs = {}
+        del self.uuid_cache_diffs_audio
+        self.uuid_cache_diffs_audio = {}
         self.total_steps_skipped = 0
         self.state_metadata = None
         return self