diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 506ab541b..d06d0c6a8 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -7,7 +7,7 @@ on: description: 'cuda version' required: true type: string - default: "121" + default: "124" python_minor: description: 'python minor version' @@ -19,7 +19,7 @@ on: description: 'python patch version' required: true type: string - default: "2" + default: "3" # push: # branches: # - master @@ -49,7 +49,7 @@ jobs: echo 'import site' >> ./python3${{ inputs.python_minor }}._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision mpmath==1.3.0 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision mpmath==1.3.0 numpy==1.26.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 8014e3349..ef57a50c7 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -76,9 +76,6 @@ def create_parser() -> argparse.ArgumentParser: help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") - parser.add_argument("--dont-upcast-attention", action="store_true", - help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") - fp_group = parser.add_mutually_exclusive_group() fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") @@ -125,6 +122,9 @@ def create_parser() -> argparse.ArgumentParser: parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") + upcast = parser.add_mutually_exclusive_group() + upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.") + upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.") vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).") diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 7187934c2..974663f06 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -37,6 +37,7 @@ class Configuration(dict): cuda_malloc (bool): Enable cudaMallocAsync. Defaults to True in applicable setups. disable_cuda_malloc (bool): Disable cudaMallocAsync. dont_upcast_attention (bool): Disable upcasting of attention. + force_upcast_attention (bool): Force upcasting of attention. force_fp32 (bool): Force using FP32 precision. force_fp16 (bool): Force using FP16 precision. bf16_unet (bool): Use BF16 precision for UNet. @@ -106,6 +107,7 @@ class Configuration(dict): self.cuda_malloc: bool = True self.disable_cuda_malloc: bool = False self.dont_upcast_attention: bool = False + self.force_upcast_attention: bool = False self.force_fp32: bool = False self.force_fp16: bool = False self.bf16_unet: bool = False diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py index b1df17741..55b47d81b 100644 --- a/comfy/language/transformers_model_management.py +++ b/comfy/language/transformers_model_management.py @@ -23,6 +23,15 @@ class TransformersManagedModel(ModelManageable): if model.device != self.offload_device: model.to(device=self.offload_device) + @property + def lowvram_patch_counter(self): + return 0 + + @lowvram_patch_counter.setter + def lowvram_patch_counter(self, value: int): + warnings.warn("Not supported") + pass + load_device: torch.device offload_device: torch.device model: PreTrainedModel @@ -57,7 +66,7 @@ class TransformersManagedModel(ModelManageable): def model_dtype(self) -> torch.dtype: return self.model.dtype - def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module: + def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights=False) -> torch.nn.Module: warnings.warn("Transformers models do not currently support adapters like LoRAs") return self.model.to(device=device_to) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 1795a2500..efeb1c16f 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -18,13 +18,13 @@ from ...cli_args import args from ... import ops ops = ops.disable_weight_init -# CrossAttn precision handling -if args.dont_upcast_attention: - logging.info("disabling upcasting of attention") - _ATTN_PRECISION = "fp16" -else: - _ATTN_PRECISION = "fp32" +def get_attn_precision(attn_precision): + if args.dont_upcast_attention: + return None + if attn_precision is None and args.force_upcast_attention: + return torch.float32 + return attn_precision def exists(val): return val is not None @@ -84,7 +84,9 @@ class FeedForward(nn.Module): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) -def attention_basic(q, k, v, heads, mask=None): +def attention_basic(q, k, v, heads, mask=None, attn_precision=None): + attn_precision = get_attn_precision(attn_precision) + b, _, dim_head = q.shape dim_head //= heads scale = dim_head ** -0.5 @@ -100,7 +102,7 @@ def attention_basic(q, k, v, heads, mask=None): ) # force cast to fp32 to avoid overflowing - if _ATTN_PRECISION =="fp32": + if attn_precision == torch.float32: sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = einsum('b i d, b j d -> b i j', q, k) * scale @@ -134,7 +136,9 @@ def attention_basic(q, k, v, heads, mask=None): return out -def attention_sub_quad(query, key, value, heads, mask=None): +def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None): + attn_precision = get_attn_precision(attn_precision) + b, _, dim_head = query.shape dim_head //= heads @@ -145,7 +149,7 @@ def attention_sub_quad(query, key, value, heads, mask=None): key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1) dtype = query.dtype - upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32 + upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32 if upcast_attention: bytes_per_token = torch.finfo(torch.float32).bits//8 else: @@ -194,7 +198,9 @@ def attention_sub_quad(query, key, value, heads, mask=None): hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) return hidden_states -def attention_split(q, k, v, heads, mask=None): +def attention_split(q, k, v, heads, mask=None, attn_precision=None): + attn_precision = get_attn_precision(attn_precision) + b, _, dim_head = q.shape dim_head //= heads scale = dim_head ** -0.5 @@ -213,10 +219,12 @@ def attention_split(q, k, v, heads, mask=None): mem_free_total = model_management.get_free_memory(q.device) - if _ATTN_PRECISION =="fp32": + if attn_precision == torch.float32: element_size = 4 + upcast = True else: element_size = q.element_size() + upcast = False gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size @@ -250,7 +258,7 @@ def attention_split(q, k, v, heads, mask=None): slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size - if _ATTN_PRECISION =="fp32": + if upcast: with torch.autocast(enabled=False, device_type = 'cuda'): s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale else: @@ -301,7 +309,7 @@ try: except: pass -def attention_xformers(q, k, v, heads, mask=None): +def attention_xformers(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads if BROKEN_XFORMERS: @@ -333,7 +341,7 @@ def attention_xformers(q, k, v, heads, mask=None): ) return out -def attention_pytorch(q, k, v, heads, mask=None): +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads q, k, v = map( @@ -383,10 +391,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False): class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) + self.attn_precision = attn_precision self.heads = heads self.dim_head = dim_head @@ -408,15 +417,15 @@ class CrossAttention(nn.Module): v = self.to_v(context) if mask is None: - out = optimized_attention(q, k, v, self.heads) + out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) else: - out = optimized_attention_masked(q, k, v, self.heads, mask) + out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) return self.to_out(out) class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, - disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops): + disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops): super().__init__() self.ff_in = ff_in or inner_dim is not None @@ -424,6 +433,7 @@ class BasicTransformerBlock(nn.Module): inner_dim = dim self.is_res = inner_dim == dim + self.attn_precision = attn_precision if self.ff_in: self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device) @@ -431,7 +441,7 @@ class BasicTransformerBlock(nn.Module): self.disable_self_attn = disable_self_attn self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn + context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) if disable_temporal_crossattention: @@ -445,7 +455,7 @@ class BasicTransformerBlock(nn.Module): context_dim_attn2 = context_dim self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, - heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none + heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) @@ -475,6 +485,7 @@ class BasicTransformerBlock(nn.Module): extra_options["n_heads"] = self.n_heads extra_options["dim_head"] = self.d_head + extra_options["attn_precision"] = self.attn_precision if self.ff_in: x_skip = x @@ -585,7 +596,7 @@ class SpatialTransformer(nn.Module): def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, - use_checkpoint=True, dtype=None, device=None, operations=ops): + use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] * depth @@ -603,7 +614,7 @@ class SpatialTransformer(nn.Module): self.transformer_blocks = nn.ModuleList( [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations) + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) for d in range(depth)] ) if not use_linear: @@ -659,6 +670,7 @@ class SpatialVideoTransformer(SpatialTransformer): disable_self_attn=False, disable_temporal_crossattention=False, max_time_embed_period: int = 10000, + attn_precision=None, dtype=None, device=None, operations=ops ): super().__init__( @@ -671,6 +683,7 @@ class SpatialVideoTransformer(SpatialTransformer): context_dim=context_dim, use_linear=use_linear, disable_self_attn=disable_self_attn, + attn_precision=attn_precision, dtype=dtype, device=device, operations=operations ) self.time_depth = time_depth @@ -700,6 +713,7 @@ class SpatialVideoTransformer(SpatialTransformer): inner_dim=time_mix_inner_dim, disable_self_attn=disable_self_attn, disable_temporal_crossattention=disable_temporal_crossattention, + attn_precision=attn_precision, dtype=dtype, device=device, operations=operations ) for _ in range(self.depth) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index f7925478e..e79114769 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -431,6 +431,7 @@ class UNetModel(nn.Module): video_kernel_size=None, disable_temporal_crossattention=False, max_ddpm_temb_period=10000, + attn_precision=None, device=None, operations=ops, ): @@ -550,13 +551,14 @@ class UNetModel(nn.Module): disable_self_attn=disable_self_attn, disable_temporal_crossattention=disable_temporal_crossattention, max_time_embed_period=max_ddpm_temb_period, + attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations ) else: return SpatialTransformer( ch, num_heads, dim_head, depth=depth, context_dim=context_dim, disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations ) def get_resblock( diff --git a/comfy/model_management.py b/comfy/model_management.py index 5de03e571..803aa8a2a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -119,8 +119,8 @@ def get_total_memory(dev=None, torch_total_too=False): elif is_intel_xpu(): stats = torch.xpu.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] - mem_total = torch.xpu.get_device_properties(dev).total_memory mem_total_torch = mem_reserved + mem_total = torch.xpu.get_device_properties(dev).total_memory else: stats = torch.cuda.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] @@ -308,7 +308,7 @@ class LoadedModel: else: return self.model_memory() - def model_load(self, lowvram_model_memory=0): + def model_load(self, lowvram_model_memory=0, force_patch_weights=False): patch_model_to = self.device self.model.model_patches_to(self.device) @@ -318,7 +318,7 @@ class LoadedModel: try: if lowvram_model_memory > 0 and load_weights: - self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory) + self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) else: self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) except Exception as e: @@ -332,6 +332,11 @@ class LoadedModel: self.weights_loaded = True return self.real_model + def should_reload_model(self, force_patch_weights=False): + if force_patch_weights and self.model.lowvram_patch_counter > 0: + return True + return False + def model_unload(self, unpatch_weights=True): self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.model_patches_to(self.model.offload_device) @@ -408,7 +413,7 @@ def free_memory(memory_required, device, keep_loaded=[]): soft_empty_cache() -def load_models_gpu(models, memory_required=0): +def load_models_gpu(models, memory_required=0, force_patch_weights=False): global vram_state with model_management_lock: @@ -420,12 +425,21 @@ def load_models_gpu(models, memory_required=0): models_already_loaded = [] for x in models: loaded_model = LoadedModel(x) + loaded = None - if loaded_model in current_loaded_models: - index = current_loaded_models.index(loaded_model) - current_loaded_models.insert(0, current_loaded_models.pop(index)) - models_already_loaded.append(loaded_model) - else: + try: + loaded_model_index = current_loaded_models.index(loaded_model) + except ValueError: + loaded_model_index = None + + if loaded_model_index is not None: + loaded = current_loaded_models[loaded_model_index] + if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic + current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True) + loaded = None + else: + models_already_loaded.append(loaded) + if loaded is None: if hasattr(x, "model"): logging.info(f"Requested to load {x.model.__class__.__name__}") models_to_load.append(loaded_model) @@ -473,7 +487,7 @@ def load_models_gpu(models, memory_required=0): if vram_set_state == VRAMState.NO_VRAM: lowvram_model_memory = 64 * 1024 * 1024 - cur_loaded_model = loaded_model.model_load(lowvram_model_memory) + loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights) current_loaded_models.insert(0, loaded_model) return @@ -738,10 +752,10 @@ def get_free_memory(dev=None, torch_free_too=False): elif is_intel_xpu(): stats = torch.xpu.memory_stats(dev) mem_active = stats['active_bytes.all.current'] - mem_allocated = stats['allocated_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] mem_free_torch = mem_reserved - mem_active - mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated + mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved + mem_free_total = mem_free_xpu + mem_free_torch else: stats = torch.cuda.memory_stats(dev) mem_active = stats['active_bytes.all.current'] diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py index 88c306b63..119eeafcd 100644 --- a/comfy/model_management_types.py +++ b/comfy/model_management_types.py @@ -38,7 +38,7 @@ class ModelManageable(Protocol): def model_dtype(self) -> torch.dtype: ... - def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module: + def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module: ... def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module: @@ -46,3 +46,7 @@ class ModelManageable(Protocol): def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module: ... + + @property + def lowvram_patch_counter(self) -> int: + ... \ No newline at end of file diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a7c4a4fb3..d135ba91c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -19,7 +19,7 @@ def apply_weight_decompose(dora_scale, weight): .transpose(0, 1) ) - return weight * (dora_scale / weight_norm) + return weight * (dora_scale / weight_norm).type(weight.dtype) def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): @@ -65,6 +65,15 @@ class ModelPatcher(ModelManageable): self.weight_inplace_update = weight_inplace_update self.model_lowvram = False self.patches_uuid = uuid.uuid4() + self._lowvram_patch_counter = 0 + + @property + def lowvram_patch_counter(self): + return self._lowvram_patch_counter + + @lowvram_patch_counter.setter + def lowvram_patch_counter(self, value: int): + self._lowvram_patch_counter = value def model_size(self): if self.size > 0: @@ -278,7 +287,7 @@ class ModelPatcher(ModelManageable): return self.model - def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0): + def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False): self.patch_model(device_to, patch_weights=False) logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024))) @@ -292,6 +301,7 @@ class ModelPatcher(ModelManageable): return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) mem_counter = 0 + patch_counter = 0 for n, m in self.model.named_modules(): lowvram_weight = False if hasattr(m, "comfy_cast_weights"): @@ -304,9 +314,17 @@ class ModelPatcher(ModelManageable): if lowvram_weight: if weight_key in self.patches: - m.weight_function = LowVramPatch(weight_key, self) + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + m.weight_function = LowVramPatch(weight_key, self) + patch_counter += 1 if bias_key in self.patches: - m.bias_function = LowVramPatch(bias_key, self) + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + m.bias_function = LowVramPatch(bias_key, self) + patch_counter += 1 m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True @@ -319,6 +337,7 @@ class ModelPatcher(ModelManageable): logging.debug("lowvram: loaded module regularly {}".format(m)) self.model_lowvram = True + self.lowvram_patch_counter = patch_counter return self.model def calculate_weight(self, patches, weight, key): @@ -470,6 +489,7 @@ class ModelPatcher(ModelManageable): m.bias_function = None self.model_lowvram = False + self.lowvram_patch_counter = 0 keys = list(self.backup.keys()) diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 891f17731..5bcae2868 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -1464,6 +1464,9 @@ class LoadImage: output_images = [] output_masks = [] + w, h = None, None + + excluded_formats = ['MPO'] # maintain the legacy path # this will ultimately return a tensor, so we'd rather have the tensors directly @@ -1478,6 +1481,14 @@ class LoadImage: if i.mode == 'I': i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") + + if len(output_images) == 0: + w = image.size[0] + h = image.size[1] + + if image.size[0] != w or image.size[1] != h: + continue + image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] if 'A' in i.getbands(): @@ -1488,14 +1499,14 @@ class LoadImage: output_images.append(image) output_masks.append(mask.unsqueeze(0)) - if len(output_images) > 1: + if len(output_images) > 1 and img.format not in excluded_formats: output_image = torch.cat(output_images, dim=0) output_mask = torch.cat(output_masks, dim=0) else: output_image = output_images[0] output_mask = output_masks[0] - return output_image, output_mask + return (output_image, output_mask) @classmethod def IS_CHANGED(s, image): diff --git a/comfy/sd.py b/comfy/sd.py index 46dedff2c..5a6f49882 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -582,7 +582,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m load_models.append(clip.load_model()) clip_sd = clip.get_sd() - model_management.load_models_gpu(load_models) + model_management.load_models_gpu(load_models, force_patch_weights=True) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) for k in extra_keys: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index b3b69e05b..6ca32e8ee 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -65,6 +65,12 @@ class SD20(supported_models_base.BASE): "use_temporal_attention": False, } + unet_extra_config = { + "num_heads": -1, + "num_head_channels": 64, + "attn_precision": torch.float32, + } + latent_format = latent_formats.SD15 def model_type(self, state_dict, prefix=""): @@ -276,6 +282,12 @@ class SVD_img2vid(supported_models_base.BASE): "use_temporal_resblock": True } + unet_extra_config = { + "num_heads": -1, + "num_head_channels": 64, + "attn_precision": torch.float32, + } + clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual." latent_format = latent_formats.SD15 diff --git a/comfy/web/scripts/app.js b/comfy/web/scripts/app.js index a3105c275..a516be704 100644 --- a/comfy/web/scripts/app.js +++ b/comfy/web/scripts/app.js @@ -262,6 +262,36 @@ export class ComfyApp { }) ); } + + #addRestoreWorkflowView() { + const serialize = LGraph.prototype.serialize; + const self = this; + LGraph.prototype.serialize = function() { + const workflow = serialize.apply(this, arguments); + + // Store the drag & scale info in the serialized workflow if the setting is enabled + if (self.enableWorkflowViewRestore.value) { + if (!workflow.extra) { + workflow.extra = {}; + } + workflow.extra.ds = { + scale: self.canvas.ds.scale, + offset: self.canvas.ds.offset, + }; + } else if (workflow.extra?.ds) { + // Clear any old view data + delete workflow.extra.ds; + } + + return workflow; + } + this.enableWorkflowViewRestore = this.ui.settings.addSetting({ + id: "Comfy.EnableWorkflowViewRestore", + name: "Save and restore canvas position and zoom level in workflows", + type: "boolean", + defaultValue: true + }); + } /** * Adds special context menu handling for nodes @@ -1505,6 +1535,7 @@ export class ComfyApp { this.#addProcessKeyHandler(); this.#addConfigureHandler(); this.#addApiUpdateHandlers(); + this.#addRestoreWorkflowView(); this.graph = new LGraph(); @@ -1805,6 +1836,10 @@ export class ComfyApp { try { this.graph.configure(graphData); + if (this.enableWorkflowViewRestore.value && graphData.extra?.ds) { + this.canvas.ds.offset = graphData.extra.ds.offset; + this.canvas.ds.scale = graphData.extra.ds.scale; + } } catch (error) { let errorHint = []; // Try extracting filename to see if it was caused by an extension script @@ -2122,6 +2157,14 @@ export class ComfyApp { api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } })); } + showErrorOnFileLoad(file) { + this.ui.dialog.show( + $el("div", [ + $el("p", {textContent: `Unable to find workflow in ${file.name}`}) + ]).outerHTML + ); + } + /** * Loads workflow data from the specified file * @param {File} file @@ -2129,27 +2172,27 @@ export class ComfyApp { async handleFile(file) { if (file.type === "image/png") { const pngInfo = await getPngMetadata(file); - if (pngInfo) { - if (pngInfo.workflow) { - await this.loadGraphData(JSON.parse(pngInfo.workflow)); - } else if (pngInfo.prompt) { - this.loadApiJson(JSON.parse(pngInfo.prompt)); - } else if (pngInfo.parameters) { - importA1111(this.graph, pngInfo.parameters); - } + if (pngInfo?.workflow) { + await this.loadGraphData(JSON.parse(pngInfo.workflow)); + } else if (pngInfo?.prompt) { + this.loadApiJson(JSON.parse(pngInfo.prompt)); + } else if (pngInfo?.parameters) { + importA1111(this.graph, pngInfo.parameters); + } else { + this.showErrorOnFileLoad(file); } } else if (file.type === "image/webp") { const pngInfo = await getWebpMetadata(file); - if (pngInfo) { - if (pngInfo.workflow) { - this.loadGraphData(JSON.parse(pngInfo.workflow)); - } else if (pngInfo.Workflow) { - this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node. - } else if (pngInfo.prompt) { - this.loadApiJson(JSON.parse(pngInfo.prompt)); - } else if (pngInfo.Prompt) { - this.loadApiJson(JSON.parse(pngInfo.Prompt)); // Support loading prompts from that webp custom node. - } + // Support loading workflows from that webp custom node. + const workflow = pngInfo?.workflow || pngInfo?.Workflow; + const prompt = pngInfo?.prompt || pngInfo?.Prompt; + + if (workflow) { + this.loadGraphData(JSON.parse(workflow)); + } else if (prompt) { + this.loadApiJson(JSON.parse(prompt)); + } else { + this.showErrorOnFileLoad(file); } } else if (file.type === "application/json" || file.name?.endsWith(".json")) { const reader = new FileReader(); @@ -2170,7 +2213,11 @@ export class ComfyApp { await this.loadGraphData(JSON.parse(info.workflow)); } else if (info.prompt) { this.loadApiJson(JSON.parse(info.prompt)); + } else { + this.showErrorOnFileLoad(file); } + } else { + this.showErrorOnFileLoad(file); } } @@ -2278,6 +2325,12 @@ export class ComfyApp { await this.#invokeExtensionsAsync("refreshComboInNodes", defs); } + resetView() { + app.canvas.ds.scale = 1; + app.canvas.ds.offset = [0, 0] + app.graph.setDirtyCanvas(true, true); + } + /** * Clean current state */ diff --git a/comfy/web/scripts/ui.js b/comfy/web/scripts/ui.js index d0fa46efb..36fed3238 100644 --- a/comfy/web/scripts/ui.js +++ b/comfy/web/scripts/ui.js @@ -597,16 +597,23 @@ export class ComfyUI { if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); app.graph.clear(); + app.resetView(); } } }), $el("button", { id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => { if (!confirmClear.value || confirm("Load default workflow?")) { + app.resetView(); await app.loadGraphData() } } }), + $el("button", { + id: "comfy-reset-view-button", textContent: "Reset View", onclick: async () => { + app.resetView(); + } + }), ]); const devMode = this.settings.addSetting({ diff --git a/comfy_extras/nodes/nodes_model_merging.py b/comfy_extras/nodes/nodes_model_merging.py index 7a5532993..8562a09f3 100644 --- a/comfy_extras/nodes/nodes_model_merging.py +++ b/comfy_extras/nodes/nodes_model_merging.py @@ -174,9 +174,14 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi enable_modelspec = True if isinstance(model.model, model_base.SDXL): - metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" + if isinstance(model.model, model_base.SDXL_instructpix2pix): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-edit" + else: + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" elif isinstance(model.model, model_base.SDXLRefiner): metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" + elif isinstance(model.model, model_base.SVD_img2vid): + metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1" else: enable_modelspec = False @@ -261,7 +266,7 @@ class CLIPSave: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) - model_management.load_models_gpu([clip.load_model()]) + model_management.load_models_gpu([clip.load_model()], force_patch_weights=True) clip_sd = clip.get_sd() for prefix in ["clip_l.", "clip_g.", ""]: diff --git a/comfy_extras/nodes/nodes_sag.py b/comfy_extras/nodes/nodes_sag.py index 1d6dd40cd..998d7e287 100644 --- a/comfy_extras/nodes/nodes_sag.py +++ b/comfy_extras/nodes/nodes_sag.py @@ -5,12 +5,12 @@ import math from einops import rearrange, repeat import os -from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION +from comfy.ldm.modules.attention import optimized_attention from comfy import samplers # from comfy/ldm/modules/attention.py # but modified to return attention scores as well as output -def attention_basic_with_sim(q, k, v, heads, mask=None): +def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads scale = dim_head ** -0.5 @@ -26,7 +26,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None): ) # force cast to fp32 to avoid overflowing - if _ATTN_PRECISION =="fp32": + if attn_precision == torch.float32: sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = einsum('b i d, b j d -> b i j', q, k) * scale @@ -121,13 +121,13 @@ class SelfAttentionGuidance: if 1 in cond_or_uncond: uncond_index = cond_or_uncond.index(1) # do the entire attention operation, but save the attention scores to attn_scores - (out, sim) = attention_basic_with_sim(q, k, v, heads=heads) + (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"]) # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn] n_slices = heads * b attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] return out else: - return optimized_attention(q, k, v, heads=heads) + return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"]) def post_cfg_function(args): nonlocal attn_scores diff --git a/tests/unit/test_openapi_nodes.py b/tests/unit/test_openapi_nodes.py index f141b69b8..09bab96ef 100644 --- a/tests/unit/test_openapi_nodes.py +++ b/tests/unit/test_openapi_nodes.py @@ -1,6 +1,7 @@ import os import pathlib import re +import sys import uuid from datetime import datetime @@ -225,6 +226,7 @@ def test_image_exif_merge(): @freeze_time("2024-01-14 03:21:34", tz_offset=-4) +@pytest.mark.skipif(sys.platform == 'win32') def test_image_exif_creation_date_and_batch_number(): assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None n = ImageExifCreationDateAndBatchNumber() @@ -264,7 +266,7 @@ def test_file_request_parameter(use_temporary_input_directory): image.save(image_path) n = ImageRequestParameter() - loaded_image, = n.execute(uri=image_path) + loaded_image, = n.execute(value=image_path) assert loaded_image.shape == (1, 1, 1, 3) from comfy.nodes.base_nodes import LoadImage