commit f69b6225c0
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
@@ -52,6 +52,7 @@ class ControlNet(nn.Module):
                  adm_in_channels=None,
                  transformer_depth_middle=None,
                  transformer_depth_output=None,
+                 attn_precision=None,
                  device=None,
                  operations=ops.disable_weight_init,
                  **kwargs,
@@ -202,7 +203,7 @@ class ControlNet(nn.Module):
                         SpatialTransformer(
                             ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                            use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                         )
                     )
             self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -262,7 +263,7 @@ class ControlNet(nn.Module):
             mid_block += [SpatialTransformer(  # always uses a self-attn
                             ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                             disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                            use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                         ),
             ResBlock(
                 ch,
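The three hunks above thread a new attn_precision argument from the ControlNet constructor into its SpatialTransformer blocks. The diff does not show how the flag is consumed further down; the sketch below is a hypothetical illustration of the usual pattern (run the attention math in a higher precision than the activations when the flag is set), not the project's actual implementation.

    import torch

    def attention_with_precision(q, k, v, attn_precision=None):
        # Hypothetical sketch: when attn_precision is torch.float32, do the
        # matmul/softmax in fp32 even if activations arrive in lower precision,
        # then cast the result back to the original dtype.
        orig_dtype = q.dtype
        if attn_precision is not None and attn_precision != orig_dtype:
            q, k, v = (t.to(attn_precision) for t in (q, k, v))
        scale = q.shape[-1] ** -0.5
        attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
        return (attn @ v).to(orig_dtype)

    q = k = v = torch.randn(1, 4, 16, 8, dtype=torch.bfloat16)
    out = attention_with_precision(q, k, v, attn_precision=torch.float32)
    print(out.dtype)  # torch.bfloat16, but the intermediate math ran in fp32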
@@ -736,8 +736,27 @@ ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[
 def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
     outputs = set()
     for x in prompt:
-        class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
-        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True:
+        if 'class_type' not in prompt[x]:
+            error = {
+                "type": "invalid_prompt",
+                "message": f"Cannot execute because a node is missing the class_type property.",
+                "details": f"Node ID '#{x}'",
+                "extra_info": {}
+            }
+            return (False, error, [], [])
+
+        class_type = prompt[x]['class_type']
+        class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
+        if class_ is None:
+            error = {
+                "type": "invalid_prompt",
+                "message": f"Cannot execute because node {class_type} does not exist.",
+                "details": f"Node ID '#{x}'",
+                "extra_info": {}
+            }
+            return (False, error, [], [])
+
+        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
             outputs.add(x)
 
     if len(outputs) == 0:
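With this hunk the validator rejects malformed prompts cleanly instead of raising KeyError. A hedged usage sketch; the import path is an assumption, while the returned tuple and error shapes follow the diff:

    from execution import validate_prompt  # assumed import path in a ComfyUI checkout

    # Entry without a 'class_type' key: previously a KeyError, now an error tuple.
    bad_prompt = {"3": {"inputs": {"text": "a photo of a cat"}}}
    valid, error, outputs, node_errors = validate_prompt(bad_prompt)
    assert valid is False and error["type"] == "invalid_prompt"

    # Unknown class_type: reported the same way instead of failing inside
    # nodes.NODE_CLASS_MAPPINGS[...].
    unknown_prompt = {"3": {"class_type": "DoesNotExist", "inputs": {}}}
    valid, error, _, _ = validate_prompt(unknown_prompt)
    assert "does not exist" in error["message"]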
@@ -40,12 +40,13 @@ class Latent2RGBPreviewer(LatentPreviewer):
         self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
 
     def decode_latent_to_preview(self, x0):
-        latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
+        self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
+        latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
 
         latents_ubyte = (((latent_image + 1) / 2)
                          .clamp(0, 1)  # change scale from -1..1 to 0..1
                          .mul(0xFF)  # to 0..255
-                         .byte()).cpu()
+                         ).to(device="cpu", dtype=torch.uint8, non_blocking=True)
 
         return Image.fromarray(latents_ubyte.numpy())
 
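The previewer now does the RGB projection on the latent's own device and dtype, and folds the uint8 cast and host transfer into a single (optionally non-blocking) call. A small self-check that the cast itself is unchanged; plain torch on CPU, where the non_blocking flag is simply a no-op:

    import torch

    latent_image = torch.rand(64, 64, 3) * 2 - 1           # pretend preview in -1..1
    scaled = ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF)

    old_style = scaled.byte().cpu()                         # previous .byte()).cpu()
    new_style = scaled.to(device="cpu", dtype=torch.uint8, non_blocking=True)
    assert torch.equal(old_style, new_style)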
@@ -66,8 +67,6 @@ def get_previewer(device, latent_format):
 
         if method == LatentPreviewMethod.Auto:
             method = LatentPreviewMethod.Latent2RGB
-            if taesd_decoder_path:
-                method = LatentPreviewMethod.TAESD
 
         if method == LatentPreviewMethod.TAESD:
             if taesd_decoder_path:
@@ -6,7 +6,7 @@ from einops import rearrange, repeat
 from typing import Optional, Any
 import logging
 
-from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
+from .diffusionmodules.util import AlphaBlender, timestep_embedding
 from .sub_quadratic_attention import efficient_dot_product_attention
 from ... import model_management
 
@@ -317,11 +317,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
         return attention_pytorch(q, k, v, heads, mask)
 
     q, k, v = map(
-        lambda t: t.unsqueeze(3)
-        .reshape(b, -1, heads, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b * heads, -1, dim_head)
-        .contiguous(),
+        lambda t: t.reshape(b, -1, heads, dim_head),
         (q, k, v),
     )
 
@@ -334,10 +330,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
     out = (
-        out.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b, -1, heads * dim_head)
+        out.reshape(b, -1, heads * dim_head)
     )
     return out
 
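These two hunks stop folding the head dimension into the batch dimension around the xformers call; as I understand it, memory_efficient_attention also accepts the 4-D [batch, seq_len, heads, dim_head] layout, which is what makes a single reshape on each side sufficient. The snippet below only checks, in plain torch, that the old output-merging chain and the new single reshape agree when fed the same per-head results:

    import torch

    b, heads, seq, dim_head = 2, 4, 16, 8
    res = torch.randn(b, heads, seq, dim_head)      # per-head attention output

    # Old path: heads had been folded into the batch dimension.
    out_old = res.reshape(b * heads, seq, dim_head)
    merged_old = (out_old.unsqueeze(0)
                  .reshape(b, heads, -1, dim_head)
                  .permute(0, 2, 1, 3)
                  .reshape(b, -1, heads * dim_head))

    # New path: the kernel works in the [B, M, H, K] layout directly.
    out_new = res.permute(0, 2, 1, 3)               # (b, seq, heads, dim_head)
    merged_new = out_new.reshape(b, -1, heads * dim_head)

    assert torch.equal(merged_old, merged_new)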
@@ -460,15 +453,11 @@ class BasicTransformerBlock(nn.Module):
 
         self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
         self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
-        self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head
         self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
 
     def forward(self, x, context=None, transformer_options={}):
-        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
-
-    def _forward(self, x, context=None, transformer_options={}):
         extra_options = {}
         block = transformer_options.get("block", None)
         block_index = transformer_options.get("block_index", 0)
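BasicTransformerBlock no longer routes forward through the checkpoint helper; the body of _forward becomes forward directly. For context, a minimal sketch of what gradient checkpointing does in general, using torch.utils.checkpoint rather than the removed project-specific helper (whose signature differed):

    import torch
    import torch.utils.checkpoint as cp

    lin = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8, requires_grad=True)

    # Checkpointing recomputes the forward pass during backward to save memory;
    # it only pays off while gradients are tracked, which is why dropping it is
    # harmless for inference-only use.
    y_ckpt = cp.checkpoint(lin, x, use_reentrant=False)
    y_plain = lin(x)
    assert torch.allclose(y_ckpt, y_plain)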
@@ -635,7 +624,7 @@ class SpatialTransformer(nn.Module):
         x = self.norm(x)
         if not self.use_linear:
             x = self.proj_in(x)
-        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        x = x.movedim(1, 3).flatten(1, 2).contiguous()
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
@@ -643,7 +632,7 @@ class SpatialTransformer(nn.Module):
             x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
-        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
         if not self.use_linear:
             x = self.proj_out(x)
         return x + x_in
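Both rearrange calls in SpatialTransformer.forward are replaced with plain tensor ops, which removes einops from this code path. A quick self-check that the two forms are element-for-element identical:

    import torch
    from einops import rearrange

    b, c, h, w = 2, 4, 3, 5
    x = torch.randn(b, c, h, w)

    # 'b c h w -> b (h w) c'  ==  movedim + flatten
    assert torch.equal(rearrange(x, 'b c h w -> b (h w) c'),
                       x.movedim(1, 3).flatten(1, 2))

    # 'b (h w) c -> b c h w'  ==  reshape + movedim
    t = torch.randn(b, h * w, c)
    assert torch.equal(rearrange(t, 'b (h w) c -> b c h w', h=h, w=w),
                       t.reshape(t.shape[0], h, w, t.shape[-1]).movedim(3, 1))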
@@ -3,7 +3,6 @@ import math
 import torch
 import torch.nn as nn
 import numpy as np
-from einops import rearrange
 from typing import Optional, Any
 import logging
 
@@ -258,7 +258,7 @@ class ResBlock(TimestepBlock):
         else:
             if emb_out is not None:
                 if self.exchange_temb_dims:
-                    emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+                    emb_out = emb_out.movedim(1, 2)
                 h = h + emb_out
             h = self.out_layers(h)
         return self.skip_connection(x) + h
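Same idea in ResBlock: swapping the second and third dimensions no longer needs einops. A one-line check of the equivalence:

    import torch
    from einops import rearrange

    emb_out = torch.randn(2, 6, 4, 1, 1)   # "b t c ..." layout
    assert torch.equal(rearrange(emb_out, "b t c ... -> b c t ..."),
                       emb_out.movedim(1, 2))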
@@ -143,6 +143,11 @@ total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
 logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 
+try:
+    logging.info("pytorch version: {}".format(torch.version.__version__))
+except:
+    pass
+
 try:
     OOM_EXCEPTION = torch.cuda.OutOfMemoryError
 except: