Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2024-05-16 14:28:49 -07:00
commit 3d98440fb7
17 changed files with 232 additions and 77 deletions

View File

@ -7,7 +7,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "121"
default: "124"
python_minor:
description: 'python minor version'
@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "2"
default: "3"
# push:
# branches:
# - master
@ -49,7 +49,7 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
python -m pip wheel torch torchvision mpmath==1.3.0 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
python -m pip wheel torch torchvision mpmath==1.3.0 numpy==1.26.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth

View File

@ -76,9 +76,6 @@ def create_parser() -> argparse.ArgumentParser:
help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
parser.add_argument("--dont-upcast-attention", action="store_true",
help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true",
help="Force fp32 (If this makes your GPU work better please report it).")
@ -125,6 +122,9 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true",
help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")

View File

@ -37,6 +37,7 @@ class Configuration(dict):
cuda_malloc (bool): Enable cudaMallocAsync. Defaults to True in applicable setups.
disable_cuda_malloc (bool): Disable cudaMallocAsync.
dont_upcast_attention (bool): Disable upcasting of attention.
force_upcast_attention (bool): Force upcasting of attention.
force_fp32 (bool): Force using FP32 precision.
force_fp16 (bool): Force using FP16 precision.
bf16_unet (bool): Use BF16 precision for UNet.
@ -106,6 +107,7 @@ class Configuration(dict):
self.cuda_malloc: bool = True
self.disable_cuda_malloc: bool = False
self.dont_upcast_attention: bool = False
self.force_upcast_attention: bool = False
self.force_fp32: bool = False
self.force_fp16: bool = False
self.bf16_unet: bool = False

View File

@ -23,6 +23,15 @@ class TransformersManagedModel(ModelManageable):
if model.device != self.offload_device:
model.to(device=self.offload_device)
@property
def lowvram_patch_counter(self):
return 0
@lowvram_patch_counter.setter
def lowvram_patch_counter(self, value: int):
warnings.warn("Not supported")
pass
load_device: torch.device
offload_device: torch.device
model: PreTrainedModel
@ -57,7 +66,7 @@ class TransformersManagedModel(ModelManageable):
def model_dtype(self) -> torch.dtype:
return self.model.dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights=False) -> torch.nn.Module:
warnings.warn("Transformers models do not currently support adapters like LoRAs")
return self.model.to(device=device_to)

View File

@ -18,13 +18,13 @@ from ...cli_args import args
from ... import ops
ops = ops.disable_weight_init
# CrossAttn precision handling
if args.dont_upcast_attention:
logging.info("disabling upcasting of attention")
_ATTN_PRECISION = "fp16"
else:
_ATTN_PRECISION = "fp32"
def get_attn_precision(attn_precision):
if args.dont_upcast_attention:
return None
if attn_precision is None and args.force_upcast_attention:
return torch.float32
return attn_precision
def exists(val):
return val is not None
@ -84,7 +84,9 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None):
def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
attn_precision = get_attn_precision(attn_precision)
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
@ -100,7 +102,7 @@ def attention_basic(q, k, v, heads, mask=None):
)
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
if attn_precision == torch.float32:
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
@ -134,7 +136,9 @@ def attention_basic(q, k, v, heads, mask=None):
return out
def attention_sub_quad(query, key, value, heads, mask=None):
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
attn_precision = get_attn_precision(attn_precision)
b, _, dim_head = query.shape
dim_head //= heads
@ -145,7 +149,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
dtype = query.dtype
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8
else:
@ -194,7 +198,9 @@ def attention_sub_quad(query, key, value, heads, mask=None):
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
def attention_split(q, k, v, heads, mask=None):
def attention_split(q, k, v, heads, mask=None, attn_precision=None):
attn_precision = get_attn_precision(attn_precision)
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
@ -213,10 +219,12 @@ def attention_split(q, k, v, heads, mask=None):
mem_free_total = model_management.get_free_memory(q.device)
if _ATTN_PRECISION =="fp32":
if attn_precision == torch.float32:
element_size = 4
upcast = True
else:
element_size = q.element_size()
upcast = False
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
@ -250,7 +258,7 @@ def attention_split(q, k, v, heads, mask=None):
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
if _ATTN_PRECISION =="fp32":
if upcast:
with torch.autocast(enabled=False, device_type = 'cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
else:
@ -301,7 +309,7 @@ try:
except:
pass
def attention_xformers(q, k, v, heads, mask=None):
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape
dim_head //= heads
if BROKEN_XFORMERS:
@ -333,7 +341,7 @@ def attention_xformers(q, k, v, heads, mask=None):
)
return out
def attention_pytorch(q, k, v, heads, mask=None):
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
@ -383,10 +391,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False):
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.attn_precision = attn_precision
self.heads = heads
self.dim_head = dim_head
@ -408,15 +417,15 @@ class CrossAttention(nn.Module):
v = self.to_v(context)
if mask is None:
out = optimized_attention(q, k, v, self.heads)
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
out = optimized_attention_masked(q, k, v, self.heads, mask)
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
self.ff_in = ff_in or inner_dim is not None
@ -424,6 +433,7 @@ class BasicTransformerBlock(nn.Module):
inner_dim = dim
self.is_res = inner_dim == dim
self.attn_precision = attn_precision
if self.ff_in:
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
@ -431,7 +441,7 @@ class BasicTransformerBlock(nn.Module):
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention:
@ -445,7 +455,7 @@ class BasicTransformerBlock(nn.Module):
context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
@ -475,6 +485,7 @@ class BasicTransformerBlock(nn.Module):
extra_options["n_heads"] = self.n_heads
extra_options["dim_head"] = self.d_head
extra_options["attn_precision"] = self.attn_precision
if self.ff_in:
x_skip = x
@ -585,7 +596,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True, dtype=None, device=None, operations=ops):
use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth
@ -603,7 +614,7 @@ class SpatialTransformer(nn.Module):
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
for d in range(depth)]
)
if not use_linear:
@ -659,6 +670,7 @@ class SpatialVideoTransformer(SpatialTransformer):
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
attn_precision=None,
dtype=None, device=None, operations=ops
):
super().__init__(
@ -671,6 +683,7 @@ class SpatialVideoTransformer(SpatialTransformer):
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
self.time_depth = time_depth
@ -700,6 +713,7 @@ class SpatialVideoTransformer(SpatialTransformer):
inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
for _ in range(self.depth)

View File

@ -431,6 +431,7 @@ class UNetModel(nn.Module):
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
attn_precision=None,
device=None,
operations=ops,
):
@ -550,13 +551,14 @@ class UNetModel(nn.Module):
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
attn_precision=attn_precision,
dtype=self.dtype, device=device, operations=operations
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
)
def get_resblock(

View File

@ -119,8 +119,8 @@ def get_total_memory(dev=None, torch_total_too=False):
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved
mem_total = torch.xpu.get_device_properties(dev).total_memory
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@ -308,7 +308,7 @@ class LoadedModel:
else:
return self.model_memory()
def model_load(self, lowvram_model_memory=0):
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
patch_model_to = self.device
self.model.model_patches_to(self.device)
@ -318,7 +318,7 @@ class LoadedModel:
try:
if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e:
@ -332,6 +332,11 @@ class LoadedModel:
self.weights_loaded = True
return self.real_model
def should_reload_model(self, force_patch_weights=False):
if force_patch_weights and self.model.lowvram_patch_counter > 0:
return True
return False
def model_unload(self, unpatch_weights=True):
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
@ -408,7 +413,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache()
def load_models_gpu(models, memory_required=0):
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
global vram_state
with model_management_lock:
@ -420,12 +425,21 @@ def load_models_gpu(models, memory_required=0):
models_already_loaded = []
for x in models:
loaded_model = LoadedModel(x)
loaded = None
if loaded_model in current_loaded_models:
index = current_loaded_models.index(loaded_model)
current_loaded_models.insert(0, current_loaded_models.pop(index))
models_already_loaded.append(loaded_model)
else:
try:
loaded_model_index = current_loaded_models.index(loaded_model)
except ValueError:
loaded_model_index = None
if loaded_model_index is not None:
loaded = current_loaded_models[loaded_model_index]
if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
models_already_loaded.append(loaded)
if loaded is None:
if hasattr(x, "model"):
logging.info(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model)
@ -473,7 +487,7 @@ def load_models_gpu(models, memory_required=0):
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
return
@ -738,10 +752,10 @@ def get_free_memory(dev=None, torch_free_too=False):
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_allocated = stats['allocated_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_torch = mem_reserved - mem_active
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_total = mem_free_xpu + mem_free_torch
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']

View File

@ -38,7 +38,7 @@ class ModelManageable(Protocol):
def model_dtype(self) -> torch.dtype:
...
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
...
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
@ -46,3 +46,7 @@ class ModelManageable(Protocol):
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
...
@property
def lowvram_patch_counter(self) -> int:
...

View File

@ -19,7 +19,7 @@ def apply_weight_decompose(dora_scale, weight):
.transpose(0, 1)
)
return weight * (dora_scale / weight_norm)
return weight * (dora_scale / weight_norm).type(weight.dtype)
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
@ -65,6 +65,15 @@ class ModelPatcher(ModelManageable):
self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
self.patches_uuid = uuid.uuid4()
self._lowvram_patch_counter = 0
@property
def lowvram_patch_counter(self):
return self._lowvram_patch_counter
@lowvram_patch_counter.setter
def lowvram_patch_counter(self, value: int):
self._lowvram_patch_counter = value
def model_size(self):
if self.size > 0:
@ -278,7 +287,7 @@ class ModelPatcher(ModelManageable):
return self.model
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
@ -292,6 +301,7 @@ class ModelPatcher(ModelManageable):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
mem_counter = 0
patch_counter = 0
for n, m in self.model.named_modules():
lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):
@ -304,9 +314,17 @@ class ModelPatcher(ModelManageable):
if lowvram_weight:
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self)
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
@ -319,6 +337,7 @@ class ModelPatcher(ModelManageable):
logging.debug("lowvram: loaded module regularly {}".format(m))
self.model_lowvram = True
self.lowvram_patch_counter = patch_counter
return self.model
def calculate_weight(self, patches, weight, key):
@ -470,6 +489,7 @@ class ModelPatcher(ModelManageable):
m.bias_function = None
self.model_lowvram = False
self.lowvram_patch_counter = 0
keys = list(self.backup.keys())

View File

@ -1464,6 +1464,9 @@ class LoadImage:
output_images = []
output_masks = []
w, h = None, None
excluded_formats = ['MPO']
# maintain the legacy path
# this will ultimately return a tensor, so we'd rather have the tensors directly
@ -1478,6 +1481,14 @@ class LoadImage:
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
if len(output_images) == 0:
w = image.size[0]
h = image.size[1]
if image.size[0] != w or image.size[1] != h:
continue
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
@ -1488,14 +1499,14 @@ class LoadImage:
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1:
if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return output_image, output_mask
return (output_image, output_mask)
@classmethod
def IS_CHANGED(s, image):

View File

@ -582,7 +582,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
model_management.load_models_gpu(load_models)
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
for k in extra_keys:

View File

@ -65,6 +65,12 @@ class SD20(supported_models_base.BASE):
"use_temporal_attention": False,
}
unet_extra_config = {
"num_heads": -1,
"num_head_channels": 64,
"attn_precision": torch.float32,
}
latent_format = latent_formats.SD15
def model_type(self, state_dict, prefix=""):
@ -276,6 +282,12 @@ class SVD_img2vid(supported_models_base.BASE):
"use_temporal_resblock": True
}
unet_extra_config = {
"num_heads": -1,
"num_head_channels": 64,
"attn_precision": torch.float32,
}
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15

View File

@ -262,6 +262,36 @@ export class ComfyApp {
})
);
}
#addRestoreWorkflowView() {
const serialize = LGraph.prototype.serialize;
const self = this;
LGraph.prototype.serialize = function() {
const workflow = serialize.apply(this, arguments);
// Store the drag & scale info in the serialized workflow if the setting is enabled
if (self.enableWorkflowViewRestore.value) {
if (!workflow.extra) {
workflow.extra = {};
}
workflow.extra.ds = {
scale: self.canvas.ds.scale,
offset: self.canvas.ds.offset,
};
} else if (workflow.extra?.ds) {
// Clear any old view data
delete workflow.extra.ds;
}
return workflow;
}
this.enableWorkflowViewRestore = this.ui.settings.addSetting({
id: "Comfy.EnableWorkflowViewRestore",
name: "Save and restore canvas position and zoom level in workflows",
type: "boolean",
defaultValue: true
});
}
/**
* Adds special context menu handling for nodes
@ -1505,6 +1535,7 @@ export class ComfyApp {
this.#addProcessKeyHandler();
this.#addConfigureHandler();
this.#addApiUpdateHandlers();
this.#addRestoreWorkflowView();
this.graph = new LGraph();
@ -1805,6 +1836,10 @@ export class ComfyApp {
try {
this.graph.configure(graphData);
if (this.enableWorkflowViewRestore.value && graphData.extra?.ds) {
this.canvas.ds.offset = graphData.extra.ds.offset;
this.canvas.ds.scale = graphData.extra.ds.scale;
}
} catch (error) {
let errorHint = [];
// Try extracting filename to see if it was caused by an extension script
@ -2122,6 +2157,14 @@ export class ComfyApp {
api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } }));
}
showErrorOnFileLoad(file) {
this.ui.dialog.show(
$el("div", [
$el("p", {textContent: `Unable to find workflow in ${file.name}`})
]).outerHTML
);
}
/**
* Loads workflow data from the specified file
* @param {File} file
@ -2129,27 +2172,27 @@ export class ComfyApp {
async handleFile(file) {
if (file.type === "image/png") {
const pngInfo = await getPngMetadata(file);
if (pngInfo) {
if (pngInfo.workflow) {
await this.loadGraphData(JSON.parse(pngInfo.workflow));
} else if (pngInfo.prompt) {
this.loadApiJson(JSON.parse(pngInfo.prompt));
} else if (pngInfo.parameters) {
importA1111(this.graph, pngInfo.parameters);
}
if (pngInfo?.workflow) {
await this.loadGraphData(JSON.parse(pngInfo.workflow));
} else if (pngInfo?.prompt) {
this.loadApiJson(JSON.parse(pngInfo.prompt));
} else if (pngInfo?.parameters) {
importA1111(this.graph, pngInfo.parameters);
} else {
this.showErrorOnFileLoad(file);
}
} else if (file.type === "image/webp") {
const pngInfo = await getWebpMetadata(file);
if (pngInfo) {
if (pngInfo.workflow) {
this.loadGraphData(JSON.parse(pngInfo.workflow));
} else if (pngInfo.Workflow) {
this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node.
} else if (pngInfo.prompt) {
this.loadApiJson(JSON.parse(pngInfo.prompt));
} else if (pngInfo.Prompt) {
this.loadApiJson(JSON.parse(pngInfo.Prompt)); // Support loading prompts from that webp custom node.
}
// Support loading workflows from that webp custom node.
const workflow = pngInfo?.workflow || pngInfo?.Workflow;
const prompt = pngInfo?.prompt || pngInfo?.Prompt;
if (workflow) {
this.loadGraphData(JSON.parse(workflow));
} else if (prompt) {
this.loadApiJson(JSON.parse(prompt));
} else {
this.showErrorOnFileLoad(file);
}
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
const reader = new FileReader();
@ -2170,7 +2213,11 @@ export class ComfyApp {
await this.loadGraphData(JSON.parse(info.workflow));
} else if (info.prompt) {
this.loadApiJson(JSON.parse(info.prompt));
} else {
this.showErrorOnFileLoad(file);
}
} else {
this.showErrorOnFileLoad(file);
}
}
@ -2278,6 +2325,12 @@ export class ComfyApp {
await this.#invokeExtensionsAsync("refreshComboInNodes", defs);
}
resetView() {
app.canvas.ds.scale = 1;
app.canvas.ds.offset = [0, 0]
app.graph.setDirtyCanvas(true, true);
}
/**
* Clean current state
*/

View File

@ -597,16 +597,23 @@ export class ComfyUI {
if (!confirmClear.value || confirm("Clear workflow?")) {
app.clean();
app.graph.clear();
app.resetView();
}
}
}),
$el("button", {
id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => {
if (!confirmClear.value || confirm("Load default workflow?")) {
app.resetView();
await app.loadGraphData()
}
}
}),
$el("button", {
id: "comfy-reset-view-button", textContent: "Reset View", onclick: async () => {
app.resetView();
}
}),
]);
const devMode = this.settings.addSetting({

View File

@ -174,9 +174,14 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
enable_modelspec = True
if isinstance(model.model, model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
if isinstance(model.model, model_base.SDXL_instructpix2pix):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-edit"
else:
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
elif isinstance(model.model, model_base.SVD_img2vid):
metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1"
else:
enable_modelspec = False
@ -261,7 +266,7 @@ class CLIPSave:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
model_management.load_models_gpu([clip.load_model()])
model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()
for prefix in ["clip_l.", "clip_g.", ""]:

View File

@ -5,12 +5,12 @@ import math
from einops import rearrange, repeat
import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
from comfy.ldm.modules.attention import optimized_attention
from comfy import samplers
# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None):
def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
@ -26,7 +26,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
)
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
if attn_precision == torch.float32:
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
@ -121,13 +121,13 @@ class SelfAttentionGuidance:
if 1 in cond_or_uncond:
uncond_index = cond_or_uncond.index(1)
# do the entire attention operation, but save the attention scores to attn_scores
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
return out
else:
return optimized_attention(q, k, v, heads=heads)
return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
def post_cfg_function(args):
nonlocal attn_scores

View File

@ -1,6 +1,7 @@
import os
import pathlib
import re
import sys
import uuid
from datetime import datetime
@ -225,6 +226,7 @@ def test_image_exif_merge():
@freeze_time("2024-01-14 03:21:34", tz_offset=-4)
@pytest.mark.skipif(sys.platform == 'win32')
def test_image_exif_creation_date_and_batch_number():
assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None
n = ImageExifCreationDateAndBatchNumber()
@ -264,7 +266,7 @@ def test_file_request_parameter(use_temporary_input_directory):
image.save(image_path)
n = ImageRequestParameter()
loaded_image, = n.execute(uri=image_path)
loaded_image, = n.execute(value=image_path)
assert loaded_image.shape == (1, 1, 1, 3)
from comfy.nodes.base_nodes import LoadImage