mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 14:20:27 +08:00

Compare commits
14 Commits: 73459299a3 ... cc34bd35e9

| SHA1 |
|---|
| cc34bd35e9 |
| fcd9a236b0 |
| 21e8425087 |
| b6c79a648a |
| 25bc1b5b57 |
| 3cd19e99c1 |
| 007b87e7ac |
| 34751fe9f9 |
| 1c705f7bfb |
| 48e5ea1dfd |
| 296b7c7b6d |
| a3f78be5c2 |
| 0c84b7650f |
| 2bff3c520f |
@@ -6,6 +6,12 @@ import comfy.ldm.common_dit
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.math import apply_rope1
 from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.kandinsky5.utils_nabla import (
+    fractal_flatten,
+    fractal_unflatten,
+    fast_sta_nabla,
+    nabla,
+)

 def attention(q, k, v, heads, transformer_options={}):
     return optimized_attention(
@@ -116,14 +122,17 @@ class SelfAttention(nn.Module):
         result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
         return apply_rope1(norm_fn(result), freqs)

-    def _forward(self, x, freqs, transformer_options={}):
+    def _forward(self, x, freqs, sparse_params=None, transformer_options={}):
         q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
         k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
         v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
-        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        if sparse_params is None:
+            out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        else:
+            out = nabla(q, k, v, sparse_params)
         return self.out_layer(out)

-    def _forward_chunked(self, x, freqs, transformer_options={}):
+    def _forward_chunked(self, x, freqs, sparse_params=None, transformer_options={}):
         def process_chunks(proj_fn, norm_fn):
             x_chunks = torch.chunk(x, self.num_chunks, dim=1)
             freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
@@ -135,14 +144,17 @@ class SelfAttention(nn.Module):
         q = process_chunks(self.to_query, self.query_norm)
         k = process_chunks(self.to_key, self.key_norm)
         v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
-        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        if sparse_params is None:
+            out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        else:
+            out = nabla(q, k, v, sparse_params)
         return self.out_layer(out)

-    def forward(self, x, freqs, transformer_options={}):
+    def forward(self, x, freqs, sparse_params=None, transformer_options={}):
         if x.shape[1] > 8192:
-            return self._forward_chunked(x, freqs, transformer_options=transformer_options)
+            return self._forward_chunked(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
         else:
-            return self._forward(x, freqs, transformer_options=transformer_options)
+            return self._forward(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)


 class CrossAttention(SelfAttention):
@@ -251,12 +263,12 @@ class TransformerDecoderBlock(nn.Module):
         self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

-    def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
+    def forward(self, visual_embed, text_embed, time_embed, freqs, sparse_params=None, transformer_options={}):
         self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
         # self attention
         shift, scale, gate = get_shift_scale_gate(self_attn_params)
         visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
-        visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
+        visual_out = self.self_attention(visual_out, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
         # cross attention
         shift, scale, gate = get_shift_scale_gate(cross_attn_params)
@@ -369,21 +381,82 @@ class Kandinsky5(nn.Module):

         visual_embed = self.visual_embeddings(x)
         visual_shape = visual_embed.shape[:-1]
-        visual_embed = visual_embed.flatten(1, -2)

         blocks_replace = patches_replace.get("dit", {})
         transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
         transformer_options["block_type"] = "double"

+        B, _, T, H, W = x.shape
+        NABLA_THR = 31 # long (10 sec) generation
+        if T > NABLA_THR:
+            assert self.patch_size[0] == 1
+
+            # pro video model uses lower P at higher resolutions
+            P = 0.7 if self.model_dim == 4096 and H * W >= 14080 else 0.9
+
+            freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
+            visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:])
+            pt, ph, pw = self.patch_size
+            T, H, W = T // pt, H // ph, W // pw
+
+            wT, wW, wH = 11, 3, 3
+            sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device)
+
+            sparse_params = dict(
+                sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0),
+                attention_type="nabla",
+                to_fractal=True,
+                P=P,
+                wT=wT, wW=wW, wH=wH,
+                add_sta=True,
+                visual_shape=(T, H, W),
+                method="topcdf",
+            )
+        else:
+            sparse_params = None
+            visual_embed = visual_embed.flatten(1, -2)

         for i, block in enumerate(self.visual_transformer_blocks):
             transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
-                    return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
-                visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
+                    return block(
+                        x=args["x"],
+                        context=args["context"],
+                        time_embed=args["time_embed"],
+                        freqs=args["freqs"],
+                        sparse_params=args.get("sparse_params"),
+                        transformer_options=args.get("transformer_options"),
+                    )
+                visual_embed = blocks_replace[("double_block", i)](
+                    {
+                        "x": visual_embed,
+                        "context": context,
+                        "time_embed": time_embed,
+                        "freqs": freqs,
+                        "sparse_params": sparse_params,
+                        "transformer_options": transformer_options,
+                    },
+                    {"original_block": block_wrap},
+                )["x"]
             else:
-                visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
+                visual_embed = block(
+                    visual_embed,
+                    context,
+                    time_embed,
+                    freqs=freqs,
+                    sparse_params=sparse_params,
+                    transformer_options=transformer_options,
+                )

-        visual_embed = visual_embed.reshape(*visual_shape, -1)
+        if T > NABLA_THR:
+            visual_embed = fractal_unflatten(
+                visual_embed,
+                visual_shape[1:],
+            )
+        else:
+            visual_embed = visual_embed.reshape(*visual_shape, -1)

         return self.out_layer(visual_embed, time_embed)

     def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
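
A rough shape sketch for the nabla path (the sizes below are made up for illustration, not taken from the commit): fractal_flatten turns every (1, 8, 8) block of the latent grid into 64 consecutive tokens, so the sequence nabla sees has length 64 * T * (H // 8) * (W // 8), which is the same block count the sta_mask above is built for, and the sparse path only activates once the latent frame count T exceeds NABLA_THR = 31.

# Illustrative bookkeeping only; T, H, W are assumed post-patch latent dims that divide evenly.
T, H, W = 37, 48, 80
blocks = T * (H // 8) * (W // 8)   # 37 * 6 * 10 = 2220 blocks, so sta_mask is 2220 x 2220
seq_len = 64 * blocks              # 142080 tokens enter flex_attention, and S // 64 == blocks
print(blocks, seq_len)             # 2220 142080
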
146  comfy/ldm/kandinsky5/utils_nabla.py  Normal file
@@ -0,0 +1,146 @@
import math

import torch
from torch import Tensor
from torch.nn.attention.flex_attention import BlockMask, flex_attention


def fractal_flatten(x, rope, shape):
    pixel_size = 8
    x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
    rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
    x = x.flatten(1, 2)
    rope = rope.flatten(1, 2)
    return x, rope


def fractal_unflatten(x, shape):
    pixel_size = 8
    x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1])
    x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
    return x

def local_patching(x, shape, group_size, dim=0):
    duration, height, width = shape
    g1, g2, g3 = group_size
    x = x.reshape(
        *x.shape[:dim],
        duration // g1,
        g1,
        height // g2,
        g2,
        width // g3,
        g3,
        *x.shape[dim + 3 :]
    )
    x = x.permute(
        *range(len(x.shape[:dim])),
        dim,
        dim + 2,
        dim + 4,
        dim + 1,
        dim + 3,
        dim + 5,
        *range(dim + 6, len(x.shape))
    )
    x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
    return x


def local_merge(x, shape, group_size, dim=0):
    duration, height, width = shape
    g1, g2, g3 = group_size
    x = x.reshape(
        *x.shape[:dim],
        duration // g1,
        height // g2,
        width // g3,
        g1,
        g2,
        g3,
        *x.shape[dim + 2 :]
    )
    x = x.permute(
        *range(len(x.shape[:dim])),
        dim,
        dim + 3,
        dim + 1,
        dim + 4,
        dim + 2,
        dim + 5,
        *range(dim + 6, len(x.shape))
    )
    x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
    return x

def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> Tensor:
    l = torch.Tensor([T, H, W]).amax()
    r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
    mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
    sta_t, sta_h, sta_w = (
        mat[:T, :T].flatten(),
        mat[:H, :H].flatten(),
        mat[:W, :W].flatten(),
    )
    sta_t = sta_t <= wT // 2
    sta_h = sta_h <= wH // 2
    sta_w = sta_w <= wW // 2
    sta_hw = (
        (sta_h.unsqueeze(1) * sta_w.unsqueeze(0))
        .reshape(H, H, W, W)
        .transpose(1, 2)
        .flatten()
    )
    sta = (
        (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0))
        .reshape(T, T, H * W, H * W)
        .transpose(1, 2)
    )
    return sta.reshape(T * H * W, T * H * W)


def nablaT_v2(q: Tensor, k: Tensor, sta: Tensor, thr: float = 0.9) -> BlockMask:
    # Map estimation
    B, h, S, D = q.shape
    s1 = S // 64
    qa = q.reshape(B, h, s1, 64, D).mean(-2)
    ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
    map = qa @ ka

    map = torch.softmax(map / math.sqrt(D), dim=-1)
    # Map binarization
    vals, inds = map.sort(-1)
    cvals = vals.cumsum_(-1)
    mask = (cvals >= 1 - thr).int()
    mask = mask.gather(-1, inds.argsort(-1))
    mask = torch.logical_or(mask, sta)

    # BlockMask creation
    kv_nb = mask.sum(-1).to(torch.int32)
    kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
    return BlockMask.from_kv_blocks(
        torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
    )

@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def nabla(query, key, value, sparse_params=None):
    query = query.transpose(1, 2).contiguous()
    key = key.transpose(1, 2).contiguous()
    value = value.transpose(1, 2).contiguous()
    block_mask = nablaT_v2(
        query,
        key,
        sparse_params["sta_mask"],
        thr=sparse_params["P"],
    )
    out = (
        flex_attention(
            query,
            key,
            value,
            block_mask=block_mask
        )
        .transpose(1, 2)
        .contiguous()
    )
    out = out.flatten(-2, -1)
    return out
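
A minimal round-trip sketch of local_patching and local_merge with the (1, 8, 8) grouping that fractal_flatten and fractal_unflatten rely on (toy sizes, chosen so every dimension divides evenly): patching regroups a (B, T, H, W, C) grid into (B, blocks, 64, C) and merging restores it exactly.

# Illustrative round-trip check using the helpers defined above.
import torch

T, H, W, C = 2, 16, 16, 4
x = torch.arange(T * H * W * C, dtype=torch.float32).reshape(1, T, H, W, C)

patched = local_patching(x, (T, H, W), (1, 8, 8), dim=1)
print(patched.shape)  # torch.Size([1, 8, 64, 4]) -> 8 blocks of 64 tokens each

merged = local_merge(patched, (T, H, W), (1, 8, 8), dim=1)
assert torch.equal(merged, x)  # local_merge inverts local_patching exactly
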
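A small sketch of the mask fast_sta_nabla produces, on a toy grid of 4 x 2 x 2 blocks with the default 3-wide windows; the expected values follow directly from the window arithmetic.

# Toy sliding-tile mask: block i attends to block j only when their temporal and
# spatial indices differ by at most half a window.
m = fast_sta_nabla(T=4, H=2, W=2, wT=3, wH=3, wW=3, device="cpu")
print(m.shape)     # torch.Size([16, 16])
print(m[0].int())  # first 8 entries are 1 (blocks at t in {0, 1}), the remaining 8 are 0
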
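The "topcdf" binarization inside nablaT_v2 keeps, for each query block, the smallest set of key blocks whose estimated attention mass reaches the threshold; a minimal sketch of that selection on a single made-up row of probabilities:

# Keep the highest-probability blocks until their mass covers thr, drop the low-mass tail.
import torch

probs = torch.tensor([[0.05, 0.50, 0.05, 0.30, 0.10]])
thr = 0.8

vals, inds = probs.sort(-1)               # ascending
cvals = vals.cumsum(-1)                   # cumulative mass of the sorted tail
keep = (cvals >= 1 - thr).int()           # zero out the tail below 1 - thr
keep = keep.gather(-1, inds.argsort(-1))  # scatter back to the original block order
print(keep)                               # tensor([[0, 1, 0, 1, 1]], dtype=torch.int32)
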
@@ -718,6 +718,7 @@ class ModelPatcher:
                 continue

            cast_weight = self.force_cast_weights
+            m.comfy_force_cast_weights = self.force_cast_weights
            if lowvram_weight:
                if hasattr(m, "comfy_cast_weights"):
                    m.weight_function = []
@@ -790,11 +791,12 @@ class ModelPatcher:
                for param in params:
                    self.pin_weight_to_device("{}.{}".format(n, param))

+        usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
         if lowvram_counter > 0:
-            logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
+            logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
             self.model.model_lowvram = True
         else:
-            logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
             self.model.model_lowvram = False
         if full_load:
             self.model.to(device_to)
30  comfy/ops.py
@@ -654,29 +654,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             run_every_op()

             input_shape = input.shape
-            tensor_3d = input.ndim == 3
-
-            if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
-                return self.forward_comfy_cast_weights(input, *args, **kwargs)
+            reshaped_3d = False

             if (getattr(self, 'layout_type', None) is not None and
-                not isinstance(input, QuantizedTensor)):
+                not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
+                not getattr(self, 'comfy_force_cast_weights', False) and
+                len(self.weight_function) == 0 and len(self.bias_function) == 0):

                 # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
-                if tensor_3d:
-                    input = input.reshape(-1, input_shape[2])
+                input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input

-                if input.ndim != 2:
-                    # Fall back to comfy_cast_weights for non-2D tensors
-                    return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
+                # Fall back to non-quantized for non-2D tensors
+                if input_reshaped.ndim == 2:
+                    reshaped_3d = input.ndim == 3
+                    # dtype is now implicit in the layout class
+                    scale = getattr(self, 'input_scale', None)
+                    if scale is not None:
+                        scale = comfy.model_management.cast_to_device(scale, input.device, None)
+                    input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)

-                # dtype is now implicit in the layout class
-                input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
-
-            output = self._forward(input, self.weight, self.bias)
+            output = self.forward_comfy_cast_weights(input)

             # Reshape output back to 3D if input was 3D
-            if tensor_3d:
+            if reshaped_3d:
                 output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))

             return output
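
The reworked forward keeps the existing 3D-to-2D convention: quantized matmul layouts expect 2D activations, so a (B, S, D) input is flattened to (B * S, D) and the result is reshaped back afterwards. A minimal sketch of that pattern with plain tensors standing in for the quantized path (names and sizes below are made up for illustration):

# Sketch of the reshape-around-the-matmul pattern; a plain matmul stands in for the quantized linear.
import torch

B, S, D, out_features = 2, 16, 64, 32
weight = torch.randn(out_features, D)
x = torch.randn(B, S, D)

x2d = x.reshape(-1, x.shape[-1]) if x.ndim == 3 else x       # (B * S, D)
y2d = x2d @ weight.t()                                        # stand-in for the quantized matmul
y = y2d.reshape(B, S, out_features) if x.ndim == 3 else y2d   # restore (B, S, out_features)
print(y.shape)  # torch.Size([2, 16, 32])
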
@@ -19,6 +19,7 @@ try:
     cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
     if cuda_version < (13,):
         ck.registry.disable("cuda")
         logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")

+    ck.registry.disable("triton")
     for k, v in ck.list_backends().items():
11  comfy/sd.py
@@ -218,7 +218,7 @@ class CLIP:
         if unprojected:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})

-        self.load_model()
+        self.load_model(tokens)
         self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         all_hooks.reset()
         self.patcher.patch_hooks(None)
@@ -266,7 +266,7 @@ class CLIP:
         if return_pooled == "unprojected":
             self.cond_stage_model.set_clip_options({"projected_pooled": False})

-        self.load_model()
+        self.load_model(tokens)
         self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         o = self.cond_stage_model.encode_token_weights(tokens)
         cond, pooled = o[:2]
@@ -299,8 +299,11 @@ class CLIP:
             sd_clip[k] = sd_tokenizer[k]
         return sd_clip

-    def load_model(self):
-        model_management.load_model_gpu(self.patcher)
+    def load_model(self, tokens={}):
+        memory_used = 0
+        if hasattr(self.cond_stage_model, "memory_estimation_function"):
+            memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device)
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
         return self.patcher

     def get_key_patches(self):
@@ -845,7 +845,7 @@ class LTXAV(LTXV):

     def __init__(self, unet_config):
         super().__init__(unet_config)
-        self.memory_usage_factor = 0.055 # TODO
+        self.memory_usage_factor = 0.061 # TODO

     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.LTXAV(self, device=device)
@@ -98,10 +98,13 @@ class LTXAVTEModel(torch.nn.Module):

         out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
+        out_device = out.device
+        if comfy.model_management.should_use_bf16(self.execution_device):
+            out = out.to(device=self.execution_device, dtype=torch.bfloat16)
         out = out.movedim(1, -1).to(self.execution_device)
         out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
         out = out.reshape((out.shape[0], out.shape[1], -1))
         out = self.text_embedding_projection(out)
         out = out.float()
         out_vid = self.video_embeddings_connector(out)[0]
         out_audio = self.audio_embeddings_connector(out)[0]
         out = torch.concat((out_vid, out_audio), dim=-1)
@@ -118,6 +121,14 @@ class LTXAVTEModel(torch.nn.Module):

         return self.load_state_dict(sdo, strict=False)

+    def memory_estimation_function(self, token_weight_pairs, device=None):
+        constant = 6.0
+        if comfy.model_management.should_use_bf16(device):
+            constant /= 2.0
+
+        token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
+        num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+        return num_tokens * constant * 1024 * 1024

 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
     class LTXAVTEModel_(LTXAVTEModel):
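
memory_estimation_function reserves a fixed budget per Gemma token before the text encoder is loaded, 6 MB per token, halved when bf16 is used. A worked example with an assumed prompt length:

# Worked example (assumed numbers): a single prompt of 512 tokens on a bf16-capable device.
num_tokens = 512
constant = 6.0 / 2.0                    # should_use_bf16(device) is True
memory_required = num_tokens * constant * 1024 * 1024
print(memory_required / (1024 * 1024))  # 1536.0 MB requested from the model loader
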
@@ -34,6 +34,9 @@ class Kandinsky5ImageToVideo(io.ComfyNode):

     @classmethod
     def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
+        if length > 121: # 10 sec generation, for nabla
+            height = 128 * round(height / 128)
+            width = 128 * round(width / 128)
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         cond_latent_out = {}
         if start_image is not None:
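
The length > 121 check lines up with the model-side threshold: 121 frames map to (121 - 1) // 4 + 1 = 31 latent frames, i.e. NABLA_THR, so only longer generations take the nabla path, and the 128-pixel rounding presumably keeps the latent grid evenly divisible into the 8 x 8 blocks that path groups tokens by. A quick arithmetic check with example inputs:

# Example numbers only: a 145-frame request at 720x1216.
length, height, width = 145, 720, 1216
latent_frames = (length - 1) // 4 + 1   # 37 > 31, so the nabla path activates
height = 128 * round(height / 128)      # 768
width = 128 * round(width / 128)        # 1280
print(latent_frames, height, width)     # 37 768 1280
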
@@ -185,6 +185,10 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
                 io.Combo.Input(
                     "ckpt_name",
                     options=folder_paths.get_filename_list("checkpoints"),
                 ),
+                io.Combo.Input(
+                    "device",
+                    options=["default", "cpu"],
+                )
             ],
             outputs=[io.Clip.Output()],
@@ -197,7 +201,11 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
         clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
         clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

-        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
+        model_options = {}
+        if device == "cpu":
+            model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
+
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
         return io.NodeOutput(clip)

@@ -1,5 +1,5 @@
 comfyui-frontend-package==1.35.9
-comfyui-workflow-templates==0.7.67
+comfyui-workflow-templates==0.7.69
 comfyui-embedded-docs==0.3.1
 torch
 torchsde
@@ -21,7 +21,7 @@ psutil
 alembic
 SQLAlchemy
 av>=14.2.0
-comfy-kitchen>=0.2.3
+comfy-kitchen>=0.2.5

 #non essential dependencies:
 kornia>=0.7.1