diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py index 4018307db..7bbb08208 100644 --- a/comfy/ldm/flipflop_transformer.py +++ b/comfy/ldm/flipflop_transformer.py @@ -56,8 +56,8 @@ class FlipFlopModule(torch.nn.Module): out = func(i, block, *out, *args, **kwargs) if isinstance(out, torch.Tensor): out = (out,) - if "transformer_blocks" in self.flipflop: - holder = self.flipflop["transformer_blocks"] + if block_type in self.flipflop: + holder = self.flipflop[block_type] with holder.context() as ctx: for i, block in enumerate(holder.blocks): out = ctx(func, i, block, *out, *args, **kwargs)