mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-18 18:43:05 +08:00
Refactored FlipFlopTransformer.__call__ to fully separate out actions between flip and flop
This commit is contained in:
parent
84e73f2aa5
commit
f083720eb4
@ -110,6 +110,58 @@ class FlipFlopTransformer:
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, **feed_dict):
|
||||
'''
|
||||
Flip accounts for even blocks (0 is first block), flop accounts for odd blocks.
|
||||
'''
|
||||
# separated flip flop refactor
|
||||
first_flip = True
|
||||
first_flop = True
|
||||
last_flip = False
|
||||
last_flop = False
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
is_flip = i % 2 == 0
|
||||
if is_flip:
|
||||
# flip
|
||||
self.compute_stream.wait_event(self.cpy_end_event)
|
||||
with torch.cuda.stream(self.compute_stream):
|
||||
feed_dict = self.block_wrap_fn(self.flip, **feed_dict)
|
||||
self.event_flip.record(self.compute_stream)
|
||||
# while flip executes, queue flop to copy to its next block
|
||||
next_flop_i = i + 1
|
||||
if next_flop_i >= self.num_blocks:
|
||||
next_flop_i = next_flop_i - self.num_blocks
|
||||
last_flip = True
|
||||
if not first_flip:
|
||||
self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[next_flop_i].state_dict(), self.event_flop, self.cpy_end_event)
|
||||
if last_flip:
|
||||
self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip)
|
||||
first_flip = False
|
||||
else:
|
||||
# flop
|
||||
if not first_flop:
|
||||
self.compute_stream.wait_event(self.cpy_end_event)
|
||||
with torch.cuda.stream(self.compute_stream):
|
||||
feed_dict = self.block_wrap_fn(self.flop, **feed_dict)
|
||||
self.event_flop.record(self.compute_stream)
|
||||
# while flop executes, queue flip to copy to its next block
|
||||
next_flip_i = i + 1
|
||||
if next_flip_i >= self.num_blocks:
|
||||
next_flip_i = next_flip_i - self.num_blocks
|
||||
last_flop = True
|
||||
self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[next_flip_i].state_dict(), self.event_flip, self.cpy_end_event)
|
||||
if last_flop:
|
||||
self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop)
|
||||
first_flop = False
|
||||
|
||||
self.compute_stream.record_event(self.cpy_end_event)
|
||||
|
||||
outputs = [feed_dict[name] for name in self.out_names]
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return tuple(outputs)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__old(self, **feed_dict):
|
||||
# contentis' prototype flip flop
|
||||
# Wait for reset
|
||||
self.compute_stream.wait_event(self.cpy_end_event)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user