mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-28 07:10:15 +08:00
Compare commits
15 Commits
d14cb83594
...
50a5ec4cbd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50a5ec4cbd | ||
|
|
5ac1372533 | ||
|
|
1dcbd9efaf | ||
|
|
db9e6edfa1 | ||
|
|
8af13b439b | ||
|
|
acd0e53653 | ||
|
|
117e7a5853 | ||
|
|
b3c0e4de57 | ||
|
|
ecaeeb990d | ||
|
|
c2b65e2fce | ||
|
|
fd5c0755af | ||
|
|
c881a1d689 | ||
|
|
a3b5d4996a | ||
|
|
a0f6703f05 | ||
|
|
5cd7530637 |
@ -92,14 +92,23 @@ def seed_from_paths_batch(
|
||||
session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
winners_by_path: set[str] = set()
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
|
||||
ins_state = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
.returning(AssetCacheState.file_path)
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
winners_by_path.update((session.execute(ins_state, chunk)).scalars().all())
|
||||
session.execute(ins_state, chunk)
|
||||
|
||||
# Query to find which of our paths won (were actually inserted)
|
||||
winners_by_path: set[str] = set()
|
||||
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetCacheState.file_path)
|
||||
.where(AssetCacheState.file_path.in_(chunk))
|
||||
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
|
||||
)
|
||||
winners_by_path.update(result.scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
@ -112,16 +121,23 @@ def seed_from_paths_batch(
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
ins_info = (
|
||||
sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
.returning(AssetInfo.id)
|
||||
)
|
||||
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
inserted_info_ids.update((session.execute(ins_info, chunk)).scalars().all())
|
||||
session.execute(ins_info, chunk)
|
||||
|
||||
# Query to find which info rows were actually inserted (by matching our generated IDs)
|
||||
all_info_ids = [row["id"] for row in winner_info_rows]
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
|
||||
)
|
||||
inserted_info_ids.update(result.scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
import math
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
|
||||
def scale_dim(size, scale):
|
||||
scaled = math.ceil(size * scale / patch_size) * patch_size
|
||||
return max(patch_size, int(scaled))
|
||||
|
||||
# Binary search for optimal scale
|
||||
lo, hi = eps / 10, 100.0
|
||||
while hi - lo >= eps:
|
||||
mid = (lo + hi) / 2
|
||||
h, w = scale_dim(oh, mid), scale_dim(ow, mid)
|
||||
if (h // patch_size) * (w // patch_size) <= max_num_patches:
|
||||
lo = mid
|
||||
else:
|
||||
hi = mid
|
||||
|
||||
return scale_dim(oh, lo), scale_dim(ow, lo)
|
||||
|
||||
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
|
||||
if size > 0:
|
||||
return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)
|
||||
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
super().__init__()
|
||||
@ -175,6 +209,27 @@ class CLIPTextModel(torch.nn.Module):
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
|
||||
embed_weight_len = round(embed_weight.shape[0] ** 0.5)
|
||||
embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len)
|
||||
embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
|
||||
embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1)
|
||||
return embeds + embed_weight
|
||||
|
||||
class Siglip2Embeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
||||
self.patch_size = patch_size
|
||||
|
||||
def forward(self, pixel_values):
|
||||
b, c, h, w = pixel_values.shape
|
||||
img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c)
|
||||
img = img.permute(0, 1, 3, 2, 4, 5)
|
||||
img = img.reshape(b, img.shape[1] * img.shape[2], -1)
|
||||
img = self.patch_embedding(img)
|
||||
return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size))
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
||||
@ -218,8 +273,11 @@ class CLIPVision(torch.nn.Module):
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
if model_type in ["siglip2_vision_model"]:
|
||||
self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
|
||||
@ -21,6 +21,7 @@ clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from br
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
}
|
||||
|
||||
@ -32,9 +33,10 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_type = config.get("model_type", "clip_vision_model")
|
||||
model_class = IMAGE_ENCODERS.get(model_type)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.model_type = config.get("model_type", "clip_vision_model")
|
||||
self.config = config.copy()
|
||||
model_class = IMAGE_ENCODERS.get(self.model_type)
|
||||
if self.model_type == "siglip_vision_model":
|
||||
self.return_all_hidden_states = True
|
||||
else:
|
||||
self.return_all_hidden_states = False
|
||||
@ -55,7 +57,10 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
if self.model_type == "siglip2_vision_model":
|
||||
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
else:
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
@ -107,10 +112,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
|
||||
if len(patch_embedding_shape) == 2:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
|
||||
else:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
elif embed_shape == 577:
|
||||
if "multi_modal_projector.linear_1.bias" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
|
||||
|
||||
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": -1,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip2_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 16,
|
||||
"num_patches": 256,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
||||
118
comfy/float.py
118
comfy/float.py
@ -65,3 +65,121 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
return output
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
|
||||
|
||||
# TODO: improve this?
|
||||
def stochastic_float_to_fp4_e2m1(x, generator):
|
||||
orig_shape = x.shape
|
||||
sign = torch.signbit(x).to(torch.uint8)
|
||||
|
||||
exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
|
||||
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
|
||||
|
||||
x = x.abs()
|
||||
exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
|
||||
|
||||
mantissa = torch.where(
|
||||
exp > 0,
|
||||
(x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
|
||||
(x * 2.0),
|
||||
out=x
|
||||
).round().to(torch.uint8)
|
||||
del x
|
||||
|
||||
exp = exp.to(torch.uint8)
|
||||
|
||||
fp4 = (sign << 3) | (exp << 1) | mantissa
|
||||
del sign, exp, mantissa
|
||||
|
||||
fp4_flat = fp4.view(-1)
|
||||
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
|
||||
return packed.reshape(list(orig_shape)[:-1] + [-1])
|
||||
|
||||
|
||||
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
|
||||
See:
|
||||
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
|
||||
|
||||
Args:
|
||||
input_matrix: Input tensor of shape (H, W)
|
||||
Returns:
|
||||
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
|
||||
"""
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
rows, cols = input_matrix.shape
|
||||
n_row_blocks = ceil_div(rows, 128)
|
||||
n_col_blocks = ceil_div(cols, 4)
|
||||
|
||||
# Calculate the padded shape
|
||||
padded_rows = n_row_blocks * 128
|
||||
padded_cols = n_col_blocks * 4
|
||||
|
||||
padded = input_matrix
|
||||
if (rows, cols) != (padded_rows, padded_cols):
|
||||
padded = torch.zeros(
|
||||
(padded_rows, padded_cols),
|
||||
device=input_matrix.device,
|
||||
dtype=input_matrix.dtype,
|
||||
)
|
||||
padded[:rows, :cols] = input_matrix
|
||||
|
||||
# Rearrange the blocks
|
||||
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
if flatten:
|
||||
return rearranged.flatten()
|
||||
|
||||
return rearranged.reshape(padded_rows, padded_cols)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
F4_E2M1_MAX = 6.0
|
||||
F8_E4M3_MAX = 448.0
|
||||
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
|
||||
orig_shape = x.shape
|
||||
|
||||
# Handle padding
|
||||
if pad_16x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 16)
|
||||
padded_cols = roundup(cols, 16)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
# Note: We update orig_shape because the output tensor logic below assumes x.shape matches
|
||||
# what we want to produce. If we pad here, we want the padded output.
|
||||
orig_shape = x.shape
|
||||
|
||||
block_size = 16
|
||||
|
||||
x = x.reshape(orig_shape[0], -1, block_size)
|
||||
max_abs = torch.amax(torch.abs(x), dim=-1)
|
||||
block_scale = max_abs / F4_E2M1_MAX
|
||||
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
|
||||
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
|
||||
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
|
||||
|
||||
# Handle zero blocks (from padding): avoid 0/0 NaN
|
||||
zero_scale_mask = (total_scale == 0)
|
||||
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
|
||||
|
||||
x = x / total_scale_safe.unsqueeze(-1)
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
|
||||
|
||||
x = x.view(orig_shape)
|
||||
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
|
||||
|
||||
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
|
||||
return data_lp, blocked_scales
|
||||
|
||||
@ -11,6 +11,69 @@ from comfy.ldm.lightricks.model import (
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
|
||||
# All patches in a frame are identical, so we only keep the first one
|
||||
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
|
||||
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
|
||||
else:
|
||||
# Not divisible or too small - store directly without compression
|
||||
self.patches_per_frame = 1
|
||||
self.num_frames = num_tokens
|
||||
self.data = tensor
|
||||
|
||||
def expand(self):
|
||||
"""Expand back to original tensor."""
|
||||
if self.patches_per_frame == 1:
|
||||
return self.data
|
||||
|
||||
# [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
|
||||
expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
|
||||
return expanded.reshape(self.batch_size, -1, self.feature_dim)
|
||||
|
||||
def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
|
||||
"""Compute ada values on compressed per-frame data, then expand spatially."""
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
# No compression - compute directly
|
||||
if self.patches_per_frame == 1:
|
||||
num_tokens = self.data.shape[1]
|
||||
dim_per_param = self.feature_dim // num_ada_params
|
||||
reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
|
||||
ada_values = (table_values + reshaped).unbind(dim=2)
|
||||
return ada_values
|
||||
|
||||
# Compressed: compute on per-frame data then expand spatially
|
||||
# Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
|
||||
frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
|
||||
device=self.data.device, dtype=self.data.dtype
|
||||
)
|
||||
frame_ada = (table_values + frame_reshaped).unbind(dim=2)
|
||||
|
||||
# Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
|
||||
return tuple(
|
||||
frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
|
||||
.reshape(batch_size, -1, frame_val.shape[-1])
|
||||
for frame_val in frame_ada
|
||||
)
|
||||
|
||||
class BasicAVTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -119,6 +182,9 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
def get_ada_values(
|
||||
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
|
||||
):
|
||||
if isinstance(timestep, CompressedTimestep):
|
||||
return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
|
||||
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
ada_values = (
|
||||
@ -146,10 +212,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
gate_timestep,
|
||||
)
|
||||
|
||||
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
|
||||
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
|
||||
|
||||
return (*scale_shift_chunks, *gate_ada_values)
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -543,72 +606,80 @@ class LTXAVModel(LTXVModel):
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
|
||||
v_embedded_timestep = v_embedded_timestep.view(
|
||||
batch_size, -1, v_embedded_timestep.shape[-1]
|
||||
)
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
v_patches_per_frame = None
|
||||
if orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
# Reshape to [batch_size, num_tokens, dim] and compress for storage
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
a_timestep = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_flat = a_timestep_scaled.flatten()
|
||||
timestep_flat = timestep_scaled.flatten()
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep.flatten(),
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep.flatten(),
|
||||
timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep.flatten() * av_ca_factor,
|
||||
timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep.flatten() * av_ca_factor,
|
||||
a_timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
||||
]
|
||||
|
||||
a_timestep, a_embedded_timestep = self.audio_adaln_single(
|
||||
a_timestep.flatten(),
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
# Audio timesteps
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(
|
||||
batch_size, -1, a_embedded_timestep.shape[-1]
|
||||
)
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep,
|
||||
av_ca_video_scale_shift_timestep,
|
||||
av_ca_a2v_gate_noise_timestep,
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
]
|
||||
cross_av_timestep_ss = list(
|
||||
[t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
|
||||
)
|
||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||
else:
|
||||
a_timestep = timestep
|
||||
a_timestep = timestep_scaled
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
|
||||
@ -767,6 +838,11 @@ class LTXAVModel(LTXVModel):
|
||||
ax = x[1]
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
|
||||
# Expand compressed video timestep if needed
|
||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||
v_embedded_timestep = v_embedded_timestep.expand()
|
||||
|
||||
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
|
||||
|
||||
# Process audio output
|
||||
|
||||
@ -322,6 +322,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||
for k in sdk:
|
||||
|
||||
@ -699,7 +699,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
|
||||
@ -7,7 +7,7 @@ try:
|
||||
QuantizedTensor,
|
||||
QuantizedLayout,
|
||||
TensorCoreFP8Layout as _CKFp8Layout,
|
||||
TensorCoreNVFP4Layout, # Direct import, no wrapper needed
|
||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||
register_layout_op,
|
||||
register_layout_class,
|
||||
get_layout_class,
|
||||
@ -34,7 +34,7 @@ except ImportError as e:
|
||||
class _CKFp8Layout:
|
||||
pass
|
||||
|
||||
class TensorCoreNVFP4Layout:
|
||||
class _CKNvfp4Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
@ -84,6 +84,39 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
if tensor.dim() != 2:
|
||||
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
|
||||
|
||||
orig_dtype = tensor.dtype
|
||||
orig_shape = tuple(tensor.shape)
|
||||
|
||||
if scale is None or (isinstance(scale, str) and scale == "recalculate"):
|
||||
scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
padded_shape = cls.get_padded_shape(orig_shape)
|
||||
needs_padding = padded_shape != orig_shape
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
|
||||
else:
|
||||
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
|
||||
|
||||
params = cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=orig_dtype,
|
||||
orig_shape=orig_shape,
|
||||
block_scale=block_scale,
|
||||
)
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
|
||||
|
||||
@ -845,7 +845,7 @@ class LTXAV(LTXV):
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = 0.061 # TODO
|
||||
self.memory_usage_factor = 0.077 # TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXAV(self, device=device)
|
||||
|
||||
@ -78,20 +78,19 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
overlap = 32
|
||||
|
||||
oom = True
|
||||
try:
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
oom = False
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
tile //= 2
|
||||
if tile < 128:
|
||||
raise e
|
||||
finally:
|
||||
upscale_model.to("cpu")
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
# KEY CHANGE: Pass device as output_device instead of default "cpu"
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, output_device=device, pbar=pbar)
|
||||
oom = False
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
tile //= 2
|
||||
if tile < 128:
|
||||
raise e
|
||||
|
||||
# upscale_model.to("cpu") # Commented out to keep model on GPU because when processing batch images then model unnecessarily moves to CPU
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||
return io.NodeOutput(s)
|
||||
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.8.2"
|
||||
__version__ = "0.9.1"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.8.2"
|
||||
version = "0.9.1"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.36.13
|
||||
comfyui-workflow-templates==0.7.69
|
||||
comfyui-frontend-package==1.36.14
|
||||
comfyui-workflow-templates==0.8.4
|
||||
comfyui-embedded-docs==0.4.0
|
||||
torch
|
||||
torchsde
|
||||
@ -21,7 +21,7 @@ psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.5
|
||||
comfy-kitchen>=0.2.6
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
Loading…
Reference in New Issue
Block a user