EasyCache: Fix for mismatch in input/output channels with some models (#10788)

Slices model input with output channels so the caching tracks only the noise channels, resolves channel mismatch with models like WanVideo I2V

Also fix for slicing deprecation in pytorch 2.9
This commit is contained in:
Jukka Seppänen 2025-11-18 17:00:21 +02:00 committed by GitHub
parent 048f49adbd
commit e1ab6bb394
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,13 +11,13 @@ if TYPE_CHECKING:
def easycache_forward_wrapper(executor, *args, **kwargs): def easycache_forward_wrapper(executor, *args, **kwargs):
# get values from args # get values from args
x: torch.Tensor = args[0]
transformer_options: dict[str] = args[-1] transformer_options: dict[str] = args[-1]
if not isinstance(transformer_options, dict): if not isinstance(transformer_options, dict):
transformer_options = kwargs.get("transformer_options") transformer_options = kwargs.get("transformer_options")
if not transformer_options: if not transformer_options:
transformer_options = args[-2] transformer_options = args[-2]
easycache: EasyCacheHolder = transformer_options["easycache"] easycache: EasyCacheHolder = transformer_options["easycache"]
x: torch.Tensor = args[0][:, :easycache.output_channels]
sigmas = transformer_options["sigmas"] sigmas = transformer_options["sigmas"]
uuids = transformer_options["uuids"] uuids = transformer_options["uuids"]
if sigmas is not None and easycache.is_past_end_timestep(sigmas): if sigmas is not None and easycache.is_past_end_timestep(sigmas):
@ -82,13 +82,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
def lazycache_predict_noise_wrapper(executor, *args, **kwargs): def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
# get values from args # get values from args
x: torch.Tensor = args[0]
timestep: float = args[1] timestep: float = args[1]
model_options: dict[str] = args[2] model_options: dict[str] = args[2]
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"] easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
if easycache.is_past_end_timestep(timestep): if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs) return executor(*args, **kwargs)
# prepare next x_prev # prepare next x_prev
x: torch.Tensor = args[0][:, :easycache.output_channels]
next_x_prev = x next_x_prev = x
input_change = None input_change = None
do_easycache = easycache.should_do_easycache(timestep) do_easycache = easycache.should_do_easycache(timestep)
@ -173,7 +173,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
class EasyCacheHolder: class EasyCacheHolder:
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False): def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
self.name = "EasyCache" self.name = "EasyCache"
self.reuse_threshold = reuse_threshold self.reuse_threshold = reuse_threshold
self.start_percent = start_percent self.start_percent = start_percent
@ -202,6 +202,7 @@ class EasyCacheHolder:
self.allow_mismatch = True self.allow_mismatch = True
self.cut_from_start = True self.cut_from_start = True
self.state_metadata = None self.state_metadata = None
self.output_channels = output_channels
def is_past_end_timestep(self, timestep: float) -> bool: def is_past_end_timestep(self, timestep: float) -> bool:
return not (timestep[0] > self.end_t).item() return not (timestep[0] > self.end_t).item()
@ -264,7 +265,7 @@ class EasyCacheHolder:
else: else:
slicing.append(slice(None)) slicing.append(slice(None))
batch_slice = batch_slice + slicing batch_slice = batch_slice + slicing
x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device) x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
return x return x
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]): def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
@ -283,7 +284,7 @@ class EasyCacheHolder:
else: else:
slicing.append(slice(None)) slicing.append(slice(None))
skip_dim = False skip_dim = False
x = x[slicing] x = x[tuple(slicing)]
diff = output - x diff = output - x
batch_offset = diff.shape[0] // len(uuids) batch_offset = diff.shape[0] // len(uuids)
for i, uuid in enumerate(uuids): for i, uuid in enumerate(uuids):
@ -323,7 +324,7 @@ class EasyCacheHolder:
return self return self
def clone(self): def clone(self):
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose) return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
class EasyCacheNode(io.ComfyNode): class EasyCacheNode(io.ComfyNode):
@ -350,7 +351,7 @@ class EasyCacheNode(io.ComfyNode):
@classmethod @classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
model = model.clone() model = model.clone()
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose) model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
@ -358,7 +359,7 @@ class EasyCacheNode(io.ComfyNode):
class LazyCacheHolder: class LazyCacheHolder:
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False): def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
self.name = "LazyCache" self.name = "LazyCache"
self.reuse_threshold = reuse_threshold self.reuse_threshold = reuse_threshold
self.start_percent = start_percent self.start_percent = start_percent
@ -382,6 +383,7 @@ class LazyCacheHolder:
self.approx_output_change_rates = [] self.approx_output_change_rates = []
self.total_steps_skipped = 0 self.total_steps_skipped = 0
self.state_metadata = None self.state_metadata = None
self.output_channels = output_channels
def has_cache_diff(self) -> bool: def has_cache_diff(self) -> bool:
return self.cache_diff is not None return self.cache_diff is not None
@ -456,7 +458,7 @@ class LazyCacheHolder:
return self return self
def clone(self): def clone(self):
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose) return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
class LazyCacheNode(io.ComfyNode): class LazyCacheNode(io.ComfyNode):
@classmethod @classmethod
@ -482,7 +484,7 @@ class LazyCacheNode(io.ComfyNode):
@classmethod @classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
model = model.clone() model = model.clone()
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose) model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
return io.NodeOutput(model) return io.NodeOutput(model)