Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2024-05-20 12:06:35 -07:00
commit f69b6225c0
7 changed files with 38 additions and 26 deletions

View File

@ -52,6 +52,7 @@ class ControlNet(nn.Module):
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
attn_precision=None,
device=None,
operations=ops.disable_weight_init,
**kwargs,
@ -202,7 +203,7 @@ class ControlNet(nn.Module):
SpatialTransformer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@ -262,7 +263,7 @@ class ControlNet(nn.Module):
mid_block += [SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
),
ResBlock(
ch,

View File

@ -736,8 +736,27 @@ ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[
def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
outputs = set()
for x in prompt:
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True:
if 'class_type' not in prompt[x]:
error = {
"type": "invalid_prompt",
"message": f"Cannot execute because a node is missing the class_type property.",
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], [])
class_type = prompt[x]['class_type']
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
if class_ is None:
error = {
"type": "invalid_prompt",
"message": f"Cannot execute because node {class_type} does not exist.",
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], [])
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
if len(outputs) == 0:

View File

@ -40,12 +40,13 @@ class Latent2RGBPreviewer(LatentPreviewer):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
def decode_latent_to_preview(self, x0):
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
latents_ubyte = (((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()).cpu()
).to(device="cpu", dtype=torch.uint8, non_blocking=True)
return Image.fromarray(latents_ubyte.numpy())
@ -66,8 +67,6 @@ def get_previewer(device, latent_format):
if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB
if taesd_decoder_path:
method = LatentPreviewMethod.TAESD
if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path:

View File

@ -6,7 +6,7 @@ from einops import rearrange, repeat
from typing import Optional, Any
import logging
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
from ... import model_management
@ -317,11 +317,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
return attention_pytorch(q, k, v, heads, mask)
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v),
)
@ -334,10 +330,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
out.reshape(b, -1, heads * dim_head)
)
return out
@ -460,15 +453,11 @@ class BasicTransformerBlock(nn.Module):
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}):
extra_options = {}
block = transformer_options.get("block", None)
block_index = transformer_options.get("block_index", 0)
@ -635,7 +624,7 @@ class SpatialTransformer(nn.Module):
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
x = x.movedim(1, 3).flatten(1, 2).contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
@ -643,7 +632,7 @@ class SpatialTransformer(nn.Module):
x = block(x, context=context[i], transformer_options=transformer_options)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in

View File

@ -3,7 +3,6 @@ import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from typing import Optional, Any
import logging

View File

@ -258,7 +258,7 @@ class ResBlock(TimestepBlock):
else:
if emb_out is not None:
if self.exchange_temb_dims:
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
emb_out = emb_out.movedim(1, 2)
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h

View File

@ -143,6 +143,11 @@ total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try:
logging.info("pytorch version: {}".format(torch.version.__version__))
except:
pass
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except: