mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-12 02:02:36 +08:00

commit 0992141135
Merge branch 'master' into feature/generic-feature-flag-cli

README.md (13 changed lines)
@@ -1,7 +1,7 @@
 <div align="center">
 
 # ComfyUI
 
-**The most powerful and modular visual AI engine and application.**
+**The most powerful and modular AI engine for content creation.**
 
 
 [![Website][website-shield]][website-url]
@@ -31,10 +31,16 @@
 [github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
 [github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
 
-(ComfyUI screenshot image; original markdown lost in extraction)
+<img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/36e065e0-bfae-4456-8c7f-8369d5ea48a2" />
+<br>
 </div>
 
-ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
+ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
+- ComfyUI natively supports the latest open-source state-of-the-art models.
+- API nodes provide access to the best closed-source models such as Nano Banana, Seedance, Hunyuan3D, etc.
+- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
+- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
+- It integrates seamlessly into production pipelines with our API endpoints.
 
 ## Get Started
 
@@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
 - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
 - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
 - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
+- Ernie Image
 - Image Editing Models
    - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
    - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@@ -91,6 +91,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
 
 parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
 parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act as if the device supports fp8 compute.")
+parser.add_argument("--enable-triton-backend", action="store_true", help="Enable the Triton backend in comfy-kitchen. Disabled by default.")
 
 class LatentPreviewMethod(enum.Enum):
     NoPreviews = "none"
@@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 
 from comfy import model_management
 
+TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
+
 if model_management.xformers_enabled():
     import xformers
     import xformers.ops
@@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     b, _, dim_head = q.shape
     dim_head //= heads
 
-    scale = dim_head ** -0.5
+    if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
+        n_rep = q.shape[-3] // k.shape[-3]
+        k = k.repeat_interleave(n_rep, dim=-3)
+        v = v.repeat_interleave(n_rep, dim=-3)
+
+    scale = kwargs.get("scale", dim_head ** -0.5)
 
     h = heads
     if skip_reshape:
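A standalone sketch (separate from the patch) of what the new GQA branch does: `repeat_interleave` along the head dimension duplicates each KV head across its group of query heads, the same expansion that `scaled_dot_product_attention(..., enable_gqa=True)` performs internally on PyTorch 2.5+.

```python
import torch
import torch.nn.functional as F

b, q_heads, kv_heads, seq, dim = 2, 8, 2, 16, 64
q = torch.randn(b, q_heads, seq, dim)
k = torch.randn(b, kv_heads, seq, dim)
v = torch.randn(b, kv_heads, seq, dim)

# Manual expansion: 4 query heads share each KV head.
n_rep = q.shape[-3] // k.shape[-3]
k_exp = k.repeat_interleave(n_rep, dim=-3)
v_exp = v.repeat_interleave(n_rep, dim=-3)

out_manual = F.scaled_dot_product_attention(q, k_exp, v_exp)
out_native = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)  # torch >= 2.5
assert torch.allclose(out_manual, out_native, atol=1e-5)
```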
@@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     b, _, dim_head = query.shape
     dim_head //= heads
 
+    if "scale" in kwargs:
+        # Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
+        query = query * (kwargs["scale"] * dim_head ** 0.5)
+
     if skip_reshape:
         query = query.reshape(b * heads, -1, dim_head)
         value = value.reshape(b * heads, -1, dim_head)
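The pre-scale trick in isolation: when a kernel hard-codes a 1/sqrt(d) softmax scale, multiplying the query by `scale * sqrt(d)` beforehand makes the effective scale equal to the requested one, since `(q * scale * sqrt(d)) @ k.T / sqrt(d) == (q @ k.T) * scale`. A quick numeric check of that identity:

```python
import torch

d = 64
scale = 0.125  # requested scale
q = torch.randn(4, 10, d)
k = torch.randn(4, 10, d)

# What the kernel computes after pre-scaling, with its fixed 1/sqrt(d):
internal = (q * (scale * d ** 0.5)) @ k.transpose(-1, -2) / d ** 0.5
# What the caller actually asked for:
wanted = (q @ k.transpose(-1, -2)) * scale
assert torch.allclose(internal, wanted, atol=1e-5)
```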
@@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     b, _, dim_head = q.shape
     dim_head //= heads
 
-    scale = dim_head ** -0.5
+    scale = kwargs.get("scale", dim_head ** -0.5)
 
     if skip_reshape:
         q, k, v = map(
@@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         if mask.ndim == 3:
             mask = mask.unsqueeze(1)
 
+    # Pass through extra SDPA kwargs (scale, enable_gqa) if provided
+    # enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
+    sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
+    sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
+
     if SDP_BATCH_LIMIT >= b:
-        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
                 k[i : i + SDP_BATCH_LIMIT],
                 v[i : i + SDP_BATCH_LIMIT],
                 attn_mask=m,
-                dropout_p=0.0, is_causal=False
+                dropout_p=0.0, is_causal=False, **sdpa_extra
             ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
     return out
 
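The kwarg-gating pattern above, reduced to a standalone sketch. The version probe here parses `torch.__version__` directly, which is an assumption for the sketch only (ComfyUI itself uses `model_management.torch_version_numeric`):

```python
import torch

TORCH_HAS_GQA = tuple(int(p) for p in torch.__version__.split(".")[:2]) >= (2, 5)

def sdpa(q, k, v, **kwargs):
    # Forward only the SDPA kwargs the installed torch actually understands,
    # so callers can always pass scale/enable_gqa without version checks.
    allowed = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
    extra = {key: val for key, val in kwargs.items() if key in allowed}
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, **extra)
```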
comfy/ops.py (87 changed lines)
@@ -1246,6 +1246,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                    self._buffers[key] = fn(buf)
                return self
 
+        class Embedding(manual_cast.Embedding):
+            def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                      strict, missing_keys, unexpected_keys, error_msgs):
+                weight_key = f"{prefix}weight"
+                layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+                if layer_conf is not None:
+                    layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+                # Only fp8 makes sense for embeddings (per-row dequant via index select).
+                # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
+                quant_format = layer_conf.get("format", None) if layer_conf is not None else None
+                if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
+                    self.quant_format = quant_format
+                    qconfig = QUANT_ALGOS[quant_format]
+                    layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
+                    weight = state_dict.pop(weight_key)
+                    manually_loaded_keys = [weight_key]
+
+                    scale_key = f"{prefix}weight_scale"
+                    scale = state_dict.pop(scale_key, None)
+                    if scale is not None:
+                        scale = scale.float()
+                        manually_loaded_keys.append(scale_key)
+
+                    params = layout_cls.Params(
+                        scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
+                        orig_dtype=MixedPrecisionOps._compute_dtype,
+                        orig_shape=(self.num_embeddings, self.embedding_dim),
+                    )
+                    self.weight = torch.nn.Parameter(
+                        QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
+                        requires_grad=False)
+
+                    super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+                    for k in manually_loaded_keys:
+                        if k in missing_keys:
+                            missing_keys.remove(k)
+                else:
+                    if layer_conf is not None:
+                        state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
+                    super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            def state_dict(self, *args, destination=None, prefix="", **kwargs):
+                if destination is not None:
+                    sd = destination
+                else:
+                    sd = {}
+
+                if not hasattr(self, 'weight') or self.weight is None:
+                    return sd
+
+                if isinstance(self.weight, QuantizedTensor):
+                    sd_out = self.weight.state_dict("{}weight".format(prefix))
+                    for k in sd_out:
+                        sd[k] = sd_out[k]
+
+                    quant_conf = {"format": self.quant_format}
+                    sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
+                else:
+                    sd["{}weight".format(prefix)] = self.weight
+                return sd
+
+            def forward_comfy_cast_weights(self, input, out_dtype=None):
+                weight = self.weight
+
+                # Optimized path: lookup in fp8, dequantize only the selected rows.
+                if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
+                    qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
+                    if isinstance(qdata, QuantizedTensor):
+                        scale = qdata._params.scale
+                        qdata = qdata._qdata
+                    else:
+                        scale = None
+
+                    x = torch.nn.functional.embedding(
+                        input, qdata, self.padding_idx, self.max_norm,
+                        self.norm_type, self.scale_grad_by_freq, self.sparse)
+                    uncast_bias_weight(self, qdata, None, offload_stream)
+                    target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
+                    x = x.to(dtype=target_dtype)
+                    if scale is not None and scale != 1.0:
+                        x = x * scale.to(dtype=target_dtype)
+                    return x
+
+                # Fallback for non-quantized or weight_function (LoRA) case
+                return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
+
     return MixedPrecisionOps
 
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
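The idea behind the `forward_comfy_cast_weights` fast path, in isolation: gather rows from the fp8 table first, then dequantize only what was selected, instead of dequantizing the whole vocabulary. The shapes and the per-tensor scale below are made up for illustration, and a PyTorch build with float8 dtypes is assumed:

```python
import torch

vocab, dim = 1000, 64
table = torch.randn(vocab, dim)
scale = table.abs().max() / 448.0  # e4m3 max normal value is ~448
table_fp8 = (table / scale).to(torch.float8_e4m3fn)

ids = torch.tensor([[3, 42, 7]])
rows = torch.nn.functional.embedding(ids, table_fp8)  # gather in fp8, no full dequant
out = rows.to(torch.bfloat16) * scale.to(torch.bfloat16)  # dequantize selected rows only
print(out.shape)  # torch.Size([1, 3, 64])
```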
@@ -1,6 +1,8 @@
 import torch
 import logging
 
+from comfy.cli_args import args
+
 try:
     import comfy_kitchen as ck
     from comfy_kitchen.tensor import (

@@ -21,7 +23,15 @@ try:
        ck.registry.disable("cuda")
        logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
 
-    ck.registry.disable("triton")
+    if args.enable_triton_backend:
+        try:
+            import triton
+            logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
+        except ImportError as e:
+            logging.error(f"Failed to import triton ({e}); the comfy-kitchen triton backend will not be available.")
+            ck.registry.disable("triton")
+    else:
+        ck.registry.disable("triton")
     for k, v in ck.list_backends().items():
         logging.info(f"Found comfy_kitchen backend {k}: {v}")
 except ImportError as e:
@@ -3,6 +3,7 @@ import comfy.model_management
 
 RMSNorm = torch.nn.RMSNorm
 
+# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
 def rms_norm(x, weight=None, eps=1e-6):
     if weight is None:
         return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
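A sketch of the mismatch the new comment refers to, comparing the fused kernel against the classic rsqrt formulation. No exact value is claimed; the gap is typically at float-rounding scale:

```python
import torch

def rms_norm_manual(x, eps=1e-6):
    # Classic formulation: x / sqrt(mean(x^2) + eps)
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(4, 512)
fused = torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=1e-6)
manual = rms_norm_manual(x)
print((fused - manual).abs().max())  # usually nonzero, at rounding scale
```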
comfy/sd.py (17 changed lines)
@@ -65,6 +65,7 @@ import comfy.text_encoders.ace15
 import comfy.text_encoders.longcat_image
 import comfy.text_encoders.qwen35
 import comfy.text_encoders.ernie
+import comfy.text_encoders.gemma4
 
 import comfy.model_patcher
 import comfy.lora
@@ -1271,6 +1272,9 @@ class TEModel(Enum):
     QWEN35_9B = 26
     QWEN35_27B = 27
     MINISTRAL_3_3B = 28
+    GEMMA_4_E4B = 29
+    GEMMA_4_E2B = 30
+    GEMMA_4_31B = 31
 
 
 def detect_te_model(sd):
@@ -1296,6 +1300,12 @@ def detect_te_model(sd):
             return TEModel.BYT5_SMALL_GLYPH
         return TEModel.T5_BASE
     if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
+        if 'model.layers.59.self_attn.q_norm.weight' in sd:
+            return TEModel.GEMMA_4_31B
+        if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
+            return TEModel.GEMMA_4_E4B
+        if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
+            return TEModel.GEMMA_4_E2B
         if 'model.layers.47.self_attn.q_norm.weight' in sd:
             return TEModel.GEMMA_3_12B
         if 'model.layers.0.self_attn.q_norm.weight' in sd:
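The detection idea behind these checks: the depth of a checkpoint is visible in which per-layer keys exist, so a 48-layer model has `layers.47.*` but nothing at `layers.48` and above. A toy illustration (names are illustrative, not the real mapping table):

```python
def guess_num_layers(sd):
    # Count consecutive per-layer keys; the first missing index is the depth.
    n = 0
    while f"model.layers.{n}.self_attn.q_norm.weight" in sd:
        n += 1
    return n

fake_sd = {f"model.layers.{i}.self_attn.q_norm.weight": None for i in range(48)}
assert guess_num_layers(fake_sd) == 48  # layers.47 present => Gemma 3 12B above
```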
@@ -1435,6 +1445,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         else:
             clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
             clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
+    elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
+        variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
+                   TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
+                   TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
+        clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
+        clip_target.tokenizer = variant.tokenizer
+        tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
     elif te_model == TEModel.GEMMA_2_2B:
         clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
         clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
comfy/text_encoders/gemma4.py (new file, 1298 lines)
File diff suppressed because it is too large
@@ -521,7 +521,7 @@ class Attention(nn.Module):
         else:
             present_key_value = (xk, xv, index + num_tokens)
 
-        if sliding_window is not None and xk.shape[2] > sliding_window:
+        if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
             xk = xk[:, :, -sliding_window:]
             xv = xv[:, :, -sliding_window:]
             attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
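Why the extra `seq_length == 1` guard: trimming the cache to the last `sliding_window` entries is only valid when a single new token is attending. During a multi-token prefill, each query position has a different window, so one shared slice would drop keys that earlier positions in the batch still need. A toy illustration:

```python
sliding_window = 4
# Three queries arriving together (prefill positions 8, 9, 10): each needs a
# different slice of the key cache, so a single shared trim would be wrong.
for q_pos in (8, 9, 10):
    window_start = max(0, q_pos - sliding_window + 1)
    print(f"query at {q_pos} attends to keys [{window_start}..{q_pos}]")
```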
@@ -533,12 +533,12 @@ class Attention(nn.Module):
         return self.o_proj(output), present_key_value
 
 class MLP(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
         super().__init__()
-        ops = ops or nn
-        self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
-        self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
-        self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
+        intermediate_size = intermediate_size or config.intermediate_size
+        self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
+        self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
+        self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
         if config.mlp_activation == "silu":
             self.activation = torch.nn.functional.silu
         elif config.mlp_activation == "gelu_pytorch_tanh":
@@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
 
         return x, present_key_value
 
+def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
+    class ScaledEmbedding(ops.Embedding):
+        def forward(self, input_ids, out_dtype=None):
+            return super().forward(input_ids, out_dtype=out_dtype) * scale
+    return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
+
+
 class Llama2_(nn.Module):
     def __init__(self, config, device=None, dtype=None, ops=None):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = ops.Embedding(
-            config.vocab_size,
-            config.hidden_size,
-            device=device,
-            dtype=dtype
-        )
         if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
             transformer = TransformerBlockGemma2
-            self.normalize_in = True
+            self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
         else:
             transformer = TransformerBlock
-            self.normalize_in = False
+            self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
 
         self.layers = nn.ModuleList([
             transformer(config, index=i, device=device, dtype=dtype, ops=ops)
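The factory bakes the Gemma `sqrt(hidden_size)` multiplier into the embedding itself, replacing the `normalize_in` flag that was checked on every forward (its deletion appears in the next hunk). The same pattern in plain PyTorch, with made-up sizes:

```python
import torch
import torch.nn as nn

def make_scaled_embedding(vocab_size, hidden_size, scale):
    class ScaledEmbedding(nn.Embedding):
        def forward(self, input_ids):
            # Every lookup comes out pre-scaled; callers need no extra flag.
            return super().forward(input_ids) * scale
    return ScaledEmbedding(vocab_size, hidden_size)

emb = make_scaled_embedding(256, 64, 64 ** 0.5)  # Gemma scales by sqrt(hidden_size)
hidden = emb(torch.tensor([[1, 2, 3]]))          # shape (1, 3, 64), already scaled
```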
@@ -690,15 +691,12 @@ class Llama2_(nn.Module):
             self.config.rope_dims,
             device=device)
 
-    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
         if embeds is not None:
             x = embeds
         else:
             x = self.embed_tokens(x, out_dtype=dtype)
 
-        if self.normalize_in:
-            x *= self.config.hidden_size ** 0.5
-
         seq_len = x.shape[1]
         past_len = 0
         if past_key_values is not None and len(past_key_values) > 0:
@@ -850,7 +848,7 @@ class BaseGenerate:
                 torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
         return past_key_values
 
-    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
+    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
         device = embeds.device
 
         if stop_tokens is None:
@@ -875,14 +873,16 @@ class BaseGenerate:
         pbar = comfy.utils.ProgressBar(max_length)
 
         # Generation loop
+        current_input_ids = initial_input_ids
         for step in tqdm(range(max_length), desc="Generating tokens"):
-            x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
+            x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
             logits = self.logits(x)[:, -1]
             next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
             token_id = next_token[0].item()
             generated_token_ids.append(token_id)
 
             embeds = self.model.embed_tokens(next_token).to(execution_dtype)
+            current_input_ids = next_token if initial_input_ids is not None else None
             pbar.update(1)
 
             if token_id in stop_tokens:
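The shape of this decode loop, reduced to a sketch: after the prompt pass, each step feeds only the embedding of the newly sampled token alongside the KV cache, and `current_input_ids` carries the raw token id for models that also want ids (the new `input_ids` argument). `model` and its return signature below are stand-ins, not the real classes:

```python
import torch

def greedy_decode(model, embed_tokens, prompt_embeds, stop_ids, max_length=32):
    past, embeds, generated = None, prompt_embeds, []
    current_input_ids = None  # mirrors the patch: set only when ids are supplied
    for _ in range(max_length):
        logits, past = model(embeds, past_key_values=past, input_ids=current_input_ids)
        next_token = logits[:, -1].argmax(dim=-1)       # greedy; the real code samples
        generated.append(int(next_token))
        if int(next_token) in stop_ids:
            break
        embeds = embed_tokens(next_token.unsqueeze(0))  # feed just the new token
        current_input_ids = next_token.unsqueeze(0)
    return generated
```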
@@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
 
     def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
         tokens_only = [[t[0] for t in b] for b in tokens]
-        embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
-        comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
+        embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
         return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
 
 class DualLinearProjection(torch.nn.Module):
@@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
 
     def process_tokens(self, tokens, device):
-        embeds, _, _, embeds_info = super().process_tokens(tokens, device)
-        comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
+        embeds, _, _, _ = super().process_tokens(tokens, device)
         return embeds
 
 class LuminaModel(sd1_clip.SD1ClipModel):
@@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
         nn.Module.__init__(self)
         self.config = config
         self.vocab_size = config.vocab_size
-        self.normalize_in = False
-
         self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
         self.layers = nn.ModuleList([
             Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
@@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
         memo[obj_id] = res
     return res
 
-def normalize_image_embeddings(embeds, embeds_info, scale_factor):
-    """Normalize image embeddings to match text embedding scale"""
-    for info in embeds_info:
-        if info.get("type") == "image":
-            start_idx = info["index"]
-            end_idx = start_idx + info["size"]
-            embeds[:, start_idx:end_idx, :] /= scale_factor
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional
 
 from pydantic import BaseModel, Field
 
@@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel):
     grain: Optional[float] = Field(None, description="Grain after AI model processing")
     grainSize: Optional[float] = Field(None, description="Size of generated grain")
     recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
-    creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
+    creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.")
     isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
+    prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)")
+    sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness")
+    realism: float | None = Field(None, description="ast-2 realism control")
 
 
 class OutputInformationVideo(BaseModel):
@@ -90,7 +93,7 @@ class Overrides(BaseModel):
 
 class CreateVideoRequest(BaseModel):
     source: CreateVideoRequestSource = Field(...)
-    filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
+    filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...)
     output: OutputInformationVideo = Field(...)
     overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
 
@@ -36,11 +36,15 @@ from comfy_api_nodes.util import (
 )
 
 UPSCALER_MODELS_MAP = {
+    "Astra 2": "ast-2",
     "Starlight (Astra) Fast": "slf-1",
     "Starlight (Astra) Creative": "slc-1",
     "Starlight Precise 2.5": "slp-2.5",
 }
 
+AST2_MAX_FRAMES = 9000
+AST2_MAX_FRAMES_WITH_PROMPT = 450
+
 
 class TopazImageEnhance(IO.ComfyNode):
     @classmethod
@@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode):
     def define_schema(cls):
         return IO.Schema(
             node_id="TopazVideoEnhance",
-            display_name="Topaz Video Enhance",
+            display_name="Topaz Video Enhance (Legacy)",
             category="api node/video/Topaz",
             description="Breathe new life into video with powerful upscaling and recovery technology.",
             inputs=[
                 IO.Video.Input("video"),
                 IO.Boolean.Input("upscaler_enabled", default=True),
-                IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
+                IO.Combo.Input(
+                    "upscaler_model",
+                    options=[
+                        "Starlight (Astra) Fast",
+                        "Starlight (Astra) Creative",
+                        "Starlight Precise 2.5",
+                    ],
+                ),
                 IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
                 IO.Combo.Input(
                     "upscaler_creativity",
@@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode):
                 IO.Hidden.unique_id,
             ],
             is_api_node=True,
+            is_deprecated=True,
         )
 
     @classmethod
@@ -457,12 +469,357 @@ class TopazVideoEnhance(IO.ComfyNode):
         return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
 
 
+class TopazVideoEnhanceV2(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="TopazVideoEnhanceV2",
+            display_name="Topaz Video Enhance",
+            category="api node/video/Topaz",
+            description="Breathe new life into video with powerful upscaling and recovery technology.",
+            inputs=[
+                IO.Video.Input("video"),
+                IO.DynamicCombo.Input(
+                    "upscaler_model",
+                    options=[
+                        IO.DynamicCombo.Option(
+                            "Astra 2",
+                            [
+                                IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
+                                IO.Float.Input(
+                                    "creativity",
+                                    default=0.5,
+                                    min=0.0,
+                                    max=1.0,
+                                    step=0.1,
+                                    display_mode=IO.NumberDisplay.slider,
+                                    tooltip="Creative strength of the upscale.",
+                                ),
+                                IO.String.Input(
+                                    "prompt",
+                                    multiline=True,
+                                    default="",
+                                    tooltip="Optional descriptive (not instructive) scene prompt. "
+                                    f"Caps input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.",
+                                ),
+                                IO.Float.Input(
+                                    "sharp",
+                                    default=0.5,
+                                    min=0.0,
+                                    max=1.0,
+                                    step=0.01,
+                                    display_mode=IO.NumberDisplay.slider,
+                                    tooltip="Pre-enhance sharpness: "
+                                    "0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.",
+                                    advanced=True,
+                                ),
+                                IO.Float.Input(
+                                    "realism",
+                                    default=0.0,
+                                    min=0.0,
+                                    max=1.0,
+                                    step=0.01,
+                                    display_mode=IO.NumberDisplay.slider,
+                                    tooltip="Pulls output toward photographic realism. "
+                                    "Leave at 0 for the model default.",
+                                    advanced=True,
+                                ),
+                            ],
+                        ),
+                        IO.DynamicCombo.Option(
+                            "Starlight (Astra) Fast",
+                            [IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])],
+                        ),
+                        IO.DynamicCombo.Option(
+                            "Starlight (Astra) Creative",
+                            [
+                                IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
+                                IO.Combo.Input(
+                                    "creativity",
+                                    options=["low", "middle", "high"],
+                                    default="low",
+                                    tooltip="Creative strength of the upscale.",
+                                ),
+                            ],
+                        ),
+                        IO.DynamicCombo.Option(
+                            "Starlight Precise 2.5",
+                            [IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])],
+                        ),
+                        IO.DynamicCombo.Option("Disabled", []),
+                    ],
+                ),
+                IO.DynamicCombo.Input(
+                    "interpolation_model",
+                    options=[
+                        IO.DynamicCombo.Option("Disabled", []),
+                        IO.DynamicCombo.Option(
+                            "apo-8",
+                            [
+                                IO.Int.Input(
+                                    "interpolation_frame_rate",
+                                    default=60,
+                                    min=15,
+                                    max=240,
+                                    display_mode=IO.NumberDisplay.number,
+                                    tooltip="Output frame rate.",
+                                ),
+                                IO.Int.Input(
+                                    "interpolation_slowmo",
+                                    default=1,
+                                    min=1,
+                                    max=16,
+                                    display_mode=IO.NumberDisplay.number,
+                                    tooltip="Slow-motion factor applied to the input video. "
+                                    "For example, 2 makes the output twice as slow and doubles the duration.",
+                                    advanced=True,
+                                ),
+                                IO.Boolean.Input(
+                                    "interpolation_duplicate",
+                                    default=False,
+                                    tooltip="Analyze the input for duplicate frames and remove them.",
+                                    advanced=True,
+                                ),
+                                IO.Float.Input(
+                                    "interpolation_duplicate_threshold",
+                                    default=0.01,
+                                    min=0.001,
+                                    max=0.1,
+                                    step=0.001,
+                                    display_mode=IO.NumberDisplay.number,
+                                    tooltip="Detection sensitivity for duplicate frames.",
+                                    advanced=True,
+                                ),
+                            ],
+                        ),
+                    ],
+                ),
+                IO.Combo.Input(
+                    "dynamic_compression_level",
+                    options=["Low", "Mid", "High"],
+                    default="Low",
+                    tooltip="CQP level.",
+                    optional=True,
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                depends_on=IO.PriceBadgeDepends(widgets=[
+                    "upscaler_model",
+                    "upscaler_model.upscaler_resolution",
+                    "interpolation_model",
+                ]),
+                expr="""
+                (
+                    $model := $lookup(widgets, "upscaler_model");
+                    $res := $lookup(widgets, "upscaler_model.upscaler_resolution");
+                    $interp := $lookup(widgets, "interpolation_model");
+                    $is4k := $contains($res, "4k");
+                    $hasInterp := $interp != "disabled";
+                    $rates := {
+                        "starlight (astra) fast": {"hd": 0.43, "uhd": 0.85},
+                        "starlight precise 2.5": {"hd": 0.70, "uhd": 1.54},
+                        "astra 2": {"hd": 1.72, "uhd": 2.85},
+                        "starlight (astra) creative": {"hd": 2.25, "uhd": 3.99}
+                    };
+                    $surcharge := $is4k ? 0.28 : 0.14;
+                    $entry := $lookup($rates, $model);
+                    $base := $is4k ? $entry.uhd : $entry.hd;
+                    $hi := $base + ($hasInterp ? $surcharge : 0);
+                    $model = "disabled"
+                        ? {"type":"text","text":"Interpolation only"}
+                        : ($hasInterp
+                            ? {"type":"text","text":"~" & $string($base) & "–" & $string($hi) & " credits/src frame"}
+                            : {"type":"text","text":"~" & $string($base) & " credits/src frame"})
+                )
+                """,
+            ),
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        video: Input.Video,
+        upscaler_model: dict,
+        interpolation_model: dict,
+        dynamic_compression_level: str = "Low",
+    ) -> IO.NodeOutput:
+        upscaler_choice = upscaler_model["upscaler_model"]
+        interpolation_choice = interpolation_model["interpolation_model"]
+        if upscaler_choice == "Disabled" and interpolation_choice == "Disabled":
+            raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
+        validate_container_format_is_mp4(video)
+        src_width, src_height = video.get_dimensions()
+        src_frame_rate = int(video.get_frame_rate())
+        duration_sec = video.get_duration()
+        src_video_stream = video.get_stream_source()
+        target_width = src_width
+        target_height = src_height
+        target_frame_rate = src_frame_rate
+        filters = []
+        if upscaler_choice != "Disabled":
+            if "1080p" in upscaler_model["upscaler_resolution"]:
+                target_pixel_p = 1080
+                max_long_side = 1920
+            else:
+                target_pixel_p = 2160
+                max_long_side = 3840
+            ar = src_width / src_height
+            if src_width >= src_height:
+                # Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
+                target_height = target_pixel_p
+                target_width = int(target_height * ar)
+                # Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
+                if target_width > max_long_side:
+                    target_width = max_long_side
+                    target_height = int(target_width / ar)
+            else:
+                # Portrait; Attempt to set width to target (e.g., 2160), calculate height
+                target_width = target_pixel_p
+                target_height = int(target_width / ar)
+                # Check if height exceeds standard bounds
+                if target_height > max_long_side:
+                    target_height = max_long_side
+                    target_width = int(target_height * ar)
+            if target_width % 2 != 0:
+                target_width += 1
+            if target_height % 2 != 0:
+                target_height += 1
+            model_id = UPSCALER_MODELS_MAP[upscaler_choice]
+            if model_id == "slc-1":
+                filters.append(
+                    VideoEnhancementFilter(
+                        model=model_id,
+                        creativity=upscaler_model["creativity"],
+                        isOptimizedMode=True,
+                    )
+                )
+            elif model_id == "ast-2":
+                n_frames = video.get_frame_count()
+                ast2_prompt = (upscaler_model["prompt"] or "").strip()
+                if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT:
+                    raise ValueError(
+                        f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames "
+                        f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip."
+                    )
+                if n_frames > AST2_MAX_FRAMES:
+                    raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.")
+                realism = upscaler_model["realism"]
+                filters.append(
+                    VideoEnhancementFilter(
+                        model=model_id,
+                        creativity=upscaler_model["creativity"],
+                        prompt=(ast2_prompt or None),
+                        sharp=upscaler_model["sharp"],
+                        realism=(realism if realism > 0 else None),
+                    )
+                )
+            else:
+                filters.append(VideoEnhancementFilter(model=model_id))
+        if interpolation_choice != "Disabled":
+            target_frame_rate = interpolation_model["interpolation_frame_rate"]
+            filters.append(
+                VideoFrameInterpolationFilter(
+                    model=interpolation_choice,
+                    slowmo=interpolation_model["interpolation_slowmo"],
+                    fps=interpolation_model["interpolation_frame_rate"],
+                    duplicate=interpolation_model["interpolation_duplicate"],
+                    duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"],
+                ),
+            )
+        initial_res = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
+            response_model=CreateVideoResponse,
+            data=CreateVideoRequest(
+                source=CreateVideoRequestSource(
+                    container="mp4",
+                    size=get_fs_object_size(src_video_stream),
+                    duration=int(duration_sec),
+                    frameCount=video.get_frame_count(),
+                    frameRate=src_frame_rate,
+                    resolution=Resolution(width=src_width, height=src_height),
+                ),
+                filters=filters,
+                output=OutputInformationVideo(
+                    resolution=Resolution(width=target_width, height=target_height),
+                    frameRate=target_frame_rate,
+                    audioCodec="AAC",
+                    audioTransfer="Copy",
+                    dynamicCompressionLevel=dynamic_compression_level,
+                ),
+            ),
+            wait_label="Creating task",
+            final_label_on_success="Task created",
+        )
+        upload_res = await sync_op(
+            cls,
+            ApiEndpoint(
+                path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
+                method="PATCH",
+            ),
+            response_model=VideoAcceptResponse,
+            wait_label="Preparing upload",
+            final_label_on_success="Upload started",
+        )
+        if len(upload_res.urls) > 1:
+            raise NotImplementedError(
+                "Large files are not currently supported. Please open an issue in the ComfyUI repository."
+            )
+        async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
+            if isinstance(src_video_stream, BytesIO):
+                src_video_stream.seek(0)
+                async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
+                    upload_etag = res.headers["Etag"]
+            else:
+                with builtins.open(src_video_stream, "rb") as video_file:
+                    async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
+                        upload_etag = res.headers["Etag"]
+        await sync_op(
+            cls,
+            ApiEndpoint(
+                path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
+                method="PATCH",
+            ),
+            response_model=VideoCompleteUploadResponse,
+            data=VideoCompleteUploadRequest(
+                uploadResults=[
+                    VideoCompleteUploadRequestPart(
+                        partNum=1,
+                        eTag=upload_etag,
+                    ),
+                ],
+            ),
+            wait_label="Finalizing upload",
+            final_label_on_success="Upload completed",
+        )
+        final_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
+            response_model=VideoStatusResponse,
+            status_extractor=lambda x: x.status,
+            progress_extractor=lambda x: getattr(x, "progress", 0),
+            price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
+            poll_interval=10.0,
+        )
+        return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
+
+
 class TopazExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[IO.ComfyNode]]:
         return [
             TopazImageEnhance,
             TopazVideoEnhance,
+            TopazVideoEnhanceV2,
         ]
 
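The resolution logic inside `execute`, condensed: aim the short side at the target (1080 or 2160), clamp ultra-wide sources to the long-side cap, then round both dimensions up to even values for the encoder. A self-contained sketch of the same arithmetic:

```python
def target_dims(src_w, src_h, short_side=2160, long_cap=3840):
    ar = src_w / src_h
    if src_w >= src_h:      # landscape or square: fix height, derive width
        h, w = short_side, int(short_side * ar)
        if w > long_cap:    # ultra-wide: clamp and recompute
            w, h = long_cap, int(long_cap / ar)
    else:                   # portrait: fix width, derive height
        w, h = short_side, int(short_side / ar)
        if h > long_cap:
            h, w = long_cap, int(long_cap * ar)
    # Video encoders want even dimensions.
    return w + (w % 2), h + (h % 2)

assert target_dims(1280, 720) == (3840, 2160)
```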
@@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode):
 
     @classmethod
     def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
-        batch_size = min(len(image), len(alpha))
-        out_images = []
-
+        batch_size = max(len(image), len(alpha))
         alpha = 1.0 - resize_mask(alpha, image.shape[1:])
-        for i in range(batch_size):
-            out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
-
-        return io.NodeOutput(torch.stack(out_images))
+        alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
+        image = comfy.utils.repeat_to_batch_size(image, batch_size)
+        return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
 
 
 class CompositingExtension(ComfyExtension):
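The min-to-max change flips truncation into broadcasting: the shorter batch is now repeated up to the longer one instead of the longer one being cut down. A sketch with a stand-in for `comfy.utils.repeat_to_batch_size` (the real helper may differ in detail, and the real node also resizes the mask first):

```python
import torch

def repeat_to_batch_size(t, batch_size):
    if t.shape[0] == batch_size:
        return t
    reps = -(-batch_size // t.shape[0])  # ceiling division
    return t.repeat((reps,) + (1,) * (t.ndim - 1))[:batch_size]

image = torch.rand(4, 8, 8, 3)
alpha = torch.rand(1, 8, 8)          # single mask shared by the whole batch
n = max(len(image), len(alpha))
alpha = repeat_to_batch_size(1.0 - alpha, n)
out = torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1)
assert out.shape == (4, 8, 8, 4)
```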
@@ -666,12 +666,13 @@ class ColorTransfer(io.ComfyNode):
     def define_schema(cls):
         return io.Schema(
             node_id="ColorTransfer",
+            display_name="Color Transfer",
             category="image/postprocessing",
             description="Match the colors of one image to another using various algorithms.",
             search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
             inputs=[
                 io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
-                io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
+                io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."),
                 io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
                 io.DynamicCombo.Input("source_stats",
                     tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
@@ -49,7 +49,7 @@ class Int(io.ComfyNode):
             display_name="Int",
             category="utils/primitive",
             inputs=[
-                io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
+                io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
             ],
             outputs=[io.Int.Output()],
         )
@@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode):
                 io.Clip.Input("clip"),
                 io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
                 io.Image.Input("image", optional=True),
+                io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
+                io.Audio.Input("audio", optional=True),
                 io.Int.Input("max_length", default=256, min=1, max=2048),
                 io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
                 io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),

@@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
 
-        tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
+        tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)
 
         # Get sampling parameters from dynamic combo
         do_sample = sampling_mode.get("sampling_mode") == "on"

@@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode):
             seed=seed
         )
 
-        generated_text = clip.decode(generated_ids, skip_special_tokens=True)
+        generated_text = clip.decode(generated_ids)
 
         return io.NodeOutput(generated_text)
 

@@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
         )
 
     @classmethod
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
         if image is None:
             formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
         else:
             formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
-        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
+        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)
 
 
 class TextgenExtension(ComfyExtension):
@@ -28,7 +28,7 @@
 #config for a1111 ui
 #all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed
 
-#a111:
+#a1111:
 # base_path: path/to/stable-diffusion-webui/
 # checkpoints: models/Stable-diffusion
 # configs: models/Stable-diffusion
@@ -86,6 +86,6 @@ def image_alpha_fix(destination, source):
     if destination.shape[-1] < source.shape[-1]:
         source = source[...,:destination.shape[-1]]
     elif destination.shape[-1] > source.shape[-1]:
-        destination = torch.nn.functional.pad(destination, (0, 1))
-        destination[..., -1] = 1.0
+        source = torch.nn.functional.pad(source, (0, 1))
+        source[..., -1] = 1.0
     return destination, source
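The corrected behavior in isolation: when the destination already carries an alpha channel and the source does not, the source gains an opaque alpha instead of the destination being widened further (the old code padded the already-wider tensor, leaving a channel mismatch):

```python
import torch

destination = torch.rand(1, 4, 4, 4)  # RGBA
source = torch.rand(1, 4, 4, 3)       # RGB
source = torch.nn.functional.pad(source, (0, 1))  # add a 4th channel
source[..., -1] = 1.0                               # fully opaque alpha
assert source.shape[-1] == destination.shape[-1]
```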
nodes.py (66 changed lines)
@@ -1754,57 +1754,49 @@ class LoadImage:
 
         return True
 
-class LoadImageMask:
+class LoadImageMask(LoadImage):
     ESSENTIALS_CATEGORY = "Image Tools"
     SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
 
     _color_channels = ["alpha", "red", "green", "blue"]
 
     @classmethod
     def INPUT_TYPES(s):
-        input_dir = folder_paths.get_input_directory()
-        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
-        return {"required":
-                    {"image": (sorted(files), {"image_upload": True}),
-                     "channel": (s._color_channels, ), }
-                }
+        types = super().INPUT_TYPES()
+        return {
+            "required": {
+                **types["required"],
+                "channel": (s._color_channels, )
+            }
+        }
 
     CATEGORY = "mask"
 
     RETURN_TYPES = ("MASK",)
-    FUNCTION = "load_image"
+    FUNCTION = "load_image_mask"
 
-    def load_image(self, image, channel):
-        image_path = folder_paths.get_annotated_filepath(image)
-        i = node_helpers.pillow(Image.open, image_path)
-        i = node_helpers.pillow(ImageOps.exif_transpose, i)
-        if i.getbands() != ("R", "G", "B", "A"):
-            if i.mode == 'I':
-                i = i.point(lambda i: i * (1 / 255))
-            i = i.convert("RGBA")
-        mask = None
+    def load_image_mask(self, image, channel):
+        image_tensor, mask_tensor = super().load_image(image)
         c = channel[0].upper()
-        if c in i.getbands():
-            mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
-            mask = torch.from_numpy(mask)
-            if c == 'A':
-                mask = 1. - mask
+        if c == 'A':
+            return (mask_tensor,)
+
+        channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0)
+        if channel_idx < image_tensor.shape[-1]:
+            return (image_tensor[..., channel_idx].clone(),)
         else:
-            mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
-        return (mask.unsqueeze(0),)
+            empty_mask = torch.zeros(
+                image_tensor.shape[:-1],
+                dtype=image_tensor.dtype,
+                device=image_tensor.device
+            )
+            return (empty_mask,)
 
     @classmethod
     def IS_CHANGED(s, image, channel):
-        image_path = folder_paths.get_annotated_filepath(image)
-        m = hashlib.sha256()
-        with open(image_path, 'rb') as f:
-            m.update(f.read())
-        return m.digest().hex()
-
-    @classmethod
-    def VALIDATE_INPUTS(s, image):
-        if not folder_paths.exists_annotated_filepath(image):
-            return "Invalid image file: {}".format(image)
-
-        return True
+        return super().IS_CHANGED(image)
 
 
 class LoadImageOutput(LoadImage):
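The rewritten channel logic as a standalone function, a hypothetical extraction of the method above for illustration: the alpha mask comes from what the parent `LoadImage` already produced, and R/G/B are sliced directly out of the image tensor:

```python
import torch

def mask_from_channel(image_tensor, mask_tensor, channel):
    c = channel[0].upper()
    if c == 'A':
        return mask_tensor  # parent loader already built (and inverted) this
    idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0)
    if idx < image_tensor.shape[-1]:
        return image_tensor[..., idx].clone()
    # Channel absent (e.g., grayscale image): empty mask of matching shape.
    return torch.zeros(image_tensor.shape[:-1], dtype=image_tensor.dtype)
```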
@@ -1,3 +1,4 @@
+import errno
 import os
 import sys
 import asyncio

@@ -1245,7 +1246,13 @@ class PromptServer():
         address = addr[0]
         port = addr[1]
         site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
-        await site.start()
+        try:
+            await site.start()
+        except OSError as e:
+            if e.errno == errno.EADDRINUSE:
+                logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.")
+                raise SystemExit(1)
+            raise
 
         if not hasattr(self, 'address'):
             self.address = address #TODO: remove this
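The same pattern outside aiohttp, as a minimal sketch: turn the opaque `OSError` from a busy port into an actionable message and re-raise everything else untouched:

```python
import errno
import socket

def bind_or_explain(address, port):
    s = socket.socket()
    try:
        s.bind((address, port))
    except OSError as e:
        if e.errno == errno.EADDRINUSE:
            raise SystemExit(f"Port {port} is already in use on {address}.")
        raise  # anything else is a real bug; let it propagate
    return s
```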