Merge branch 'master' into worksplit-multigpu

Jedrzej Kosinski 2025-08-17 16:02:44 -07:00
commit 383f9b34cb
5 changed files with 83 additions and 20 deletions


@@ -164,8 +164,11 @@ class IndexListContextHandler(ContextHandlerABC):
         return resized_cond

     def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
-        indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0])
-        self._step = int(indexes[0])
+        mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
+        matches = torch.nonzero(mask)
+        if torch.numel(matches) == 0:
+            raise Exception("No sample_sigmas matched current timestep; something went wrong.")
+        self._step = int(matches[0].item())

     def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
         full_length = x_in.size(self.dim)  # TODO: choose dim based on model
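Illustrative sketch (not from the commit): a minimal standalone example of the tolerance-based step lookup above, with made-up sigma values, showing a case where exact equality misses a slightly drifted timestep while torch.isclose still finds it.

    import torch

    # Made-up sigma schedule and a slightly drifted timestep, for illustration only.
    sample_sigmas = torch.tensor([14.6146, 7.4850, 4.5965, 1.0289, 0.0])
    timestep = torch.tensor([7.4853])

    exact = torch.where(sample_sigmas == timestep[0])[0]        # empty: exact match fails
    mask = torch.isclose(sample_sigmas, timestep, rtol=0.0001)  # tolerant match succeeds
    matches = torch.nonzero(mask)
    if torch.numel(matches) == 0:
        raise Exception("No sample_sigmas matched current timestep; something went wrong.")
    step = int(matches[0].item())
    print(exact.numel(), step)  # 0 1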


@@ -333,21 +333,25 @@ class QwenImageTransformer2DModel(nn.Module):
         self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
         self.gradient_checkpointing = False

-    def pos_embeds(self, x, context):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
+        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
+        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
+        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+        w_offset = ((w_offset + (patch_size // 2)) // patch_size)

         img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
-
-        txt_start = round(max(h_len, w_len))
-        txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
+        img_ids[:, :, 0] = img_ids[:, :, 1] + index
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape

     def forward(
         self,
@@ -356,6 +360,7 @@ class QwenImageTransformer2DModel(nn.Module):
         context,
         attention_mask=None,
         guidance: torch.Tensor = None,
+        ref_latents=None,
         transformer_options={},
         **kwargs
     ):
@@ -363,13 +368,38 @@ class QwenImageTransformer2DModel(nn.Module):
         encoder_hidden_states = context
         encoder_hidden_states_mask = attention_mask

-        image_rotary_emb = self.pos_embeds(x, context)
-
-        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
-        orig_shape = hidden_states.shape
-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
-        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
-        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        hidden_states, img_ids, orig_shape = self.process_img(x)
+        num_embeds = hidden_states.shape[1]
+
+        if ref_latents is not None:
+            h = 0
+            w = 0
+            index = 0
+            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
+            for ref in ref_latents:
+                if index_ref_method:
+                    index += 1
+                    h_offset = 0
+                    w_offset = 0
+                else:
+                    index = 1
+                    h_offset = 0
+                    w_offset = 0
+                    if ref.shape[-2] + h > ref.shape[-1] + w:
+                        w_offset = w
+                    else:
+                        h_offset = h
+                    h = max(h, ref.shape[-2] + h_offset)
+                    w = max(w, ref.shape[-1] + w_offset)
+
+                kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
+                hidden_states = torch.cat([hidden_states, kontext], dim=1)
+                img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+
+        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
+        txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)

         hidden_states = self.img_in(hidden_states)
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@@ -408,6 +438,6 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)

-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
         hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
         return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
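Illustrative sketch (not from the commit): the ref_latents loop above only does index/offset bookkeeping before packing each reference latent; the function below mirrors that bookkeeping with made-up latent heights and widths (the "offset" label here is just a stand-in for the non-index branch). It shows how the "index" method gives each reference its own index with no spatial offset, while the other branch tiles references next to the running canvas.

    # Made-up shapes; only the index/offset bookkeeping from the loop above is mirrored.
    def place_refs(ref_shapes, method="index"):
        h = 0
        w = 0
        index = 0
        placements = []
        for rh, rw in ref_shapes:  # (height, width) of each reference latent
            if method == "index":
                index += 1         # each ref gets its own index, no spatial offset
                h_offset = 0
                w_offset = 0
            else:
                index = 1          # all refs share index 1 and are tiled spatially
                h_offset = 0
                w_offset = 0
                if rh + h > rw + w:
                    w_offset = w
                else:
                    h_offset = h
                h = max(h, rh + h_offset)
                w = max(w, rw + w_offset)
            placements.append((index, h_offset, w_offset))
        return placements

    print(place_refs([(64, 64), (32, 96)], method="index"))   # [(1, 0, 0), (2, 0, 0)]
    print(place_refs([(64, 64), (32, 96)], method="offset"))  # [(1, 0, 0), (1, 64, 0)]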


@@ -1331,4 +1331,14 @@ class QwenImage(BaseModel):
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            latents = []
+            for lat in ref_latents:
+                latents.append(self.process_latent_in(lat))
+            out['ref_latents'] = comfy.conds.CONDList(latents)
+        ref_latents_method = kwargs.get("reference_latents_method", None)
+        if ref_latents_method is not None:
+            out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
         return out


@@ -346,6 +346,24 @@ class LoadAudio:
             return "Invalid audio file: {}".format(audio)
         return True

+class RecordAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"audio": ("AUDIO_RECORD", {})}}
+
+    CATEGORY = "audio"
+
+    RETURN_TYPES = ("AUDIO", )
+    FUNCTION = "load"
+
+    def load(self, audio):
+        audio_path = folder_paths.get_annotated_filepath(audio)
+
+        waveform, sample_rate = torchaudio.load(audio_path)
+        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
+        return (audio, )
+
+
 NODE_CLASS_MAPPINGS = {
     "EmptyLatentAudio": EmptyLatentAudio,
     "VAEEncodeAudio": VAEEncodeAudio,
@@ -356,6 +374,7 @@ NODE_CLASS_MAPPINGS = {
     "LoadAudio": LoadAudio,
     "PreviewAudio": PreviewAudio,
    "ConditioningStableAudio": ConditioningStableAudio,
+    "RecordAudio": RecordAudio,
 }

 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -367,4 +386,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "SaveAudio": "Save Audio (FLAC)",
     "SaveAudioMP3": "Save Audio (MP3)",
     "SaveAudioOpus": "Save Audio (Opus)",
+    "RecordAudio": "Record Audio",
 }
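Illustrative sketch (not from the commit): the AUDIO dict the new RecordAudio node returns is a batched waveform plus a sample rate. The snippet below builds the same structure from a synthetic sine wave instead of a recorded file; the file name and tone are made up.

    import math
    import torch
    import torchaudio

    # Synthetic stand-in for a recorded clip: 1 second of 440 Hz mono audio.
    sample_rate = 16000
    t = torch.linspace(0, 1, sample_rate)
    waveform = torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)  # (channels, samples)

    # Same structure RecordAudio.load returns: batched waveform plus sample rate.
    audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
    print(audio["waveform"].shape)  # torch.Size([1, 1, 16000])
    torchaudio.save("record_audio_demo.wav", audio["waveform"][0], audio["sample_rate"])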


@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.24.4
-comfyui-workflow-templates==0.1.59
+comfyui-frontend-package==1.25.8
+comfyui-workflow-templates==0.1.60
 comfyui-embedded-docs==0.2.6
 torch
 torchsde