Mirror of https://github.com/comfyanonymous/ComfyUI.git
Merge remote-tracking branch 'origin/master' into group-nodes
commit 73cc92af77
@@ -416,7 +416,7 @@ class T2IAdapter(ControlBase):
                 if control_prev is not None:
                     return control_prev
                 else:
-                    return {}
+                    return None
 
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
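Returning None instead of an empty dict presumably matches the equivalent early-return in the other control classes. Both values are falsy, so only identity checks notice the difference; a tiny illustration (not ComfyUI code):

    # Hypothetical snippet: {} and None are both falsy, but only None
    # passes an "is None" test, which is where the two return values differ.
    for result in ({}, None):
        print(bool(result), result is None)  # False False, then False True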
@@ -95,9 +95,19 @@ def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
 def attention_basic(q, k, v, heads, mask=None):
+    b, _, dim_head = q.shape
+    dim_head //= heads
+    scale = dim_head ** -0.5
+
     h = heads
-    scale = (q.shape[-1] // heads) ** -0.5
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+    q, k, v = map(
+        lambda t: t.unsqueeze(3)
+        .reshape(b, -1, heads, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b * heads, -1, dim_head)
+        .contiguous(),
+        (q, k, v),
+    )
 
     # force cast to fp32 to avoid overflowing
     if _ATTN_PRECISION =="fp32":
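Each attention helper now splits heads with plain unsqueeze/reshape/permute instead of einops.rearrange. A quick standalone sanity check of the equivalence, with made-up shapes (einops is imported here only to compare against the removed call):

    import torch
    from einops import rearrange

    b, n, heads, dim_head = 2, 5, 4, 8
    t = torch.randn(b, n, heads * dim_head)

    old = rearrange(t, 'b n (h d) -> (b h) n d', h=heads)  # removed form
    new = (                                                # replacement form
        t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous()
    )
    assert torch.equal(old, new)  # both are (b * heads, n, dim_head)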
@@ -119,16 +129,24 @@ def attention_basic(q, k, v, heads, mask=None):
     sim = sim.softmax(dim=-1)
 
     out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
-    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    out = (
+        out.unsqueeze(0)
+        .reshape(b, heads, -1, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b, -1, heads * dim_head)
+    )
     return out
 
 
 def attention_sub_quad(query, key, value, heads, mask=None):
-    scale = (query.shape[-1] // heads) ** -0.5
-    query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
-    key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1)
-    del key
-    value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
+    b, _, dim_head = query.shape
+    dim_head //= heads
+
+    scale = dim_head ** -0.5
+    query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+    value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+
+    key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
 
     dtype = query.dtype
     upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
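In attention_sub_quad the separate key_t tensor goes away: key is now built with permute(0, 2, 3, 1), so it already has the transposed (batch * heads, dim_head, tokens) layout that key_t used to carry. A shape sketch with made-up sizes:

    import torch

    b, n, heads, dim_head = 2, 5, 4, 8
    key = torch.randn(b, n, heads * dim_head)

    # permute(0, 2, 3, 1) puts dim_head ahead of the token axis, so the key is
    # stored pre-transposed for the query @ key matmul.
    key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
    print(key.shape)  # torch.Size([8, 8, 5]) -> (b * heads, dim_head, tokens)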
@@ -137,7 +155,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     else:
         bytes_per_token = torch.finfo(query.dtype).bits//8
     batch_x_heads, q_tokens, _ = query.shape
-    _, _, k_tokens = key_t.shape
+    _, _, k_tokens = key.shape
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
     mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
@@ -171,7 +189,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
 
     hidden_states = efficient_dot_product_attention(
         query,
-        key_t,
+        key,
         value,
         query_chunk_size=query_chunk_size,
         kv_chunk_size=kv_chunk_size,
@@ -186,9 +204,19 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     return hidden_states
 
 def attention_split(q, k, v, heads, mask=None):
-    scale = (q.shape[-1] // heads) ** -0.5
+    b, _, dim_head = q.shape
+    dim_head //= heads
+    scale = dim_head ** -0.5
+
     h = heads
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+    q, k, v = map(
+        lambda t: t.unsqueeze(3)
+        .reshape(b, -1, heads, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b * heads, -1, dim_head)
+        .contiguous(),
+        (q, k, v),
+    )
 
     r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
 
@@ -248,17 +276,23 @@ def attention_split(q, k, v, heads, mask=None):
 
     del q, k, v
 
-    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
-    del r1
-    return r2
+    r1 = (
+        r1.unsqueeze(0)
+        .reshape(b, heads, -1, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b, -1, heads * dim_head)
+    )
+    return r1
 
 def attention_xformers(q, k, v, heads, mask=None):
-    b, _, _ = q.shape
+    b, _, dim_head = q.shape
+    dim_head //= heads
+
     q, k, v = map(
         lambda t: t.unsqueeze(3)
-        .reshape(b, t.shape[1], heads, -1)
+        .reshape(b, -1, heads, dim_head)
         .permute(0, 2, 1, 3)
-        .reshape(b * heads, t.shape[1], -1)
+        .reshape(b * heads, -1, dim_head)
         .contiguous(),
         (q, k, v),
     )
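The merge back to (b, tokens, heads * dim_head) is likewise the plain-torch equivalent of the removed '(b h) n d -> b n (h d)' rearrange; a quick check with made-up shapes:

    import torch
    from einops import rearrange

    b, n, heads, dim_head = 2, 5, 4, 8
    r1 = torch.randn(b * heads, n, dim_head)

    old = rearrange(r1, '(b h) n d -> b n (h d)', h=heads)  # removed form
    new = (                                                 # replacement form
        r1.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    assert torch.equal(old, new)  # both are (b, n, heads * dim_head)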
@@ -270,9 +304,9 @@ def attention_xformers(q, k, v, heads, mask=None):
         raise NotImplementedError
     out = (
         out.unsqueeze(0)
-        .reshape(b, heads, out.shape[1], -1)
+        .reshape(b, heads, -1, dim_head)
         .permute(0, 2, 1, 3)
-        .reshape(b, out.shape[1], -1)
+        .reshape(b, -1, heads * dim_head)
     )
     return out
 
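Taken together, the split applied to q/k/v and the merge applied to the output are exact inverses, so the attention result keeps the original (b, tokens, heads * dim_head) layout. A round-trip check with made-up shapes:

    import torch

    b, n, heads, dim_head = 2, 5, 4, 8
    x = torch.randn(b, n, heads * dim_head)

    # split heads: (b, n, heads * dim_head) -> (b * heads, n, dim_head)
    split = (
        x.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
    )
    # merge heads back: (b * heads, n, dim_head) -> (b, n, heads * dim_head)
    merged = (
        split.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    assert torch.equal(x, merged)  # split and merge are exact inverses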
@@ -463,7 +463,11 @@ app.registerExtension({
			}

			if ((widget.type === "number" && !inputData?.[1]?.control_after_generate) || widget.type === "combo") {
-				addValueControlWidget(this, widget, "fixed");
+				let control_value = this.widgets_values?.[1];
+				if (!control_value) {
+					control_value = "fixed";
+				}
+				addValueControlWidget(this, widget, control_value);
			}

			// When our value changes, update other widgets to reflect our changes