mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-28 15:20:25 +08:00
Compare commits
16 Commits
86ba1392e5
...
8f4bfcdc3a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8f4bfcdc3a | ||
|
|
5ac1372533 | ||
|
|
1dcbd9efaf | ||
|
|
db9e6edfa1 | ||
|
|
8af13b439b | ||
|
|
acd0e53653 | ||
|
|
117e7a5853 | ||
|
|
b3c0e4de57 | ||
|
|
ecaeeb990d | ||
|
|
c2b65e2fce | ||
|
|
fd5c0755af | ||
|
|
c881a1d689 | ||
|
|
a3b5d4996a | ||
|
|
c6238047ee | ||
|
|
27d11db345 | ||
|
|
90f23bac28 |
@ -183,7 +183,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
Update your Nvidia drivers if it doesn't start.
|
||||
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
|
||||
@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
|
||||
### Instructions:
|
||||
|
||||
|
||||
@ -92,14 +92,23 @@ def seed_from_paths_batch(
|
||||
session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
winners_by_path: set[str] = set()
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
|
||||
ins_state = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
.returning(AssetCacheState.file_path)
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
winners_by_path.update((session.execute(ins_state, chunk)).scalars().all())
|
||||
session.execute(ins_state, chunk)
|
||||
|
||||
# Query to find which of our paths won (were actually inserted)
|
||||
winners_by_path: set[str] = set()
|
||||
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetCacheState.file_path)
|
||||
.where(AssetCacheState.file_path.in_(chunk))
|
||||
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
|
||||
)
|
||||
winners_by_path.update(result.scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
@ -112,16 +121,23 @@ def seed_from_paths_batch(
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
ins_info = (
|
||||
sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
.returning(AssetInfo.id)
|
||||
)
|
||||
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
inserted_info_ids.update((session.execute(ins_info, chunk)).scalars().all())
|
||||
session.execute(ins_info, chunk)
|
||||
|
||||
# Query to find which info rows were actually inserted (by matching our generated IDs)
|
||||
all_info_ids = [row["id"] for row in winner_info_rows]
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
|
||||
)
|
||||
inserted_info_ids.update(result.scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
import math
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
|
||||
def scale_dim(size, scale):
|
||||
scaled = math.ceil(size * scale / patch_size) * patch_size
|
||||
return max(patch_size, int(scaled))
|
||||
|
||||
# Binary search for optimal scale
|
||||
lo, hi = eps / 10, 100.0
|
||||
while hi - lo >= eps:
|
||||
mid = (lo + hi) / 2
|
||||
h, w = scale_dim(oh, mid), scale_dim(ow, mid)
|
||||
if (h // patch_size) * (w // patch_size) <= max_num_patches:
|
||||
lo = mid
|
||||
else:
|
||||
hi = mid
|
||||
|
||||
return scale_dim(oh, lo), scale_dim(ow, lo)
|
||||
|
||||
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
|
||||
if size > 0:
|
||||
return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)
|
||||
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
super().__init__()
|
||||
@ -175,6 +209,27 @@ class CLIPTextModel(torch.nn.Module):
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
|
||||
embed_weight_len = round(embed_weight.shape[0] ** 0.5)
|
||||
embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len)
|
||||
embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
|
||||
embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1)
|
||||
return embeds + embed_weight
|
||||
|
||||
class Siglip2Embeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
||||
self.patch_size = patch_size
|
||||
|
||||
def forward(self, pixel_values):
|
||||
b, c, h, w = pixel_values.shape
|
||||
img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c)
|
||||
img = img.permute(0, 1, 3, 2, 4, 5)
|
||||
img = img.reshape(b, img.shape[1] * img.shape[2], -1)
|
||||
img = self.patch_embedding(img)
|
||||
return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size))
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
||||
@ -218,8 +273,11 @@ class CLIPVision(torch.nn.Module):
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
if model_type in ["siglip2_vision_model"]:
|
||||
self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
|
||||
@ -21,6 +21,7 @@ clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from br
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
}
|
||||
|
||||
@ -32,9 +33,10 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_type = config.get("model_type", "clip_vision_model")
|
||||
model_class = IMAGE_ENCODERS.get(model_type)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.model_type = config.get("model_type", "clip_vision_model")
|
||||
self.config = config.copy()
|
||||
model_class = IMAGE_ENCODERS.get(self.model_type)
|
||||
if self.model_type == "siglip_vision_model":
|
||||
self.return_all_hidden_states = True
|
||||
else:
|
||||
self.return_all_hidden_states = False
|
||||
@ -55,7 +57,10 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
if self.model_type == "siglip2_vision_model":
|
||||
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
else:
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
@ -107,10 +112,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
|
||||
if len(patch_embedding_shape) == 2:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
|
||||
else:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
elif embed_shape == 577:
|
||||
if "multi_modal_projector.linear_1.bias" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
|
||||
|
||||
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": -1,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip2_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 16,
|
||||
"num_patches": 256,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
||||
118
comfy/float.py
118
comfy/float.py
@ -65,3 +65,121 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
return output
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
|
||||
|
||||
# TODO: improve this?
|
||||
def stochastic_float_to_fp4_e2m1(x, generator):
|
||||
orig_shape = x.shape
|
||||
sign = torch.signbit(x).to(torch.uint8)
|
||||
|
||||
exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
|
||||
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
|
||||
|
||||
x = x.abs()
|
||||
exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
|
||||
|
||||
mantissa = torch.where(
|
||||
exp > 0,
|
||||
(x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
|
||||
(x * 2.0),
|
||||
out=x
|
||||
).round().to(torch.uint8)
|
||||
del x
|
||||
|
||||
exp = exp.to(torch.uint8)
|
||||
|
||||
fp4 = (sign << 3) | (exp << 1) | mantissa
|
||||
del sign, exp, mantissa
|
||||
|
||||
fp4_flat = fp4.view(-1)
|
||||
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
|
||||
return packed.reshape(list(orig_shape)[:-1] + [-1])
|
||||
|
||||
|
||||
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
|
||||
See:
|
||||
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
|
||||
|
||||
Args:
|
||||
input_matrix: Input tensor of shape (H, W)
|
||||
Returns:
|
||||
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
|
||||
"""
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
rows, cols = input_matrix.shape
|
||||
n_row_blocks = ceil_div(rows, 128)
|
||||
n_col_blocks = ceil_div(cols, 4)
|
||||
|
||||
# Calculate the padded shape
|
||||
padded_rows = n_row_blocks * 128
|
||||
padded_cols = n_col_blocks * 4
|
||||
|
||||
padded = input_matrix
|
||||
if (rows, cols) != (padded_rows, padded_cols):
|
||||
padded = torch.zeros(
|
||||
(padded_rows, padded_cols),
|
||||
device=input_matrix.device,
|
||||
dtype=input_matrix.dtype,
|
||||
)
|
||||
padded[:rows, :cols] = input_matrix
|
||||
|
||||
# Rearrange the blocks
|
||||
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
if flatten:
|
||||
return rearranged.flatten()
|
||||
|
||||
return rearranged.reshape(padded_rows, padded_cols)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
F4_E2M1_MAX = 6.0
|
||||
F8_E4M3_MAX = 448.0
|
||||
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
|
||||
orig_shape = x.shape
|
||||
|
||||
# Handle padding
|
||||
if pad_16x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 16)
|
||||
padded_cols = roundup(cols, 16)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
# Note: We update orig_shape because the output tensor logic below assumes x.shape matches
|
||||
# what we want to produce. If we pad here, we want the padded output.
|
||||
orig_shape = x.shape
|
||||
|
||||
block_size = 16
|
||||
|
||||
x = x.reshape(orig_shape[0], -1, block_size)
|
||||
max_abs = torch.amax(torch.abs(x), dim=-1)
|
||||
block_scale = max_abs / F4_E2M1_MAX
|
||||
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
|
||||
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
|
||||
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
|
||||
|
||||
# Handle zero blocks (from padding): avoid 0/0 NaN
|
||||
zero_scale_mask = (total_scale == 0)
|
||||
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
|
||||
|
||||
x = x / total_scale_safe.unsqueeze(-1)
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
|
||||
|
||||
x = x.view(orig_shape)
|
||||
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
|
||||
|
||||
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
|
||||
return data_lp, blocked_scales
|
||||
|
||||
@ -11,6 +11,69 @@ from comfy.ldm.lightricks.model import (
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
|
||||
# All patches in a frame are identical, so we only keep the first one
|
||||
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
|
||||
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
|
||||
else:
|
||||
# Not divisible or too small - store directly without compression
|
||||
self.patches_per_frame = 1
|
||||
self.num_frames = num_tokens
|
||||
self.data = tensor
|
||||
|
||||
def expand(self):
|
||||
"""Expand back to original tensor."""
|
||||
if self.patches_per_frame == 1:
|
||||
return self.data
|
||||
|
||||
# [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
|
||||
expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
|
||||
return expanded.reshape(self.batch_size, -1, self.feature_dim)
|
||||
|
||||
def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
|
||||
"""Compute ada values on compressed per-frame data, then expand spatially."""
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
# No compression - compute directly
|
||||
if self.patches_per_frame == 1:
|
||||
num_tokens = self.data.shape[1]
|
||||
dim_per_param = self.feature_dim // num_ada_params
|
||||
reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
|
||||
ada_values = (table_values + reshaped).unbind(dim=2)
|
||||
return ada_values
|
||||
|
||||
# Compressed: compute on per-frame data then expand spatially
|
||||
# Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
|
||||
frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
|
||||
device=self.data.device, dtype=self.data.dtype
|
||||
)
|
||||
frame_ada = (table_values + frame_reshaped).unbind(dim=2)
|
||||
|
||||
# Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
|
||||
return tuple(
|
||||
frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
|
||||
.reshape(batch_size, -1, frame_val.shape[-1])
|
||||
for frame_val in frame_ada
|
||||
)
|
||||
|
||||
class BasicAVTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -119,6 +182,9 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
def get_ada_values(
|
||||
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
|
||||
):
|
||||
if isinstance(timestep, CompressedTimestep):
|
||||
return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
|
||||
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
ada_values = (
|
||||
@ -146,10 +212,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
gate_timestep,
|
||||
)
|
||||
|
||||
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
|
||||
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
|
||||
|
||||
return (*scale_shift_chunks, *gate_ada_values)
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -543,72 +606,80 @@ class LTXAVModel(LTXVModel):
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
|
||||
v_embedded_timestep = v_embedded_timestep.view(
|
||||
batch_size, -1, v_embedded_timestep.shape[-1]
|
||||
)
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
v_patches_per_frame = None
|
||||
if orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
# Reshape to [batch_size, num_tokens, dim] and compress for storage
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
a_timestep = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_flat = a_timestep_scaled.flatten()
|
||||
timestep_flat = timestep_scaled.flatten()
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep.flatten(),
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep.flatten(),
|
||||
timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep.flatten() * av_ca_factor,
|
||||
timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep.flatten() * av_ca_factor,
|
||||
a_timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
||||
]
|
||||
|
||||
a_timestep, a_embedded_timestep = self.audio_adaln_single(
|
||||
a_timestep.flatten(),
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
# Audio timesteps
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(
|
||||
batch_size, -1, a_embedded_timestep.shape[-1]
|
||||
)
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep,
|
||||
av_ca_video_scale_shift_timestep,
|
||||
av_ca_a2v_gate_noise_timestep,
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
]
|
||||
cross_av_timestep_ss = list(
|
||||
[t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
|
||||
)
|
||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||
else:
|
||||
a_timestep = timestep
|
||||
a_timestep = timestep_scaled
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
|
||||
@ -767,6 +838,11 @@ class LTXAVModel(LTXVModel):
|
||||
ax = x[1]
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
|
||||
# Expand compressed video timestep if needed
|
||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||
v_embedded_timestep = v_embedded_timestep.expand()
|
||||
|
||||
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
|
||||
|
||||
# Process audio output
|
||||
|
||||
@ -322,6 +322,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||
for k in sdk:
|
||||
|
||||
@ -699,7 +699,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
|
||||
@ -7,7 +7,7 @@ try:
|
||||
QuantizedTensor,
|
||||
QuantizedLayout,
|
||||
TensorCoreFP8Layout as _CKFp8Layout,
|
||||
TensorCoreNVFP4Layout, # Direct import, no wrapper needed
|
||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||
register_layout_op,
|
||||
register_layout_class,
|
||||
get_layout_class,
|
||||
@ -34,7 +34,7 @@ except ImportError as e:
|
||||
class _CKFp8Layout:
|
||||
pass
|
||||
|
||||
class TensorCoreNVFP4Layout:
|
||||
class _CKNvfp4Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
@ -84,6 +84,39 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
if tensor.dim() != 2:
|
||||
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
|
||||
|
||||
orig_dtype = tensor.dtype
|
||||
orig_shape = tuple(tensor.shape)
|
||||
|
||||
if scale is None or (isinstance(scale, str) and scale == "recalculate"):
|
||||
scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
padded_shape = cls.get_padded_shape(orig_shape)
|
||||
needs_padding = padded_shape != orig_shape
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
|
||||
else:
|
||||
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
|
||||
|
||||
params = cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=orig_dtype,
|
||||
orig_shape=orig_shape,
|
||||
block_scale=block_scale,
|
||||
)
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
|
||||
|
||||
@ -845,7 +845,7 @@ class LTXAV(LTXV):
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = 0.061 # TODO
|
||||
self.memory_usage_factor = 0.077 # TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXAV(self, device=device)
|
||||
|
||||
@ -5,6 +5,7 @@ from comfy import sdxl_clip
|
||||
import comfy.model_management
|
||||
import torch
|
||||
import logging
|
||||
import folder_paths
|
||||
|
||||
|
||||
class HiDreamTokenizer:
|
||||
@ -91,6 +92,8 @@ class HiDreamTEModel(torch.nn.Module):
|
||||
token_weight_pairs_llama = token_weight_pairs["llama"]
|
||||
lg_out = None
|
||||
pooled = None
|
||||
t5_out = None
|
||||
ll_out = None
|
||||
extra = {}
|
||||
|
||||
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
||||
@ -104,8 +107,9 @@ class HiDreamTEModel(torch.nn.Module):
|
||||
else:
|
||||
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
|
||||
|
||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
if self.clip_g is not None and self.clip_l is not None:
|
||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
if self.t5xxl is not None:
|
||||
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
||||
t5_out, t5_pooled = t5_output[:2]
|
||||
@ -120,13 +124,15 @@ class HiDreamTEModel(torch.nn.Module):
|
||||
ll_out = None
|
||||
|
||||
if t5_out is None:
|
||||
t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())
|
||||
t5_path = folder_paths.get_full_path_or_raise("hidream_empty_latents", "t5_blank.pt")
|
||||
t5_out = torch.load(t5_path, map_location=comfy.model_management.intermediate_device())
|
||||
|
||||
if ll_out is None:
|
||||
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
|
||||
|
||||
if pooled is None:
|
||||
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
||||
pooled_path = folder_paths.get_full_path_or_raise("hidream_empty_latents", "pooled_blank.pt")
|
||||
pooled = torch.load(pooled_path, map_location=comfy.model_management.intermediate_device())
|
||||
|
||||
extra["conditioning_llama3"] = ll_out
|
||||
return t5_out, pooled, extra
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.8.2"
|
||||
__version__ = "0.9.1"
|
||||
|
||||
@ -47,6 +47,7 @@ folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetwo
|
||||
folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
|
||||
folder_names_and_paths["hidream_empty_latents"] = ([os.path.join(models_dir, "hidream_empty_latents")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions)
|
||||
|
||||
|
||||
BIN
models/hidream_empty_latents/pooled_blank.pt
Normal file
BIN
models/hidream_empty_latents/pooled_blank.pt
Normal file
Binary file not shown.
BIN
models/hidream_empty_latents/t5_blank.pt
Normal file
BIN
models/hidream_empty_latents/t5_blank.pt
Normal file
Binary file not shown.
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.8.2"
|
||||
version = "0.9.1"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.36.13
|
||||
comfyui-workflow-templates==0.7.69
|
||||
comfyui-frontend-package==1.36.14
|
||||
comfyui-workflow-templates==0.8.4
|
||||
comfyui-embedded-docs==0.4.0
|
||||
torch
|
||||
torchsde
|
||||
@ -21,7 +21,7 @@ psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.5
|
||||
comfy-kitchen>=0.2.6
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
Loading…
Reference in New Issue
Block a user