Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2025-08-18 15:14:34 +03:00 committed by GitHub
commit 3f09b4dba5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 44 additions and 5 deletions

View File

@ -164,8 +164,11 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0])
self._step = int(indexes[0])
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
self._step = int(matches[0].item())
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model

View File

@ -360,6 +360,7 @@ class QwenImageTransformer2DModel(nn.Module):
context,
attention_mask=None,
guidance: torch.Tensor = None,
ref_latents=None,
transformer_options={},
**kwargs
):
@ -370,6 +371,31 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]
if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
for ref in ref_latents:
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
else:
index = 1
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)

View File

@ -1331,4 +1331,14 @@ class QwenImage(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
ref_latents_method = kwargs.get("reference_latents_method", None)
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out

View File

@ -699,7 +699,7 @@ class WanTrackToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanPhantomSubjectToVideo",
node_id="WanTrackToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.25.8
comfyui-workflow-templates==0.1.59
comfyui-frontend-package==1.25.9
comfyui-workflow-templates==0.1.60
comfyui-embedded-docs==0.2.6
torch
torchsde