Mirror of https://github.com/comfyanonymous/ComfyUI.git
EasyCache: Fix for mismatch in input/output channels with some models (#10788)
Slices the model input to the output channel count so the caching tracks only the noise channels; this resolves the channel mismatch with models like WanVideo I2V. Also fixes a tensor-slicing deprecation in PyTorch 2.9.
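For context, a minimal sketch of the mismatch this resolves (hypothetical channel counts, not the actual WanVideo code): an I2V model concatenates image-conditioning channels onto the noise latent, so the tensor fed to the diffusion model is wider than the tensor it predicts, and a cache diff computed against the raw input cannot be subtracted from the output.

import torch

# Hypothetical shapes for illustration only: 16 noise channels plus 20
# concatenated conditioning channels on the input side.
output_channels = 16
x_in = torch.randn(1, 36, 21, 60, 104)       # model input (noise + conditioning)
model_out = torch.randn(1, 16, 21, 60, 104)  # model output (noise channels only)

# model_out - x_in would raise a shape error (36 vs 16 channels).
# Slicing the input to the output channel count keeps the cached
# residual shape-compatible with the model output:
x = x_in[:, :output_channels]
cache_diff = model_out - x                   # both tensors are 16-channel
print(cache_diff.shape)                      # torch.Size([1, 16, 21, 60, 104])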
This commit is contained in:
parent 048f49adbd
commit e1ab6bb394
@@ -11,13 +11,13 @@ if TYPE_CHECKING:
 
 def easycache_forward_wrapper(executor, *args, **kwargs):
     # get values from args
-    x: torch.Tensor = args[0]
     transformer_options: dict[str] = args[-1]
     if not isinstance(transformer_options, dict):
         transformer_options = kwargs.get("transformer_options")
         if not transformer_options:
             transformer_options = args[-2]
     easycache: EasyCacheHolder = transformer_options["easycache"]
+    x: torch.Tensor = args[0][:, :easycache.output_channels]
     sigmas = transformer_options["sigmas"]
     uuids = transformer_options["uuids"]
     if sigmas is not None and easycache.is_past_end_timestep(sigmas):
@@ -82,13 +82,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
 
 def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
     # get values from args
-    x: torch.Tensor = args[0]
     timestep: float = args[1]
     model_options: dict[str] = args[2]
     easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
     if easycache.is_past_end_timestep(timestep):
         return executor(*args, **kwargs)
     # prepare next x_prev
+    x: torch.Tensor = args[0][:, :easycache.output_channels]
     next_x_prev = x
     input_change = None
     do_easycache = easycache.should_do_easycache(timestep)
@@ -173,7 +173,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
 
 
 class EasyCacheHolder:
-    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
+    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
         self.name = "EasyCache"
         self.reuse_threshold = reuse_threshold
         self.start_percent = start_percent
@@ -202,6 +202,7 @@ class EasyCacheHolder:
         self.allow_mismatch = True
         self.cut_from_start = True
         self.state_metadata = None
+        self.output_channels = output_channels
 
     def is_past_end_timestep(self, timestep: float) -> bool:
         return not (timestep[0] > self.end_t).item()
@@ -264,7 +265,7 @@ class EasyCacheHolder:
                 else:
                     slicing.append(slice(None))
             batch_slice = batch_slice + slicing
-            x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device)
+            x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
         return x
 
     def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
@@ -283,7 +284,7 @@ class EasyCacheHolder:
                 else:
                     slicing.append(slice(None))
                     skip_dim = False
-            x = x[slicing]
+            x = x[tuple(slicing)]
         diff = output - x
         batch_offset = diff.shape[0] // len(uuids)
         for i, uuid in enumerate(uuids):
@@ -323,7 +324,7 @@ class EasyCacheHolder:
         return self
 
     def clone(self):
-        return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
+        return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
 
 
 class EasyCacheNode(io.ComfyNode):
@@ -350,7 +351,7 @@ class EasyCacheNode(io.ComfyNode):
     @classmethod
     def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
         model = model.clone()
-        model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
+        model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
@@ -358,7 +359,7 @@ class EasyCacheNode(io.ComfyNode):
 
 
 class LazyCacheHolder:
-    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
+    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
         self.name = "LazyCache"
         self.reuse_threshold = reuse_threshold
         self.start_percent = start_percent
@@ -382,6 +383,7 @@ class LazyCacheHolder:
         self.approx_output_change_rates = []
         self.total_steps_skipped = 0
         self.state_metadata = None
+        self.output_channels = output_channels
 
     def has_cache_diff(self) -> bool:
         return self.cache_diff is not None
@@ -456,7 +458,7 @@ class LazyCacheHolder:
         return self
 
     def clone(self):
-        return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
+        return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
 
 class LazyCacheNode(io.ComfyNode):
     @classmethod
@@ -482,7 +484,7 @@ class LazyCacheNode(io.ComfyNode):
     @classmethod
     def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
         model = model.clone()
-        model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
+        model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
         return io.NodeOutput(model)
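The x[tuple(batch_slice)] and x[tuple(slicing)] changes address the second half of the commit message: indexing a tensor with a plain Python list of slice objects relied on a long-deprecated PyTorch behavior (non-tuple sequences used for multidimensional indexing), which per the commit message breaks under PyTorch 2.9; converting the accumulated list to a tuple is the supported form. A standalone sketch of the pattern, not code from this file:

import torch

x = torch.arange(24).reshape(2, 3, 4)
slicing = [slice(None), slice(0, 2)]  # index built up incrementally, as in the patch

# Deprecated: x[slicing] -- a list of slices as a multidimensional index.
# Supported: convert the list to a tuple before indexing.
y = x[tuple(slicing)]                 # equivalent to x[:, 0:2]
print(y.shape)                        # torch.Size([2, 2, 4])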