mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 21:30:15 +08:00
Fix for edge case of EasyCache when conditionings change during a sampling run (like with timestep scheduling) (#12020)
This commit is contained in:
parent
abe2ec26a6
commit
f09904720d
@ -29,8 +29,10 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
|||||||
do_easycache = easycache.should_do_easycache(sigmas)
|
do_easycache = easycache.should_do_easycache(sigmas)
|
||||||
if do_easycache:
|
if do_easycache:
|
||||||
easycache.check_metadata(x)
|
easycache.check_metadata(x)
|
||||||
|
# if there isn't a cache diff for current conds, we cannot skip this step
|
||||||
|
can_apply_cache_diff = easycache.can_apply_cache_diff(uuids)
|
||||||
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
||||||
if easycache.skip_current_step:
|
if easycache.skip_current_step and can_apply_cache_diff:
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
||||||
return easycache.apply_cache_diff(x, uuids)
|
return easycache.apply_cache_diff(x, uuids)
|
||||||
@ -44,7 +46,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
|||||||
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
|
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
|
||||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||||
easycache.cumulative_change_rate += approx_output_change_rate
|
easycache.cumulative_change_rate += approx_output_change_rate
|
||||||
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff:
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||||
# other conds should also skip this step, and instead use their cached values
|
# other conds should also skip this step, and instead use their cached values
|
||||||
@ -240,6 +242,9 @@ class EasyCacheHolder:
|
|||||||
return to_return.clone()
|
return to_return.clone()
|
||||||
return to_return
|
return to_return
|
||||||
|
|
||||||
|
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
|
||||||
|
return all(uuid in self.uuid_cache_diffs for uuid in uuids)
|
||||||
|
|
||||||
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
|
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
|
||||||
if self.first_cond_uuid in uuids:
|
if self.first_cond_uuid in uuids:
|
||||||
self.total_steps_skipped += 1
|
self.total_steps_skipped += 1
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user