From b79fd7d92c7355b5e9cd5c1ea746d7dd06c27351 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 28 Aug 2024 16:12:24 -0400 Subject: [PATCH 1/3] ComfyUI supports more than just stable diffusion. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c72f27857..370f64100 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@
# ComfyUI -**The most powerful and modular stable diffusion GUI and backend.** +**The most powerful and modular diffusion model GUI and backend.** [![Website][website-shield]][website-url] From d31e226650ad01daefff66ec202992b8c3bf8384 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 28 Aug 2024 16:18:39 -0400 Subject: [PATCH 2/3] Unify RMSNorm code. --- comfy/ldm/common_dit.py | 13 +++++++++++ comfy/ldm/flux/layers.py | 4 ++-- comfy/ldm/modules/diffusionmodules/mmdit.py | 24 ++------------------- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index 990025521..9016abc44 100644 --- a/comfy/ldm/common_dit.py +++ b/comfy/ldm/common_dit.py @@ -1,4 +1,5 @@ import torch +import comfy.ops def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting(): @@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0] pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1] return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode) + +try: + rms_norm_torch = torch.nn.functional.rms_norm +except: + rms_norm_torch = None + +def rms_norm(x, weight, eps=1e-6): + if rms_norm_torch is not None: + return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) + else: + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) + return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 20bd28505..dabab3e33 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -6,6 +6,7 @@ from torch import Tensor, nn from .math import attention, rope import comfy.ops +import comfy.ldm.common_dit class EmbedND(nn.Module): @@ -63,8 +64,7 @@ class RMSNorm(torch.nn.Module): self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) def forward(self, x: Tensor): - rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) - return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device) + return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6) class QKNorm(torch.nn.Module): diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 491a58a20..759788a97 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module): else: self.register_parameter("weight", None) - def _norm(self, x): - """ - Apply the RMSNorm normalization to the input tensor. - Args: - x (torch.Tensor): The input tensor. - Returns: - torch.Tensor: The normalized tensor. - """ - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - def forward(self, x): - """ - Forward pass through the RMSNorm layer. - Args: - x (torch.Tensor): The input tensor. - Returns: - torch.Tensor: The output tensor after applying RMSNorm. - """ - x = self._norm(x) - if self.learnable_scale: - return x * self.weight.to(device=x.device, dtype=x.dtype) - else: - return x + return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps) + class SwiGLUFeedForward(nn.Module): From 34eda0f853daffdfcab04e1b3187de40f1d30bbf Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Thu, 29 Aug 2024 06:46:30 +0900 Subject: [PATCH 3/3] fix: remove redundant useless loop (#4656) fix: potential error of undefined variable https://github.com/comfyanonymous/ComfyUI/discussions/4650 --- script_examples/websockets_api_example.py | 17 ++++++++--------- tests/inference/test_execution.py | 21 ++++++++++----------- tests/inference/test_inference.py | 17 ++++++++--------- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/script_examples/websockets_api_example.py b/script_examples/websockets_api_example.py index 57a6cbd9b..04c9fa21b 100644 --- a/script_examples/websockets_api_example.py +++ b/script_examples/websockets_api_example.py @@ -41,15 +41,14 @@ def get_images(ws, prompt): continue #previews are binary data history = get_history(prompt_id)[prompt_id] - for o in history['outputs']: - for node_id in history['outputs']: - node_output = history['outputs'][node_id] - if 'images' in node_output: - images_output = [] - for image in node_output['images']: - image_data = get_image(image['filename'], image['subfolder'], image['type']) - images_output.append(image_data) - output_images[node_id] = images_output + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + images_output = [] + if 'images' in node_output: + for image in node_output['images']: + image_data = get_image(image['filename'], image['subfolder'], image['type']) + images_output.append(image_data) + output_images[node_id] = images_output return output_images diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index ffc0c482a..c7daddeb6 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -95,17 +95,16 @@ class ComfyClient: pass # Probably want to store this off for testing history = self.get_history(prompt_id)[prompt_id] - for o in history['outputs']: - for node_id in history['outputs']: - node_output = history['outputs'][node_id] - result.outputs[node_id] = node_output - if 'images' in node_output: - images_output = [] - for image in node_output['images']: - image_data = self.get_image(image['filename'], image['subfolder'], image['type']) - image_obj = Image.open(BytesIO(image_data)) - images_output.append(image_obj) - node_output['image_objects'] = images_output + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + result.outputs[node_id] = node_output + images_output = [] + if 'images' in node_output: + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + image_obj = Image.open(BytesIO(image_data)) + images_output.append(image_obj) + node_output['image_objects'] = images_output return result diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 141cc5c7e..2e11778f2 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -109,15 +109,14 @@ class ComfyClient: continue #previews are binary data history = self.get_history(prompt_id)[prompt_id] - for o in history['outputs']: - for node_id in history['outputs']: - node_output = history['outputs'][node_id] - if 'images' in node_output: - images_output = [] - for image in node_output['images']: - image_data = self.get_image(image['filename'], image['subfolder'], image['type']) - images_output.append(image_data) - output_images[node_id] = images_output + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + images_output = [] + if 'images' in node_output: + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + images_output.append(image_data) + output_images[node_id] = images_output return output_images