mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 01:52:59 +08:00
Do batch_slice in EasyCache's apply_cache_diff (#10376)
This commit is contained in:
parent
b1293d50ef
commit
d8d60b5609
@ -244,6 +244,8 @@ class EasyCacheHolder:
|
|||||||
self.total_steps_skipped += 1
|
self.total_steps_skipped += 1
|
||||||
batch_offset = x.shape[0] // len(uuids)
|
batch_offset = x.shape[0] // len(uuids)
|
||||||
for i, uuid in enumerate(uuids):
|
for i, uuid in enumerate(uuids):
|
||||||
|
# slice out only what is relevant to this cond
|
||||||
|
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
|
||||||
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||||
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
|
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
|
||||||
if not self.allow_mismatch:
|
if not self.allow_mismatch:
|
||||||
@ -261,9 +263,8 @@ class EasyCacheHolder:
|
|||||||
slicing.append(slice(None, dim_u))
|
slicing.append(slice(None, dim_u))
|
||||||
else:
|
else:
|
||||||
slicing.append(slice(None))
|
slicing.append(slice(None))
|
||||||
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
|
batch_slice = batch_slice + slicing
|
||||||
x = x[slicing]
|
x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device)
|
||||||
x += self.uuid_cache_diffs[uuid].to(x.device)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user