Refactored FlipFlopTransformer.__call__ to fully separate out actions between flip and flop

Jedrzej Kosinski 2025-09-25 16:16:51 -07:00
parent 84e73f2aa5
commit f083720eb4


@@ -110,6 +110,58 @@ class FlipFlopTransformer:
    @torch.no_grad()
    def __call__(self, **feed_dict):
        '''
        Flip accounts for even blocks (0 is first block), flop accounts for odd blocks.
        '''
        # separated flip flop refactor
        first_flip = True
        first_flop = True
        last_flip = False
        last_flop = False
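        # first_* gate the waits/copies that need a prior iteration's events;
        # last_* trigger the wrap-around prefetch of blocks 0/1 for the next call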
        for i, block in enumerate(self.transformer_blocks):
            is_flip = i % 2 == 0
            if is_flip:
                # flip
                self.compute_stream.wait_event(self.cpy_end_event)
                with torch.cuda.stream(self.compute_stream):
                    feed_dict = self.block_wrap_fn(self.flip, **feed_dict)
                    self.event_flip.record(self.compute_stream)
                # while flip executes, queue flop to copy to its next block
                next_flop_i = i + 1
                if next_flop_i >= self.num_blocks:
                    next_flop_i = next_flop_i - self.num_blocks
                    last_flip = True
                if not first_flip:
                    self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[next_flop_i].state_dict(), self.event_flop, self.cpy_end_event)
                if last_flip:
                    self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip)
                first_flip = False
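            # flop mirrors flip, but skips the initial wait: on the first flop,
            # no copy into flop has been queued yet in this call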
            else:
                # flop
                if not first_flop:
                    self.compute_stream.wait_event(self.cpy_end_event)
                with torch.cuda.stream(self.compute_stream):
                    feed_dict = self.block_wrap_fn(self.flop, **feed_dict)
                    self.event_flop.record(self.compute_stream)
                # while flop executes, queue flip to copy to its next block
                next_flip_i = i + 1
                if next_flip_i >= self.num_blocks:
                    next_flip_i = next_flip_i - self.num_blocks
                    last_flop = True
                self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[next_flip_i].state_dict(), self.event_flip, self.cpy_end_event)
                if last_flop:
                    self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop)
                first_flop = False
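        # re-record cpy_end_event on the compute stream so the first
        # wait_event of the next call has a fresh event to wait on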
        self.compute_stream.record_event(self.cpy_end_event)
        outputs = [feed_dict[name] for name in self.out_names]
        if len(outputs) == 1:
            return outputs[0]
        return tuple(outputs)

    @torch.no_grad()
    def __call__old(self, **feed_dict):
        # contentis' prototype flip flop
        # Wait for reset
        self.compute_stream.wait_event(self.cpy_end_event)
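
For orientation, here is a minimal, self-contained sketch of the double-buffered "flip/flop" streaming pattern this refactor implements: two GPU-resident buffers alternate between computing block i and being refilled with block i+1 on a side copy stream, with CUDA events ordering compute against copies. All names here (FlipFlopSketch, _refill, copied, done) are illustrative assumptions, not the actual FlipFlopTransformer API:

import copy
import torch
import torch.nn as nn

class FlipFlopSketch(nn.Module):
    '''Illustrative double-buffered block streaming; not the actual
    FlipFlopTransformer implementation.'''

    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks  # CPU-resident blocks (pin their weights for truly async copies)
        self.bufs = [copy.deepcopy(blocks[0]).cuda(),  # "flip": runs even blocks
                     copy.deepcopy(blocks[1]).cuda()]  # "flop": runs odd blocks
        self.compute_stream = torch.cuda.Stream()
        self.copy_stream = torch.cuda.Stream()
        self.done = [torch.cuda.Event(), torch.cuda.Event()]  # per-buffer compute finished
        self.copied = torch.cuda.Event()                      # last weight copy finished

    def _refill(self, buf_idx, block_idx):
        # On the copy stream: wait until this buffer's previous compute is
        # done, then overwrite its GPU weights with the requested block's.
        with torch.cuda.stream(self.copy_stream):
            self.copy_stream.wait_event(self.done[buf_idx])
            dst = self.bufs[buf_idx].state_dict()
            src = self.blocks[block_idx].state_dict()
            for k, v in dst.items():
                v.copy_(src[k], non_blocking=True)
            self.copied.record(self.copy_stream)

    @torch.no_grad()
    def forward(self, x):
        n = len(self.blocks)
        # Make the side stream see work already queued by the caller.
        self.compute_stream.wait_stream(torch.cuda.current_stream())
        for i in range(n):
            cur, other = i % 2, (i + 1) % 2
            # Never compute with weights that may still be mid-copy.
            self.compute_stream.wait_event(self.copied)
            with torch.cuda.stream(self.compute_stream):
                x = self.bufs[cur](x)
                self.done[cur].record(self.compute_stream)
            # Overlap: while `cur` computes block i, refill the idle buffer
            # with block i+1 (wrapping to prefetch block 0 for the next call;
            # the i == 0 refill re-copies block 1, which is harmless).
            self._refill(other, (i + 1) % n)
        torch.cuda.current_stream().wait_stream(self.compute_stream)
        return x

Unlike this sketch, the committed code uses the first_*/last_* flags to skip the redundant initial copy and to explicitly prefetch blocks 0 and 1 for the next forward pass.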