mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-21 20:00:17 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
3f09b4dba5
@ -164,8 +164,11 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
return resized_cond
|
return resized_cond
|
||||||
|
|
||||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0])
|
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
||||||
self._step = int(indexes[0])
|
matches = torch.nonzero(mask)
|
||||||
|
if torch.numel(matches) == 0:
|
||||||
|
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||||
|
self._step = int(matches[0].item())
|
||||||
|
|
||||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||||
|
|||||||
@ -360,6 +360,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
context,
|
context,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
guidance: torch.Tensor = None,
|
guidance: torch.Tensor = None,
|
||||||
|
ref_latents=None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@ -370,6 +371,31 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||||
num_embeds = hidden_states.shape[1]
|
num_embeds = hidden_states.shape[1]
|
||||||
|
|
||||||
|
if ref_latents is not None:
|
||||||
|
h = 0
|
||||||
|
w = 0
|
||||||
|
index = 0
|
||||||
|
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
|
||||||
|
for ref in ref_latents:
|
||||||
|
if index_ref_method:
|
||||||
|
index += 1
|
||||||
|
h_offset = 0
|
||||||
|
w_offset = 0
|
||||||
|
else:
|
||||||
|
index = 1
|
||||||
|
h_offset = 0
|
||||||
|
w_offset = 0
|
||||||
|
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||||
|
w_offset = w
|
||||||
|
else:
|
||||||
|
h_offset = h
|
||||||
|
h = max(h, ref.shape[-2] + h_offset)
|
||||||
|
w = max(w, ref.shape[-1] + w_offset)
|
||||||
|
|
||||||
|
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||||
|
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||||
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
|
|
||||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
|
||||||
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
|||||||
@ -1331,4 +1331,14 @@ class QwenImage(BaseModel):
|
|||||||
cross_attn = kwargs.get("cross_attn", None)
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
ref_latents = kwargs.get("reference_latents", None)
|
||||||
|
if ref_latents is not None:
|
||||||
|
latents = []
|
||||||
|
for lat in ref_latents:
|
||||||
|
latents.append(self.process_latent_in(lat))
|
||||||
|
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||||
|
|
||||||
|
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||||
|
if ref_latents_method is not None:
|
||||||
|
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -699,7 +699,7 @@ class WanTrackToVideo(io.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="WanPhantomSubjectToVideo",
|
node_id="WanTrackToVideo",
|
||||||
category="conditioning/video_models",
|
category="conditioning/video_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Conditioning.Input("positive"),
|
io.Conditioning.Input("positive"),
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.25.8
|
comfyui-frontend-package==1.25.9
|
||||||
comfyui-workflow-templates==0.1.59
|
comfyui-workflow-templates==0.1.60
|
||||||
comfyui-embedded-docs==0.2.6
|
comfyui-embedded-docs==0.2.6
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user