mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 21:00:49 +08:00
ruff
This commit is contained in:
parent
0da072e098
commit
f588e6c821
@ -98,7 +98,7 @@ def var_attn_arg(kwargs):
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
|
||||
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
||||
max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
|
||||
assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
@ -449,9 +449,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
k = k.view(1, total_tokens, heads, dim_head)
|
||||
v = v.view(1, total_tokens, heads, dim_head)
|
||||
else:
|
||||
if q.ndim == 3: q = q.unsqueeze(0)
|
||||
if k.ndim == 3: k = k.unsqueeze(0)
|
||||
if v.ndim == 3: v = v.unsqueeze(0)
|
||||
if q.ndim == 3:
|
||||
q = q.unsqueeze(0)
|
||||
if k.ndim == 3:
|
||||
k = k.unsqueeze(0)
|
||||
if v.ndim == 3:
|
||||
v = v.unsqueeze(0)
|
||||
dim_head = q.shape[-1]
|
||||
|
||||
target_output_shape = (q.shape[1], -1)
|
||||
@ -526,7 +529,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
k = k.view(k.shape[0], heads, head_dim)
|
||||
v = v.view(v.shape[0], heads, head_dim)
|
||||
|
||||
b = q.size(0); dim_head = q.shape[-1]
|
||||
b = q.size(0)
|
||||
dim_head = q.shape[-1]
|
||||
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
||||
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
|
||||
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
|
||||
|
||||
@ -78,8 +78,10 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
else:
|
||||
out = vae_model.decode_(t_chunk)
|
||||
|
||||
if isinstance(out, (tuple, list)): out = out[0]
|
||||
if out.ndim == 4: out = out.unsqueeze(2)
|
||||
if isinstance(out, (tuple, list)):
|
||||
out = out[0]
|
||||
if out.ndim == 4:
|
||||
out = out.unsqueeze(2)
|
||||
|
||||
if pad_amount > 0:
|
||||
if encode:
|
||||
@ -136,13 +138,17 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
|
||||
if cur_ov_h > 0:
|
||||
r = get_ramp(cur_ov_h)
|
||||
if y_idx > 0: w_h[:cur_ov_h] = r
|
||||
if y_end < h: w_h[-cur_ov_h:] = 1.0 - r
|
||||
if y_idx > 0:
|
||||
w_h[:cur_ov_h] = r
|
||||
if y_end < h:
|
||||
w_h[-cur_ov_h:] = 1.0 - r
|
||||
|
||||
if cur_ov_w > 0:
|
||||
r = get_ramp(cur_ov_w)
|
||||
if x_idx > 0: w_w[:cur_ov_w] = r
|
||||
if x_end < w: w_w[-cur_ov_w:] = 1.0 - r
|
||||
if x_idx > 0:
|
||||
w_w[:cur_ov_w] = r
|
||||
if x_end < w:
|
||||
w_w[-cur_ov_w:] = 1.0 - r
|
||||
|
||||
final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
|
||||
|
||||
@ -335,7 +341,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
|
||||
comfy.model_management.load_models_gpu([vae.patcher])
|
||||
vae_model = vae.first_stage_model
|
||||
scale = 0.9152; shift = 0
|
||||
scale = 0.9152
|
||||
shift = 0
|
||||
if images.dim() != 5: # add the t dim
|
||||
images = images.unsqueeze(0)
|
||||
images = images.permute(0, 1, 4, 2, 3)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user