Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2024-05-16 14:28:49 -07:00
commit 3d98440fb7
17 changed files with 232 additions and 77 deletions

View File

@ -7,7 +7,7 @@ on:
description: 'cuda version' description: 'cuda version'
required: true required: true
type: string type: string
default: "121" default: "124"
python_minor: python_minor:
description: 'python minor version' description: 'python minor version'
@ -19,7 +19,7 @@ on:
description: 'python patch version' description: 'python patch version'
required: true required: true
type: string type: string
default: "2" default: "3"
# push: # push:
# branches: # branches:
# - master # - master
@ -49,7 +49,7 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py ./python.exe get-pip.py
python -m pip wheel torch torchvision mpmath==1.3.0 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir python -m pip wheel torch torchvision mpmath==1.3.0 numpy==1.26.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/* ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth

View File

@ -76,9 +76,6 @@ def create_parser() -> argparse.ArgumentParser:
help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
parser.add_argument("--dont-upcast-attention", action="store_true",
help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
fp_group = parser.add_mutually_exclusive_group() fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", fp_group.add_argument("--force-fp32", action="store_true",
help="Force fp32 (If this makes your GPU work better please report it).") help="Force fp32 (If this makes your GPU work better please report it).")
@ -125,6 +122,9 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
vram_group = parser.add_mutually_exclusive_group() vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", vram_group.add_argument("--gpu-only", action="store_true",
help="Store and run everything (text encoders/CLIP models, etc... on the GPU).") help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")

View File

@ -37,6 +37,7 @@ class Configuration(dict):
cuda_malloc (bool): Enable cudaMallocAsync. Defaults to True in applicable setups. cuda_malloc (bool): Enable cudaMallocAsync. Defaults to True in applicable setups.
disable_cuda_malloc (bool): Disable cudaMallocAsync. disable_cuda_malloc (bool): Disable cudaMallocAsync.
dont_upcast_attention (bool): Disable upcasting of attention. dont_upcast_attention (bool): Disable upcasting of attention.
force_upcast_attention (bool): Force upcasting of attention.
force_fp32 (bool): Force using FP32 precision. force_fp32 (bool): Force using FP32 precision.
force_fp16 (bool): Force using FP16 precision. force_fp16 (bool): Force using FP16 precision.
bf16_unet (bool): Use BF16 precision for UNet. bf16_unet (bool): Use BF16 precision for UNet.
@ -106,6 +107,7 @@ class Configuration(dict):
self.cuda_malloc: bool = True self.cuda_malloc: bool = True
self.disable_cuda_malloc: bool = False self.disable_cuda_malloc: bool = False
self.dont_upcast_attention: bool = False self.dont_upcast_attention: bool = False
self.force_upcast_attention: bool = False
self.force_fp32: bool = False self.force_fp32: bool = False
self.force_fp16: bool = False self.force_fp16: bool = False
self.bf16_unet: bool = False self.bf16_unet: bool = False

View File

@ -23,6 +23,15 @@ class TransformersManagedModel(ModelManageable):
if model.device != self.offload_device: if model.device != self.offload_device:
model.to(device=self.offload_device) model.to(device=self.offload_device)
@property
def lowvram_patch_counter(self):
return 0
@lowvram_patch_counter.setter
def lowvram_patch_counter(self, value: int):
warnings.warn("Not supported")
pass
load_device: torch.device load_device: torch.device
offload_device: torch.device offload_device: torch.device
model: PreTrainedModel model: PreTrainedModel
@ -57,7 +66,7 @@ class TransformersManagedModel(ModelManageable):
def model_dtype(self) -> torch.dtype: def model_dtype(self) -> torch.dtype:
return self.model.dtype return self.model.dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module: def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights=False) -> torch.nn.Module:
warnings.warn("Transformers models do not currently support adapters like LoRAs") warnings.warn("Transformers models do not currently support adapters like LoRAs")
return self.model.to(device=device_to) return self.model.to(device=device_to)

View File

@ -18,13 +18,13 @@ from ...cli_args import args
from ... import ops from ... import ops
ops = ops.disable_weight_init ops = ops.disable_weight_init
# CrossAttn precision handling
if args.dont_upcast_attention:
logging.info("disabling upcasting of attention")
_ATTN_PRECISION = "fp16"
else:
_ATTN_PRECISION = "fp32"
def get_attn_precision(attn_precision):
if args.dont_upcast_attention:
return None
if attn_precision is None and args.force_upcast_attention:
return torch.float32
return attn_precision
def exists(val): def exists(val):
return val is not None return val is not None
@ -84,7 +84,9 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None): def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None): def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
attn_precision = get_attn_precision(attn_precision)
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 scale = dim_head ** -0.5
@ -100,7 +102,7 @@ def attention_basic(q, k, v, heads, mask=None):
) )
# force cast to fp32 to avoid overflowing # force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32": if attn_precision == torch.float32:
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else: else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale sim = einsum('b i d, b j d -> b i j', q, k) * scale
@ -134,7 +136,9 @@ def attention_basic(q, k, v, heads, mask=None):
return out return out
def attention_sub_quad(query, key, value, heads, mask=None): def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
attn_precision = get_attn_precision(attn_precision)
b, _, dim_head = query.shape b, _, dim_head = query.shape
dim_head //= heads dim_head //= heads
@ -145,7 +149,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1) key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
dtype = query.dtype dtype = query.dtype
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32 upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
if upcast_attention: if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8 bytes_per_token = torch.finfo(torch.float32).bits//8
else: else:
@ -194,7 +198,9 @@ def attention_sub_quad(query, key, value, heads, mask=None):
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states return hidden_states
def attention_split(q, k, v, heads, mask=None): def attention_split(q, k, v, heads, mask=None, attn_precision=None):
attn_precision = get_attn_precision(attn_precision)
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 scale = dim_head ** -0.5
@ -213,10 +219,12 @@ def attention_split(q, k, v, heads, mask=None):
mem_free_total = model_management.get_free_memory(q.device) mem_free_total = model_management.get_free_memory(q.device)
if _ATTN_PRECISION =="fp32": if attn_precision == torch.float32:
element_size = 4 element_size = 4
upcast = True
else: else:
element_size = q.element_size() element_size = q.element_size()
upcast = False
gb = 1024 ** 3 gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
@ -250,7 +258,7 @@ def attention_split(q, k, v, heads, mask=None):
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size): for i in range(0, q.shape[1], slice_size):
end = i + slice_size end = i + slice_size
if _ATTN_PRECISION =="fp32": if upcast:
with torch.autocast(enabled=False, device_type = 'cuda'): with torch.autocast(enabled=False, device_type = 'cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
else: else:
@ -301,7 +309,7 @@ try:
except: except:
pass pass
def attention_xformers(q, k, v, heads, mask=None): def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
if BROKEN_XFORMERS: if BROKEN_XFORMERS:
@ -333,7 +341,7 @@ def attention_xformers(q, k, v, heads, mask=None):
) )
return out return out
def attention_pytorch(q, k, v, heads, mask=None): def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
q, k, v = map( q, k, v = map(
@ -383,10 +391,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False):
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
self.attn_precision = attn_precision
self.heads = heads self.heads = heads
self.dim_head = dim_head self.dim_head = dim_head
@ -408,15 +417,15 @@ class CrossAttention(nn.Module):
v = self.to_v(context) v = self.to_v(context)
if mask is None: if mask is None:
out = optimized_attention(q, k, v, self.heads) out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else: else:
out = optimized_attention_masked(q, k, v, self.heads, mask) out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out) return self.to_out(out)
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops): disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
self.ff_in = ff_in or inner_dim is not None self.ff_in = ff_in or inner_dim is not None
@ -424,6 +433,7 @@ class BasicTransformerBlock(nn.Module):
inner_dim = dim inner_dim = dim
self.is_res = inner_dim == dim self.is_res = inner_dim == dim
self.attn_precision = attn_precision
if self.ff_in: if self.ff_in:
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device) self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
@ -431,7 +441,7 @@ class BasicTransformerBlock(nn.Module):
self.disable_self_attn = disable_self_attn self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention: if disable_temporal_crossattention:
@ -445,7 +455,7 @@ class BasicTransformerBlock(nn.Module):
context_dim_attn2 = context_dim context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
@ -475,6 +485,7 @@ class BasicTransformerBlock(nn.Module):
extra_options["n_heads"] = self.n_heads extra_options["n_heads"] = self.n_heads
extra_options["dim_head"] = self.d_head extra_options["dim_head"] = self.d_head
extra_options["attn_precision"] = self.attn_precision
if self.ff_in: if self.ff_in:
x_skip = x x_skip = x
@ -585,7 +596,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head, def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None, depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False, disable_self_attn=False, use_linear=False,
use_checkpoint=True, dtype=None, device=None, operations=ops): use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
if exists(context_dim) and not isinstance(context_dim, list): if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth context_dim = [context_dim] * depth
@ -603,7 +614,7 @@ class SpatialTransformer(nn.Module):
self.transformer_blocks = nn.ModuleList( self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations) disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
for d in range(depth)] for d in range(depth)]
) )
if not use_linear: if not use_linear:
@ -659,6 +670,7 @@ class SpatialVideoTransformer(SpatialTransformer):
disable_self_attn=False, disable_self_attn=False,
disable_temporal_crossattention=False, disable_temporal_crossattention=False,
max_time_embed_period: int = 10000, max_time_embed_period: int = 10000,
attn_precision=None,
dtype=None, device=None, operations=ops dtype=None, device=None, operations=ops
): ):
super().__init__( super().__init__(
@ -671,6 +683,7 @@ class SpatialVideoTransformer(SpatialTransformer):
context_dim=context_dim, context_dim=context_dim,
use_linear=use_linear, use_linear=use_linear,
disable_self_attn=disable_self_attn, disable_self_attn=disable_self_attn,
attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations dtype=dtype, device=device, operations=operations
) )
self.time_depth = time_depth self.time_depth = time_depth
@ -700,6 +713,7 @@ class SpatialVideoTransformer(SpatialTransformer):
inner_dim=time_mix_inner_dim, inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn, disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention, disable_temporal_crossattention=disable_temporal_crossattention,
attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations dtype=dtype, device=device, operations=operations
) )
for _ in range(self.depth) for _ in range(self.depth)

View File

@ -431,6 +431,7 @@ class UNetModel(nn.Module):
video_kernel_size=None, video_kernel_size=None,
disable_temporal_crossattention=False, disable_temporal_crossattention=False,
max_ddpm_temb_period=10000, max_ddpm_temb_period=10000,
attn_precision=None,
device=None, device=None,
operations=ops, operations=ops,
): ):
@ -550,13 +551,14 @@ class UNetModel(nn.Module):
disable_self_attn=disable_self_attn, disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention, disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period, max_time_embed_period=max_ddpm_temb_period,
attn_precision=attn_precision,
dtype=self.dtype, device=device, operations=operations dtype=self.dtype, device=device, operations=operations
) )
else: else:
return SpatialTransformer( return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim, ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer, disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
) )
def get_resblock( def get_resblock(

View File

@ -119,8 +119,8 @@ def get_total_memory(dev=None, torch_total_too=False):
elif is_intel_xpu(): elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev) stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current']
mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved mem_total_torch = mem_reserved
mem_total = torch.xpu.get_device_properties(dev).total_memory
else: else:
stats = torch.cuda.memory_stats(dev) stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current']
@ -308,7 +308,7 @@ class LoadedModel:
else: else:
return self.model_memory() return self.model_memory()
def model_load(self, lowvram_model_memory=0): def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
patch_model_to = self.device patch_model_to = self.device
self.model.model_patches_to(self.device) self.model.model_patches_to(self.device)
@ -318,7 +318,7 @@ class LoadedModel:
try: try:
if lowvram_model_memory > 0 and load_weights: if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory) self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else: else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e: except Exception as e:
@ -332,6 +332,11 @@ class LoadedModel:
self.weights_loaded = True self.weights_loaded = True
return self.real_model return self.real_model
def should_reload_model(self, force_patch_weights=False):
if force_patch_weights and self.model.lowvram_patch_counter > 0:
return True
return False
def model_unload(self, unpatch_weights=True): def model_unload(self, unpatch_weights=True):
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device) self.model.model_patches_to(self.model.offload_device)
@ -408,7 +413,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache() soft_empty_cache()
def load_models_gpu(models, memory_required=0): def load_models_gpu(models, memory_required=0, force_patch_weights=False):
global vram_state global vram_state
with model_management_lock: with model_management_lock:
@ -420,12 +425,21 @@ def load_models_gpu(models, memory_required=0):
models_already_loaded = [] models_already_loaded = []
for x in models: for x in models:
loaded_model = LoadedModel(x) loaded_model = LoadedModel(x)
loaded = None
if loaded_model in current_loaded_models: try:
index = current_loaded_models.index(loaded_model) loaded_model_index = current_loaded_models.index(loaded_model)
current_loaded_models.insert(0, current_loaded_models.pop(index)) except ValueError:
models_already_loaded.append(loaded_model) loaded_model_index = None
else:
if loaded_model_index is not None:
loaded = current_loaded_models[loaded_model_index]
if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
models_already_loaded.append(loaded)
if loaded is None:
if hasattr(x, "model"): if hasattr(x, "model"):
logging.info(f"Requested to load {x.model.__class__.__name__}") logging.info(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model) models_to_load.append(loaded_model)
@ -473,7 +487,7 @@ def load_models_gpu(models, memory_required=0):
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024 lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory) loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model) current_loaded_models.insert(0, loaded_model)
return return
@ -738,10 +752,10 @@ def get_free_memory(dev=None, torch_free_too=False):
elif is_intel_xpu(): elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev) stats = torch.xpu.memory_stats(dev)
mem_active = stats['active_bytes.all.current'] mem_active = stats['active_bytes.all.current']
mem_allocated = stats['allocated_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current']
mem_free_torch = mem_reserved - mem_active mem_free_torch = mem_reserved - mem_active
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_total = mem_free_xpu + mem_free_torch
else: else:
stats = torch.cuda.memory_stats(dev) stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current'] mem_active = stats['active_bytes.all.current']

View File

@ -38,7 +38,7 @@ class ModelManageable(Protocol):
def model_dtype(self) -> torch.dtype: def model_dtype(self) -> torch.dtype:
... ...
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module: def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
... ...
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module: def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
@ -46,3 +46,7 @@ class ModelManageable(Protocol):
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module: def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
... ...
@property
def lowvram_patch_counter(self) -> int:
...

View File

@ -19,7 +19,7 @@ def apply_weight_decompose(dora_scale, weight):
.transpose(0, 1) .transpose(0, 1)
) )
return weight * (dora_scale / weight_norm) return weight * (dora_scale / weight_norm).type(weight.dtype)
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
@ -65,6 +65,15 @@ class ModelPatcher(ModelManageable):
self.weight_inplace_update = weight_inplace_update self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False self.model_lowvram = False
self.patches_uuid = uuid.uuid4() self.patches_uuid = uuid.uuid4()
self._lowvram_patch_counter = 0
@property
def lowvram_patch_counter(self):
return self._lowvram_patch_counter
@lowvram_patch_counter.setter
def lowvram_patch_counter(self, value: int):
self._lowvram_patch_counter = value
def model_size(self): def model_size(self):
if self.size > 0: if self.size > 0:
@ -278,7 +287,7 @@ class ModelPatcher(ModelManageable):
return self.model return self.model
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0): def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False) self.patch_model(device_to, patch_weights=False)
logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024))) logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
@ -292,6 +301,7 @@ class ModelPatcher(ModelManageable):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
mem_counter = 0 mem_counter = 0
patch_counter = 0
for n, m in self.model.named_modules(): for n, m in self.model.named_modules():
lowvram_weight = False lowvram_weight = False
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
@ -304,9 +314,17 @@ class ModelPatcher(ModelManageable):
if lowvram_weight: if lowvram_weight:
if weight_key in self.patches: if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self) if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches: if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self) if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True m.comfy_cast_weights = True
@ -319,6 +337,7 @@ class ModelPatcher(ModelManageable):
logging.debug("lowvram: loaded module regularly {}".format(m)) logging.debug("lowvram: loaded module regularly {}".format(m))
self.model_lowvram = True self.model_lowvram = True
self.lowvram_patch_counter = patch_counter
return self.model return self.model
def calculate_weight(self, patches, weight, key): def calculate_weight(self, patches, weight, key):
@ -470,6 +489,7 @@ class ModelPatcher(ModelManageable):
m.bias_function = None m.bias_function = None
self.model_lowvram = False self.model_lowvram = False
self.lowvram_patch_counter = 0
keys = list(self.backup.keys()) keys = list(self.backup.keys())

View File

@ -1464,6 +1464,9 @@ class LoadImage:
output_images = [] output_images = []
output_masks = [] output_masks = []
w, h = None, None
excluded_formats = ['MPO']
# maintain the legacy path # maintain the legacy path
# this will ultimately return a tensor, so we'd rather have the tensors directly # this will ultimately return a tensor, so we'd rather have the tensors directly
@ -1478,6 +1481,14 @@ class LoadImage:
if i.mode == 'I': if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255)) i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB") image = i.convert("RGB")
if len(output_images) == 0:
w = image.size[0]
h = image.size[1]
if image.size[0] != w or image.size[1] != h:
continue
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,] image = torch.from_numpy(image)[None,]
if 'A' in i.getbands(): if 'A' in i.getbands():
@ -1488,14 +1499,14 @@ class LoadImage:
output_images.append(image) output_images.append(image)
output_masks.append(mask.unsqueeze(0)) output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1: if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0) output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0) output_mask = torch.cat(output_masks, dim=0)
else: else:
output_image = output_images[0] output_image = output_images[0]
output_mask = output_masks[0] output_mask = output_masks[0]
return output_image, output_mask return (output_image, output_mask)
@classmethod @classmethod
def IS_CHANGED(s, image): def IS_CHANGED(s, image):

View File

@ -582,7 +582,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
load_models.append(clip.load_model()) load_models.append(clip.load_model())
clip_sd = clip.get_sd() clip_sd = clip.get_sd()
model_management.load_models_gpu(load_models) model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
for k in extra_keys: for k in extra_keys:

View File

@ -65,6 +65,12 @@ class SD20(supported_models_base.BASE):
"use_temporal_attention": False, "use_temporal_attention": False,
} }
unet_extra_config = {
"num_heads": -1,
"num_head_channels": 64,
"attn_precision": torch.float32,
}
latent_format = latent_formats.SD15 latent_format = latent_formats.SD15
def model_type(self, state_dict, prefix=""): def model_type(self, state_dict, prefix=""):
@ -276,6 +282,12 @@ class SVD_img2vid(supported_models_base.BASE):
"use_temporal_resblock": True "use_temporal_resblock": True
} }
unet_extra_config = {
"num_heads": -1,
"num_head_channels": 64,
"attn_precision": torch.float32,
}
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual." clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15 latent_format = latent_formats.SD15

View File

@ -263,6 +263,36 @@ export class ComfyApp {
); );
} }
#addRestoreWorkflowView() {
const serialize = LGraph.prototype.serialize;
const self = this;
LGraph.prototype.serialize = function() {
const workflow = serialize.apply(this, arguments);
// Store the drag & scale info in the serialized workflow if the setting is enabled
if (self.enableWorkflowViewRestore.value) {
if (!workflow.extra) {
workflow.extra = {};
}
workflow.extra.ds = {
scale: self.canvas.ds.scale,
offset: self.canvas.ds.offset,
};
} else if (workflow.extra?.ds) {
// Clear any old view data
delete workflow.extra.ds;
}
return workflow;
}
this.enableWorkflowViewRestore = this.ui.settings.addSetting({
id: "Comfy.EnableWorkflowViewRestore",
name: "Save and restore canvas position and zoom level in workflows",
type: "boolean",
defaultValue: true
});
}
/** /**
* Adds special context menu handling for nodes * Adds special context menu handling for nodes
* e.g. this adds Open Image functionality for nodes that show images * e.g. this adds Open Image functionality for nodes that show images
@ -1505,6 +1535,7 @@ export class ComfyApp {
this.#addProcessKeyHandler(); this.#addProcessKeyHandler();
this.#addConfigureHandler(); this.#addConfigureHandler();
this.#addApiUpdateHandlers(); this.#addApiUpdateHandlers();
this.#addRestoreWorkflowView();
this.graph = new LGraph(); this.graph = new LGraph();
@ -1805,6 +1836,10 @@ export class ComfyApp {
try { try {
this.graph.configure(graphData); this.graph.configure(graphData);
if (this.enableWorkflowViewRestore.value && graphData.extra?.ds) {
this.canvas.ds.offset = graphData.extra.ds.offset;
this.canvas.ds.scale = graphData.extra.ds.scale;
}
} catch (error) { } catch (error) {
let errorHint = []; let errorHint = [];
// Try extracting filename to see if it was caused by an extension script // Try extracting filename to see if it was caused by an extension script
@ -2122,6 +2157,14 @@ export class ComfyApp {
api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } })); api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } }));
} }
showErrorOnFileLoad(file) {
this.ui.dialog.show(
$el("div", [
$el("p", {textContent: `Unable to find workflow in ${file.name}`})
]).outerHTML
);
}
/** /**
* Loads workflow data from the specified file * Loads workflow data from the specified file
* @param {File} file * @param {File} file
@ -2129,27 +2172,27 @@ export class ComfyApp {
async handleFile(file) { async handleFile(file) {
if (file.type === "image/png") { if (file.type === "image/png") {
const pngInfo = await getPngMetadata(file); const pngInfo = await getPngMetadata(file);
if (pngInfo) { if (pngInfo?.workflow) {
if (pngInfo.workflow) { await this.loadGraphData(JSON.parse(pngInfo.workflow));
await this.loadGraphData(JSON.parse(pngInfo.workflow)); } else if (pngInfo?.prompt) {
} else if (pngInfo.prompt) { this.loadApiJson(JSON.parse(pngInfo.prompt));
this.loadApiJson(JSON.parse(pngInfo.prompt)); } else if (pngInfo?.parameters) {
} else if (pngInfo.parameters) { importA1111(this.graph, pngInfo.parameters);
importA1111(this.graph, pngInfo.parameters); } else {
} this.showErrorOnFileLoad(file);
} }
} else if (file.type === "image/webp") { } else if (file.type === "image/webp") {
const pngInfo = await getWebpMetadata(file); const pngInfo = await getWebpMetadata(file);
if (pngInfo) { // Support loading workflows from that webp custom node.
if (pngInfo.workflow) { const workflow = pngInfo?.workflow || pngInfo?.Workflow;
this.loadGraphData(JSON.parse(pngInfo.workflow)); const prompt = pngInfo?.prompt || pngInfo?.Prompt;
} else if (pngInfo.Workflow) {
this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node. if (workflow) {
} else if (pngInfo.prompt) { this.loadGraphData(JSON.parse(workflow));
this.loadApiJson(JSON.parse(pngInfo.prompt)); } else if (prompt) {
} else if (pngInfo.Prompt) { this.loadApiJson(JSON.parse(prompt));
this.loadApiJson(JSON.parse(pngInfo.Prompt)); // Support loading prompts from that webp custom node. } else {
} this.showErrorOnFileLoad(file);
} }
} else if (file.type === "application/json" || file.name?.endsWith(".json")) { } else if (file.type === "application/json" || file.name?.endsWith(".json")) {
const reader = new FileReader(); const reader = new FileReader();
@ -2170,7 +2213,11 @@ export class ComfyApp {
await this.loadGraphData(JSON.parse(info.workflow)); await this.loadGraphData(JSON.parse(info.workflow));
} else if (info.prompt) { } else if (info.prompt) {
this.loadApiJson(JSON.parse(info.prompt)); this.loadApiJson(JSON.parse(info.prompt));
} else {
this.showErrorOnFileLoad(file);
} }
} else {
this.showErrorOnFileLoad(file);
} }
} }
@ -2278,6 +2325,12 @@ export class ComfyApp {
await this.#invokeExtensionsAsync("refreshComboInNodes", defs); await this.#invokeExtensionsAsync("refreshComboInNodes", defs);
} }
resetView() {
app.canvas.ds.scale = 1;
app.canvas.ds.offset = [0, 0]
app.graph.setDirtyCanvas(true, true);
}
/** /**
* Clean current state * Clean current state
*/ */

View File

@ -597,16 +597,23 @@ export class ComfyUI {
if (!confirmClear.value || confirm("Clear workflow?")) { if (!confirmClear.value || confirm("Clear workflow?")) {
app.clean(); app.clean();
app.graph.clear(); app.graph.clear();
app.resetView();
} }
} }
}), }),
$el("button", { $el("button", {
id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => { id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => {
if (!confirmClear.value || confirm("Load default workflow?")) { if (!confirmClear.value || confirm("Load default workflow?")) {
app.resetView();
await app.loadGraphData() await app.loadGraphData()
} }
} }
}), }),
$el("button", {
id: "comfy-reset-view-button", textContent: "Reset View", onclick: async () => {
app.resetView();
}
}),
]); ]);
const devMode = this.settings.addSetting({ const devMode = this.settings.addSetting({

View File

@ -174,9 +174,14 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
enable_modelspec = True enable_modelspec = True
if isinstance(model.model, model_base.SDXL): if isinstance(model.model, model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" if isinstance(model.model, model_base.SDXL_instructpix2pix):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-edit"
else:
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, model_base.SDXLRefiner): elif isinstance(model.model, model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
elif isinstance(model.model, model_base.SVD_img2vid):
metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1"
else: else:
enable_modelspec = False enable_modelspec = False
@ -261,7 +266,7 @@ class CLIPSave:
for x in extra_pnginfo: for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x]) metadata[x] = json.dumps(extra_pnginfo[x])
model_management.load_models_gpu([clip.load_model()]) model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd() clip_sd = clip.get_sd()
for prefix in ["clip_l.", "clip_g.", ""]: for prefix in ["clip_l.", "clip_g.", ""]:

View File

@ -5,12 +5,12 @@ import math
from einops import rearrange, repeat from einops import rearrange, repeat
import os import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION from comfy.ldm.modules.attention import optimized_attention
from comfy import samplers from comfy import samplers
# from comfy/ldm/modules/attention.py # from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output # but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None): def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 scale = dim_head ** -0.5
@ -26,7 +26,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
) )
# force cast to fp32 to avoid overflowing # force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32": if attn_precision == torch.float32:
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else: else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale sim = einsum('b i d, b j d -> b i j', q, k) * scale
@ -121,13 +121,13 @@ class SelfAttentionGuidance:
if 1 in cond_or_uncond: if 1 in cond_or_uncond:
uncond_index = cond_or_uncond.index(1) uncond_index = cond_or_uncond.index(1)
# do the entire attention operation, but save the attention scores to attn_scores # do the entire attention operation, but save the attention scores to attn_scores
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads) (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn] # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
return out return out
else: else:
return optimized_attention(q, k, v, heads=heads) return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
def post_cfg_function(args): def post_cfg_function(args):
nonlocal attn_scores nonlocal attn_scores

View File

@ -1,6 +1,7 @@
import os import os
import pathlib import pathlib
import re import re
import sys
import uuid import uuid
from datetime import datetime from datetime import datetime
@ -225,6 +226,7 @@ def test_image_exif_merge():
@freeze_time("2024-01-14 03:21:34", tz_offset=-4) @freeze_time("2024-01-14 03:21:34", tz_offset=-4)
@pytest.mark.skipif(sys.platform == 'win32')
def test_image_exif_creation_date_and_batch_number(): def test_image_exif_creation_date_and_batch_number():
assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None
n = ImageExifCreationDateAndBatchNumber() n = ImageExifCreationDateAndBatchNumber()
@ -264,7 +266,7 @@ def test_file_request_parameter(use_temporary_input_directory):
image.save(image_path) image.save(image_path)
n = ImageRequestParameter() n = ImageRequestParameter()
loaded_image, = n.execute(uri=image_path) loaded_image, = n.execute(value=image_path)
assert loaded_image.shape == (1, 1, 1, 3) assert loaded_image.shape == (1, 1, 1, 3)
from comfy.nodes.base_nodes import LoadImage from comfy.nodes.base_nodes import LoadImage