Merge branch 'master' of github.com:comfyanonymous/ComfyUI

doctorpangloss 2024-05-20 12:06:35 -07:00
commit f69b6225c0
7 changed files with 38 additions and 26 deletions

View File

@@ -52,6 +52,7 @@ class ControlNet(nn.Module):
                  adm_in_channels=None,
                  transformer_depth_middle=None,
                  transformer_depth_output=None,
+                 attn_precision=None,
                  device=None,
                  operations=ops.disable_weight_init,
                  **kwargs,
@@ -202,7 +203,7 @@ class ControlNet(nn.Module):
                         SpatialTransformer(
                             ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                            use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                         )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -262,7 +263,7 @@ class ControlNet(nn.Module):
             mid_block += [SpatialTransformer(  # always uses a self-attn
                           ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                           disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                          use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                          use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                       ),
             ResBlock(
                 ch,
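
The two hunks above thread the new attn_precision constructor argument through to every SpatialTransformer the ControlNet builds. As a minimal sketch of the kind of behavior such a flag typically gates inside an attention routine (upcasting the similarity/softmax math), the helper below is illustrative only and not code from this repository:

import torch

def scaled_dot_product(q, k, v, attn_precision=None):
    # Illustrative: when attn_precision is torch.float32, compute the
    # similarity and softmax in fp32 even for fp16 inputs, then cast back.
    compute_dtype = attn_precision or q.dtype
    scale = q.shape[-1] ** -0.5
    sim = (q.to(compute_dtype) @ k.to(compute_dtype).transpose(-2, -1)) * scale
    attn = sim.softmax(dim=-1)
    return (attn @ v.to(compute_dtype)).to(v.dtype)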

View File

@@ -736,8 +736,27 @@ ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[
 def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
     outputs = set()
     for x in prompt:
-        class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
-        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True:
+        if 'class_type' not in prompt[x]:
+            error = {
+                "type": "invalid_prompt",
+                "message": f"Cannot execute because a node is missing the class_type property.",
+                "details": f"Node ID '#{x}'",
+                "extra_info": {}
+            }
+            return (False, error, [], [])
+
+        class_type = prompt[x]['class_type']
+        class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
+        if class_ is None:
+            error = {
+                "type": "invalid_prompt",
+                "message": f"Cannot execute because node {class_type} does not exist.",
+                "details": f"Node ID '#{x}'",
+                "extra_info": {}
+            }
+            return (False, error, [], [])
+
+        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
             outputs.add(x)
 
     if len(outputs) == 0:
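
With these checks, a prompt that omits class_type or names an unknown node class now fails validation up front with a structured error instead of raising a KeyError. A hypothetical caller-side sketch (the node id and class name below are invented, and the last two tuple slots are only assumed to be lists, as in the return statements above):

ok, error, third, fourth = validate_prompt({
    "1": {"class_type": "DoesNotExist", "inputs": {}},  # made-up class name
})
if not ok:
    # per the message format above: "Cannot execute because node DoesNotExist does not exist."
    print(error["type"], "-", error["message"])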

View File

@@ -40,12 +40,13 @@ class Latent2RGBPreviewer(LatentPreviewer):
         self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
 
     def decode_latent_to_preview(self, x0):
-        latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
+        self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
+        latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
 
         latents_ubyte = (((latent_image + 1) / 2)
                          .clamp(0, 1)  # change scale from -1..1 to 0..1
                          .mul(0xFF)  # to 0..255
-                         .byte()).cpu()
+                         ).to(device="cpu", dtype=torch.uint8, non_blocking=True)
 
         return Image.fromarray(latents_ubyte.numpy())
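
The rewritten decode_latent_to_preview keeps the matmul on the latent's own device and dtype and only converts to uint8 on the way back to the CPU. A standalone sketch of the same -1..1 to 0..255 conversion, with an invented tensor standing in for the preview image:

import torch

latent_image = torch.rand(64, 64, 3) * 2 - 1           # pretend preview in [-1, 1]
latents_ubyte = (((latent_image + 1) / 2)               # [-1, 1] -> [0, 1]
                 .clamp(0, 1)
                 .mul(0xFF)                             # [0, 1] -> [0, 255]
                 ).to(device="cpu", dtype=torch.uint8, non_blocking=True)
assert latents_ubyte.dtype == torch.uint8
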
@@ -66,8 +67,6 @@ def get_previewer(device, latent_format):
         if method == LatentPreviewMethod.Auto:
             method = LatentPreviewMethod.Latent2RGB
-            if taesd_decoder_path:
-                method = LatentPreviewMethod.TAESD
 
         if method == LatentPreviewMethod.TAESD:
             if taesd_decoder_path:

View File

@@ -6,7 +6,7 @@ from einops import rearrange, repeat
 from typing import Optional, Any
 import logging
 
-from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
+from .diffusionmodules.util import AlphaBlender, timestep_embedding
 from .sub_quadratic_attention import efficient_dot_product_attention
 
 from ... import model_management
@@ -317,11 +317,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
         return attention_pytorch(q, k, v, heads, mask)
 
     q, k, v = map(
-        lambda t: t.unsqueeze(3)
-        .reshape(b, -1, heads, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b * heads, -1, dim_head)
-        .contiguous(),
+        lambda t: t.reshape(b, -1, heads, dim_head),
        (q, k, v),
     )
@@ -334,10 +330,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
     out = (
-        out.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b, -1, heads * dim_head)
+        out.reshape(b, -1, heads * dim_head)
     )
 
     return out
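
The old code flattened the heads into the batch dimension, giving (b*heads, n, dim_head); the new code hands xformers a (b, n, heads, dim_head) tensor, which memory_efficient_attention accepts directly. A quick check that the two layouts carry the same data (sizes are arbitrary):

import torch

b, n, heads, dim_head = 2, 16, 8, 64
t = torch.randn(b, n, heads * dim_head)

old = (t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head))      # old per-head batch layout
new = t.reshape(b, -1, heads, dim_head)         # new (b, n, heads, dim_head) layout

assert torch.equal(old, new.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head))
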
@@ -460,15 +453,11 @@ class BasicTransformerBlock(nn.Module):
         self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
         self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
-        self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head
         self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
 
     def forward(self, x, context=None, transformer_options={}):
-        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
-
-    def _forward(self, x, context=None, transformer_options={}):
         extra_options = {}
         block = transformer_options.get("block", None)
         block_index = transformer_options.get("block_index", 0)
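
BasicTransformerBlock no longer routes its forward pass through the util checkpoint helper; forward now does the work directly. If activation checkpointing were still wanted for training, it could be applied at the call site instead; a hypothetical sketch using PyTorch's own utility, with an nn.Linear standing in for a transformer block:

import torch
import torch.nn as nn
import torch.utils.checkpoint

block = nn.Linear(320, 320)                     # stand-in for a transformer block
x = torch.randn(2, 64, 320, requires_grad=True)

# recompute the block's activations during backward instead of storing them
out = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
out.sum().backward()
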
@@ -635,7 +624,7 @@ class SpatialTransformer(nn.Module):
         x = self.norm(x)
         if not self.use_linear:
             x = self.proj_in(x)
-        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        x = x.movedim(1, 3).flatten(1, 2).contiguous()
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
@@ -643,7 +632,7 @@ class SpatialTransformer(nn.Module):
             x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
-        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
         if not self.use_linear:
             x = self.proj_out(x)
         return x + x_in
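
The two SpatialTransformer hunks (and the ResBlock hunk further down) replace einops rearrange calls with plain tensor ops. A quick equivalence check for both directions, with einops used here only to verify and arbitrary sizes:

import torch
from einops import rearrange

x = torch.randn(2, 320, 24, 32)                           # (b, c, h, w)
flat_old = rearrange(x, 'b c h w -> b (h w) c')
flat_new = x.movedim(1, 3).flatten(1, 2)                  # same (b, h*w, c) layout
assert torch.equal(flat_old, flat_new)

h, w = 24, 32
back_old = rearrange(flat_new, 'b (h w) c -> b c h w', h=h, w=w)
back_new = flat_new.reshape(flat_new.shape[0], h, w, flat_new.shape[-1]).movedim(3, 1)
assert torch.equal(back_old, back_new)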

View File

@@ -3,7 +3,6 @@ import math
 import torch
 import torch.nn as nn
 import numpy as np
-from einops import rearrange
 from typing import Optional, Any
 import logging

View File

@@ -258,7 +258,7 @@ class ResBlock(TimestepBlock):
         else:
             if emb_out is not None:
                 if self.exchange_temb_dims:
-                    emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+                    emb_out = emb_out.movedim(1, 2)
                 h = h + emb_out
             h = self.out_layers(h)
         return self.skip_connection(x) + h

View File

@@ -143,6 +143,11 @@ total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
 logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 
+try:
+    logging.info("pytorch version: {}".format(torch.version.__version__))
+except:
+    pass
+
 try:
     OOM_EXCEPTION = torch.cuda.OutOfMemoryError
 except: