mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-21 20:10:48 +08:00
Merge branch 'master' into worksplit-multigpu
This commit is contained in:
commit
383f9b34cb
@ -164,8 +164,11 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
return resized_cond
|
return resized_cond
|
||||||
|
|
||||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0])
|
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
||||||
self._step = int(indexes[0])
|
matches = torch.nonzero(mask)
|
||||||
|
if torch.numel(matches) == 0:
|
||||||
|
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||||
|
self._step = int(matches[0].item())
|
||||||
|
|
||||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||||
|
|||||||
@ -333,21 +333,25 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
def pos_embeds(self, x, context):
|
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
|
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||||
|
orig_shape = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||||
|
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||||
|
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
|
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
||||||
|
|
||||||
txt_start = round(max(h_len, w_len))
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
|
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||||
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||||
|
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -356,6 +360,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
context,
|
context,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
guidance: torch.Tensor = None,
|
guidance: torch.Tensor = None,
|
||||||
|
ref_latents=None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@ -363,13 +368,38 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
encoder_hidden_states = context
|
encoder_hidden_states = context
|
||||||
encoder_hidden_states_mask = attention_mask
|
encoder_hidden_states_mask = attention_mask
|
||||||
|
|
||||||
image_rotary_emb = self.pos_embeds(x, context)
|
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||||
|
num_embeds = hidden_states.shape[1]
|
||||||
|
|
||||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
if ref_latents is not None:
|
||||||
orig_shape = hidden_states.shape
|
h = 0
|
||||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
w = 0
|
||||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
index = 0
|
||||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
|
||||||
|
for ref in ref_latents:
|
||||||
|
if index_ref_method:
|
||||||
|
index += 1
|
||||||
|
h_offset = 0
|
||||||
|
w_offset = 0
|
||||||
|
else:
|
||||||
|
index = 1
|
||||||
|
h_offset = 0
|
||||||
|
w_offset = 0
|
||||||
|
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||||
|
w_offset = w
|
||||||
|
else:
|
||||||
|
h_offset = h
|
||||||
|
h = max(h, ref.shape[-2] + h_offset)
|
||||||
|
w = max(w, ref.shape[-1] + w_offset)
|
||||||
|
|
||||||
|
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||||
|
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||||
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
|
|
||||||
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
|
||||||
|
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||||
|
|
||||||
hidden_states = self.img_in(hidden_states)
|
hidden_states = self.img_in(hidden_states)
|
||||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||||
@ -408,6 +438,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
hidden_states = self.norm_out(hidden_states, temb)
|
hidden_states = self.norm_out(hidden_states, temb)
|
||||||
hidden_states = self.proj_out(hidden_states)
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|
||||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
||||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||||
|
|||||||
@ -1331,4 +1331,14 @@ class QwenImage(BaseModel):
|
|||||||
cross_attn = kwargs.get("cross_attn", None)
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
ref_latents = kwargs.get("reference_latents", None)
|
||||||
|
if ref_latents is not None:
|
||||||
|
latents = []
|
||||||
|
for lat in ref_latents:
|
||||||
|
latents.append(self.process_latent_in(lat))
|
||||||
|
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||||
|
|
||||||
|
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||||
|
if ref_latents_method is not None:
|
||||||
|
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -346,6 +346,24 @@ class LoadAudio:
|
|||||||
return "Invalid audio file: {}".format(audio)
|
return "Invalid audio file: {}".format(audio)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
class RecordAudio:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"audio": ("AUDIO_RECORD", {})}}
|
||||||
|
|
||||||
|
CATEGORY = "audio"
|
||||||
|
|
||||||
|
RETURN_TYPES = ("AUDIO", )
|
||||||
|
FUNCTION = "load"
|
||||||
|
|
||||||
|
def load(self, audio):
|
||||||
|
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||||
|
|
||||||
|
waveform, sample_rate = torchaudio.load(audio_path)
|
||||||
|
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||||
|
return (audio, )
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"EmptyLatentAudio": EmptyLatentAudio,
|
"EmptyLatentAudio": EmptyLatentAudio,
|
||||||
"VAEEncodeAudio": VAEEncodeAudio,
|
"VAEEncodeAudio": VAEEncodeAudio,
|
||||||
@ -356,6 +374,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LoadAudio": LoadAudio,
|
"LoadAudio": LoadAudio,
|
||||||
"PreviewAudio": PreviewAudio,
|
"PreviewAudio": PreviewAudio,
|
||||||
"ConditioningStableAudio": ConditioningStableAudio,
|
"ConditioningStableAudio": ConditioningStableAudio,
|
||||||
|
"RecordAudio": RecordAudio,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@ -367,4 +386,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"SaveAudio": "Save Audio (FLAC)",
|
"SaveAudio": "Save Audio (FLAC)",
|
||||||
"SaveAudioMP3": "Save Audio (MP3)",
|
"SaveAudioMP3": "Save Audio (MP3)",
|
||||||
"SaveAudioOpus": "Save Audio (Opus)",
|
"SaveAudioOpus": "Save Audio (Opus)",
|
||||||
|
"RecordAudio": "Record Audio",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.24.4
|
comfyui-frontend-package==1.25.8
|
||||||
comfyui-workflow-templates==0.1.59
|
comfyui-workflow-templates==0.1.60
|
||||||
comfyui-embedded-docs==0.2.6
|
comfyui-embedded-docs==0.2.6
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user