mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 15:20:51 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
39c3ef9d66
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
|
||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||
@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||
|
||||
try:
|
||||
rms_norm_torch = torch.nn.functional.rms_norm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
|
||||
def rms_norm(x, weight, eps=1e-6):
|
||||
if rms_norm_torch is not None:
|
||||
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
else:
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||
|
||||
@ -6,6 +6,7 @@ from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
@ -63,8 +64,7 @@ class RMSNorm(torch.nn.Module):
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device)
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
|
||||
@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module):
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
"""
|
||||
x = self._norm(x)
|
||||
if self.learnable_scale:
|
||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
return x
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
|
||||
class SwiGLUFeedForward(nn.Module):
|
||||
|
||||
@ -41,15 +41,14 @@ def get_images(ws, prompt):
|
||||
continue #previews are binary data
|
||||
|
||||
history = get_history(prompt_id)[prompt_id]
|
||||
for o in history['outputs']:
|
||||
for node_id in history['outputs']:
|
||||
node_output = history['outputs'][node_id]
|
||||
if 'images' in node_output:
|
||||
images_output = []
|
||||
for image in node_output['images']:
|
||||
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
||||
images_output.append(image_data)
|
||||
output_images[node_id] = images_output
|
||||
for node_id in history['outputs']:
|
||||
node_output = history['outputs'][node_id]
|
||||
images_output = []
|
||||
if 'images' in node_output:
|
||||
for image in node_output['images']:
|
||||
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
||||
images_output.append(image_data)
|
||||
output_images[node_id] = images_output
|
||||
|
||||
return output_images
|
||||
|
||||
|
||||
@ -95,17 +95,16 @@ class ComfyClient:
|
||||
pass # Probably want to store this off for testing
|
||||
|
||||
history = self.get_history(prompt_id)[prompt_id]
|
||||
for o in history['outputs']:
|
||||
for node_id in history['outputs']:
|
||||
node_output = history['outputs'][node_id]
|
||||
result.outputs[node_id] = node_output
|
||||
if 'images' in node_output:
|
||||
images_output = []
|
||||
for image in node_output['images']:
|
||||
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
||||
image_obj = Image.open(BytesIO(image_data))
|
||||
images_output.append(image_obj)
|
||||
node_output['image_objects'] = images_output
|
||||
for node_id in history['outputs']:
|
||||
node_output = history['outputs'][node_id]
|
||||
result.outputs[node_id] = node_output
|
||||
images_output = []
|
||||
if 'images' in node_output:
|
||||
for image in node_output['images']:
|
||||
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
||||
image_obj = Image.open(BytesIO(image_data))
|
||||
images_output.append(image_obj)
|
||||
node_output['image_objects'] = images_output
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@ -109,15 +109,14 @@ class ComfyClient:
|
||||
continue #previews are binary data
|
||||
|
||||
history = self.get_history(prompt_id)[prompt_id]
|
||||
for o in history['outputs']:
|
||||
for node_id in history['outputs']:
|
||||
node_output = history['outputs'][node_id]
|
||||
if 'images' in node_output:
|
||||
images_output = []
|
||||
for image in node_output['images']:
|
||||
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
||||
images_output.append(image_data)
|
||||
output_images[node_id] = images_output
|
||||
for node_id in history['outputs']:
|
||||
node_output = history['outputs'][node_id]
|
||||
images_output = []
|
||||
if 'images' in node_output:
|
||||
for image in node_output['images']:
|
||||
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
||||
images_output.append(image_data)
|
||||
output_images[node_id] = images_output
|
||||
|
||||
return output_images
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user