Add a way to set the default ref method in the qwen image code. (#11349)

This commit is contained in:
comfyanonymous 2025-12-15 22:26:55 -08:00 committed by GitHub
parent 645ee1881e
commit bc606d7d64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -322,6 +322,7 @@ class QwenImageTransformer2DModel(nn.Module):
pooled_projection_dim: int = 768, pooled_projection_dim: int = 768,
guidance_embeds: bool = False, guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
default_ref_method="index",
image_model=None, image_model=None,
final_layer=True, final_layer=True,
dtype=None, dtype=None,
@ -334,6 +335,7 @@ class QwenImageTransformer2DModel(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels or in_channels self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim self.inner_dim = num_attention_heads * attention_head_dim
self.default_ref_method = default_ref_method
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
@ -416,7 +418,7 @@ class QwenImageTransformer2DModel(nn.Module):
h = 0 h = 0
w = 0 w = 0
index = 0 index = 0
ref_method = kwargs.get("ref_latents_method", "index") ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero") index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
timestep_zero = ref_method == "index_timestep_zero" timestep_zero = ref_method == "index_timestep_zero"
for ref in ref_latents: for ref in ref_latents: