Yousef Rafat 2026-01-04 20:30:24 +02:00
parent 0da072e098
commit f588e6c821
2 changed files with 23 additions and 12 deletions


@@ -98,7 +98,7 @@ def var_attn_arg(kwargs):
     cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
     max_seqlen_q = kwargs.get("max_seqlen_q", None)
     max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
-    assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
+    assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
     return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
 # feedforward
 class GEGLU(nn.Module):
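
Side note on the assertion fix above: != dispatches to __eq__/__ne__, which classes (including tensors) may override, while "is not None" is a plain identity test and the form PEP 8 recommends for None checks. A minimal, standalone sketch of the difference (not part of this commit):

import torch

class AlwaysEqual:
    # __eq__ can be overridden, so "x != None" can report the wrong thing
    def __eq__(self, other):
        return True

x = AlwaysEqual()
print(x != None)      # False, even though x is a real object
print(x is not None)  # True: identity is never fooled by __eq__

t = torch.zeros(3)
print(t is not None)  # always a plain Python bool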
@@ -449,9 +449,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         k = k.view(1, total_tokens, heads, dim_head)
         v = v.view(1, total_tokens, heads, dim_head)
     else:
-        if q.ndim == 3: q = q.unsqueeze(0)
-        if k.ndim == 3: k = k.unsqueeze(0)
-        if v.ndim == 3: v = v.unsqueeze(0)
+        if q.ndim == 3:
+            q = q.unsqueeze(0)
+        if k.ndim == 3:
+            k = k.unsqueeze(0)
+        if v.ndim == 3:
+            v = v.unsqueeze(0)
 
     dim_head = q.shape[-1]
     target_output_shape = (q.shape[1], -1)
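
In the variable-length branch above, all sequences are packed along one token axis and viewed as a single batch element, while the dense branch only adds a missing batch dim. A minimal, standalone sketch of that packing with made-up shapes (heads, dim_head, and the sequence lengths are placeholders, not values from this code):

import torch

heads, dim_head = 8, 64
seq_lens = [5, 9, 3]                     # hypothetical per-sample lengths
total_tokens = sum(seq_lens)

# packed projections: one row per token across all samples
q = torch.randn(total_tokens, heads, dim_head)

# varlen path: one "super sequence" in a batch of 1; the attention mask
# built from cu_seqlens elsewhere is what keeps the samples separate
q_packed = q.view(1, total_tokens, heads, dim_head)

# dense path: an unbatched 3D tensor just gets a batch dim
q_single = torch.randn(17, heads, dim_head)
if q_single.ndim == 3:
    q_single = q_single.unsqueeze(0)     # (1, 17, heads, dim_head)

print(q_packed.shape, q_single.shape)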
@@ -526,7 +529,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         k = k.view(k.shape[0], heads, head_dim)
         v = v.view(v.shape[0], heads, head_dim)
-        b = q.size(0); dim_head = q.shape[-1]
+        b = q.size(0)
+        dim_head = q.shape[-1]
         q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
         k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
         v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
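
torch.nested.nested_tensor_from_jagged builds a jagged nested tensor directly from packed values plus cumulative-length offsets, which is exactly what cu_seqlens_q / cu_seqlens_k encode. A minimal, standalone sketch, assuming a recent PyTorch (2.3 or newer) where this constructor is available:

import torch

heads, head_dim = 4, 16
cu_seqlens = torch.tensor([0, 5, 12, 15])   # cumulative lengths of 3 sequences
total_tokens = int(cu_seqlens[-1])

packed = torch.randn(total_tokens, heads, head_dim)

# offsets mark where each sequence starts/ends inside the packed token dim
njt = torch.nested.nested_tensor_from_jagged(packed, offsets=cu_seqlens.long())

# unbind recovers the per-sequence views without copying
for seq in njt.unbind():
    print(seq.shape)   # (5, 4, 16), (7, 4, 16), (3, 4, 16)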


@@ -78,8 +78,10 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
             else:
                 out = vae_model.decode_(t_chunk)
-            if isinstance(out, (tuple, list)): out = out[0]
-            if out.ndim == 4: out = out.unsqueeze(2)
+            if isinstance(out, (tuple, list)):
+                out = out[0]
+            if out.ndim == 4:
+                out = out.unsqueeze(2)
 
             if pad_amount > 0:
                 if encode:
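
The decode path above normalizes whatever the VAE returns into a 5D (batch, channels, time, height, width) tensor. A minimal, standalone sketch of that normalization with dummy outputs (shapes are illustrative only):

import torch

def normalize_vae_output(out):
    # some VAEs return (sample, extra_info) tuples; keep only the sample
    if isinstance(out, (tuple, list)):
        out = out[0]
    # image VAEs return 4D (B, C, H, W); insert a singleton time dim at index 2
    if out.ndim == 4:
        out = out.unsqueeze(2)
    return out

print(normalize_vae_output(torch.randn(1, 3, 64, 64)).shape)             # (1, 3, 1, 64, 64)
print(normalize_vae_output((torch.randn(1, 3, 4, 64, 64), None)).shape)  # (1, 3, 4, 64, 64)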
@@ -136,13 +138,17 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
             if cur_ov_h > 0:
                 r = get_ramp(cur_ov_h)
-                if y_idx > 0: w_h[:cur_ov_h] = r
-                if y_end < h: w_h[-cur_ov_h:] = 1.0 - r
+                if y_idx > 0:
+                    w_h[:cur_ov_h] = r
+                if y_end < h:
+                    w_h[-cur_ov_h:] = 1.0 - r
             if cur_ov_w > 0:
                 r = get_ramp(cur_ov_w)
-                if x_idx > 0: w_w[:cur_ov_w] = r
-                if x_end < w: w_w[-cur_ov_w:] = 1.0 - r
+                if x_idx > 0:
+                    w_w[:cur_ov_w] = r
+                if x_end < w:
+                    w_w[-cur_ov_w:] = 1.0 - r
 
             final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
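
The blend weights above taper each tile linearly over the overlap region, so a tile's fade-out and its neighbor's fade-in sum to one and the seams disappear when tiles are accumulated. A minimal 1D sketch; get_ramp here is a hypothetical stand-in (a plain linspace) and may differ from the actual helper:

import torch

def get_ramp(n):
    # hypothetical stand-in for the helper used above: a 0..1 linear ramp
    return torch.linspace(0.0, 1.0, n)

tile, overlap = 8, 4
w = torch.ones(tile)

# left tile: fade out over its trailing overlap
w_left = w.clone()
w_left[-overlap:] = 1.0 - get_ramp(overlap)

# right tile: fade in over its leading overlap
w_right = w.clone()
w_right[:overlap] = get_ramp(overlap)

# in the overlap the two weights sum to exactly 1, so blending is seamless
print(w_left[-overlap:] + w_right[:overlap])   # tensor([1., 1., 1., 1.])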
@@ -335,7 +341,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
         comfy.model_management.load_models_gpu([vae.patcher])
         vae_model = vae.first_stage_model
-        scale = 0.9152; shift = 0
+        scale = 0.9152
+        shift = 0
         if images.dim() != 5: # add the t dim
             images = images.unsqueeze(0)
         images = images.permute(0, 1, 4, 2, 3)
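
The input-processing step above inserts the missing leading dim and converts channels-last frames to the channel-first layout the VAE expects. A minimal shape sketch with a made-up frame stack (sizes are illustrative only):

import torch

images = torch.rand(16, 64, 64, 3)      # hypothetical frame stack, channels last

if images.dim() != 5:                    # add the missing leading dim, as in the diff
    images = images.unsqueeze(0)         # (1, 16, 64, 64, 3)

images = images.permute(0, 1, 4, 2, 3)   # channel-first: (1, 16, 3, 64, 64)
print(images.shape)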