mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 09:57:24 +08:00
Compare commits
10 Commits
eafe5425d1
...
7b112dbac5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b112dbac5 | ||
|
|
f3ea976cba | ||
|
|
5538f62b0b | ||
|
|
2806163f6e | ||
|
|
cea8d0925f | ||
|
|
b138133ffa | ||
|
|
7d600fe114 | ||
|
|
a616140d6b | ||
|
|
79641ceff1 | ||
|
|
e41e0060b9 |
@ -91,6 +91,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
|
||||
|
||||
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
||||
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
|
||||
parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
|
||||
|
||||
class LatentPreviewMethod(enum.Enum):
|
||||
NoPreviews = "none"
|
||||
|
||||
@ -9,12 +9,13 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import repeat
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
|
||||
from comfy.ldm.flux.layers import EmbedND, timestep_embedding, DoubleStreamBlock, SingleStreamBlock
|
||||
|
||||
from comfy.ldm.chroma.model import Chroma, ChromaParams
|
||||
from comfy.ldm.chroma.layers import (
|
||||
Approximator,
|
||||
ChromaModulationOut,
|
||||
)
|
||||
from .layers import (
|
||||
NerfEmbedder,
|
||||
@ -25,7 +26,26 @@ from .layers import (
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaRadianceParams(ChromaParams):
|
||||
class ChromaRadianceParams:
|
||||
# Fields from ChromaParams (now independent)
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
in_dim: int
|
||||
out_dim: int
|
||||
hidden_dim: int
|
||||
n_layers: int
|
||||
txt_ids_dims: list
|
||||
vec_in_dim: int
|
||||
# ChromaRadiance-specific fields
|
||||
patch_size: int
|
||||
nerf_hidden_size: int
|
||||
nerf_mlp_ratio: int
|
||||
@ -39,7 +59,7 @@ class ChromaRadianceParams(ChromaParams):
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
use_x0: bool
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
class ChromaRadiance(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
@ -47,7 +67,7 @@ class ChromaRadiance(Chroma):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
if operations is None:
|
||||
raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
|
||||
nn.Module.__init__(self)
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
params = ChromaRadianceParams(**kwargs)
|
||||
self.params = params
|
||||
@ -176,6 +196,155 @@ class ChromaRadiance(Chroma):
|
||||
# flatten into a sequence for the transformer.
|
||||
return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
|
||||
|
||||
def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
|
||||
# This function slices up the modulations tensor which has the following layout:
|
||||
# single : num_single_blocks * 3 elements
|
||||
# double_img : num_double_blocks * 6 elements
|
||||
# double_txt : num_double_blocks * 6 elements
|
||||
# final : 2 elements
|
||||
if block_type == "final":
|
||||
return (tensor[:, -2:-1, :], tensor[:, -1:, :])
|
||||
single_block_count = self.params.depth_single_blocks
|
||||
double_block_count = self.params.depth
|
||||
offset = 3 * idx
|
||||
if block_type == "single":
|
||||
return ChromaModulationOut.from_offset(tensor, offset)
|
||||
# Double block modulations are 6 elements so we double 3 * idx.
|
||||
offset *= 2
|
||||
if block_type in {"double_img", "double_txt"}:
|
||||
# Advance past the single block modulations.
|
||||
offset += 3 * single_block_count
|
||||
if block_type == "double_txt":
|
||||
# Advance past the double block img modulations.
|
||||
offset += 6 * double_block_count
|
||||
return (
|
||||
ChromaModulationOut.from_offset(tensor, offset),
|
||||
ChromaModulationOut.from_offset(tensor, offset + 3),
|
||||
)
|
||||
raise ValueError("Bad block_type")
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control = None,
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
|
||||
# distilled vector guidance
|
||||
mod_index_length = 344
|
||||
distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
|
||||
# guidance = guidance *
|
||||
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
|
||||
|
||||
# get all modulation index
|
||||
modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
|
||||
# we need to broadcast the modulation index here so each batch has all of the index
|
||||
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
|
||||
# and we need to broadcast timestep and guidance along too
|
||||
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype)
|
||||
# then and only then we could concatenate it together
|
||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
|
||||
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_mmdit:
|
||||
double_mod = (
|
||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||
self.get_modulations(mod_vectors, "double_txt", idx=i),
|
||||
)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": double_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img,
|
||||
txt=txt,
|
||||
vec=double_mod,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_dit:
|
||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": single_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
return img
|
||||
|
||||
def forward_nerf(
|
||||
self,
|
||||
img_orig: Tensor,
|
||||
@ -285,6 +454,13 @@ class ChromaRadiance(Chroma):
|
||||
eps = 0.0
|
||||
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
|
||||
|
||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
@ -332,4 +508,3 @@ class ChromaRadiance(Chroma):
|
||||
if hasattr(self, "__x0__"):
|
||||
out = self._apply_x0_residual(out, img, timestep)
|
||||
return out
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
from comfy_kitchen.tensor import (
|
||||
@ -21,7 +23,15 @@ try:
|
||||
ck.registry.disable("cuda")
|
||||
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
|
||||
|
||||
ck.registry.disable("triton")
|
||||
if args.enable_triton_backend:
|
||||
try:
|
||||
import triton
|
||||
logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
|
||||
except ImportError as e:
|
||||
logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
|
||||
ck.registry.disable("triton")
|
||||
else:
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
except ImportError as e:
|
||||
|
||||
@ -666,12 +666,13 @@ class ColorTransfer(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ColorTransfer",
|
||||
display_name="Color Transfer",
|
||||
category="image/postprocessing",
|
||||
description="Match the colors of one image to another using various algorithms.",
|
||||
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
|
||||
inputs=[
|
||||
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
|
||||
io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
|
||||
io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."),
|
||||
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
|
||||
io.DynamicCombo.Input("source_stats",
|
||||
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
|
||||
|
||||
@ -49,7 +49,7 @@ class Int(io.ComfyNode):
|
||||
display_name="Int",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
|
||||
],
|
||||
outputs=[io.Int.Output()],
|
||||
)
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
#config for a1111 ui
|
||||
#all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed
|
||||
|
||||
#a111:
|
||||
#a1111:
|
||||
# base_path: path/to/stable-diffusion-webui/
|
||||
# checkpoints: models/Stable-diffusion
|
||||
# configs: models/Stable-diffusion
|
||||
|
||||
66
nodes.py
66
nodes.py
@ -1754,57 +1754,49 @@ class LoadImage:
|
||||
|
||||
return True
|
||||
|
||||
class LoadImageMask:
|
||||
|
||||
class LoadImageMask(LoadImage):
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
|
||||
|
||||
_color_channels = ["alpha", "red", "green", "blue"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(files), {"image_upload": True}),
|
||||
"channel": (s._color_channels, ), }
|
||||
}
|
||||
types = super().INPUT_TYPES()
|
||||
return {
|
||||
"required": {
|
||||
**types["required"],
|
||||
"channel": (s._color_channels, )
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = node_helpers.pillow(Image.open, image_path)
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
if i.getbands() != ("R", "G", "B", "A"):
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
i = i.convert("RGBA")
|
||||
mask = None
|
||||
FUNCTION = "load_image_mask"
|
||||
|
||||
def load_image_mask(self, image, channel):
|
||||
image_tensor, mask_tensor = super().load_image(image)
|
||||
c = channel[0].upper()
|
||||
if c in i.getbands():
|
||||
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
|
||||
mask = torch.from_numpy(mask)
|
||||
if c == 'A':
|
||||
mask = 1. - mask
|
||||
|
||||
if c == 'A':
|
||||
return (mask_tensor,)
|
||||
|
||||
channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0)
|
||||
|
||||
if channel_idx < image_tensor.shape[-1]:
|
||||
return (image_tensor[..., channel_idx].clone(),)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
return (mask.unsqueeze(0),)
|
||||
empty_mask = torch.zeros(
|
||||
image_tensor.shape[:-1],
|
||||
dtype=image_tensor.dtype,
|
||||
device=image_tensor.device
|
||||
)
|
||||
return (empty_mask,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image):
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
|
||||
return True
|
||||
return super().IS_CHANGED(image)
|
||||
|
||||
|
||||
class LoadImageOutput(LoadImage):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user