mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-25 08:27:25 +08:00
Compare commits
13 Commits
417aa3948d
...
4afb3cbd0f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4afb3cbd0f | ||
|
|
025e6792ee | ||
|
|
867b8d2408 | ||
|
|
d0f0b15cf5 | ||
|
|
b5bb83c964 | ||
|
|
f6d5068ac0 | ||
|
|
be95871adc | ||
|
|
f756d801a1 | ||
|
|
1d23a875ed | ||
|
|
ef6722f6be | ||
|
|
783782d5d7 | ||
|
|
bf7257448e | ||
|
|
6f2e815adf |
13
README.md
13
README.md
@ -1,7 +1,7 @@
|
||||
<div align="center">
|
||||
|
||||
# ComfyUI
|
||||
**The most powerful and modular visual AI engine and application.**
|
||||
**The most powerful and modular AI engine for content creation.**
|
||||
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
@ -31,10 +31,16 @@
|
||||
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
|
||||

|
||||
<img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/36e065e0-bfae-4456-8c7f-8369d5ea48a2" />
|
||||
<br>
|
||||
</div>
|
||||
|
||||
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
|
||||
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
|
||||
- ComfyUI natively supports the latest open-source state of the art models.
|
||||
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
|
||||
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
|
||||
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
|
||||
- It integrates seamlessly into production pipelines with our API endpoints.
|
||||
|
||||
## Get Started
|
||||
|
||||
@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
|
||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
||||
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
|
||||
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
|
||||
- Ernie Image
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
|
||||
@ -647,22 +647,29 @@ def upsert_reference(
|
||||
if created:
|
||||
return True, False
|
||||
|
||||
update_conditions = [
|
||||
AssetReference.asset_id != asset_id,
|
||||
AssetReference.mtime_ns.is_(None),
|
||||
AssetReference.mtime_ns != int(mtime_ns),
|
||||
AssetReference.is_missing == True, # noqa: E712
|
||||
AssetReference.deleted_at.isnot(None),
|
||||
]
|
||||
update_values = {
|
||||
"asset_id": asset_id,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
"is_missing": False,
|
||||
"deleted_at": None,
|
||||
"updated_at": now,
|
||||
}
|
||||
if owner_id:
|
||||
update_conditions.append(AssetReference.owner_id != owner_id)
|
||||
update_values["owner_id"] = owner_id
|
||||
|
||||
upd = (
|
||||
sa.update(AssetReference)
|
||||
.where(AssetReference.file_path == file_path)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetReference.asset_id != asset_id,
|
||||
AssetReference.mtime_ns.is_(None),
|
||||
AssetReference.mtime_ns != int(mtime_ns),
|
||||
AssetReference.is_missing == True, # noqa: E712
|
||||
AssetReference.deleted_at.isnot(None),
|
||||
)
|
||||
)
|
||||
.values(
|
||||
asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False,
|
||||
deleted_at=None, updated_at=now,
|
||||
)
|
||||
.where(sa.or_(*update_conditions))
|
||||
.values(**update_values)
|
||||
)
|
||||
res2 = session.execute(upd)
|
||||
updated = int(res2.rowcount or 0) > 0
|
||||
|
||||
@ -3,6 +3,7 @@ from app.assets.services.asset_management import (
|
||||
delete_asset_reference,
|
||||
get_asset_by_hash,
|
||||
get_asset_detail,
|
||||
is_file_visible_to_owner,
|
||||
list_assets_page,
|
||||
resolve_asset_for_download,
|
||||
set_asset_preview,
|
||||
@ -23,6 +24,7 @@ from app.assets.services.ingest import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
create_from_hash,
|
||||
collect_output_absolute_paths,
|
||||
ingest_existing_file,
|
||||
register_output_files,
|
||||
upload_from_temp_path,
|
||||
@ -71,10 +73,12 @@ __all__ = [
|
||||
"asset_exists",
|
||||
"batch_insert_seed_assets",
|
||||
"create_from_hash",
|
||||
"collect_output_absolute_paths",
|
||||
"delete_asset_reference",
|
||||
"get_asset_by_hash",
|
||||
"get_asset_detail",
|
||||
"ingest_existing_file",
|
||||
"is_file_visible_to_owner",
|
||||
"register_output_files",
|
||||
"get_mtime_ns",
|
||||
"get_size_and_mtime_ns",
|
||||
|
||||
@ -13,6 +13,7 @@ from app.assets.database.queries import (
|
||||
soft_delete_reference_by_id,
|
||||
fetch_reference_asset_and_tags,
|
||||
get_asset_by_hash as queries_get_asset_by_hash,
|
||||
get_reference_by_file_path,
|
||||
get_reference_by_id,
|
||||
get_reference_with_owner_check,
|
||||
list_references_page,
|
||||
@ -321,6 +322,22 @@ def resolve_hash_to_path(
|
||||
)
|
||||
|
||||
|
||||
def is_file_visible_to_owner(
|
||||
abs_path: str,
|
||||
owner_id: str = "",
|
||||
) -> bool:
|
||||
"""Return whether a file-backed asset reference is visible to owner_id."""
|
||||
locator = os.path.abspath(abs_path)
|
||||
owner_id = (owner_id or "").strip()
|
||||
with create_session() as session:
|
||||
ref = get_reference_by_file_path(session, locator)
|
||||
if not ref:
|
||||
return os.path.isfile(locator)
|
||||
if ref.deleted_at is not None:
|
||||
return False
|
||||
return ref.owner_id == "" or ref.owner_id == owner_id
|
||||
|
||||
|
||||
def resolve_asset_for_download(
|
||||
reference_id: str,
|
||||
owner_id: str = "",
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Any, Sequence
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import folder_paths
|
||||
import app.assets.services.hashing as hashing
|
||||
from app.assets.database.queries import (
|
||||
add_tags_to_reference,
|
||||
@ -138,6 +139,7 @@ def register_output_files(
|
||||
file_paths: Sequence[str],
|
||||
user_metadata: UserMetadata = None,
|
||||
job_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> int:
|
||||
"""Register a batch of output file paths as assets.
|
||||
|
||||
@ -149,7 +151,7 @@ def register_output_files(
|
||||
continue
|
||||
try:
|
||||
if ingest_existing_file(
|
||||
abs_path, user_metadata=user_metadata, job_id=job_id
|
||||
abs_path, user_metadata=user_metadata, job_id=job_id, owner_id=owner_id
|
||||
):
|
||||
registered += 1
|
||||
except Exception:
|
||||
@ -157,6 +159,51 @@ def register_output_files(
|
||||
return registered
|
||||
|
||||
|
||||
def collect_output_absolute_paths(output_data: dict) -> list[str]:
|
||||
"""Extract absolute output/temp paths from a node UI output or history result."""
|
||||
if not isinstance(output_data, dict):
|
||||
return []
|
||||
|
||||
if isinstance(output_data.get("outputs"), dict):
|
||||
node_outputs = output_data["outputs"].values()
|
||||
else:
|
||||
node_outputs = [output_data]
|
||||
|
||||
paths: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for node_output in node_outputs:
|
||||
if not isinstance(node_output, dict):
|
||||
continue
|
||||
for items in node_output.values():
|
||||
if not isinstance(items, list):
|
||||
continue
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type not in ("output", "temp"):
|
||||
continue
|
||||
base_dir = folder_paths.get_directory_by_type(item_type)
|
||||
if base_dir is None:
|
||||
continue
|
||||
base_dir = os.path.abspath(base_dir)
|
||||
filename = item.get("filename")
|
||||
if not filename:
|
||||
continue
|
||||
abs_path = os.path.abspath(
|
||||
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
||||
)
|
||||
try:
|
||||
if os.path.commonpath((base_dir, abs_path)) != base_dir:
|
||||
continue
|
||||
except ValueError:
|
||||
continue
|
||||
if abs_path not in seen:
|
||||
seen.add(abs_path)
|
||||
paths.append(abs_path)
|
||||
return paths
|
||||
|
||||
|
||||
def ingest_existing_file(
|
||||
abs_path: str,
|
||||
user_metadata: UserMetadata = None,
|
||||
@ -184,6 +231,8 @@ def ingest_existing_file(
|
||||
existing_ref = get_reference_by_file_path(session, locator)
|
||||
if existing_ref is not None:
|
||||
now = get_utc_now()
|
||||
if owner_id and existing_ref.owner_id != owner_id:
|
||||
existing_ref.owner_id = owner_id
|
||||
existing_ref.mtime_ns = mtime_ns
|
||||
existing_ref.job_id = job_id
|
||||
existing_ref.is_missing = False
|
||||
|
||||
@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_prefetch
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
|
||||
|
||||
# Process transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_prompt_timestep=a_prompt_timestep,
|
||||
)
|
||||
|
||||
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
|
||||
@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
|
||||
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
|
||||
n_rep = q.shape[-3] // k.shape[-3]
|
||||
k = k.repeat_interleave(n_rep, dim=-3)
|
||||
v = v.repeat_interleave(n_rep, dim=-3)
|
||||
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
h = heads
|
||||
if skip_reshape:
|
||||
@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
b, _, dim_head = query.shape
|
||||
dim_head //= heads
|
||||
|
||||
if "scale" in kwargs:
|
||||
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
|
||||
query = query * (kwargs["scale"] * dim_head ** 0.5)
|
||||
|
||||
if skip_reshape:
|
||||
query = query.reshape(b * heads, -1, dim_head)
|
||||
value = value.reshape(b * heads, -1, dim_head)
|
||||
@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
|
||||
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
|
||||
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
|
||||
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
k[i : i + SDP_BATCH_LIMIT],
|
||||
v[i : i + SDP_BATCH_LIMIT],
|
||||
attn_mask=m,
|
||||
dropout_p=0.0, is_causal=False
|
||||
dropout_p=0.0, is_causal=False, **sdpa_extra
|
||||
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_base
|
||||
@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
def prefetch_prepared_value(value, allocate_buffer, stream):
|
||||
if isinstance(value, torch.Tensor):
|
||||
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
|
||||
elif isinstance(value, weight_adapter.WeightAdapterBase):
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
|
||||
elif isinstance(value, tuple):
|
||||
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
|
||||
elif isinstance(value, list):
|
||||
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
@ -214,6 +214,11 @@ class BaseModel(torch.nn.Module):
|
||||
if "latent_shapes" in extra_conds:
|
||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["prefetch_dynamic_vbars"] = (
|
||||
self.current_patcher is not None and self.current_patcher.is_dynamic()
|
||||
)
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||
model_output, _ = utils.pack_latents(model_output)
|
||||
|
||||
@ -31,6 +31,7 @@ from contextlib import nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
import comfy_aimdo.vram_buffer
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -1175,6 +1176,10 @@ stream_counters = {}
|
||||
|
||||
STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
|
||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||
|
||||
def get_cast_buffer(offload_stream, device, size, ref):
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
@ -1208,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
|
||||
|
||||
return cast_buffer
|
||||
|
||||
def get_aimdo_cast_buffer(offload_stream, device):
|
||||
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
|
||||
if cast_buffer is None:
|
||||
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
|
||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
|
||||
return cast_buffer
|
||||
def reset_cast_buffers():
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
|
||||
if offload_stream is not None:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
|
||||
@ -121,9 +121,20 @@ class LowVramPatch:
|
||||
self.patches = patches
|
||||
self.convert_func = convert_func # TODO: remove
|
||||
self.set_func = set_func
|
||||
self.prepared_patches = None
|
||||
|
||||
def prepare(self, allocate_buffer, stream):
|
||||
self.prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
|
||||
for patch in self.patches[self.key]
|
||||
]
|
||||
|
||||
def clear_prepared(self):
|
||||
self.prepared_patches = None
|
||||
|
||||
def __call__(self, weight):
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
|
||||
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
|
||||
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
|
||||
|
||||
|
||||
65
comfy/model_prefetch.py
Normal file
65
comfy/model_prefetch.py
Normal file
@ -0,0 +1,65 @@
|
||||
import comfy_aimdo.model_vbar
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
|
||||
PREFETCH_QUEUES = []
|
||||
|
||||
def cleanup_prefetched_modules(comfy_modules):
|
||||
for s in comfy_modules:
|
||||
prefetch = getattr(s, "_prefetch", None)
|
||||
if prefetch is None:
|
||||
continue
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_fn.clear_prepared()
|
||||
if prefetch["signature"] is not None:
|
||||
comfy_aimdo.model_vbar.vbar_unpin(s._v)
|
||||
delattr(s, "_prefetch")
|
||||
|
||||
def cleanup_prefetch_queues():
|
||||
global PREFETCH_QUEUES
|
||||
|
||||
for queue in PREFETCH_QUEUES:
|
||||
for entry in queue:
|
||||
if entry is None or not isinstance(entry, tuple):
|
||||
continue
|
||||
_, prefetch_state = entry
|
||||
comfy_modules = prefetch_state[1]
|
||||
if comfy_modules is not None:
|
||||
cleanup_prefetched_modules(comfy_modules)
|
||||
PREFETCH_QUEUES = []
|
||||
|
||||
def prefetch_queue_pop(queue, device, module):
|
||||
if queue is None:
|
||||
return
|
||||
|
||||
consumed = queue.pop(0)
|
||||
if consumed is not None:
|
||||
offload_stream, prefetch_state = consumed
|
||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||
_, comfy_modules = prefetch_state
|
||||
if comfy_modules is not None:
|
||||
cleanup_prefetched_modules(comfy_modules)
|
||||
|
||||
prefetch = queue[0]
|
||||
if prefetch is not None:
|
||||
comfy_modules = []
|
||||
for s in prefetch.modules():
|
||||
if hasattr(s, "_v"):
|
||||
comfy_modules.append(s)
|
||||
|
||||
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
queue[0] = (offload_stream, (prefetch, comfy_modules))
|
||||
|
||||
def make_prefetch_queue(queue, device, transformer_options):
|
||||
if (not transformer_options.get("prefetch_dynamic_vbars", False)
|
||||
or comfy.model_management.NUM_STREAMS == 0
|
||||
or comfy.model_management.is_device_cpu(device)
|
||||
or not comfy.model_management.device_supports_non_blocking(device)):
|
||||
return None
|
||||
|
||||
queue = [None] + queue + [None]
|
||||
PREFETCH_QUEUES.append(queue)
|
||||
return queue
|
||||
268
comfy/ops.py
268
comfy/ops.py
@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
|
||||
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True)
|
||||
return weight, bias, (None, None, None)
|
||||
|
||||
# FIXME: add n=1 cache hit fast path
|
||||
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
cast_buffer = None
|
||||
cast_buffer_offset = 0
|
||||
|
||||
def ensure_offload_stream(module, required_size, check_largest):
|
||||
nonlocal offload_stream
|
||||
nonlocal cast_buffer
|
||||
|
||||
if offload_stream is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
|
||||
return
|
||||
|
||||
current_size = 0 if cast_buffer is None else cast_buffer.size()
|
||||
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
cast_buffer = None
|
||||
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
|
||||
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
|
||||
|
||||
def get_cast_buffer(buffer_size):
|
||||
nonlocal offload_stream
|
||||
nonlocal cast_buffer
|
||||
nonlocal cast_buffer_offset
|
||||
|
||||
if buffer_size == 0:
|
||||
return None
|
||||
|
||||
if offload_stream is None:
|
||||
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
|
||||
|
||||
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
|
||||
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
|
||||
cast_buffer_offset += buffer_size
|
||||
return buffer
|
||||
|
||||
for s in comfy_modules:
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
prefetch = {
|
||||
"signature": signature,
|
||||
"resident": resident,
|
||||
}
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
if signature is not None:
|
||||
if resident:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
s._prefetch = prefetch
|
||||
continue
|
||||
|
||||
if not resident:
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
cast_dest = None
|
||||
needs_cast = False
|
||||
|
||||
xfer_source = [ s.weight, s.bias ]
|
||||
|
||||
@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
if data is None:
|
||||
continue
|
||||
if data.dtype != geometry.dtype:
|
||||
needs_cast = True
|
||||
cast_dest = xfer_dest
|
||||
if cast_dest is None:
|
||||
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
|
||||
xfer_dest = None
|
||||
break
|
||||
|
||||
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if xfer_dest is None and offload_stream is not None:
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
if xfer_dest is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
|
||||
if xfer_dest is None:
|
||||
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
|
||||
offload_stream = None
|
||||
xfer_dest = get_cast_buffer(dest_size)
|
||||
|
||||
if signature is None and pin is None:
|
||||
comfy.pinned_memory.pin_memory(s)
|
||||
@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
xfer_source = [ pin ]
|
||||
#send it over
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
if cast_dest is not None:
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
ensure_offload_stream(s, cast_buffer_offset, False)
|
||||
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
|
||||
|
||||
prefetch["xfer_dest"] = xfer_dest
|
||||
prefetch["cast_dest"] = cast_dest
|
||||
prefetch["cast_geometry"] = cast_geometry
|
||||
prefetch["needs_cast"] = needs_cast
|
||||
s._prefetch = prefetch
|
||||
|
||||
return offload_stream
|
||||
|
||||
|
||||
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
|
||||
|
||||
prefetch = getattr(s, "_prefetch", None)
|
||||
|
||||
if prefetch["resident"]:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = prefetch["xfer_dest"]
|
||||
if prefetch["needs_cast"]:
|
||||
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
|
||||
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
||||
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
|
||||
if post_cast is not None:
|
||||
post_cast.copy_(pre_cast)
|
||||
xfer_dest = cast_dest
|
||||
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
if signature is not None:
|
||||
if prefetch["signature"] is not None:
|
||||
s._v_weight = weight
|
||||
s._v_bias = bias
|
||||
s._v_signature=signature
|
||||
s._v_signature = prefetch["signature"]
|
||||
|
||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
fns = getattr(s, param_key + "_function", [])
|
||||
|
||||
if x is None:
|
||||
return None
|
||||
|
||||
orig = x
|
||||
|
||||
def to_dequant(tensor, dtype):
|
||||
@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
x = f(x)
|
||||
return x
|
||||
|
||||
update_weight = signature is not None
|
||||
update_weight = prefetch["signature"] is not None
|
||||
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
|
||||
if bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
|
||||
|
||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||
if s.bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||
|
||||
#FIXME: weird offload return protocol
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
return weight, bias
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||
@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
def format_return(result, offloadable):
|
||||
weight, bias, offload_stream = result
|
||||
return (weight, bias, offload_stream) if offloadable else (weight, bias)
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||
|
||||
prefetched = hasattr(s, "_prefetch")
|
||||
offload_stream = None
|
||||
offload_device = None
|
||||
if not prefetched:
|
||||
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
|
||||
|
||||
if not prefetched:
|
||||
if getattr(s, "_prefetch")["signature"] is not None:
|
||||
offload_device = device
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_fn.clear_prepared()
|
||||
delattr(s, "_prefetch")
|
||||
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
|
||||
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
if offloadable:
|
||||
return weight, bias, (offload_stream, weight_a, bias_a)
|
||||
else:
|
||||
#Legacy function signature
|
||||
return weight, bias
|
||||
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
|
||||
|
||||
|
||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
@ -1173,6 +1246,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
|
||||
class Embedding(manual_cast.Embedding):
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
weight_key = f"{prefix}weight"
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
# Only fp8 makes sense for embeddings (per-row dequant via index select).
|
||||
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
|
||||
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
|
||||
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
|
||||
self.quant_format = quant_format
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
|
||||
weight = state_dict.pop(weight_key)
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(scale_key, None)
|
||||
if scale is not None:
|
||||
scale = scale.float()
|
||||
manually_loaded_keys.append(scale_key)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.num_embeddings, self.embedding_dim),
|
||||
)
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
|
||||
requires_grad=False)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for k in manually_loaded_keys:
|
||||
if k in missing_keys:
|
||||
missing_keys.remove(k)
|
||||
else:
|
||||
if layer_conf is not None:
|
||||
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight') or self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
|
||||
def forward_comfy_cast_weights(self, input, out_dtype=None):
|
||||
weight = self.weight
|
||||
|
||||
# Optimized path: lookup in fp8, dequantize only the selected rows.
|
||||
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
|
||||
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
|
||||
if isinstance(qdata, QuantizedTensor):
|
||||
scale = qdata._params.scale
|
||||
qdata = qdata._qdata
|
||||
else:
|
||||
scale = None
|
||||
|
||||
x = torch.nn.functional.embedding(
|
||||
input, qdata, self.padding_idx, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
uncast_bias_weight(self, qdata, None, offload_stream)
|
||||
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
|
||||
x = x.to(dtype=target_dtype)
|
||||
if scale is not None and scale != 1.0:
|
||||
x = x * scale.to(dtype=target_dtype)
|
||||
return x
|
||||
|
||||
# Fallback for non-quantized or weight_function (LoRA) case
|
||||
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
|
||||
|
||||
return MixedPrecisionOps
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
|
||||
@ -3,6 +3,7 @@ import comfy.model_management
|
||||
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
|
||||
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if weight is None:
|
||||
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
|
||||
|
||||
17
comfy/sd.py
17
comfy/sd.py
@ -65,6 +65,7 @@ import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.gemma4
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -1271,6 +1272,9 @@ class TEModel(Enum):
|
||||
QWEN35_9B = 26
|
||||
QWEN35_27B = 27
|
||||
MINISTRAL_3_3B = 28
|
||||
GEMMA_4_E4B = 29
|
||||
GEMMA_4_E2B = 30
|
||||
GEMMA_4_31B = 31
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1296,6 +1300,12 @@ def detect_te_model(sd):
|
||||
return TEModel.BYT5_SMALL_GLYPH
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
if 'model.layers.59.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_4_31B
|
||||
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
|
||||
return TEModel.GEMMA_4_E4B
|
||||
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
|
||||
return TEModel.GEMMA_4_E2B
|
||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_12B
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
@ -1435,6 +1445,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
|
||||
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
|
||||
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
|
||||
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
|
||||
clip_target.tokenizer = variant.tokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
|
||||
1298
comfy/text_encoders/gemma4.py
Normal file
1298
comfy/text_encoders/gemma4.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -521,7 +521,7 @@ class Attention(nn.Module):
|
||||
else:
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
if sliding_window is not None and xk.shape[2] > sliding_window:
|
||||
if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
|
||||
xk = xk[:, :, -sliding_window:]
|
||||
xv = xv[:, :, -sliding_window:]
|
||||
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
|
||||
@ -533,12 +533,12 @@ class Attention(nn.Module):
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
|
||||
super().__init__()
|
||||
ops = ops or nn
|
||||
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
intermediate_size = intermediate_size or config.intermediate_size
|
||||
self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
if config.mlp_activation == "silu":
|
||||
self.activation = torch.nn.functional.silu
|
||||
elif config.mlp_activation == "gelu_pytorch_tanh":
|
||||
@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
|
||||
|
||||
return x, present_key_value
|
||||
|
||||
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
|
||||
class ScaledEmbedding(ops.Embedding):
|
||||
def forward(self, input_ids, out_dtype=None):
|
||||
return super().forward(input_ids, out_dtype=out_dtype) * scale
|
||||
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
|
||||
|
||||
|
||||
class Llama2_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = ops.Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
|
||||
transformer = TransformerBlockGemma2
|
||||
self.normalize_in = True
|
||||
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
|
||||
else:
|
||||
transformer = TransformerBlock
|
||||
self.normalize_in = False
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
@ -690,15 +691,12 @@ class Llama2_(nn.Module):
|
||||
self.config.rope_dims,
|
||||
device=device)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
x = self.embed_tokens(x, out_dtype=dtype)
|
||||
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
seq_len = x.shape[1]
|
||||
past_len = 0
|
||||
if past_key_values is not None and len(past_key_values) > 0:
|
||||
@ -850,7 +848,7 @@ class BaseGenerate:
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
return past_key_values
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
|
||||
device = embeds.device
|
||||
|
||||
if stop_tokens is None:
|
||||
@ -875,14 +873,16 @@ class BaseGenerate:
|
||||
pbar = comfy.utils.ProgressBar(max_length)
|
||||
|
||||
# Generation loop
|
||||
current_input_ids = initial_input_ids
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||
token_id = next_token[0].item()
|
||||
generated_token_ids.append(token_id)
|
||||
|
||||
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
||||
current_input_ids = next_token if initial_input_ids is not None else None
|
||||
pbar.update(1)
|
||||
|
||||
if token_id in stop_tokens:
|
||||
|
||||
@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
|
||||
|
||||
class DualLinearProjection(torch.nn.Module):
|
||||
|
||||
@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def process_tokens(self, tokens, device):
|
||||
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
embeds, _, _, _ = super().process_tokens(tokens, device)
|
||||
return embeds
|
||||
|
||||
class LuminaModel(sd1_clip.SD1ClipModel):
|
||||
|
||||
@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.normalize_in = False
|
||||
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
|
||||
"""Normalize image embeddings to match text embedding scale"""
|
||||
for info in embeds_info:
|
||||
if info.get("type") == "image":
|
||||
start_idx = info["index"]
|
||||
end_idx = start_idx + info["size"]
|
||||
embeds[:, start_idx:end_idx, :] /= scale_factor
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel):
|
||||
grain: Optional[float] = Field(None, description="Grain after AI model processing")
|
||||
grainSize: Optional[float] = Field(None, description="Size of generated grain")
|
||||
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
|
||||
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
|
||||
creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.")
|
||||
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
|
||||
prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)")
|
||||
sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness")
|
||||
realism: float | None = Field(None, description="ast-2 realism control")
|
||||
|
||||
|
||||
class OutputInformationVideo(BaseModel):
|
||||
@ -90,7 +93,7 @@ class Overrides(BaseModel):
|
||||
|
||||
class CreateVideoRequest(BaseModel):
|
||||
source: CreateVideoRequestSource = Field(...)
|
||||
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
|
||||
filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...)
|
||||
output: OutputInformationVideo = Field(...)
|
||||
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
|
||||
|
||||
|
||||
@ -36,11 +36,15 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
UPSCALER_MODELS_MAP = {
|
||||
"Astra 2": "ast-2",
|
||||
"Starlight (Astra) Fast": "slf-1",
|
||||
"Starlight (Astra) Creative": "slc-1",
|
||||
"Starlight Precise 2.5": "slp-2.5",
|
||||
}
|
||||
|
||||
AST2_MAX_FRAMES = 9000
|
||||
AST2_MAX_FRAMES_WITH_PROMPT = 450
|
||||
|
||||
|
||||
class TopazImageEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhance",
|
||||
display_name="Topaz Video Enhance",
|
||||
display_name="Topaz Video Enhance (Legacy)",
|
||||
category="api node/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Boolean.Input("upscaler_enabled", default=True),
|
||||
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
|
||||
IO.Combo.Input(
|
||||
"upscaler_model",
|
||||
options=[
|
||||
"Starlight (Astra) Fast",
|
||||
"Starlight (Astra) Creative",
|
||||
"Starlight Precise 2.5",
|
||||
],
|
||||
),
|
||||
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||
IO.Combo.Input(
|
||||
"upscaler_creativity",
|
||||
@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -457,12 +469,357 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
|
||||
|
||||
|
||||
class TopazVideoEnhanceV2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhanceV2",
|
||||
display_name="Topaz Video Enhance",
|
||||
category="api node/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.DynamicCombo.Input(
|
||||
"upscaler_model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Astra 2",
|
||||
[
|
||||
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||
IO.Float.Input(
|
||||
"creativity",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Creative strength of the upscale.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Optional descriptive (not instructive) scene prompt."
|
||||
f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"sharp",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Pre-enhance sharpness: "
|
||||
"0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"realism",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Pulls output toward photographic realism."
|
||||
"Leave at 0 for the model default.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Starlight (Astra) Fast",
|
||||
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Starlight (Astra) Creative",
|
||||
[
|
||||
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||
IO.Combo.Input(
|
||||
"creativity",
|
||||
options=["low", "middle", "high"],
|
||||
default="low",
|
||||
tooltip="Creative strength of the upscale.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Starlight Precise 2.5",
|
||||
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])],
|
||||
),
|
||||
IO.DynamicCombo.Option("Disabled", []),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"interpolation_model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Disabled", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"apo-8",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"interpolation_frame_rate",
|
||||
default=60,
|
||||
min=15,
|
||||
max=240,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Output frame rate.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"interpolation_slowmo",
|
||||
default=1,
|
||||
min=1,
|
||||
max=16,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Slow-motion factor applied to the input video. "
|
||||
"For example, 2 makes the output twice as slow and doubles the duration.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"interpolation_duplicate",
|
||||
default=False,
|
||||
tooltip="Analyze the input for duplicate frames and remove them.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"interpolation_duplicate_threshold",
|
||||
default=0.01,
|
||||
min=0.001,
|
||||
max=0.1,
|
||||
step=0.001,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Detection sensitivity for duplicate frames.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"dynamic_compression_level",
|
||||
options=["Low", "Mid", "High"],
|
||||
default="Low",
|
||||
tooltip="CQP level.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=[
|
||||
"upscaler_model",
|
||||
"upscaler_model.upscaler_resolution",
|
||||
"interpolation_model",
|
||||
]),
|
||||
expr="""
|
||||
(
|
||||
$model := $lookup(widgets, "upscaler_model");
|
||||
$res := $lookup(widgets, "upscaler_model.upscaler_resolution");
|
||||
$interp := $lookup(widgets, "interpolation_model");
|
||||
$is4k := $contains($res, "4k");
|
||||
$hasInterp := $interp != "disabled";
|
||||
$rates := {
|
||||
"starlight (astra) fast": {"hd": 0.43, "uhd": 0.85},
|
||||
"starlight precise 2.5": {"hd": 0.70, "uhd": 1.54},
|
||||
"astra 2": {"hd": 1.72, "uhd": 2.85},
|
||||
"starlight (astra) creative": {"hd": 2.25, "uhd": 3.99}
|
||||
};
|
||||
$surcharge := $is4k ? 0.28 : 0.14;
|
||||
$entry := $lookup($rates, $model);
|
||||
$base := $is4k ? $entry.uhd : $entry.hd;
|
||||
$hi := $base + ($hasInterp ? $surcharge : 0);
|
||||
$model = "disabled"
|
||||
? {"type":"text","text":"Interpolation only"}
|
||||
: ($hasInterp
|
||||
? {"type":"text","text":"~" & $string($base) & "–" & $string($hi) & " credits/src frame"}
|
||||
: {"type":"text","text":"~" & $string($base) & " credits/src frame"})
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
upscaler_model: dict,
|
||||
interpolation_model: dict,
|
||||
dynamic_compression_level: str = "Low",
|
||||
) -> IO.NodeOutput:
|
||||
upscaler_choice = upscaler_model["upscaler_model"]
|
||||
interpolation_choice = interpolation_model["interpolation_model"]
|
||||
if upscaler_choice == "Disabled" and interpolation_choice == "Disabled":
|
||||
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
|
||||
validate_container_format_is_mp4(video)
|
||||
src_width, src_height = video.get_dimensions()
|
||||
src_frame_rate = int(video.get_frame_rate())
|
||||
duration_sec = video.get_duration()
|
||||
src_video_stream = video.get_stream_source()
|
||||
target_width = src_width
|
||||
target_height = src_height
|
||||
target_frame_rate = src_frame_rate
|
||||
filters = []
|
||||
if upscaler_choice != "Disabled":
|
||||
if "1080p" in upscaler_model["upscaler_resolution"]:
|
||||
target_pixel_p = 1080
|
||||
max_long_side = 1920
|
||||
else:
|
||||
target_pixel_p = 2160
|
||||
max_long_side = 3840
|
||||
ar = src_width / src_height
|
||||
if src_width >= src_height:
|
||||
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
|
||||
target_height = target_pixel_p
|
||||
target_width = int(target_height * ar)
|
||||
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
|
||||
if target_width > max_long_side:
|
||||
target_width = max_long_side
|
||||
target_height = int(target_width / ar)
|
||||
else:
|
||||
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
|
||||
target_width = target_pixel_p
|
||||
target_height = int(target_width / ar)
|
||||
# Check if height exceeds standard bounds
|
||||
if target_height > max_long_side:
|
||||
target_height = max_long_side
|
||||
target_width = int(target_height * ar)
|
||||
if target_width % 2 != 0:
|
||||
target_width += 1
|
||||
if target_height % 2 != 0:
|
||||
target_height += 1
|
||||
model_id = UPSCALER_MODELS_MAP[upscaler_choice]
|
||||
if model_id == "slc-1":
|
||||
filters.append(
|
||||
VideoEnhancementFilter(
|
||||
model=model_id,
|
||||
creativity=upscaler_model["creativity"],
|
||||
isOptimizedMode=True,
|
||||
)
|
||||
)
|
||||
elif model_id == "ast-2":
|
||||
n_frames = video.get_frame_count()
|
||||
ast2_prompt = (upscaler_model["prompt"] or "").strip()
|
||||
if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT:
|
||||
raise ValueError(
|
||||
f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames "
|
||||
f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip."
|
||||
)
|
||||
if n_frames > AST2_MAX_FRAMES:
|
||||
raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.")
|
||||
realism = upscaler_model["realism"]
|
||||
filters.append(
|
||||
VideoEnhancementFilter(
|
||||
model=model_id,
|
||||
creativity=upscaler_model["creativity"],
|
||||
prompt=(ast2_prompt or None),
|
||||
sharp=upscaler_model["sharp"],
|
||||
realism=(realism if realism > 0 else None),
|
||||
)
|
||||
)
|
||||
else:
|
||||
filters.append(VideoEnhancementFilter(model=model_id))
|
||||
if interpolation_choice != "Disabled":
|
||||
target_frame_rate = interpolation_model["interpolation_frame_rate"]
|
||||
filters.append(
|
||||
VideoFrameInterpolationFilter(
|
||||
model=interpolation_choice,
|
||||
slowmo=interpolation_model["interpolation_slowmo"],
|
||||
fps=interpolation_model["interpolation_frame_rate"],
|
||||
duplicate=interpolation_model["interpolation_duplicate"],
|
||||
duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"],
|
||||
),
|
||||
)
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
|
||||
response_model=CreateVideoResponse,
|
||||
data=CreateVideoRequest(
|
||||
source=CreateVideoRequestSource(
|
||||
container="mp4",
|
||||
size=get_fs_object_size(src_video_stream),
|
||||
duration=int(duration_sec),
|
||||
frameCount=video.get_frame_count(),
|
||||
frameRate=src_frame_rate,
|
||||
resolution=Resolution(width=src_width, height=src_height),
|
||||
),
|
||||
filters=filters,
|
||||
output=OutputInformationVideo(
|
||||
resolution=Resolution(width=target_width, height=target_height),
|
||||
frameRate=target_frame_rate,
|
||||
audioCodec="AAC",
|
||||
audioTransfer="Copy",
|
||||
dynamicCompressionLevel=dynamic_compression_level,
|
||||
),
|
||||
),
|
||||
wait_label="Creating task",
|
||||
final_label_on_success="Task created",
|
||||
)
|
||||
upload_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=VideoAcceptResponse,
|
||||
wait_label="Preparing upload",
|
||||
final_label_on_success="Upload started",
|
||||
)
|
||||
if len(upload_res.urls) > 1:
|
||||
raise NotImplementedError(
|
||||
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
|
||||
)
|
||||
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
|
||||
if isinstance(src_video_stream, BytesIO):
|
||||
src_video_stream.seek(0)
|
||||
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
else:
|
||||
with builtins.open(src_video_stream, "rb") as video_file:
|
||||
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=VideoCompleteUploadResponse,
|
||||
data=VideoCompleteUploadRequest(
|
||||
uploadResults=[
|
||||
VideoCompleteUploadRequestPart(
|
||||
partNum=1,
|
||||
eTag=upload_etag,
|
||||
),
|
||||
],
|
||||
),
|
||||
wait_label="Finalizing upload",
|
||||
final_label_on_success="Upload completed",
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
|
||||
response_model=VideoStatusResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
|
||||
poll_interval=10.0,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
|
||||
|
||||
|
||||
class TopazExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TopazImageEnhance,
|
||||
TopazVideoEnhance,
|
||||
TopazVideoEnhanceV2,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
||||
batch_size = min(len(image), len(alpha))
|
||||
out_images = []
|
||||
|
||||
batch_size = max(len(image), len(alpha))
|
||||
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
||||
for i in range(batch_size):
|
||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||
|
||||
return io.NodeOutput(torch.stack(out_images))
|
||||
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
|
||||
image = comfy.utils.repeat_to_batch_size(image, batch_size)
|
||||
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
|
||||
|
||||
|
||||
class CompositingExtension(ComfyExtension):
|
||||
|
||||
@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode):
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Image.Input("image", optional=True),
|
||||
io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
|
||||
io.Audio.Input("audio", optional=True),
|
||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
||||
@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
|
||||
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)
|
||||
|
||||
# Get sampling parameters from dynamic combo
|
||||
do_sample = sampling_mode.get("sampling_mode") == "on"
|
||||
@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode):
|
||||
seed=seed
|
||||
)
|
||||
|
||||
generated_text = clip.decode(generated_ids, skip_special_tokens=True)
|
||||
generated_text = clip.decode(generated_ids)
|
||||
|
||||
return io.NodeOutput(generated_text)
|
||||
|
||||
|
||||
@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
|
||||
if image is None:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
else:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)
|
||||
|
||||
|
||||
class TextgenExtension(ComfyExtension):
|
||||
|
||||
10
execution.py
10
execution.py
@ -15,6 +15,7 @@ import torch
|
||||
from comfy.cli_args import args
|
||||
import comfy.memory_management
|
||||
import comfy.model_management
|
||||
import comfy.model_prefetch
|
||||
import comfy_aimdo.model_vbar
|
||||
|
||||
from latent_preview import set_preview_method
|
||||
@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if args.verbose == "DEBUG":
|
||||
comfy_aimdo.control.analyze()
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
comfy.model_prefetch.cleanup_prefetch_queues()
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
|
||||
if has_pending_tasks:
|
||||
@ -549,6 +551,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
register_output_assets = getattr(server, "register_output_assets", None)
|
||||
if register_output_assets is not None:
|
||||
user_id_key = getattr(server, "INTERNAL_USER_ID_KEY", "_comfy_user_id")
|
||||
register_output_assets(
|
||||
output_ui,
|
||||
prompt_id,
|
||||
extra_data.get(user_id_key, ""),
|
||||
)
|
||||
ui_outputs[unique_id] = {
|
||||
"meta": {
|
||||
"node_id": unique_id,
|
||||
|
||||
39
main.py
39
main.py
@ -12,7 +12,7 @@ from app.logger import setup_logger
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.services import register_output_files
|
||||
from app.assets.services import collect_output_absolute_paths, register_output_files
|
||||
import itertools
|
||||
import utils.extra_config
|
||||
from utils.mime_types import init_mime_types
|
||||
@ -241,38 +241,6 @@ def cuda_malloc_warning():
|
||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
|
||||
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||
"""Extract absolute file paths for output items from a history result."""
|
||||
paths: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for node_output in history_result.get("outputs", {}).values():
|
||||
for items in node_output.values():
|
||||
if not isinstance(items, list):
|
||||
continue
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type not in ("output", "temp"):
|
||||
continue
|
||||
base_dir = folder_paths.get_directory_by_type(item_type)
|
||||
if base_dir is None:
|
||||
continue
|
||||
base_dir = os.path.abspath(base_dir)
|
||||
filename = item.get("filename")
|
||||
if not filename:
|
||||
continue
|
||||
abs_path = os.path.abspath(
|
||||
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
||||
)
|
||||
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
|
||||
continue
|
||||
if abs_path not in seen:
|
||||
seen.add(abs_path)
|
||||
paths.append(abs_path)
|
||||
return paths
|
||||
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
cache_ram = args.cache_ram
|
||||
@ -335,8 +303,9 @@ def prompt_worker(q, server_instance):
|
||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
|
||||
if not asset_seeder.is_disabled():
|
||||
paths = _collect_output_absolute_paths(e.history_result)
|
||||
register_output_files(paths, job_id=prompt_id)
|
||||
paths = collect_output_absolute_paths(e.history_result)
|
||||
owner_id = extra_data.get(server.INTERNAL_USER_ID_KEY, "")
|
||||
register_output_files(paths, job_id=prompt_id, owner_id=owner_id)
|
||||
|
||||
flags = q.get_flags()
|
||||
free_memory = flags.get("free_memory", False)
|
||||
|
||||
@ -86,6 +86,6 @@ def image_alpha_fix(destination, source):
|
||||
if destination.shape[-1] < source.shape[-1]:
|
||||
source = source[...,:destination.shape[-1]]
|
||||
elif destination.shape[-1] > source.shape[-1]:
|
||||
destination = torch.nn.functional.pad(destination, (0, 1))
|
||||
destination[..., -1] = 1.0
|
||||
source = torch.nn.functional.pad(source, (0, 1))
|
||||
source[..., -1] = 1.0
|
||||
return destination, source
|
||||
|
||||
29
nodes.py
29
nodes.py
@ -1694,26 +1694,27 @@ class LoadImage:
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "load_image"
|
||||
|
||||
def load_image(self, image):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
device = comfy.model_management.intermediate_device()
|
||||
|
||||
components = InputImpl.VideoFromFile(image_path).get_components()
|
||||
if components.images.shape[0] > 0:
|
||||
return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu"))
|
||||
return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device))
|
||||
|
||||
# This code is left here to handle animated webp which pyav does not support loading
|
||||
img = node_helpers.pillow(Image.open, image_path)
|
||||
|
||||
output_images = []
|
||||
output_masks = []
|
||||
w, h = None, None
|
||||
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
image = i.convert("RGB")
|
||||
|
||||
if len(output_images) == 0:
|
||||
@ -1728,25 +1729,15 @@ class LoadImage:
|
||||
if 'A' in i.getbands():
|
||||
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
elif i.mode == 'P' and 'transparency' in i.info:
|
||||
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
|
||||
output_images.append(image.to(dtype=dtype))
|
||||
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
|
||||
|
||||
if img.format == "MPO":
|
||||
break # ignore all frames except the first one for MPO format
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
|
||||
if len(output_images) > 1:
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
else:
|
||||
output_image = output_images[0]
|
||||
output_mask = output_masks[0]
|
||||
|
||||
return (output_image, output_mask)
|
||||
return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype))
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.42.15
|
||||
comfyui-workflow-templates==0.9.66
|
||||
comfyui-workflow-templates==0.9.68
|
||||
comfyui-embedded-docs==0.4.4
|
||||
torch
|
||||
torchsde
|
||||
|
||||
206
server.py
206
server.py
@ -1,3 +1,4 @@
|
||||
import errno
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
@ -35,8 +36,15 @@ from app.frontend_management import FrontendManager, parse_version
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.api.routes import register_assets_routes
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
from app.assets.services.asset_management import resolve_hash_to_path
|
||||
from app.assets.services.ingest import (
|
||||
collect_output_absolute_paths,
|
||||
register_file_in_place,
|
||||
register_output_files,
|
||||
)
|
||||
from app.assets.services.asset_management import (
|
||||
is_file_visible_to_owner,
|
||||
resolve_hash_to_path,
|
||||
)
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@ -53,10 +61,83 @@ from middleware.cache_middleware import cache_control
|
||||
if args.enable_manager:
|
||||
import comfyui_manager
|
||||
|
||||
INTERNAL_USER_ID_KEY = "_comfy_user_id"
|
||||
|
||||
|
||||
def _remove_sensitive_from_queue(queue: list) -> list:
|
||||
"""Remove sensitive data (index 5) from queue item tuples."""
|
||||
return [item[:5] for item in queue]
|
||||
return [_scrub_prompt_tuple(item[:5]) for item in queue]
|
||||
|
||||
|
||||
def _scrub_prompt_tuple(prompt_tuple):
|
||||
"""Remove internal-only prompt metadata before returning queue data."""
|
||||
if not isinstance(prompt_tuple, (list, tuple)):
|
||||
return prompt_tuple
|
||||
if len(prompt_tuple) <= 3 or not isinstance(prompt_tuple[3], dict):
|
||||
return prompt_tuple
|
||||
out = list(prompt_tuple)
|
||||
extra_data = dict(out[3])
|
||||
extra_data.pop(INTERNAL_USER_ID_KEY, None)
|
||||
out[3] = extra_data
|
||||
return tuple(out) if isinstance(prompt_tuple, tuple) else out
|
||||
|
||||
|
||||
def _scrub_history_for_response(history: dict) -> dict:
|
||||
"""Remove internal-only prompt metadata from history responses."""
|
||||
out = {}
|
||||
for prompt_id, item in history.items():
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
clean_item = dict(item)
|
||||
if "prompt" in clean_item:
|
||||
clean_item["prompt"] = _scrub_prompt_tuple(clean_item["prompt"])
|
||||
out[prompt_id] = clean_item
|
||||
return out
|
||||
|
||||
|
||||
def _prompt_tuple_owner_id(prompt_tuple) -> str | None:
|
||||
"""Return the stored owner id, or None for legacy prompts without one."""
|
||||
try:
|
||||
extra_data = prompt_tuple[3]
|
||||
except Exception:
|
||||
return "default"
|
||||
if not isinstance(extra_data, dict):
|
||||
return "default"
|
||||
if INTERNAL_USER_ID_KEY not in extra_data:
|
||||
return None
|
||||
return str(extra_data.get(INTERNAL_USER_ID_KEY) or "default")
|
||||
|
||||
|
||||
def _prompt_tuple_visible_to_user(prompt_tuple, owner_id: str) -> bool:
|
||||
"""Return whether a prompt tuple is visible to the requesting user."""
|
||||
prompt_owner_id = _prompt_tuple_owner_id(prompt_tuple)
|
||||
return prompt_owner_id is None or prompt_owner_id == str(owner_id or "default")
|
||||
|
||||
|
||||
def _filter_queue_for_user(queue: list, owner_id: str) -> list:
|
||||
"""Filter queue entries to those visible to the requesting user."""
|
||||
return [item for item in queue if _prompt_tuple_visible_to_user(item, owner_id)]
|
||||
|
||||
|
||||
def _filter_history_for_user(history: dict, owner_id: str) -> dict:
|
||||
"""Filter history entries to those visible to the requesting user."""
|
||||
return {
|
||||
prompt_id: item
|
||||
for prompt_id, item in history.items()
|
||||
if isinstance(item, dict)
|
||||
and _prompt_tuple_visible_to_user(item.get("prompt"), owner_id)
|
||||
}
|
||||
|
||||
|
||||
def _slice_history(history: dict, max_items: int | None, offset: int) -> dict:
|
||||
"""Return a stable paginated slice of a history mapping."""
|
||||
items = list(history.items())
|
||||
if offset < 0 and max_items is not None:
|
||||
offset = len(items) - max_items
|
||||
offset = max(offset, 0)
|
||||
if max_items is None:
|
||||
return dict(items[offset:])
|
||||
return dict(items[offset:offset + max_items])
|
||||
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
@ -381,7 +462,7 @@ class PromptServer():
|
||||
return a.hexdigest() == b.hexdigest()
|
||||
return False
|
||||
|
||||
def image_upload(post, image_save_function=None):
|
||||
def image_upload(post, image_save_function=None, owner_id=""):
|
||||
image = post.get("image")
|
||||
overwrite = post.get("overwrite")
|
||||
image_is_duplicate = False
|
||||
@ -430,7 +511,12 @@ class PromptServer():
|
||||
if args.enable_assets:
|
||||
try:
|
||||
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
|
||||
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
|
||||
result = register_file_in_place(
|
||||
abs_path=filepath,
|
||||
name=filename,
|
||||
tags=[tag],
|
||||
owner_id=owner_id,
|
||||
)
|
||||
resp["asset"] = {
|
||||
"id": result.ref.id,
|
||||
"name": result.ref.name,
|
||||
@ -449,12 +535,20 @@ class PromptServer():
|
||||
@routes.post("/upload/image")
|
||||
async def upload_image(request):
|
||||
post = await request.post()
|
||||
return image_upload(post)
|
||||
try:
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
except KeyError:
|
||||
return web.Response(status=403)
|
||||
return image_upload(post, owner_id=owner_id)
|
||||
|
||||
|
||||
@routes.post("/upload/mask")
|
||||
async def upload_mask(request):
|
||||
post = await request.post()
|
||||
try:
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
except KeyError:
|
||||
return web.Response(status=403)
|
||||
|
||||
def image_save_function(image, post, filepath):
|
||||
original_ref = json.loads(post.get("original_ref"))
|
||||
@ -496,7 +590,7 @@ class PromptServer():
|
||||
original_pil.putalpha(new_alpha)
|
||||
original_pil.save(filepath, compress_level=4, pnginfo=metadata)
|
||||
|
||||
return image_upload(post, image_save_function)
|
||||
return image_upload(post, image_save_function, owner_id=owner_id)
|
||||
|
||||
@routes.get("/view")
|
||||
async def view_image(request):
|
||||
@ -541,6 +635,14 @@ class PromptServer():
|
||||
file = os.path.join(output_dir, filename)
|
||||
|
||||
if os.path.isfile(file):
|
||||
if args.enable_assets and not asset_seeder.is_disabled():
|
||||
try:
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
except KeyError:
|
||||
return web.Response(status=403)
|
||||
if not is_file_visible_to_owner(file, owner_id=owner_id):
|
||||
return web.Response(status=403)
|
||||
|
||||
if 'preview' in request.rel_url.query:
|
||||
with Image.open(file) as img:
|
||||
preview_info = request.rel_url.query['preview'].split(';')
|
||||
@ -832,11 +934,20 @@ class PromptServer():
|
||||
status=400
|
||||
)
|
||||
|
||||
try:
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
except KeyError:
|
||||
return web.Response(status=403)
|
||||
|
||||
running, queued = self.prompt_queue.get_current_queue_volatile()
|
||||
history = self.prompt_queue.get_history()
|
||||
|
||||
running = _filter_queue_for_user(running, owner_id)
|
||||
queued = _filter_queue_for_user(queued, owner_id)
|
||||
history = _filter_history_for_user(history, owner_id)
|
||||
running = _remove_sensitive_from_queue(running)
|
||||
queued = _remove_sensitive_from_queue(queued)
|
||||
history = _scrub_history_for_response(history)
|
||||
|
||||
jobs, total = get_all_jobs(
|
||||
running, queued, history,
|
||||
@ -870,11 +981,20 @@ class PromptServer():
|
||||
status=400
|
||||
)
|
||||
|
||||
try:
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
except KeyError:
|
||||
return web.Response(status=403)
|
||||
|
||||
running, queued = self.prompt_queue.get_current_queue_volatile()
|
||||
history = self.prompt_queue.get_history(prompt_id=job_id)
|
||||
|
||||
running = _filter_queue_for_user(running, owner_id)
|
||||
queued = _filter_queue_for_user(queued, owner_id)
|
||||
history = _filter_history_for_user(history, owner_id)
|
||||
running = _remove_sensitive_from_queue(running)
|
||||
queued = _remove_sensitive_from_queue(queued)
|
||||
history = _scrub_history_for_response(history)
|
||||
|
||||
job = get_job(job_id, running, queued, history)
|
||||
if job is None:
|
||||
@ -897,24 +1017,55 @@ class PromptServer():
|
||||
else:
|
||||
offset = -1
|
||||
|
||||
return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset))
|
||||
try:
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
except KeyError:
|
||||
return web.Response(status=403)
|
||||
|
||||
history = self.prompt_queue.get_history()
|
||||
history = _filter_history_for_user(history, owner_id)
|
||||
history = _slice_history(history, max_items=max_items, offset=offset)
|
||||
history = _scrub_history_for_response(history)
|
||||
return web.json_response(history)
|
||||
|
||||
@routes.get("/history/{prompt_id}")
|
||||
async def get_history_prompt_id(request):
|
||||
prompt_id = request.match_info.get("prompt_id", None)
|
||||
return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
|
||||
try:
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
except KeyError:
|
||||
return web.Response(status=403)
|
||||
|
||||
history = self.prompt_queue.get_history(prompt_id=prompt_id)
|
||||
history = _filter_history_for_user(history, owner_id)
|
||||
history = _scrub_history_for_response(history)
|
||||
return web.json_response(history)
|
||||
|
||||
@routes.get("/queue")
|
||||
async def get_queue(request):
|
||||
try:
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
except KeyError:
|
||||
return web.Response(status=403)
|
||||
|
||||
queue_info = {}
|
||||
current_queue = self.prompt_queue.get_current_queue_volatile()
|
||||
queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0])
|
||||
queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
|
||||
queue_info['queue_running'] = _remove_sensitive_from_queue(
|
||||
_filter_queue_for_user(current_queue[0], owner_id)
|
||||
)
|
||||
queue_info['queue_pending'] = _remove_sensitive_from_queue(
|
||||
_filter_queue_for_user(current_queue[1], owner_id)
|
||||
)
|
||||
return web.json_response(queue_info)
|
||||
|
||||
@routes.post("/prompt")
|
||||
async def post_prompt(request):
|
||||
logging.info("got prompt")
|
||||
try:
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
except KeyError:
|
||||
return web.Response(status=403)
|
||||
|
||||
json_data = await request.json()
|
||||
json_data = self.trigger_on_prompt(json_data)
|
||||
|
||||
@ -945,6 +1096,7 @@ class PromptServer():
|
||||
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
extra_data[INTERNAL_USER_ID_KEY] = owner_id
|
||||
if valid[0]:
|
||||
outputs_to_execute = valid[2]
|
||||
sensitive = {}
|
||||
@ -1026,17 +1178,37 @@ class PromptServer():
|
||||
|
||||
@routes.post("/history")
|
||||
async def post_history(request):
|
||||
try:
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
except KeyError:
|
||||
return web.Response(status=403)
|
||||
|
||||
json_data = await request.json()
|
||||
if "clear" in json_data:
|
||||
if json_data["clear"]:
|
||||
self.prompt_queue.wipe_history()
|
||||
history = self.prompt_queue.get_history()
|
||||
history = _filter_history_for_user(history, owner_id)
|
||||
for history_id in history.keys():
|
||||
self.prompt_queue.delete_history_item(history_id)
|
||||
if "delete" in json_data:
|
||||
to_delete = json_data['delete']
|
||||
for id_to_delete in to_delete:
|
||||
self.prompt_queue.delete_history_item(id_to_delete)
|
||||
history = self.prompt_queue.get_history(prompt_id=id_to_delete)
|
||||
if _filter_history_for_user(history, owner_id):
|
||||
self.prompt_queue.delete_history_item(id_to_delete)
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
def register_output_assets(self, output_ui, prompt_id: str, owner_id: str):
|
||||
if not args.enable_assets or asset_seeder.is_disabled():
|
||||
return
|
||||
try:
|
||||
paths = collect_output_absolute_paths(output_ui)
|
||||
if paths:
|
||||
register_output_files(paths, job_id=prompt_id, owner_id=owner_id)
|
||||
except Exception:
|
||||
logging.warning("Failed to register node output assets", exc_info=True)
|
||||
|
||||
async def setup(self):
|
||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||
@ -1245,7 +1417,13 @@ class PromptServer():
|
||||
address = addr[0]
|
||||
port = addr[1]
|
||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||
await site.start()
|
||||
try:
|
||||
await site.start()
|
||||
except OSError as e:
|
||||
if e.errno == errno.EADDRINUSE:
|
||||
logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.")
|
||||
raise SystemExit(1)
|
||||
raise
|
||||
|
||||
if not hasattr(self, 'address'):
|
||||
self.address = address #TODO: remove this
|
||||
|
||||
Loading…
Reference in New Issue
Block a user