mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 08:40:50 +08:00
47 lines
1.9 KiB
Python
47 lines
1.9 KiB
Python
import torch
|
|
from einops import rearrange, repeat
|
|
import comfy.rmsnorm
|
|
|
|
|
|
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
|
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
|
padding_mode = "reflect"
|
|
|
|
pad = ()
|
|
for i in range(img.ndim - 2):
|
|
pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad
|
|
|
|
return torch.nn.functional.pad(img, pad, mode=padding_mode)
|
|
|
|
|
|
rms_norm = comfy.rmsnorm.rms_norm
|
|
|
|
def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=(2, 2), transformer_options={}):
|
|
bs, c, h, w = x.shape
|
|
x = pad_to_patch_size(x, (patch_size, patch_size))
|
|
|
|
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
|
|
|
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
|
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
|
|
|
steps_h = h_len
|
|
steps_w = w_len
|
|
|
|
rope_options = transformer_options.get("rope_options", None)
|
|
if rope_options is not None:
|
|
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
|
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
|
|
|
index += rope_options.get("shift_t", 0.0)
|
|
h_offset += rope_options.get("shift_y", 0.0)
|
|
w_offset += rope_options.get("shift_x", 0.0)
|
|
|
|
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
|
|
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
|
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
|
|
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
|
|
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|