mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 17:42:58 +08:00
Do batch_slice in EasyCache's apply_cache_diff (#10376)
This commit is contained in:
parent
b1293d50ef
commit
d8d60b5609
@ -244,6 +244,8 @@ class EasyCacheHolder:
|
||||
self.total_steps_skipped += 1
|
||||
batch_offset = x.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
# slice out only what is relevant to this cond
|
||||
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
|
||||
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
@ -261,9 +263,8 @@ class EasyCacheHolder:
|
||||
slicing.append(slice(None, dim_u))
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
|
||||
x = x[slicing]
|
||||
x += self.uuid_cache_diffs[uuid].to(x.device)
|
||||
batch_slice = batch_slice + slicing
|
||||
x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device)
|
||||
return x
|
||||
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user