Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-05-06 23:32:30 +08:00)

Commit 26fe783917: Merge upstream/master, keep local README.md
comfy/clip_model.py

@@ -1,6 +1,7 @@
 import torch
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.ops
+import math
 
 def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
     image = image[:, :, :, :3] if image.shape[3] > 3 else image
@@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
     image = torch.clip((255. * image), 0, 255).round() / 255.0
     return (image - mean.view([3,1,1])) / std.view([3,1,1])
 
+
+def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
+    def scale_dim(size, scale):
+        scaled = math.ceil(size * scale / patch_size) * patch_size
+        return max(patch_size, int(scaled))
+
+    # Binary search for optimal scale
+    lo, hi = eps / 10, 100.0
+    while hi - lo >= eps:
+        mid = (lo + hi) / 2
+        h, w = scale_dim(oh, mid), scale_dim(ow, mid)
+        if (h // patch_size) * (w // patch_size) <= max_num_patches:
+            lo = mid
+        else:
+            hi = mid
+
+    return scale_dim(oh, lo), scale_dim(ow, lo)
+
+
+def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
+    if size > 0:
+        return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)
+
+    image = image[:, :, :, :3] if image.shape[3] > 3 else image
+    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
+    std = torch.tensor(std, device=image.device, dtype=image.dtype)
+    image = image.movedim(-1, 1)
+
+    b, c, h, w = image.shape
+    h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)
+    image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
+    image = torch.clip((255. * image), 0, 255).round() / 255.0
+    return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
+
+
 class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device, operations):
         super().__init__()
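Note: siglip2_flex_calc_resolution searches for the largest scale whose patch grid still fits inside max_num_patches, so the result is the biggest aspect-preserving resolution, rounded up to patch multiples, with at most that many patches. A quick sanity check, assuming this patch is applied (the input size and printed values are illustrative):

    from comfy.clip_model import siglip2_flex_calc_resolution

    h, w = siglip2_flex_calc_resolution(480, 640, patch_size=16, max_num_patches=256)
    print(h, w)                    # 224 288, both multiples of 16, aspect ratio roughly preserved
    print((h // 16) * (w // 16))   # 252, i.e. within the 256-patch budget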
@@ -175,6 +209,27 @@ class CLIPTextModel(torch.nn.Module):
         out = self.text_projection(x[2])
         return (x[0], x[1], out, x[2])
 
+
+def siglip2_pos_embed(embed_weight, embeds, orig_shape):
+    embed_weight_len = round(embed_weight.shape[0] ** 0.5)
+    embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len)
+    embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
+    embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1)
+    return embeds + embed_weight
+
+
+class Siglip2Embeddings(torch.nn.Module):
+    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
+        self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
+        self.patch_size = patch_size
+
+    def forward(self, pixel_values):
+        b, c, h, w = pixel_values.shape
+        img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c)
+        img = img.permute(0, 1, 3, 2, 4, 5)
+        img = img.reshape(b, img.shape[1] * img.shape[2], -1)
+        img = self.patch_embedding(img)
+        return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size))
+
+
 class CLIPVisionEmbeddings(torch.nn.Module):
     def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
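Note: Siglip2Embeddings patchifies with a reshape/permute followed by a Linear layer, so it accepts whatever H x W the preprocessor picked, and siglip2_pos_embed bilinearly resizes the learned square position grid to the actual token grid. A standalone shape walk-through of the same reshape (plain torch, illustrative sizes):

    import torch

    b, c, h, w, p = 1, 3, 224, 288, 16                 # e.g. the resolution picked for a 480x640 input
    pixel_values = torch.randn(b, c, h, w)
    img = pixel_values.movedim(1, -1).reshape(b, h // p, p, w // p, p, c)
    img = img.permute(0, 1, 3, 2, 4, 5)                # -> [b, h//p, w//p, p, p, c]
    img = img.reshape(b, (h // p) * (w // p), -1)
    print(img.shape)                                   # torch.Size([1, 252, 768]); 768 = 3 * 16 * 16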
@@ -218,8 +273,11 @@ class CLIPVision(torch.nn.Module):
         intermediate_activation = config_dict["hidden_act"]
         model_type = config_dict["model_type"]
 
-        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
-        if model_type == "siglip_vision_model":
+        if model_type in ["siglip2_vision_model"]:
+            self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
+        else:
+            self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
+        if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
             self.pre_layrnorm = lambda a: a
             self.output_layernorm = True
         else:
comfy/clip_vision.py

@@ -21,6 +21,7 @@ clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from br
 IMAGE_ENCODERS = {
     "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
     "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
+    "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
     "dinov2": comfy.image_encoders.dino2.Dinov2Model,
 }
@@ -32,9 +33,10 @@ class ClipVisionModel():
         self.image_size = config.get("image_size", 224)
         self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
         self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
-        model_type = config.get("model_type", "clip_vision_model")
-        model_class = IMAGE_ENCODERS.get(model_type)
-        if model_type == "siglip_vision_model":
+        self.model_type = config.get("model_type", "clip_vision_model")
+        self.config = config.copy()
+        model_class = IMAGE_ENCODERS.get(self.model_type)
+        if self.model_type == "siglip_vision_model":
             self.return_all_hidden_states = True
         else:
             self.return_all_hidden_states = False
@@ -55,7 +57,10 @@ class ClipVisionModel():
 
     def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
+        if self.model_type == "siglip2_vision_model":
+            pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
+        else:
+            pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
 
         outputs = Output()
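Note: every model type other than siglip2 goes through the unchanged clip_preprocess branch. A minimal usage sketch, assuming this patch is applied and a SigLIP2 NaFlex checkpoint is available (the file name below is hypothetical):

    import torch
    import comfy.clip_vision

    clip_vision = comfy.clip_vision.load("models/clip_vision/siglip2_base_naflex.safetensors")  # hypothetical path
    image = torch.rand(1, 480, 640, 3)                 # ComfyUI image tensors are [B, H, W, C] in 0..1
    out = clip_vision.encode_image(image, crop=False)
    print(out.last_hidden_state.shape)                 # roughly [1, 252, 1152]: at most num_patches tokens of hidden_size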
@@ -107,10 +112,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
         embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
         if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
-            if embed_shape == 729:
-                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
-            elif embed_shape == 1024:
-                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
+            patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
+            if len(patch_embedding_shape) == 2:
+                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
+            else:
+                if embed_shape == 729:
+                    json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
+                elif embed_shape == 1024:
+                    json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
         elif embed_shape == 577:
             if "multi_modal_projector.linear_1.bias" in sd:
                 json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
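Note: detection keys off the rank of the patch-embedding weight. The new Siglip2Embeddings uses a Linear layer, so its weight is 2-D, whereas the existing conv-style SigLIP patch embedding stores a 4-D kernel (assumed here). Illustrative shapes only, not read from a real checkpoint:

    import torch

    naflex_w = torch.empty(1152, 3 * 16 * 16)   # Linear patch embedding -> len(shape) == 2
    siglip_w = torch.empty(1152, 3, 14, 14)     # Conv2d patch embedding -> len(shape) == 4
    for name, w in (("naflex", naflex_w), ("siglip", siglip_w)):
        print(name, "-> siglip2_base_naflex config" if len(w.shape) == 2 else "-> picked by embed_shape")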
comfy/clip_vision_siglip2_base_naflex.json (new file, 14 lines)

@@ -0,0 +1,14 @@
+{
+    "num_channels": 3,
+    "hidden_act": "gelu_pytorch_tanh",
+    "hidden_size": 1152,
+    "image_size": -1,
+    "intermediate_size": 4304,
+    "model_type": "siglip2_vision_model",
+    "num_attention_heads": 16,
+    "num_hidden_layers": 27,
+    "patch_size": 16,
+    "num_patches": 256,
+    "image_mean": [0.5, 0.5, 0.5],
+    "image_std": [0.5, 0.5, 0.5]
+}
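Note: image_size is -1, so encode_image falls through the "if size > 0" guard into the flexible siglip2_preprocess path, and num_patches together with patch_size bounds the rescaled input at 256 * 16 * 16 = 65,536 pixels. A tiny check of that arithmetic (path assumes a checkout with this file):

    import json

    with open("comfy/clip_vision_siglip2_base_naflex.json") as f:
        cfg = json.load(f)
    print(cfg["num_patches"] * cfg["patch_size"] ** 2)   # 65536-pixel budget after rescaling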
LTXAV model (comfy/ldm/lightricks)

@@ -11,6 +11,69 @@ from comfy.ldm.lightricks.model import (
 from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
 import comfy.ldm.common_dit
 
+
+class CompressedTimestep:
+    """Store video timestep embeddings in compressed form using per-frame indexing."""
+    __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
+
+    def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
+        """
+        tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
+        patches_per_frame: Number of spatial patches per frame (height * width in latent space)
+        """
+        self.batch_size, num_tokens, self.feature_dim = tensor.shape
+
+        # Check if compression is valid (num_tokens must be divisible by patches_per_frame)
+        if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
+            self.patches_per_frame = patches_per_frame
+            self.num_frames = num_tokens // patches_per_frame
+
+            # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
+            # All patches in a frame are identical, so we only keep the first one
+            reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
+            self.data = reshaped[:, :, 0, :].contiguous()  # [batch, frames, feature_dim]
+        else:
+            # Not divisible or too small - store directly without compression
+            self.patches_per_frame = 1
+            self.num_frames = num_tokens
+            self.data = tensor
+
+    def expand(self):
+        """Expand back to original tensor."""
+        if self.patches_per_frame == 1:
+            return self.data
+
+        # [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
+        expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
+        return expanded.reshape(self.batch_size, -1, self.feature_dim)
+
+    def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
+        """Compute ada values on compressed per-frame data, then expand spatially."""
+        num_ada_params = scale_shift_table.shape[0]
+
+        # No compression - compute directly
+        if self.patches_per_frame == 1:
+            num_tokens = self.data.shape[1]
+            dim_per_param = self.feature_dim // num_ada_params
+            reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
+            table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
+            ada_values = (table_values + reshaped).unbind(dim=2)
+            return ada_values
+
+        # Compressed: compute on per-frame data then expand spatially
+        # Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
+        frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
+        table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
+            device=self.data.device, dtype=self.data.dtype
+        )
+        frame_ada = (table_values + frame_reshaped).unbind(dim=2)
+
+        # Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
+        return tuple(
+            frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
+            .reshape(batch_size, -1, frame_val.shape[-1])
+            for frame_val in frame_ada
+        )
+
+
 class BasicAVTransformerBlock(nn.Module):
     def __init__(
         self,
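Note: the compression is lossless only because every spatial patch within a frame carries the same timestep embedding; when num_tokens is not a multiple of patches_per_frame the class silently stores the full tensor instead. A standalone round-trip sketch with illustrative sizes, assuming CompressedTimestep from the hunk above is in scope:

    import torch

    B, F, P, D = 1, 24, 16 * 24, 4096                  # 24 latent frames, 16x24 patches per frame
    per_frame = torch.randn(B, F, 1, D)
    full = per_frame.expand(B, F, P, D).reshape(B, F * P, D)

    ct = CompressedTimestep(full, patches_per_frame=P)
    print(ct.data.shape)                               # torch.Size([1, 24, 4096]) instead of [1, 9216, 4096]
    assert torch.equal(ct.expand(), full)              # exact round trip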
@@ -119,6 +182,9 @@ class BasicAVTransformerBlock(nn.Module):
     def get_ada_values(
         self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
     ):
+        if isinstance(timestep, CompressedTimestep):
+            return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
+
         num_ada_params = scale_shift_table.shape[0]
 
         ada_values = (
@@ -146,10 +212,7 @@ class BasicAVTransformerBlock(nn.Module):
             gate_timestep,
         )
 
-        scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
-        gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
-
-        return (*scale_shift_chunks, *gate_ada_values)
+        return (*scale_shift_ada_values, *gate_ada_values)
 
     def forward(
         self,
@@ -543,72 +606,80 @@ class LTXAVModel(LTXVModel):
         if grid_mask is not None:
             timestep = timestep[:, grid_mask]
 
-        timestep = timestep * self.timestep_scale_multiplier
+        timestep_scaled = timestep * self.timestep_scale_multiplier
 
         v_timestep, v_embedded_timestep = self.adaln_single(
-            timestep.flatten(),
+            timestep_scaled.flatten(),
             {"resolution": None, "aspect_ratio": None},
             batch_size=batch_size,
             hidden_dtype=hidden_dtype,
         )
 
-        # Second dimension is 1 or number of tokens (if timestep_per_token)
-        v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
-        v_embedded_timestep = v_embedded_timestep.view(
-            batch_size, -1, v_embedded_timestep.shape[-1]
-        )
+        # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
+        # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
+        orig_shape = kwargs.get("orig_shape")
+        v_patches_per_frame = None
+        if orig_shape is not None and len(orig_shape) == 5:
+            # orig_shape[3] = height, orig_shape[4] = width (in latent space)
+            v_patches_per_frame = orig_shape[3] * orig_shape[4]
+
+        # Reshape to [batch_size, num_tokens, dim] and compress for storage
+        v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
+        v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
 
         # Prepare audio timestep
         a_timestep = kwargs.get("a_timestep")
         if a_timestep is not None:
-            a_timestep = a_timestep * self.timestep_scale_multiplier
+            a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
+            a_timestep_flat = a_timestep_scaled.flatten()
+            timestep_flat = timestep_scaled.flatten()
             av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
 
+            # Cross-attention timesteps - compress these too
             av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
-                a_timestep.flatten(),
+                a_timestep_flat,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
             av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
-                timestep.flatten(),
+                timestep_flat,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
            )
             av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
-                timestep.flatten() * av_ca_factor,
+                timestep_flat * av_ca_factor,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
             av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
-                a_timestep.flatten() * av_ca_factor,
+                a_timestep_flat * av_ca_factor,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
 
+            # Compress cross-attention timesteps (only video side, audio is too small to benefit)
+            cross_av_timestep_ss = [
+                av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
+                CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
+                CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
+                av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
+            ]
+
             a_timestep, a_embedded_timestep = self.audio_adaln_single(
-                a_timestep.flatten(),
+                a_timestep_flat,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
+            # Audio timesteps
             a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
-            a_embedded_timestep = a_embedded_timestep.view(
-                batch_size, -1, a_embedded_timestep.shape[-1]
-            )
-            cross_av_timestep_ss = [
-                av_ca_audio_scale_shift_timestep,
-                av_ca_video_scale_shift_timestep,
-                av_ca_a2v_gate_noise_timestep,
-                av_ca_v2a_gate_noise_timestep,
-            ]
-            cross_av_timestep_ss = list(
-                [t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
-            )
+            a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
         else:
-            a_timestep = timestep
+            a_timestep = timestep_scaled
             a_embedded_timestep = kwargs.get("embedded_timestep")
             cross_av_timestep_ss = []
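Note: the saving comes from the video side, where the AdaLN timestep tensors repeat one value across every spatial patch of a frame. Back-of-the-envelope arithmetic with illustrative sizes (not taken from a real run):

    frames, height, width, dim = 121, 32, 48, 4096
    tokens = frames * height * width                   # 185,856 video tokens
    fp16_bytes_full = tokens * dim * 2                 # roughly 1.4 GiB per timestep tensor
    fp16_bytes_compressed = frames * dim * 2           # roughly 1 MiB
    print(fp16_bytes_full // fp16_bytes_compressed)    # 1536, i.e. height * width times smaller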
@@ -767,6 +838,11 @@ class LTXAVModel(LTXVModel):
         ax = x[1]
         v_embedded_timestep = embedded_timestep[0]
         a_embedded_timestep = embedded_timestep[1]
+
+        # Expand compressed video timestep if needed
+        if isinstance(v_embedded_timestep, CompressedTimestep):
+            v_embedded_timestep = v_embedded_timestep.expand()
+
         vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
 
         # Process audio output
comfy/lora.py

@@ -322,6 +322,7 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["diffusion_model.{}".format(key_lora)] = to
             key_map["transformer.{}".format(key_lora)] = to
             key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
+            key_map[key_lora] = to
 
     if isinstance(model, comfy.model_base.Kandinsky5):
         for k in sdk:
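Note: the new entry also maps the bare module path, so LoRA files whose tensor names carry no prefix resolve in addition to the prefixed forms already handled. Illustration with a hypothetical key:

    key_lora = "blocks.0.attn.to_q"                    # hypothetical module path
    matched_key_forms = [
        "diffusion_model." + key_lora,
        "transformer." + key_lora,
        "lycoris_" + key_lora.replace(".", "_"),
        key_lora,                                      # newly matched: no prefix at all
    ]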
requirements.txt

@@ -1,5 +1,5 @@
 comfyui-frontend-package==1.36.13
-comfyui-workflow-templates==0.7.69
+comfyui-workflow-templates==0.8.0
 comfyui-embedded-docs==0.4.0
 torch
 torchsde