mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 16:50:57 +08:00
Support block replace patches (mostly for SLG, i.e. Skip Layer Guidance)
This commit is contained in:
parent
ccfc3ab7cf
commit
dd318ada2f
@@ -336,6 +336,7 @@ class Kandinsky5(nn.Module):
|
||||
return freqs
|
||||
|
||||
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
context = self.text_embeddings(context)
|
||||
time_embed = self.time_embeddings(timestep).to(x.dtype) + self.pooled_text_embeddings(y)
|
||||
|
||||
@@ -345,8 +346,18 @@ class Kandinsky5(nn.Module):
|
||||
visual_embed = self.visual_embeddings(x)
|
||||
visual_shape = visual_embed.shape[:-1]
|
||||
|
||||
for block in self.visual_transformer_blocks:
|
||||
visual_embed = block(visual_embed.flatten(1, -2), context, time_embed, freqs=freqs, transformer_options=transformer_options)
|
||||
visual_embed = visual_embed.flatten(1, -2)
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.visual_transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
|
||||
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
|
||||
else:
|
||||
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||
return self.out_layer(visual_embed, time_embed)
|
||||
|
||||
@@ -1691,5 +1691,5 @@ class Kandinsky5_image(Kandinsky5):
|
||||
def concat_cond(self, **kwargs):
|
||||
return None
|
||||
|
||||
def process_latent_out(self, latent):
|
||||
def process_latent_out(self, latent): # input is still 5D, return single frame to decode with Flux VAE
|
||||
return self.latent_format.process_out(latent)[:, :, 0]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user