mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-19 02:53:05 +08:00
Refactored FlipFlopTransformer.__call__ to fully separate out actions between flip and flop
This commit is contained in:
parent
84e73f2aa5
commit
f083720eb4
@ -110,6 +110,58 @@ class FlipFlopTransformer:
|
|||||||
|
|
||||||
@torch.no_grad()
def __call__(self, **feed_dict):
    '''
    Run every transformer block through the two resident flip/flop modules.

    Flip accounts for even blocks (0 is first block), flop accounts for odd
    blocks.  While one module computes on ``self.compute_stream``, the other
    module's weights are asynchronously overwritten with the state of the
    next block it will stand in for, so compute and weight transfer overlap.

    Keyword arguments are the named tensors consumed by ``self.block_wrap_fn``;
    returns the tensor named in ``self.out_names`` (a tuple when more than one
    output name is configured).
    '''
    # separated flip flop refactor
    flip_ran = False       # flip has executed at least one block
    flop_ran = False       # flop has executed at least one block
    flip_wrapped = False   # flip just consumed the final even-indexed block
    flop_wrapped = False   # flop just consumed the final odd-indexed block

    for i, _unused_block in enumerate(self.transformer_blocks):
        if i % 2 == 0:
            # --- flip handles this (even) block ---
            # Do not start computing until the pending weight copy has landed.
            self.compute_stream.wait_event(self.cpy_end_event)
            with torch.cuda.stream(self.compute_stream):
                feed_dict = self.block_wrap_fn(self.flip, **feed_dict)
                self.event_flip.record(self.compute_stream)
            # While flip executes, queue flop's copy toward its next block.
            upcoming = i + 1
            if upcoming >= self.num_blocks:
                upcoming -= self.num_blocks
                flip_wrapped = True
            if flip_ran:
                self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[upcoming].state_dict(), self.event_flop, self.cpy_end_event)
            if flip_wrapped:
                # Reset flip back to block 0 for the next forward pass.
                self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip)
            flip_ran = True
        else:
            # --- flop handles this (odd) block ---
            # First flop run needs no wait: its weights were staged up front.
            if flop_ran:
                self.compute_stream.wait_event(self.cpy_end_event)
            with torch.cuda.stream(self.compute_stream):
                feed_dict = self.block_wrap_fn(self.flop, **feed_dict)
                self.event_flop.record(self.compute_stream)
            # While flop executes, queue flip's copy toward its next block.
            upcoming = i + 1
            if upcoming >= self.num_blocks:
                upcoming -= self.num_blocks
                flop_wrapped = True
            self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[upcoming].state_dict(), self.event_flip, self.cpy_end_event)
            if flop_wrapped:
                # Reset flop back to block 1 for the next forward pass.
                self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop)
            flop_ran = True

    # Publish the final compute as the copy-fence for the next invocation.
    self.compute_stream.record_event(self.cpy_end_event)

    outputs = [feed_dict[name] for name in self.out_names]
    return outputs[0] if len(outputs) == 1 else tuple(outputs)
||||||
|
@torch.no_grad()
|
||||||
|
def __call__old(self, **feed_dict):
|
||||||
# contentis' prototype flip flop
|
# contentis' prototype flip flop
|
||||||
# Wait for reset
|
# Wait for reset
|
||||||
self.compute_stream.wait_event(self.cpy_end_event)
|
self.compute_stream.wait_event(self.cpy_end_event)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user