diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 8d8f6565d..d8059eafe 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -110,6 +110,58 @@ class FlipFlopTransformer: @torch.no_grad() def __call__(self, **feed_dict): + ''' + Flip accounts for even blocks (0 is first block), flop accounts for odd blocks. + ''' + # separated flip flop refactor + first_flip = True + first_flop = True + last_flip = False + last_flop = False + for i, block in enumerate(self.transformer_blocks): + is_flip = i % 2 == 0 + if is_flip: + # flip + self.compute_stream.wait_event(self.cpy_end_event) + with torch.cuda.stream(self.compute_stream): + feed_dict = self.block_wrap_fn(self.flip, **feed_dict) + self.event_flip.record(self.compute_stream) + # while flip executes, queue flop to copy to its next block + next_flop_i = i + 1 + if next_flop_i >= self.num_blocks: + next_flop_i = next_flop_i - self.num_blocks + last_flip = True + if not first_flip: + self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[next_flop_i].state_dict(), self.event_flop, self.cpy_end_event) + if last_flip: + self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[0].state_dict(), cpy_start_event=self.event_flip) + first_flip = False + else: + # flop + if not first_flop: + self.compute_stream.wait_event(self.cpy_end_event) + with torch.cuda.stream(self.compute_stream): + feed_dict = self.block_wrap_fn(self.flop, **feed_dict) + self.event_flop.record(self.compute_stream) + # while flop executes, queue flip to copy to its next block + next_flip_i = i + 1 + if next_flip_i >= self.num_blocks: + next_flip_i = next_flip_i - self.num_blocks + last_flop = True + self._copy_state_dict(self.flip.state_dict(), self.transformer_blocks[next_flip_i].state_dict(), self.event_flip, self.cpy_end_event) + if last_flop: + self._copy_state_dict(self.flop.state_dict(), self.transformer_blocks[1].state_dict(), cpy_start_event=self.event_flop) + first_flop = False + + self.compute_stream.record_event(self.cpy_end_event) + + outputs = [feed_dict[name] for name in self.out_names] + if len(outputs) == 1: + return outputs[0] + return tuple(outputs) + + @torch.no_grad() + def __call__old(self, **feed_dict): # contentis' prototype flip flop # Wait for reset self.compute_stream.wait_event(self.cpy_end_event)