From a4522017c518d1f0c3c5d2a803a2d31265da5cd4 Mon Sep 17 00:00:00 2001 From: Tavi Halperin Date: Thu, 26 Feb 2026 08:25:23 +0200 Subject: [PATCH 01/81] feat: per-guide attention strength control in self-attention (#12518) Implements per-guide attention attenuation via log-space additive bias in self-attention. Each guide reference tracks its own strength and optional spatial mask in conditioning metadata (guide_attention_entries). --- comfy/ldm/lightricks/av_model.py | 9 +- comfy/ldm/lightricks/model.py | 264 ++++++++++++++++++++++++++++++- comfy/model_base.py | 44 ++++++ comfy_extras/nodes_lt.py | 47 +++++- 4 files changed, 352 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 2b080aaeb..553fd5b38 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -218,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module): def forward( self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, - v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, + v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -234,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module): vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2))) norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa del vshift_msa, vscale_msa - attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) + attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options) del norm_vx # video cross-attention vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] @@ -726,7 +726,7 @@ class LTXAVModel(LTXVModel): return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] def _process_transformer_blocks( - self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs + self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs ): vx = x[0] ax = x[1] @@ -770,6 +770,7 @@ class LTXAVModel(LTXVModel): v_cross_gate_timestep=args["v_cross_gate_timestep"], a_cross_gate_timestep=args["a_cross_gate_timestep"], transformer_options=args["transformer_options"], + self_attention_mask=args.get("self_attention_mask"), ) return out @@ -790,6 +791,7 @@ class LTXAVModel(LTXVModel): "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "transformer_options": transformer_options, + "self_attention_mask": self_attention_mask, }, {"original_block": block_wrap}, ) @@ -811,6 +813,7 @@ class LTXAVModel(LTXVModel): v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, transformer_options=transformer_options, + self_attention_mask=self_attention_mask, ) return [vx, ax] diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index d61e19d6e..60d760d29 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from enum import Enum import functools +import logging import math from typing import Dict, Optional, Tuple @@ -14,6 +15,8 @@ import comfy.ldm.common_dit from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords +logger = logging.getLogger(__name__) + def _log_base(x, base): return np.log(x) / np.log(base) @@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module): self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) attn1_input = comfy.ldm.common_dit.rms_norm(x) attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) - attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) + attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) x.addcmul_(attn1_input, gate_msa) del attn1_input @@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC): """Process input data. Must be implemented by subclasses.""" pass + def _build_guide_self_attention_mask(self, x, transformer_options, merged_args): + """Build self-attention mask for per-guide attention attenuation. + + Base implementation returns None (no attenuation). Subclasses that + support guide-based attention control should override this. + """ + return None + @abstractmethod - def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs): + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs): """Process transformer blocks. Must be implemented by subclasses.""" pass @@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC): attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) + # Build self-attention mask for per-guide attenuation + self_attention_mask = self._build_guide_self_attention_mask( + x, transformer_options, merged_args + ) + # Process transformer blocks x = self._process_transformer_blocks( - x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args + x, context, attention_mask, timestep, pe, + transformer_options=transformer_options, + self_attention_mask=self_attention_mask, + **merged_args, ) # Process output @@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel): pixel_coords = pixel_coords[:, :, grid_mask, ...] kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] + + # Compute per-guide surviving token counts from guide_attention_entries. + # Each entry tracks one guide reference; they are appended in order and + # their pre_filter_counts partition the kf_grid_mask. + guide_entries = kwargs.get("guide_attention_entries", None) + if guide_entries: + total_pfc = sum(e["pre_filter_count"] for e in guide_entries) + if total_pfc != len(kf_grid_mask): + raise ValueError( + f"guide pre_filter_counts ({total_pfc}) != " + f"keyframe grid mask length ({len(kf_grid_mask)})" + ) + resolved_entries = [] + offset = 0 + for entry in guide_entries: + pfc = entry["pre_filter_count"] + entry_mask = kf_grid_mask[offset:offset + pfc] + surviving = int(entry_mask.sum().item()) + resolved_entries.append({ + **entry, + "surviving_count": surviving, + }) + offset += pfc + additional_args["resolved_guide_entries"] = resolved_entries + keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs + # Total surviving guide tokens (all guides) + additional_args["num_guide_tokens"] = keyframe_idxs.shape[2] + x = self.patchify_proj(x) return x, pixel_coords, additional_args - def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs): + def _build_guide_self_attention_mask(self, x, transformer_options, merged_args): + """Build self-attention mask for per-guide attention attenuation. + + Reads resolved_guide_entries from merged_args (computed in _process_input) + to build a log-space additive bias mask that attenuates noisy ↔ guide + attention for each guide reference independently. + + Returns None if no attenuation is needed (all strengths == 1.0 and no + spatial masks, or no guide tokens). + """ + if isinstance(x, list): + # AV model: x = [vx, ax]; use vx for token count and device + total_tokens = x[0].shape[1] + device = x[0].device + dtype = x[0].dtype + else: + total_tokens = x.shape[1] + device = x.device + dtype = x.dtype + + num_guide_tokens = merged_args.get("num_guide_tokens", 0) + if num_guide_tokens == 0: + return None + + resolved_entries = merged_args.get("resolved_guide_entries", None) + if not resolved_entries: + return None + + # Check if any attenuation is actually needed + needs_attenuation = any( + e["strength"] < 1.0 or e.get("pixel_mask") is not None + for e in resolved_entries + ) + if not needs_attenuation: + return None + + # Build per-guide-token weights for all tracked guide tokens. + # Guides are appended in order at the end of the sequence. + guide_start = total_tokens - num_guide_tokens + all_weights = [] + total_tracked = 0 + + for entry in resolved_entries: + surviving = entry["surviving_count"] + if surviving == 0: + continue + + strength = entry["strength"] + pixel_mask = entry.get("pixel_mask") + latent_shape = entry.get("latent_shape") + + if pixel_mask is not None and latent_shape is not None: + f_lat, h_lat, w_lat = latent_shape + per_token = self._downsample_mask_to_latent( + pixel_mask.to(device=device, dtype=dtype), + f_lat, h_lat, w_lat, + ) + # per_token shape: (B, f_lat*h_lat*w_lat). + # Collapse batch dim — the mask is assumed identical across the + # batch; validate and take the first element to get (1, tokens). + if per_token.shape[0] > 1: + ref = per_token[0] + for bi in range(1, per_token.shape[0]): + if not torch.equal(ref, per_token[bi]): + logger.warning( + "pixel_mask differs across batch elements; " + "using first element only." + ) + break + per_token = per_token[:1] + # `surviving` is the post-grid_mask token count. + # Clamp to surviving to handle any mismatch safely. + n_weights = min(per_token.shape[1], surviving) + weights = per_token[:, :n_weights] * strength # (1, n_weights) + else: + weights = torch.full( + (1, surviving), strength, device=device, dtype=dtype + ) + + all_weights.append(weights) + total_tracked += weights.shape[1] + + if not all_weights: + return None + + # Concatenate per-token weights for all tracked guides + tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked) + + # Check if any weight is actually < 1.0 (otherwise no attenuation needed) + if (tracked_weights >= 1.0).all(): + return None + + # Build the mask: guide tokens are at the end of the sequence. + # Tracked guides come first (in order), untracked follow. + return self._build_self_attention_mask( + total_tokens, num_guide_tokens, total_tracked, + tracked_weights, guide_start, device, dtype, + ) + + @staticmethod + def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat): + """Downsample a pixel-space mask to per-token latent weights. + + Args: + mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1]. + f_lat: Number of latent frames (pre-dilation original count). + h_lat: Latent height (pre-dilation original height). + w_lat: Latent width (pre-dilation original width). + + Returns: + (B, F_lat * H_lat * W_lat) flattened per-token weights. + """ + b = mask.shape[0] + f_pix = mask.shape[2] + + # Spatial downsampling: area interpolation per frame + spatial_down = torch.nn.functional.interpolate( + rearrange(mask, "b 1 f h w -> (b f) 1 h w"), + size=(h_lat, w_lat), + mode="area", + ) + spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b) + + # Temporal downsampling: first pixel frame maps to first latent frame, + # remaining pixel frames are averaged in groups for causal temporal structure. + first_frame = spatial_down[:, :, :1, :, :] + if f_pix > 1 and f_lat > 1: + remaining_pix = f_pix - 1 + remaining_lat = f_lat - 1 + t = remaining_pix // remaining_lat + if t < 1: + # Fewer pixel frames than latent frames — upsample by repeating + # the available pixel frames via nearest interpolation. + rest_flat = rearrange( + spatial_down[:, :, 1:, :, :], + "b 1 f h w -> (b h w) 1 f", + ) + rest_up = torch.nn.functional.interpolate( + rest_flat, size=remaining_lat, mode="nearest", + ) + rest = rearrange( + rest_up, "(b h w) 1 f -> b 1 f h w", + b=b, h=h_lat, w=w_lat, + ) + else: + # Trim trailing pixel frames that don't fill a complete group + usable = remaining_lat * t + rest = rearrange( + spatial_down[:, :, 1:1 + usable, :, :], + "b 1 (f t) h w -> b 1 f t h w", + t=t, + ) + rest = rest.mean(dim=3) + latent_mask = torch.cat([first_frame, rest], dim=2) + elif f_lat > 1: + # Single pixel frame but multiple latent frames — repeat the + # single frame across all latent frames. + latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1) + else: + latent_mask = first_frame + + return rearrange(latent_mask, "b 1 f h w -> b (f h w)") + + @staticmethod + def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count, + tracked_weights, guide_start, device, dtype): + """Build a log-space additive self-attention bias mask. + + Attenuates attention between noisy tokens and tracked guide tokens. + Untracked guide tokens (at the end of the guide portion) keep full attention. + + Args: + total_tokens: Total sequence length. + num_guide_tokens: Total guide tokens (all guides) at end of sequence. + tracked_count: Number of tracked guide tokens (first in the guide portion). + tracked_weights: (1, tracked_count) tensor, values in [0, 1]. + guide_start: Index where guide tokens begin in the sequence. + device: Target device. + dtype: Target dtype. + + Returns: + (1, 1, total_tokens, total_tokens) additive bias mask. + 0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked. + """ + finfo = torch.finfo(dtype) + mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype) + tracked_end = guide_start + tracked_count + + # Convert weights to log-space bias + w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count) + log_w = torch.full_like(w, finfo.min) + positive_mask = w > 0 + if positive_mask.any(): + log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny)) + + # noisy → tracked guides: each noisy row gets the same per-guide weight + mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1) + # tracked guides → noisy: each guide row broadcasts its weight across noisy cols + mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1) + + return mask + + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel): def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask")) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel): timestep=timestep, pe=pe, transformer_options=transformer_options, + self_attention_mask=self_attention_mask, ) return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 4e2096d4b..04695c079 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -65,6 +65,42 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher + +class _CONDGuideEntries(comfy.conds.CONDConstant): + """CONDConstant subclass that safely compares guide_attention_entries. + + guide_attention_entries may contain ``pixel_mask`` tensors. The default + ``CONDConstant.can_concat`` uses ``!=`` which triggers a ``ValueError`` + on tensors. This subclass performs a structural comparison instead. + """ + + def can_concat(self, other): + if not isinstance(other, _CONDGuideEntries): + return False + a, b = self.cond, other.cond + if len(a) != len(b): + return False + for ea, eb in zip(a, b): + if ea["pre_filter_count"] != eb["pre_filter_count"]: + return False + if ea["strength"] != eb["strength"]: + return False + if ea.get("latent_shape") != eb.get("latent_shape"): + return False + a_has = ea.get("pixel_mask") is not None + b_has = eb.get("pixel_mask") is not None + if a_has != b_has: + return False + if a_has: + pm_a, pm_b = ea["pixel_mask"], eb["pixel_mask"] + if pm_a is not pm_b: + if (pm_a.shape != pm_b.shape + or pm_a.device != pm_b.device + or pm_a.dtype != pm_b.dtype + or not torch.equal(pm_a, pm_b)): + return False + return True + class ModelType(Enum): EPS = 1 V_PREDICTION = 2 @@ -974,6 +1010,10 @@ class LTXV(BaseModel): if keyframe_idxs is not None: out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if guide_attention_entries is not None: + out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) + return out def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): @@ -1026,6 +1066,10 @@ class LTXAV(BaseModel): if latent_shapes is not None: out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if guide_attention_entries is not None: + out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) + return out def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 1eeeec011..32fe921ff 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode): generate = execute # TODO: remove +def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0): + """Append a guide_attention_entry to both positive and negative conditioning. + + Each entry tracks one guide reference for per-reference attention control. + Entries are derived independently from each conditioning to avoid cross-contamination. + """ + new_entry = { + "pre_filter_count": pre_filter_count, + "strength": strength, + "pixel_mask": None, + "latent_shape": latent_shape, + } + results = [] + for cond in (positive, negative): + # Read existing entries from this specific conditioning + existing = [] + for t in cond: + found = t[1].get("guide_attention_entries", None) + if found is not None: + existing = found + break + # Shallow copy and append (no deepcopy needed — entries contain + # only scalars and None for pixel_mask at this call site). + entries = [*existing, new_entry] + results.append(node_helpers.conditioning_set_values( + cond, {"guide_attention_entries": entries} + )) + return results[0], results[1] + + def conditioning_get_any_value(conditioning, key, default=None): for t in conditioning: if key in t[1]: @@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode): scale_factors, ) + # Track this guide for per-reference attention control. + pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4] + guide_latent_shape = list(t.shape[2:]) # [F, H, W] + positive, negative = _append_guide_attention_entry( + positive, negative, pre_filter_count, guide_latent_shape, strength=strength, + ) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) generate = execute # TODO: remove @@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode): latent_image = latent_image[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes] - positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) - negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) + positive = node_helpers.conditioning_set_values(positive, { + "keyframe_idxs": None, + "guide_attention_entries": None, + }) + negative = node_helpers.conditioning_set_values(negative, { + "keyframe_idxs": None, + "guide_attention_entries": None, + }) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) From 8a4d85c708435b47d0570637fdf1e89199702c48 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 25 Feb 2026 22:30:31 -0800 Subject: [PATCH 02/81] Cleanups to the last PR. (#12646) --- comfy/conds.py | 21 ++++++++++++++++++++- comfy/model_base.py | 40 ++-------------------------------------- 2 files changed, 22 insertions(+), 39 deletions(-) diff --git a/comfy/conds.py b/comfy/conds.py index 5af3e93ea..55d8cdd78 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -4,6 +4,25 @@ import comfy.utils import logging +def is_equal(x, y): + if torch.is_tensor(x) and torch.is_tensor(y): + return torch.equal(x, y) + elif isinstance(x, dict) and isinstance(y, dict): + if x.keys() != y.keys(): + return False + return all(is_equal(x[k], y[k]) for k in x) + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + if type(x) is not type(y) or len(x) != len(y): + return False + return all(is_equal(a, b) for a, b in zip(x, y)) + else: + try: + return x == y + except Exception: + logging.warning("comparison issue with COND") + return False + + class CONDRegular: def __init__(self, cond): self.cond = cond @@ -84,7 +103,7 @@ class CONDConstant(CONDRegular): return self._copy_with(self.cond) def can_concat(self, other): - if self.cond != other.cond: + if not is_equal(self.cond, other.cond): return False return True diff --git a/comfy/model_base.py b/comfy/model_base.py index 04695c079..8f852e3c6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -65,42 +65,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher - -class _CONDGuideEntries(comfy.conds.CONDConstant): - """CONDConstant subclass that safely compares guide_attention_entries. - - guide_attention_entries may contain ``pixel_mask`` tensors. The default - ``CONDConstant.can_concat`` uses ``!=`` which triggers a ``ValueError`` - on tensors. This subclass performs a structural comparison instead. - """ - - def can_concat(self, other): - if not isinstance(other, _CONDGuideEntries): - return False - a, b = self.cond, other.cond - if len(a) != len(b): - return False - for ea, eb in zip(a, b): - if ea["pre_filter_count"] != eb["pre_filter_count"]: - return False - if ea["strength"] != eb["strength"]: - return False - if ea.get("latent_shape") != eb.get("latent_shape"): - return False - a_has = ea.get("pixel_mask") is not None - b_has = eb.get("pixel_mask") is not None - if a_has != b_has: - return False - if a_has: - pm_a, pm_b = ea["pixel_mask"], eb["pixel_mask"] - if pm_a is not pm_b: - if (pm_a.shape != pm_b.shape - or pm_a.device != pm_b.device - or pm_a.dtype != pm_b.dtype - or not torch.equal(pm_a, pm_b)): - return False - return True - class ModelType(Enum): EPS = 1 V_PREDICTION = 2 @@ -1012,7 +976,7 @@ class LTXV(BaseModel): guide_attention_entries = kwargs.get("guide_attention_entries", None) if guide_attention_entries is not None: - out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) return out @@ -1068,7 +1032,7 @@ class LTXAV(BaseModel): guide_attention_entries = kwargs.get("guide_attention_entries", None) if guide_attention_entries is not None: - out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) return out From 74b5a337dcc4d6b276e4af45aa8a654c82569072 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Thu, 26 Feb 2026 01:00:32 -0800 Subject: [PATCH 03/81] fix: move essentials_category to correct replacement nodes (#12568) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move essentials_category from deprecated/incorrect nodes to their replacements: - ImageBatch → BatchImagesNode (ImageBatch is deprecated) - Blur → removed (should use subgraph blueprint) - GetVideoComponents → Video Slice Amp-Thread-ID: https://ampcode.com/threads/T-019c8340-4da2-723b-a09f-83895c5bbda5 --- comfy_extras/nodes_post_processing.py | 2 +- comfy_extras/nodes_video.py | 2 +- nodes.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index c67245e7a..4a0f7141a 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -79,7 +79,6 @@ class Blur(io.ComfyNode): node_id="ImageBlur", display_name="Image Blur", category="image/postprocessing", - essentials_category="Image Tools", inputs=[ io.Image.Input("image"), io.Int.Input("blur_radius", default=1, min=1, max=31, step=1), @@ -568,6 +567,7 @@ class BatchImagesNode(io.ComfyNode): node_id="BatchImagesNode", display_name="Batch Images", category="image", + essentials_category="Image Tools", search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], inputs=[ io.Autogrow.Input("images", template=autogrow_template) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index db7c171a2..5c096c232 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -147,7 +147,6 @@ class GetVideoComponents(io.ComfyNode): search_aliases=["extract frames", "split video", "video to images", "demux"], display_name="Get Video Components", category="image/video", - essentials_category="Video Tools", description="Extracts all components from a video: frames, audio, and framerate.", inputs=[ io.Video.Input("video", tooltip="The video to extract components from."), @@ -218,6 +217,7 @@ class VideoSlice(io.ComfyNode): "start time", ], category="image/video", + essentials_category="Video Tools", inputs=[ io.Video.Input("video"), io.Float.Input( diff --git a/nodes.py b/nodes.py index e2fc20d53..bff073e30 100644 --- a/nodes.py +++ b/nodes.py @@ -1925,7 +1925,6 @@ class ImageInvert: class ImageBatch: SEARCH_ALIASES = ["combine images", "merge images", "stack images"] - ESSENTIALS_CATEGORY = "Image Tools" @classmethod def INPUT_TYPES(s): From 38ca94599f7444f55589308d1cf611fb77f6ca16 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 26 Feb 2026 11:07:35 +0000 Subject: [PATCH 04/81] pyopengl-accelerate can cause object to be numpy ints instead of bare ints, which the glDeleteTextures function does not accept, explicitly cast to int (#12650) --- comfy_extras/nodes_glsl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_glsl.py b/comfy_extras/nodes_glsl.py index 75ffb6d80..6d210b307 100644 --- a/comfy_extras/nodes_glsl.py +++ b/comfy_extras/nodes_glsl.py @@ -717,11 +717,11 @@ def _render_shader_batch( gl.glUseProgram(0) for tex in input_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) for tex in output_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) for tex in ping_pong_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) if fbo is not None: gl.glDeleteFramebuffers(1, [fbo]) for pp_fbo in ping_pong_fbos: From 420e900f692f72a4e0108594a80a3465c036bebe Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 26 Feb 2026 12:19:38 -0800 Subject: [PATCH 05/81] main: load aimdo earlier (#12655) Some custom node packs are naughty, and violate the dont-load-torch-on-load rule. This causes aimdo to lose preference on its allocator hook on linux. Go super early on the aimdo first-stage init before custom nodes are mentioned at all. --- main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 39e605deb..3fe8f0589 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,10 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags +import comfy_aimdo.control + +if enables_dynamic_vram(): + comfy_aimdo.control.init() if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. @@ -173,10 +177,6 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") -import comfy_aimdo.control - -if enables_dynamic_vram(): - comfy_aimdo.control.init() import comfy.utils From fd41ec97cc2e457f322afdf136fd2f2b2454a240 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 26 Feb 2026 22:52:10 +0200 Subject: [PATCH 06/81] feat(api-nodes): add NanoBanana2 (#12660) --- comfy_api_nodes/apis/gemini.py | 6 + comfy_api_nodes/nodes_bytedance.py | 2 +- comfy_api_nodes/nodes_gemini.py | 203 +++++++++++++++++++++++++++-- 3 files changed, 196 insertions(+), 15 deletions(-) diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index 3304d7e76..639035fef 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -127,9 +127,15 @@ class GeminiImageConfig(BaseModel): imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions) +class GeminiThinkingConfig(BaseModel): + includeThoughts: bool | None = Field(None) + thinkingLevel: str = Field(...) + + class GeminiImageGenerationConfig(GeminiGenerationConfig): responseModalities: list[str] | None = Field(None) imageConfig: GeminiImageConfig | None = Field(None) + thinkingConfig: GeminiThinkingConfig | None = Field(None) class GeminiImageGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index cfc604aa3..6dbd5984e 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -186,7 +186,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ByteDanceSeedreamNode", - display_name="ByteDance Seedream 5.0", + display_name="ByteDance Seedream 4.5 & 5.0", category="api node/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index b69285be5..3fe804e0b 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -29,6 +29,7 @@ from comfy_api_nodes.apis.gemini import ( GeminiRole, GeminiSystemInstructionContent, GeminiTextPart, + GeminiThinkingConfig, Modality, ) from comfy_api_nodes.util import ( @@ -55,6 +56,21 @@ GEMINI_IMAGE_SYS_PROMPT = ( "Prioritize generating the visual representation above any text, formatting, or conversational requests." ) +GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution"]), + expr=""" + ( + $m := widgets.model; + $r := widgets.resolution; + $isFlash := $contains($m, "nano banana 2"); + $flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123}; + $proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24}; + $prices := $isFlash ? $flashPrices : $proPrices; + {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} + ) + """, +) + class GeminiModel(str, Enum): """ @@ -229,6 +245,10 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 120.0 + elif response.modelVersion == "gemini-3.1-flash-image-preview": + input_tokens_price = 0.5 + output_text_tokens_price = 3.0 + output_image_tokens_price = 60.0 else: return None final_price = response.usageMetadata.promptTokenCount * input_tokens_price @@ -686,7 +706,7 @@ class GeminiImage2(IO.ComfyNode): ), IO.Combo.Input( "model", - options=["gemini-3-pro-image-preview"], + options=["gemini-3-pro-image-preview", "Nano Banana 2 (Gemini 3.1 Flash Image)"], ), IO.Int.Input( "seed", @@ -750,19 +770,7 @@ class GeminiImage2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), - expr=""" - ( - $r := widgets.resolution; - ($contains($r,"1k") or $contains($r,"2k")) - ? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}} - : $contains($r,"4k") - ? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}} - : {"type":"text","text":"Token-based"} - ) - """, - ), + price_badge=GEMINI_IMAGE_2_PRICE_BADGE, ) @classmethod @@ -779,6 +787,10 @@ class GeminiImage2(IO.ComfyNode): system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) + if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": + model = "gemini-3.1-flash-image-preview" + if response_modalities == "IMAGE+TEXT": + raise ValueError("IMAGE+TEXT is not currently available for the Nano Banana 2 model.") parts: list[GeminiPart] = [GeminiPart(text=prompt)] if images is not None: @@ -815,6 +827,168 @@ class GeminiImage2(IO.ComfyNode): return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) +class GeminiNanoBanana2(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiNanoBanana2", + display_name="Nano Banana 2", + category="api node/image/Gemini", + description="Generate or edit images synchronously via Google Vertex API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt describing the image to generate or the edits to apply. " + "Include any constraints, styles, or details the model should follow.", + default="", + ), + IO.Combo.Input( + "model", + options=["Nano Banana 2 (Gemini 3.1 Flash Image)"], + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "4:5", + "5:4", + "9:16", + "16:9", + "21:9", + # "1:4", + # "4:1", + # "8:1", + # "1:8", + ], + default="auto", + tooltip="If set to 'auto', matches your input image's aspect ratio; " + "if no image is provided, a 16:9 square is usually generated.", + ), + IO.Combo.Input( + "resolution", + options=[ + # "512px", + "1K", + "2K", + "4K", + ], + tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.", + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE"], + advanced=True, + ), + IO.Combo.Input( + "thinking_level", + options=["MINIMAL", "HIGH"], + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional reference image(s). " + "To include multiple images, use the Batch Images node (up to 14).", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + advanced=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=GEMINI_IMAGE_2_PRICE_BADGE, + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + aspect_ratio: str, + resolution: str, + response_modalities: str, + thinking_level: str, + images: Input.Image | None = None, + files: list[GeminiPart] | None = None, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": + model = "gemini-3.1-flash-image-preview" + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + if images is not None: + if get_number_of_images(images) > 14: + raise ValueError("The current maximum number of supported images is 14.") + parts.extend(await create_image_parts(cls, images)) + if files is not None: + parts.extend(files) + + image_config = GeminiImageConfig(imageSize=resolution) + if aspect_ratio != "auto": + image_config.aspectRatio = aspect_ratio + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent(role=GeminiRole.user, parts=parts), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), + imageConfig=image_config, + thinkingConfig=GeminiThinkingConfig(thinkingLevel=thinking_level), + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) + + class GeminiExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -822,6 +996,7 @@ class GeminiExtension(ComfyExtension): GeminiNode, GeminiImage, GeminiImage2, + GeminiNanoBanana2, GeminiInputFiles, ] From 88d05fe483a9f420a69d681e615422930404292b Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 27 Feb 2026 04:52:45 +0800 Subject: [PATCH 07/81] chore: update workflow templates to v0.9.4 (#12664) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 88a056e5c..b5b292980 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.3 +comfyui-workflow-templates==0.9.4 comfyui-embedded-docs==0.4.3 torch torchsde From 3dd10a59c00248d00f0cb0ab794ff1bb9fb00a5f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 26 Feb 2026 15:59:22 -0500 Subject: [PATCH 08/81] ComfyUI v0.15.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 6dbda1a87..6a35c6de3 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.15.0" +__version__ = "0.15.1" diff --git a/pyproject.toml b/pyproject.toml index eaf558740..1b2318273 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.15.0" +version = "0.15.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 3811780e4f73f9dbace01a85e6d97502406f8ccb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 26 Feb 2026 14:12:29 -0800 Subject: [PATCH 09/81] Portable with cu128 isn't useful anymore. (#12666) Users should either use the cu126 one or the regular one (cu130 at the moment) The cu128 portable is still included in the latest github release but I will stop including it as soon as it becomes slightly annoying to deal with. This might happen as soon as next week. --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index d97c6181d..56b7966cf 100644 --- a/README.md +++ b/README.md @@ -189,8 +189,6 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat [Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) -[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z). - [Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). #### How do I share models between another UI and ComfyUI? From b233dbe0bc179847b81680e0b59c493a8dc8d9a6 Mon Sep 17 00:00:00 2001 From: fappaz Date: Fri, 27 Feb 2026 12:19:19 +1300 Subject: [PATCH 10/81] feat(ace-step): add ACE-Step 1.5 lycoris key alias mapping for LoKR #12638 (#12665) --- comfy/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/lora.py b/comfy/lora.py index 279cf38bb..f36ddb046 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -337,6 +337,7 @@ def model_lora_keys_unet(model, key_map={}): if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"): key_lora = k[len("diffusion_model.decoder."):-len(".weight")] key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format return key_map From 08b26ed7c2fe43417058f4c6c5934de3cebf3f20 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 26 Feb 2026 15:59:24 -0800 Subject: [PATCH 11/81] bug_report template: Push harder for logs (#12657) We get a lot od bug reports without logs, especially for performance issues. --- .github/ISSUE_TEMPLATE/bug-report.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 6556677e0..95cc48f88 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -16,7 +16,7 @@ body: ## Very Important - Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored. + Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored. - type: checkboxes id: custom-nodes-test attributes: From c7f7d52b684f661d911f1747bb6954978fa1d1b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 27 Feb 2026 02:59:05 +0200 Subject: [PATCH 12/81] feat: Support SDPose-OOD (#12661) --- .../modules/diffusionmodules/openaimodel.py | 6 + comfy/ldm/modules/sdpose.py | 130 +++ comfy/model_detection.py | 6 +- comfy/supported_models.py | 3 +- comfy_api/latest/_io.py | 5 +- comfy_extras/nodes_sdpose.py | 740 ++++++++++++++++++ nodes.py | 1 + 7 files changed, 888 insertions(+), 3 deletions(-) create mode 100644 comfy/ldm/modules/sdpose.py create mode 100644 comfy_extras/nodes_sdpose.py diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4c8d53cac..295310df6 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -18,6 +18,8 @@ import comfy.patcher_extension import comfy.ops ops = comfy.ops.disable_weight_init +from ..sdpose import HeatmapHead + class TimestepBlock(nn.Module): """ Any module where forward() takes timestep embeddings as a second argument. @@ -441,6 +443,7 @@ class UNetModel(nn.Module): disable_temporal_crossattention=False, max_ddpm_temb_period=10000, attn_precision=None, + heatmap_head=False, device=None, operations=ops, ): @@ -827,6 +830,9 @@ class UNetModel(nn.Module): #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) + if heatmap_head: + self.heatmap_head = HeatmapHead(device=device, dtype=self.dtype, operations=operations) + def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, diff --git a/comfy/ldm/modules/sdpose.py b/comfy/ldm/modules/sdpose.py new file mode 100644 index 000000000..d67b60b76 --- /dev/null +++ b/comfy/ldm/modules/sdpose.py @@ -0,0 +1,130 @@ +import torch +import numpy as np +from scipy.ndimage import gaussian_filter + +class HeatmapHead(torch.nn.Module): + def __init__( + self, + in_channels=640, + out_channels=133, + input_size=(768, 1024), + heatmap_scale=4, + deconv_out_channels=(640,), + deconv_kernel_sizes=(4,), + conv_out_channels=(640,), + conv_kernel_sizes=(1,), + final_layer_kernel_size=1, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.heatmap_size = (input_size[0] // heatmap_scale, input_size[1] // heatmap_scale) + self.scale_factor = ((np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)).astype(np.float32) + + # Deconv layers + if deconv_out_channels: + deconv_layers = [] + for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes): + if kernel_size == 4: + padding, output_padding = 1, 0 + elif kernel_size == 3: + padding, output_padding = 1, 1 + elif kernel_size == 2: + padding, output_padding = 0, 0 + else: + raise ValueError(f'Unsupported kernel size {kernel_size}') + + deconv_layers.extend([ + operations.ConvTranspose2d(in_channels, out_ch, kernel_size, + stride=2, padding=padding, output_padding=output_padding, bias=False, device=device, dtype=dtype), + torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype), + torch.nn.SiLU(inplace=True) + ]) + in_channels = out_ch + self.deconv_layers = torch.nn.Sequential(*deconv_layers) + else: + self.deconv_layers = torch.nn.Identity() + + # Conv layers + if conv_out_channels: + conv_layers = [] + for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes): + padding = (kernel_size - 1) // 2 + conv_layers.extend([ + operations.Conv2d(in_channels, out_ch, kernel_size, + stride=1, padding=padding, device=device, dtype=dtype), + torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype), + torch.nn.SiLU(inplace=True) + ]) + in_channels = out_ch + self.conv_layers = torch.nn.Sequential(*conv_layers) + else: + self.conv_layers = torch.nn.Identity() + + self.final_layer = operations.Conv2d(in_channels, out_channels, kernel_size=final_layer_kernel_size, padding=final_layer_kernel_size // 2, device=device, dtype=dtype) + + def forward(self, x): # Decode heatmaps to keypoints + heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x))) + heatmaps_np = heatmaps.float().cpu().numpy() # (B, K, H, W) + B, K, H, W = heatmaps_np.shape + + batch_keypoints = [] + batch_scores = [] + + for b in range(B): + hm = heatmaps_np[b].copy() # (K, H, W) + + # --- vectorised argmax --- + flat = hm.reshape(K, -1) + idx = np.argmax(flat, axis=1) + scores = flat[np.arange(K), idx].copy() + y_locs, x_locs = np.unravel_index(idx, (H, W)) + keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32) # (K, 2) in heatmap space + invalid = scores <= 0. + keypoints[invalid] = -1 + + # --- DARK sub-pixel refinement (UDP) --- + # 1. Gaussian blur with max-preserving normalisation + border = 5 # (kernel-1)//2 for kernel=11 + for k in range(K): + origin_max = np.max(hm[k]) + dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32) + dr[border:-border, border:-border] = hm[k].copy() + dr = gaussian_filter(dr, sigma=2.0) + hm[k] = dr[border:-border, border:-border].copy() + cur_max = np.max(hm[k]) + if cur_max > 0: + hm[k] *= origin_max / cur_max + # 2. Log-space for Taylor expansion + np.clip(hm, 1e-3, 50., hm) + np.log(hm, hm) + # 3. Hessian-based Newton step + hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten() + index = keypoints[:, 0] + 1 + (keypoints[:, 1] + 1) * (W + 2) + index += (W + 2) * (H + 2) * np.arange(0, K) + index = index.astype(int).reshape(-1, 1) + i_ = hm_pad[index] + ix1 = hm_pad[index + 1] + iy1 = hm_pad[index + W + 2] + ix1y1 = hm_pad[index + W + 3] + ix1_y1_ = hm_pad[index - W - 3] + ix1_ = hm_pad[index - 1] + iy1_ = hm_pad[index - 2 - W] + dx = 0.5 * (ix1 - ix1_) + dy = 0.5 * (iy1 - iy1_) + derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1) + dxx = ix1 - 2 * i_ + ix1_ + dyy = iy1 - 2 * i_ + iy1_ + dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) + hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2) + hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) + keypoints -= np.einsum('imn,ink->imk', hessian, derivative).squeeze(axis=-1) + + # --- restore to input image space --- + keypoints = keypoints * self.scale_factor + keypoints[invalid] = -1 + + batch_keypoints.append(keypoints) + batch_scores.append(scores) + + return batch_keypoints, batch_scores diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 030ae6980..b4b51b200 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -795,6 +795,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config["use_temporal_resblock"] = False unet_config["use_temporal_attention"] = False + heatmap_key = '{}heatmap_head.conv_layers.0.weight'.format(key_prefix) + if heatmap_key in state_dict_keys: + unet_config["heatmap_head"] = True + return unet_config def model_config_from_unet_config(unet_config, state_dict=None): @@ -1015,7 +1019,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], - 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8, + 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], 'use_temporal_attention': False, 'use_temporal_resblock': False} diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6d08ff0a5..1bb7b7011 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -525,7 +525,8 @@ class LotusD(SD20): } unet_extra_config = { - "num_classes": 'sequential' + "num_classes": 'sequential', + "num_head_channels": 64, } def get_model(self, state_dict, prefix="", device=None): diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 025727071..189d7d9bc 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1224,9 +1224,10 @@ class BoundingBox(ComfyTypeIO): class Input(WidgetInput): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, - socketless: bool=True, default: dict=None, component: str=None): + socketless: bool=True, default: dict=None, component: str=None, force_input: bool=None): super().__init__(id, display_name, optional, tooltip, None, default, socketless) self.component = component + self.force_input = force_input if default is None: self.default = {"x": 0, "y": 0, "width": 512, "height": 512} @@ -1234,6 +1235,8 @@ class BoundingBox(ComfyTypeIO): d = super().as_dict() if self.component: d["component"] = self.component + if self.force_input is not None: + d["forceInput"] = self.force_input return d diff --git a/comfy_extras/nodes_sdpose.py b/comfy_extras/nodes_sdpose.py new file mode 100644 index 000000000..71441848e --- /dev/null +++ b/comfy_extras/nodes_sdpose.py @@ -0,0 +1,740 @@ +import torch +import comfy.utils +import numpy as np +import math +import colorsys +from tqdm import tqdm +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_extras.nodes_lotus import LotusConditioning + + +def _preprocess_keypoints(kp_raw, sc_raw): + """Insert neck keypoint and remap from MMPose to OpenPose ordering. + + Returns (kp, sc) where kp has shape (134, 2) and sc has shape (134,). + Layout: + 0-17 body (18 kp, OpenPose order) + 18-23 feet (6 kp) + 24-91 face (68 kp) + 92-112 right hand (21 kp) + 113-133 left hand (21 kp) + """ + kp = np.array(kp_raw, dtype=np.float32) + sc = np.array(sc_raw, dtype=np.float32) + if len(kp) >= 17: + neck = (kp[5] + kp[6]) / 2 + neck_score = min(sc[5], sc[6]) if sc[5] > 0.3 and sc[6] > 0.3 else 0 + kp = np.insert(kp, 17, neck, axis=0) + sc = np.insert(sc, 17, neck_score) + mmpose_idx = np.array([17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]) + openpose_idx = np.array([ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]) + tmp_kp, tmp_sc = kp.copy(), sc.copy() + tmp_kp[openpose_idx] = kp[mmpose_idx] + tmp_sc[openpose_idx] = sc[mmpose_idx] + kp, sc = tmp_kp, tmp_sc + return kp, sc + + +def _to_openpose_frames(all_keypoints, all_scores, height, width): + """Convert raw keypoint lists to a list of OpenPose-style frame dicts. + + Each frame dict contains: + canvas_width, canvas_height, people: list of person dicts with keys: + pose_keypoints_2d - 18 body kp as flat [x,y,score,...] (absolute pixels) + foot_keypoints_2d - 6 foot kp as flat [x,y,score,...] (absolute pixels) + face_keypoints_2d - 70 face kp as flat [x,y,score,...] (absolute pixels) + indices 0-67: 68 face landmarks + index 68: right eye (body[14]) + index 69: left eye (body[15]) + hand_right_keypoints_2d - 21 right-hand kp (absolute pixels) + hand_left_keypoints_2d - 21 left-hand kp (absolute pixels) + """ + def _flatten(kp_slice, sc_slice): + return np.stack([kp_slice[:, 0], kp_slice[:, 1], sc_slice], axis=1).flatten().tolist() + + frames = [] + for img_idx in range(len(all_keypoints)): + people = [] + for kp_raw, sc_raw in zip(all_keypoints[img_idx], all_scores[img_idx]): + kp, sc = _preprocess_keypoints(kp_raw, sc_raw) + # 70 face kp = 68 face landmarks + REye (body[14]) + LEye (body[15]) + face_kp = np.concatenate([kp[24:92], kp[[14, 15]]], axis=0) + face_sc = np.concatenate([sc[24:92], sc[[14, 15]]], axis=0) + people.append({ + "pose_keypoints_2d": _flatten(kp[0:18], sc[0:18]), + "foot_keypoints_2d": _flatten(kp[18:24], sc[18:24]), + "face_keypoints_2d": _flatten(face_kp, face_sc), + "hand_right_keypoints_2d": _flatten(kp[92:113], sc[92:113]), + "hand_left_keypoints_2d": _flatten(kp[113:134], sc[113:134]), + }) + frames.append({"canvas_width": width, "canvas_height": height, "people": people}) + return frames + + +class KeypointDraw: + """ + Pose keypoint drawing class that supports both numpy and cv2 backends. + """ + def __init__(self): + try: + import cv2 + self.draw = cv2 + except ImportError: + self.draw = self + + # Hand connections (same for both hands) + self.hand_edges = [ + [0, 1], [1, 2], [2, 3], [3, 4], # thumb + [0, 5], [5, 6], [6, 7], [7, 8], # index + [0, 9], [9, 10], [10, 11], [11, 12], # middle + [0, 13], [13, 14], [14, 15], [15, 16], # ring + [0, 17], [17, 18], [18, 19], [19, 20], # pinky + ] + + # Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed) + self.body_limbSeq = [ + [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], + [1, 16], [16, 18] + ] + + # Colors matching DWPose + self.colors = [ + [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], + [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], + [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85] + ] + + @staticmethod + def circle(canvas_np, center, radius, color, **kwargs): + """Draw a filled circle using NumPy vectorized operations.""" + cx, cy = center + h, w = canvas_np.shape[:2] + + radius_int = int(np.ceil(radius)) + + y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1) + x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1) + + if y_max <= y_min or x_max <= x_min: + return + + y, x = np.ogrid[y_min:y_max, x_min:x_max] + mask = (x - cx)**2 + (y - cy)**2 <= radius**2 + canvas_np[y_min:y_max, x_min:x_max][mask] = color + + @staticmethod + def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs): + """Draw line using Bresenham's algorithm with NumPy operations.""" + x0, y0, x1, y1 = *pt1, *pt2 + h, w = canvas_np.shape[:2] + dx, dy = abs(x1 - x0), abs(y1 - y0) + sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1) + err, x, y, line_points = dx - dy, x0, y0, [] + + while True: + line_points.append((x, y)) + if x == x1 and y == y1: + break + e2 = 2 * err + if e2 > -dy: + err, x = err - dy, x + sx + if e2 < dx: + err, y = err + dx, y + sy + + if thickness > 1: + radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5)) + for px, py in line_points: + y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1) + if y_max > y_min and x_max > x_min: + yy, xx = np.ogrid[y_min:y_max, x_min:x_max] + canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color + else: + line_points = np.array(line_points) + valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w) + if (valid_points := line_points[valid]).size: + canvas_np[valid_points[:, 1], valid_points[:, 0]] = color + + @staticmethod + def fillConvexPoly(canvas_np, pts, color, **kwargs): + """Fill polygon using vectorized scanline algorithm.""" + if len(pts) < 3: + return + pts = np.array(pts, dtype=np.int32) + h, w = canvas_np.shape[:2] + y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1) + if y_max <= y_min or x_max <= x_min: + return + yy, xx = np.mgrid[y_min:y_max, x_min:x_max] + mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool) + + for i in range(len(pts)): + p1, p2 = pts[i], pts[(i + 1) % len(pts)] + y1, y2 = p1[1], p2[1] + if y1 == y2: + continue + if y1 > y2: + p1, p2, y1, y2 = p2, p1, p2[1], p1[1] + if not (edge_mask := (yy >= y1) & (yy < y2)).any(): + continue + mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1)) + + canvas_np[y_min:y_max, x_min:x_max][mask] = color + + @staticmethod + def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs): + """Python implementation of cv2.ellipse2Poly.""" + axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output + angle = angle % 360 + if arc_start > arc_end: + arc_start, arc_end = arc_end, arc_start + while arc_start < 0: + arc_start, arc_end = arc_start + 360, arc_end + 360 + while arc_end > 360: + arc_end, arc_start = arc_end - 360, arc_start - 360 + if arc_end - arc_start > 360: + arc_start, arc_end = 0, 360 + + angle_rad = math.radians(angle) + alpha, beta = math.cos(angle_rad), math.sin(angle_rad) + pts = [] + for i in range(arc_start, arc_end + delta, delta): + theta_rad = math.radians(min(i, arc_end)) + x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad) + pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))]) + + unique_pts, prev_pt = [], (float('inf'), float('inf')) + for pt in pts: + if (pt_tuple := tuple(pt)) != prev_pt: + unique_pts.append(pt) + prev_pt = pt_tuple + + return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]] + + def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3, + draw_body=True, draw_feet=True, draw_face=True, draw_hands=True, stick_width=4, face_point_size=3): + """ + Draw wholebody keypoints (134 keypoints after processing) in DWPose style. + + Expected keypoint format (after neck insertion and remapping): + - Body: 0-17 (18 keypoints in OpenPose format, neck at index 1) + - Foot: 18-23 (6 keypoints) + - Face: 24-91 (68 landmarks) + - Right hand: 92-112 (21 keypoints) + - Left hand: 113-133 (21 keypoints) + + Args: + canvas: The canvas to draw on (numpy array) + keypoints: Array of keypoint coordinates + scores: Optional confidence scores for each keypoint + threshold: Minimum confidence threshold for drawing keypoints + + Returns: + canvas: The canvas with keypoints drawn + """ + H, W, C = canvas.shape + + # Draw body limbs + if draw_body and len(keypoints) >= 18: + for i, limb in enumerate(self.body_limbSeq): + # Convert from 1-indexed to 0-indexed + idx1, idx2 = limb[0] - 1, limb[1] - 1 + + if idx1 >= 18 or idx2 >= 18: + continue + + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + Y = [keypoints[idx1][0], keypoints[idx2][0]] + X = [keypoints[idx1][1], keypoints[idx2][1]] + mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2 + length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) + + if length < 1: + continue + + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + + polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1) + + self.draw.fillConvexPoly(canvas, polygon, self.colors[i % len(self.colors)]) + + # Draw body keypoints + if draw_body and len(keypoints) >= 18: + for i in range(18): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1) + + # Draw foot keypoints (18-23, 6 keypoints) + if draw_feet and len(keypoints) >= 24: + for i in range(18, 24): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1) + + # Draw right hand (92-112) + if draw_hands and len(keypoints) >= 113: + eps = 0.01 + for ie, edge in enumerate(self.hand_edges): + idx1, idx2 = 92 + edge[0], 92 + edge[1] + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) + x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) + + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: + # HSV to RGB conversion for rainbow colors + r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) + color = (int(r * 255), int(g * 255), int(b * 255)) + self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2) + + # Draw right hand keypoints + for i in range(92, 113): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + + # Draw left hand (113-133) + if draw_hands and len(keypoints) >= 134: + eps = 0.01 + for ie, edge in enumerate(self.hand_edges): + idx1, idx2 = 113 + edge[0], 113 + edge[1] + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) + x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) + + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: + # HSV to RGB conversion for rainbow colors + r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) + color = (int(r * 255), int(g * 255), int(b * 255)) + self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2) + + # Draw left hand keypoints + for i in range(113, 134): + if scores is not None and i < len(scores) and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + + # Draw face keypoints (24-91) - white dots only, no lines + if draw_face and len(keypoints) >= 92: + eps = 0.01 + for i in range(24, 92): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1) + + return canvas + +class SDPoseDrawKeypoints(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseDrawKeypoints", + category="image/preprocessors", + search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "pose"], + inputs=[ + io.Custom("POSE_KEYPOINT").Input("keypoints"), + io.Boolean.Input("draw_body", default=True), + io.Boolean.Input("draw_hands", default=True), + io.Boolean.Input("draw_face", default=True), + io.Boolean.Input("draw_feet", default=False), + io.Int.Input("stick_width", default=4, min=1, max=10, step=1), + io.Int.Input("face_point_size", default=3, min=1, max=10, step=1), + io.Float.Input("score_threshold", default=0.3, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, keypoints, draw_body, draw_hands, draw_face, draw_feet, stick_width, face_point_size, score_threshold) -> io.NodeOutput: + if not keypoints: + return io.NodeOutput(torch.zeros((1, 64, 64, 3), dtype=torch.float32)) + height = keypoints[0]["canvas_height"] + width = keypoints[0]["canvas_width"] + + def _parse(flat, n): + arr = np.array(flat, dtype=np.float32).reshape(n, 3) + return arr[:, :2], arr[:, 2] + + def _zeros(n): + return np.zeros((n, 2), dtype=np.float32), np.zeros(n, dtype=np.float32) + + pose_outputs = [] + drawer = KeypointDraw() + + for frame in tqdm(keypoints, desc="Drawing keypoints on frames"): + canvas = np.zeros((height, width, 3), dtype=np.uint8) + for person in frame["people"]: + body_kp, body_sc = _parse(person["pose_keypoints_2d"], 18) + foot_raw = person.get("foot_keypoints_2d") + foot_kp, foot_sc = _parse(foot_raw, 6) if foot_raw else _zeros(6) + face_kp, face_sc = _parse(person["face_keypoints_2d"], 70) + face_kp, face_sc = face_kp[:68], face_sc[:68] # drop appended eye kp; body already draws them + rhand_kp, rhand_sc = _parse(person["hand_right_keypoints_2d"], 21) + lhand_kp, lhand_sc = _parse(person["hand_left_keypoints_2d"], 21) + + kp = np.concatenate([body_kp, foot_kp, face_kp, rhand_kp, lhand_kp], axis=0) + sc = np.concatenate([body_sc, foot_sc, face_sc, rhand_sc, lhand_sc], axis=0) + + canvas = drawer.draw_wholebody_keypoints( + canvas, kp, sc, + threshold=score_threshold, + draw_body=draw_body, draw_feet=draw_feet, + draw_face=draw_face, draw_hands=draw_hands, + stick_width=stick_width, face_point_size=face_point_size, + ) + pose_outputs.append(canvas) + + pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0) + final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0 + return io.NodeOutput(final_pose_output) + +class SDPoseKeypointExtractor(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseKeypointExtractor", + category="image/preprocessors", + search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "sdpose"], + description="Extract pose keypoints from images using the SDPose model: https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints", + inputs=[ + io.Model.Input("model"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Int.Input("batch_size", default=16, min=1, max=10000, step=1), + io.BoundingBox.Input("bboxes", optional=True, force_input=True, tooltip="Optional bounding boxes for more accurate detections. Required for multi-person detection."), + ], + outputs=[ + io.Custom("POSE_KEYPOINT").Output("keypoints", tooltip="Keypoints in OpenPose frame format (canvas_width, canvas_height, people)"), + ], + ) + + @classmethod + def execute(cls, model, vae, image, batch_size, bboxes=None) -> io.NodeOutput: + + height, width = image.shape[-3], image.shape[-2] + context = LotusConditioning().execute().result[0] + + # Use output_block_patch to capture the last 640-channel feature + def output_patch(h, hsp, transformer_options): + nonlocal captured_feat + if h.shape[1] == 640: # Capture the features for wholebody + captured_feat = h.clone() + return h, hsp + + model_clone = model.clone() + model_clone.model_options["transformer_options"] = {"patches": {"output_block_patch": [output_patch]}} + + if not hasattr(model.model.diffusion_model, 'heatmap_head'): + raise ValueError("The provided model does not have a heatmap_head. Please use SDPose model from here https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints.") + + head = model.model.diffusion_model.heatmap_head + total_images = image.shape[0] + captured_feat = None + + model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768 + model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024 + + def _run_on_latent(latent_batch): + """Run one forward pass and return (keypoints_list, scores_list) for the batch.""" + nonlocal captured_feat + captured_feat = None + _ = comfy.sample.sample( + model_clone, + noise=torch.zeros_like(latent_batch), + steps=1, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=context, negative=context, + latent_image=latent_batch, disable_noise=True, disable_pbar=True, + ) + return head(captured_feat) # keypoints_batch, scores_batch + + # all_keypoints / all_scores are lists-of-lists: + # outer index = input image index + # inner index = detected person (one per bbox, or one for full-image) + all_keypoints = [] # shape: [n_images][n_persons] + all_scores = [] # shape: [n_images][n_persons] + pbar = comfy.utils.ProgressBar(total_images) + + if bboxes is not None: + if not isinstance(bboxes, list): + bboxes = [[bboxes]] + elif len(bboxes) == 0: + bboxes = [None] * total_images + # --- bbox-crop mode: one forward pass per crop ------------------------- + for img_idx in tqdm(range(total_images), desc="Extracting keypoints from crops"): + img = image[img_idx:img_idx + 1] # (1, H, W, C) + # Broadcasting: if fewer bbox lists than images, repeat the last one. + img_bboxes = bboxes[min(img_idx, len(bboxes) - 1)] if bboxes else None + + img_keypoints = [] + img_scores = [] + + if img_bboxes: + for bbox in img_bboxes: + x1 = max(0, int(bbox["x"])) + y1 = max(0, int(bbox["y"])) + x2 = min(width, int(bbox["x"] + bbox["width"])) + y2 = min(height, int(bbox["y"] + bbox["height"])) + + if x2 <= x1 or y2 <= y1: + continue + + crop_h_px, crop_w_px = y2 - y1, x2 - x1 + crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C) + + # scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size. + scale = min(model_h / crop_h_px, model_w / crop_w_px) + scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale)) + pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2 + + crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW + scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled") + padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device) + padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled + crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC + + latent_crop = vae.encode(crop_resized) + kp_batch, sc_batch = _run_on_latent(latent_crop) + kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space + + # remove padding offset, undo scale, offset to full-image coordinates. + kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32) + kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1 + kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1 + + img_keypoints.append(kp) + img_scores.append(sc) + else: + # No bboxes for this image – run on the full image + latent_img = vae.encode(img) + kp_batch, sc_batch = _run_on_latent(latent_img) + img_keypoints.append(kp_batch[0]) + img_scores.append(sc_batch[0]) + + all_keypoints.append(img_keypoints) + all_scores.append(img_scores) + pbar.update(1) + + else: # full-image mode, batched + tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints") + for batch_start in range(0, total_images, batch_size): + batch_end = min(batch_start + batch_size, total_images) + latent_batch = vae.encode(image[batch_start:batch_end]) + + kp_batch, sc_batch = _run_on_latent(latent_batch) + + for kp, sc in zip(kp_batch, sc_batch): + all_keypoints.append([kp]) + all_scores.append([sc]) + tqdm_pbar.update(1) + + pbar.update(batch_end - batch_start) + + openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width) + return io.NodeOutput(openpose_frames) + + +def get_face_bboxes(kp2ds, scale, image_shape): + h, w = image_shape + kp2ds_face = kp2ds.copy()[1:] * (w, h) + + min_x, min_y = np.min(kp2ds_face, axis=0) + max_x, max_y = np.max(kp2ds_face, axis=0) + + initial_width = max_x - min_x + initial_height = max_y - min_y + + if initial_width <= 0 or initial_height <= 0: + return [0, 0, 0, 0] + + initial_area = initial_width * initial_height + + expanded_area = initial_area * scale + + new_width = np.sqrt(expanded_area * (initial_width / initial_height)) + new_height = np.sqrt(expanded_area * (initial_height / initial_width)) + + delta_width = (new_width - initial_width) / 2 + delta_height = (new_height - initial_height) / 4 + + expanded_min_x = max(min_x - delta_width, 0) + expanded_max_x = min(max_x + delta_width, w) + expanded_min_y = max(min_y - 3 * delta_height, 0) + expanded_max_y = min(max_y + delta_height, h) + + return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)] + +class SDPoseFaceBBoxes(io.ComfyNode): + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseFaceBBoxes", + category="image/preprocessors", + search_aliases=["face bbox", "face bounding box", "pose", "keypoints"], + inputs=[ + io.Custom("POSE_KEYPOINT").Input("keypoints"), + io.Float.Input("scale", default=1.5, min=1.0, max=10.0, step=0.1, tooltip="Multiplier for the bounding box area around each detected face."), + io.Boolean.Input("force_square", default=True, tooltip="Expand the shorter bbox axis so the crop region is always square."), + ], + outputs=[ + io.BoundingBox.Output("bboxes", tooltip="Face bounding boxes per frame, compatible with SDPoseKeypointExtractor bboxes input."), + ], + ) + + @classmethod + def execute(cls, keypoints, scale, force_square) -> io.NodeOutput: + all_bboxes = [] + for frame in keypoints: + h = frame["canvas_height"] + w = frame["canvas_width"] + frame_bboxes = [] + for person in frame["people"]: + face_flat = person.get("face_keypoints_2d", []) + if not face_flat: + continue + # Parse absolute-pixel face keypoints (70 kp: 68 landmarks + REye + LEye) + face_arr = np.array(face_flat, dtype=np.float32).reshape(-1, 3) + face_xy = face_arr[:, :2] # (70, 2) in absolute pixels + + kp_norm = face_xy / np.array([w, h], dtype=np.float32) + kp_padded = np.vstack([np.zeros((1, 2), dtype=np.float32), kp_norm]) # (71, 2) + + x1, x2, y1, y2 = get_face_bboxes(kp_padded, scale, (h, w)) + if x2 > x1 and y2 > y1: + if force_square: + bw, bh = x2 - x1, y2 - y1 + if bw != bh: + side = max(bw, bh) + cx, cy = (x1 + x2) // 2, (y1 + y2) // 2 + half = side // 2 + x1 = max(0, cx - half) + y1 = max(0, cy - half) + x2 = min(w, x1 + side) + y2 = min(h, y1 + side) + # Re-anchor if clamped + x1 = max(0, x2 - side) + y1 = max(0, y2 - side) + frame_bboxes.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}) + + all_bboxes.append(frame_bboxes) + + return io.NodeOutput(all_bboxes) + + +class CropByBBoxes(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CropByBBoxes", + category="image/preprocessors", + search_aliases=["crop", "face crop", "bbox crop", "pose", "bounding box"], + description="Crop and resize regions from the input image batch based on provided bounding boxes.", + inputs=[ + io.Image.Input("image"), + io.BoundingBox.Input("bboxes", force_input=True), + io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."), + io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."), + io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."), + ], + outputs=[ + io.Image.Output(tooltip="All crops stacked into a single image batch."), + ], + ) + + @classmethod + def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput: + total_frames = image.shape[0] + img_h = image.shape[1] + img_w = image.shape[2] + num_ch = image.shape[3] + + if not isinstance(bboxes, list): + bboxes = [[bboxes]] + elif len(bboxes) == 0: + return io.NodeOutput(image) + + crops = [] + + for frame_idx in range(total_frames): + frame_bboxes = bboxes[min(frame_idx, len(bboxes) - 1)] + if not frame_bboxes: + continue + + frame_chw = image[frame_idx].permute(2, 0, 1).unsqueeze(0) # BHWC → BCHW (1, C, H, W) + + # Union all bboxes for this frame into a single crop region + x1 = min(int(b["x"]) for b in frame_bboxes) + y1 = min(int(b["y"]) for b in frame_bboxes) + x2 = max(int(b["x"] + b["width"]) for b in frame_bboxes) + y2 = max(int(b["y"] + b["height"]) for b in frame_bboxes) + + if padding > 0: + x1 = max(0, x1 - padding) + y1 = max(0, y1 - padding) + x2 = min(img_w, x2 + padding) + y2 = min(img_h, y2 + padding) + + x1, x2 = max(0, x1), min(img_w, x2) + y1, y2 = max(0, y1), min(img_h, y2) + + # Fallback for empty/degenerate crops + if x2 <= x1 or y2 <= y1: + fallback_size = int(min(img_h, img_w) * 0.3) + fb_x1 = max(0, (img_w - fallback_size) // 2) + fb_y1 = max(0, int(img_h * 0.1)) + fb_x2 = min(img_w, fb_x1 + fallback_size) + fb_y2 = min(img_h, fb_y1 + fallback_size) + if fb_x2 <= fb_x1 or fb_y2 <= fb_y1: + crops.append(torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)) + continue + x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2 + + crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w) + resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled") + crops.append(resized) + + if not crops: + return io.NodeOutput(image) + + out_images = torch.cat(crops, dim=0).permute(0, 2, 3, 1) # (N, H, W, C) + return io.NodeOutput(out_images) + + +class SDPoseExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SDPoseKeypointExtractor, + SDPoseDrawKeypoints, + SDPoseFaceBBoxes, + CropByBBoxes, + ] + +async def comfy_entrypoint() -> SDPoseExtension: + return SDPoseExtension() diff --git a/nodes.py b/nodes.py index bff073e30..0222ec629 100644 --- a/nodes.py +++ b/nodes.py @@ -2447,6 +2447,7 @@ async def init_builtin_extra_nodes(): "nodes_toolkit.py", "nodes_replacements.py", "nodes_nag.py", + "nodes_sdpose.py", ] import_failed = [] From 35e9fce7756604050f07a05d090e697b81322c44 Mon Sep 17 00:00:00 2001 From: vickytsang Date: Thu, 26 Feb 2026 17:16:12 -0800 Subject: [PATCH 13/81] Enable Pytorch Attention for gfx950 (#12641) --- comfy/model_management.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 1fe56a62b..f73613f17 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -350,7 +350,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN' try: if is_amd(): - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0] if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD @@ -378,7 +378,7 @@ try: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True if rocm_version >= (7, 0): if any((a in arch) for a in ["gfx1200", "gfx1201"]): From 0a7f8e11b6f280b1b574f5dd642e0b46b8f0e045 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:13:24 +0000 Subject: [PATCH 14/81] fix torch.cat requiring inputs to all be same dimensions (#12673) --- comfy_extras/nodes_glsl.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_glsl.py b/comfy_extras/nodes_glsl.py index 6d210b307..2a59a9285 100644 --- a/comfy_extras/nodes_glsl.py +++ b/comfy_extras/nodes_glsl.py @@ -865,14 +865,15 @@ class GLSLShader(io.ComfyNode): cls, image_list: list[torch.Tensor], output_batch: torch.Tensor ) -> dict[str, list]: """Build UI output with input and output images for client-side shader execution.""" - combined_inputs = torch.cat(image_list, dim=0) - input_images_ui = ui.ImageSaveHelper.save_images( - combined_inputs, - filename_prefix="GLSLShader_input", - folder_type=io.FolderType.temp, - cls=None, - compress_level=1, - ) + input_images_ui = [] + for img in image_list: + input_images_ui.extend(ui.ImageSaveHelper.save_images( + img, + filename_prefix="GLSLShader_input", + folder_type=io.FolderType.temp, + cls=None, + compress_level=1, + )) output_images_ui = ui.ImageSaveHelper.save_images( output_batch, From 1f1ec377ce9d4c525d1615099524231756a69e5e Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 27 Feb 2026 09:13:57 -0800 Subject: [PATCH 15/81] feat: add ResolutionSelector node for aspect ratio and megapixel-based resolution calculation (#12199) Amp-Thread-ID: https://ampcode.com/threads/T-019c179e-cd8c-768f-ae66-207c7a53c01d Co-authored-by: Jedrzej Kosinski --- comfy_extras/nodes_resolution.py | 82 ++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 83 insertions(+) create mode 100644 comfy_extras/nodes_resolution.py diff --git a/comfy_extras/nodes_resolution.py b/comfy_extras/nodes_resolution.py new file mode 100644 index 000000000..d94156433 --- /dev/null +++ b/comfy_extras/nodes_resolution.py @@ -0,0 +1,82 @@ +from __future__ import annotations +import math +from enum import Enum +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class AspectRatio(str, Enum): + SQUARE = "1:1 (Square)" + PHOTO_H = "3:2 (Photo)" + STANDARD_H = "4:3 (Standard)" + WIDESCREEN_H = "16:9 (Widescreen)" + ULTRAWIDE_H = "21:9 (Ultrawide)" + PHOTO_V = "2:3 (Portrait Photo)" + STANDARD_V = "3:4 (Portrait Standard)" + WIDESCREEN_V = "9:16 (Portrait Widescreen)" + + +ASPECT_RATIOS: dict[str, tuple[int, int]] = { + "1:1 (Square)": (1, 1), + "3:2 (Photo)": (3, 2), + "4:3 (Standard)": (4, 3), + "16:9 (Widescreen)": (16, 9), + "21:9 (Ultrawide)": (21, 9), + "2:3 (Portrait Photo)": (2, 3), + "3:4 (Portrait Standard)": (3, 4), + "9:16 (Portrait Widescreen)": (9, 16), +} + + +class ResolutionSelector(io.ComfyNode): + """Calculate width and height from aspect ratio and megapixel target.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ResolutionSelector", + display_name="Resolution Selector", + category="utils", + description="Calculate width and height from aspect ratio and megapixel target. Useful for setting up Empty Latent Image dimensions.", + inputs=[ + io.Combo.Input( + "aspect_ratio", + options=AspectRatio, + default=AspectRatio.SQUARE, + tooltip="The aspect ratio for the output dimensions.", + ), + io.Float.Input( + "megapixels", + default=1.0, + min=0.1, + max=16.0, + step=0.1, + tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.", + ), + ], + outputs=[ + io.Int.Output("width", tooltip="Calculated width in pixels (multiple of 8)."), + io.Int.Output("height", tooltip="Calculated height in pixels (multiple of 8)."), + ], + ) + + @classmethod + def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput: + w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] + total_pixels = megapixels * 1024 * 1024 + scale = math.sqrt(total_pixels / (w_ratio * h_ratio)) + width = round(w_ratio * scale / 8) * 8 + height = round(h_ratio * scale / 8) * 8 + return io.NodeOutput(width, height) + + +class ResolutionExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ResolutionSelector, + ] + + +async def comfy_entrypoint() -> ResolutionExtension: + return ResolutionExtension() diff --git a/nodes.py b/nodes.py index 0222ec629..bf6ce5736 100644 --- a/nodes.py +++ b/nodes.py @@ -2435,6 +2435,7 @@ async def init_builtin_extra_nodes(): "nodes_audio_encoder.py", "nodes_rope.py", "nodes_logic.py", + "nodes_resolution.py", "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", From 25ec3d96a323c8455c6ee69e43bdd7a5599d3cc0 Mon Sep 17 00:00:00 2001 From: "Reiner \"Tiles\" Prokein" Date: Sat, 28 Feb 2026 01:03:45 +0100 Subject: [PATCH 16/81] Class WanVAE, def encode, feat_map is using self.decoder instead of self.encoder (#12682) --- comfy/ldm/wan/vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 7903c7690..71f73c64e 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -485,7 +485,7 @@ class WanVAE(nn.Module): iter_ = 1 + (t - 1) // 4 feat_map = None if iter_ > 1: - feat_map = [None] * count_conv3d(self.decoder) + feat_map = [None] * count_conv3d(self.encoder) ## 对encode输入的x,按时间拆分为1、4、4、4.... for i in range(iter_): conv_idx = [0] From e721e24136b5480c396bf0e37a114f6e4083482b Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:05:51 -0800 Subject: [PATCH 17/81] ops: implement lora requanting for non QuantizedTensor fp8 (#12668) Allow non QuantizedTensor layer to set want_requant to get the post lora calculation stochastic cast down to the original input dtype. This is then used by the legacy fp8 Linear implementation to set the compute_dtype to the preferred lora dtype but then want_requant it back down to fp8. This fixes the issue with --fast fp8_matrix_mult is combined with --fast dynamic_vram which doing a lora on an fp8_ non QT model. --- comfy/ops.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 98fec1e1d..6ee6075fb 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -167,17 +167,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu x = to_dequant(x, dtype) if not resident and lowvram_fn is not None: x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) - #FIXME: this is not accurate, we need to be sensitive to the compute dtype x = lowvram_fn(x) - if (isinstance(orig, QuantizedTensor) and - (want_requant and len(fns) == 0 or update_weight)): + if (want_requant and len(fns) == 0 or update_weight): seed = comfy.utils.string_to_seed(s.seed_key) - y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) - if want_requant and len(fns) == 0: - #The layer actually wants our freshly saved QT - x = y - elif update_weight: - y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key)) + if isinstance(orig, QuantizedTensor): + y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) + else: + y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed) + if want_requant and len(fns) == 0: + x = y if update_weight: orig.copy_(y) for f in fns: @@ -617,7 +615,8 @@ def fp8_linear(self, input): if input.ndim != 2: return None - w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device) + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True) scale_weight = torch.ones((), device=input.device, dtype=torch.float32) scale_input = torch.ones((), device=input.device, dtype=torch.float32) From 94f1a1cc9df69cbc75fe6d0f78a4de5d1d857d9d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 27 Feb 2026 17:16:24 -0800 Subject: [PATCH 18/81] Limit overlap in image tile and combine nodes to prevent issues. (#12688) --- comfy_extras/nodes_images.py | 34 +++------------------------------- 1 file changed, 3 insertions(+), 31 deletions(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 727d7d09d..4c57bb5cb 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -706,8 +706,8 @@ class SplitImageToTileList(IO.ComfyNode): @staticmethod def get_grid_coords(width, height, tile_width, tile_height, overlap): coords = [] - stride_x = max(1, tile_width - overlap) - stride_y = max(1, tile_height - overlap) + stride_x = round(max(tile_width * 0.25, tile_width - overlap)) + stride_y = round(max(tile_width * 0.25, tile_height - overlap)) y = 0 while y < height: @@ -764,34 +764,6 @@ class ImageMergeTileList(IO.ComfyNode): ], ) - @staticmethod - def get_grid_coords(width, height, tile_width, tile_height, overlap): - coords = [] - stride_x = max(1, tile_width - overlap) - stride_y = max(1, tile_height - overlap) - - y = 0 - while y < height: - x = 0 - y_end = min(y + tile_height, height) - y_start = max(0, y_end - tile_height) - - while x < width: - x_end = min(x + tile_width, width) - x_start = max(0, x_end - tile_width) - - coords.append((x_start, y_start, x_end, y_end)) - - if x_end >= width: - break - x += stride_x - - if y_end >= height: - break - y += stride_y - - return coords - @classmethod def execute(cls, image_list, final_width, final_height, overlap): w = final_width[0] @@ -804,7 +776,7 @@ class ImageMergeTileList(IO.ComfyNode): device = first_tile.device dtype = first_tile.dtype - coords = cls.get_grid_coords(w, h, t_w, t_h, ovlp) + coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp) canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype) weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype) From ac4412d0fa2b9df8469fb6018e0036c47332397a Mon Sep 17 00:00:00 2001 From: Talmaj Date: Sat, 28 Feb 2026 05:04:34 +0100 Subject: [PATCH 19/81] Native LongCat-Image implementation (#12597) --- comfy/model_base.py | 19 ++ comfy/model_detection.py | 2 + comfy/sd.py | 5 + comfy/supported_models.py | 34 +++- comfy/text_encoders/longcat_image.py | 184 ++++++++++++++++++ nodes.py | 2 +- tests-unit/comfy_test/model_detection_test.py | 112 +++++++++++ 7 files changed, 356 insertions(+), 2 deletions(-) create mode 100644 comfy/text_encoders/longcat_image.py create mode 100644 tests-unit/comfy_test/model_detection_test.py diff --git a/comfy/model_base.py b/comfy/model_base.py index 8f852e3c6..85cd30bae 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -925,6 +925,25 @@ class Flux(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class LongCatImage(Flux): + def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + transformer_options = transformer_options.copy() + rope_opts = transformer_options.get("rope_options", {}) + rope_opts = dict(rope_opts) + rope_opts.setdefault("shift_t", 1.0) + rope_opts.setdefault("shift_y", 512.0) + rope_opts.setdefault("shift_x", 512.0) + transformer_options["rope_options"] = rope_opts + return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) + + def encode_adm(self, **kwargs): + return None + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out.pop('guidance', None) + return out + class Flux2(Flux): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index b4b51b200..8a1d8ea4d 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -279,6 +279,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"]) if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model dit_config["txt_ids_dims"] = [1, 2] + if dit_config.get("context_in_dim") == 3584 and dit_config["vec_in_dim"] is None: # LongCat-Image + dit_config["txt_ids_dims"] = [1, 2] return dit_config diff --git a/comfy/sd.py b/comfy/sd.py index de119eb8e..7713d4678 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -60,6 +60,7 @@ import comfy.text_encoders.jina_clip_2 import comfy.text_encoders.newbie import comfy.text_encoders.anima import comfy.text_encoders.ace15 +import comfy.text_encoders.longcat_image import comfy.model_patcher import comfy.lora @@ -1160,6 +1161,7 @@ class CLIPType(Enum): KANDINSKY5_IMAGE = 23 NEWBIE = 24 FLUX2 = 25 + LONGCAT_IMAGE = 26 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1372,6 +1374,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip if clip_type == CLIPType.HUNYUAN_IMAGE: clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer + elif clip_type == CLIPType.LONGCAT_IMAGE: + clip_target.clip = comfy.text_encoders.longcat_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.longcat_image.LongCatImageTokenizer else: clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1bb7b7011..473fbbfd4 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -25,6 +25,7 @@ import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image import comfy.text_encoders.anima import comfy.text_encoders.ace15 +import comfy.text_encoders.longcat_image from . import supported_models_base from . import latent_formats @@ -1678,6 +1679,37 @@ class ACEStep15(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +class LongCatImage(supported_models_base.BASE): + unet_config = { + "image_model": "flux", + "guidance_embed": False, + "vec_in_dim": None, + "context_in_dim": 3584, + "txt_ids_dims": [1, 2], + } + + sampling_settings = { + } + + unet_extra_config = {} + latent_format = latent_formats.Flux + + memory_usage_factor = 2.5 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.LongCatImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/longcat_image.py b/comfy/text_encoders/longcat_image.py new file mode 100644 index 000000000..882d80901 --- /dev/null +++ b/comfy/text_encoders/longcat_image.py @@ -0,0 +1,184 @@ +import re +import numbers +import torch +from comfy import sd1_clip +from comfy.text_encoders.qwen_image import Qwen25_7BVLITokenizer, Qwen25_7BVLIModel +import logging + +logger = logging.getLogger(__name__) + +QUOTE_PAIRS = [("'", "'"), ('"', '"'), ("\u2018", "\u2019"), ("\u201c", "\u201d")] +QUOTE_PATTERN = "|".join( + [ + re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) + for q1, q2 in QUOTE_PAIRS + ] +) +WORD_INTERNAL_QUOTE_RE = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + + +def split_quotation(prompt): + matches = WORD_INTERNAL_QUOTE_RE.findall(prompt) + mapping = [] + for i, word_src in enumerate(set(matches)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping.append((word_src, word_tgt)) + + parts = re.split(f"({QUOTE_PATTERN})", prompt) + result = [] + for part in parts: + for word_src, word_tgt in mapping: + part = part.replace(word_tgt, word_src) + if not part: + continue + is_quoted = bool(re.match(QUOTE_PATTERN, part)) + result.append((part, is_quoted)) + return result + + +class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_length = 512 + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + parts = split_quotation(text) + all_tokens = [] + for part_text, is_quoted in parts: + if is_quoted: + for char in part_text: + ids = self.tokenizer(char, add_special_tokens=False)["input_ids"] + all_tokens.extend(ids) + else: + ids = self.tokenizer(part_text, add_special_tokens=False)["input_ids"] + all_tokens.extend(ids) + + if len(all_tokens) > self.max_length: + all_tokens = all_tokens[: self.max_length] + logger.warning(f"Truncated prompt to {self.max_length} tokens") + + output = [(t, 1.0) for t in all_tokens] + # Pad to max length + self.pad_tokens(output, self.max_length - len(output)) + return [output] + + +class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__( + embedding_directory=embedding_directory, + tokenizer_data=tokenizer_data, + name="qwen25_7b", + tokenizer=LongCatImageBaseTokenizer, + ) + self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" + self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + skip_template = False + if text.startswith("<|im_start|>"): + skip_template = True + if text.startswith("<|start_header_id|>"): + skip_template = True + if text == "": + text = " " + + base_tok = getattr(self, "qwen25_7b") + if skip_template: + tokens = super().tokenize_with_weights( + text, return_word_ids=return_word_ids, disable_weights=True, **kwargs + ) + else: + prefix_ids = base_tok.tokenizer( + self.longcat_template_prefix, add_special_tokens=False + )["input_ids"] + suffix_ids = base_tok.tokenizer( + self.longcat_template_suffix, add_special_tokens=False + )["input_ids"] + + prompt_tokens = base_tok.tokenize_with_weights( + text, return_word_ids=return_word_ids, **kwargs + ) + prompt_pairs = prompt_tokens[0] + + prefix_pairs = [(t, 1.0) for t in prefix_ids] + suffix_pairs = [(t, 1.0) for t in suffix_ids] + + combined = prefix_pairs + prompt_pairs + suffix_pairs + tokens = {"qwen25_7b": [combined]} + + return tokens + + +class LongCatImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__( + device=device, + dtype=dtype, + name="qwen25_7b", + clip_model=Qwen25_7BVLIModel, + model_options=model_options, + ) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen25_7b"][0] + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 151644 and count_im_start < 2: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 3): + if tok_pairs[template_end + 1][0] == 872: + if tok_pairs[template_end + 2][0] == 198: + template_end += 3 + + if template_end == -1: + template_end = 0 + + suffix_start = None + for i in range(len(tok_pairs) - 1, -1, -1): + elem = tok_pairs[i][0] + if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral): + if elem == 151645: + suffix_start = i + break + + out = out[:, template_end:] + + if "attention_mask" in extra: + extra["attention_mask"] = extra["attention_mask"][:, template_end:] + if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): + extra.pop("attention_mask") + + if suffix_start is not None: + suffix_len = len(tok_pairs) - suffix_start + if suffix_len > 0 and out.shape[1] > suffix_len: + out = out[:, :-suffix_len] + if "attention_mask" in extra: + extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len] + if extra["attention_mask"].sum() == torch.numel( + extra["attention_mask"] + ): + extra.pop("attention_mask") + + return out, pooled, extra + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class LongCatImageTEModel_(LongCatImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + + return LongCatImageTEModel_ diff --git a/nodes.py b/nodes.py index bf6ce5736..5be9b16f9 100644 --- a/nodes.py +++ b/nodes.py @@ -976,7 +976,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py new file mode 100644 index 000000000..2551a417b --- /dev/null +++ b/tests-unit/comfy_test/model_detection_test.py @@ -0,0 +1,112 @@ +import torch + +from comfy.model_detection import detect_unet_config, model_config_from_unet_config +import comfy.supported_models + + +def _make_longcat_comfyui_sd(): + """Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights.""" + sd = {} + H = 32 # Reduce hidden state dimension to reduce memory usage + C_IN = 16 + C_CTX = 3584 + + sd["img_in.weight"] = torch.empty(H, C_IN * 4) + sd["img_in.bias"] = torch.empty(H) + sd["txt_in.weight"] = torch.empty(H, C_CTX) + sd["txt_in.bias"] = torch.empty(H) + + sd["time_in.in_layer.weight"] = torch.empty(H, 256) + sd["time_in.in_layer.bias"] = torch.empty(H) + sd["time_in.out_layer.weight"] = torch.empty(H, H) + sd["time_in.out_layer.bias"] = torch.empty(H) + + sd["final_layer.adaLN_modulation.1.weight"] = torch.empty(2 * H, H) + sd["final_layer.adaLN_modulation.1.bias"] = torch.empty(2 * H) + sd["final_layer.linear.weight"] = torch.empty(C_IN * 4, H) + sd["final_layer.linear.bias"] = torch.empty(C_IN * 4) + + for i in range(19): + sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128) + sd[f"double_blocks.{i}.img_attn.qkv.weight"] = torch.empty(3 * H, H) + sd[f"double_blocks.{i}.img_mod.lin.weight"] = torch.empty(H, H) + for i in range(38): + sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H) + + return sd + + +def _make_flux_schnell_comfyui_sd(): + """Minimal ComfyUI-format state dict for standard Flux Schnell.""" + sd = {} + H = 32 # Reduce hidden state dimension to reduce memory usage + C_IN = 16 + + sd["img_in.weight"] = torch.empty(H, C_IN * 4) + sd["img_in.bias"] = torch.empty(H) + sd["txt_in.weight"] = torch.empty(H, 4096) + sd["txt_in.bias"] = torch.empty(H) + + sd["double_blocks.0.img_attn.norm.key_norm.weight"] = torch.empty(128) + sd["double_blocks.0.img_attn.qkv.weight"] = torch.empty(3 * H, H) + sd["double_blocks.0.img_mod.lin.weight"] = torch.empty(H, H) + + for i in range(19): + sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128) + for i in range(38): + sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H) + + return sd + + +class TestModelDetection: + """Verify that first-match model detection selects the correct model + based on list ordering and unet_config specificity.""" + + def test_longcat_before_schnell_in_models_list(self): + """LongCatImage must appear before FluxSchnell in the models list.""" + models = comfy.supported_models.models + longcat_idx = next(i for i, m in enumerate(models) if m.__name__ == "LongCatImage") + schnell_idx = next(i for i, m in enumerate(models) if m.__name__ == "FluxSchnell") + assert longcat_idx < schnell_idx, ( + f"LongCatImage (index {longcat_idx}) must come before " + f"FluxSchnell (index {schnell_idx}) in the models list" + ) + + def test_longcat_comfyui_detected_as_longcat(self): + sd = _make_longcat_comfyui_sd() + unet_config = detect_unet_config(sd, "") + assert unet_config is not None + assert unet_config["image_model"] == "flux" + assert unet_config["context_in_dim"] == 3584 + assert unet_config["vec_in_dim"] is None + assert unet_config["guidance_embed"] is False + assert unet_config["txt_ids_dims"] == [1, 2] + + model_config = model_config_from_unet_config(unet_config, sd) + assert model_config is not None + assert type(model_config).__name__ == "LongCatImage" + + def test_longcat_comfyui_keys_pass_through_unchanged(self): + """Pre-converted weights should not be transformed by process_unet_state_dict.""" + sd = _make_longcat_comfyui_sd() + unet_config = detect_unet_config(sd, "") + model_config = model_config_from_unet_config(unet_config, sd) + + processed = model_config.process_unet_state_dict(dict(sd)) + assert "img_in.weight" in processed + assert "txt_in.weight" in processed + assert "time_in.in_layer.weight" in processed + assert "final_layer.linear.weight" in processed + + def test_flux_schnell_comfyui_detected_as_flux_schnell(self): + sd = _make_flux_schnell_comfyui_sd() + unet_config = detect_unet_config(sd, "") + assert unet_config is not None + assert unet_config["image_model"] == "flux" + assert unet_config["context_in_dim"] == 4096 + assert unet_config["txt_ids_dims"] == [] + + model_config = model_config_from_unet_config(unet_config, sd) + assert model_config is not None + assert type(model_config).__name__ == "FluxSchnell" From 9d0e114ee380d3eac8aeb00260a9df1212b6046a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 27 Feb 2026 20:34:58 -0800 Subject: [PATCH 20/81] PyOpenGL-accelerate is not necessary. (#12692) --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b5b292980..1b2bd0ae6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,5 +31,4 @@ spandrel pydantic~=2.0 pydantic-settings~=2.0 PyOpenGL -PyOpenGL-accelerate glfw From 80d49441e5e255f8d91d2f335f930e74ba85cbe8 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 27 Feb 2026 20:53:46 -0800 Subject: [PATCH 21/81] refactor: use AspectRatio enum members as ASPECT_RATIOS dict keys (#12689) Amp-Thread-ID: https://ampcode.com/threads/T-019ca1cb-0150-7549-8b1b-6713060d3408 Co-authored-by: Jedrzej Kosinski --- comfy_extras/nodes_resolution.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/comfy_extras/nodes_resolution.py b/comfy_extras/nodes_resolution.py index d94156433..520b4067e 100644 --- a/comfy_extras/nodes_resolution.py +++ b/comfy_extras/nodes_resolution.py @@ -16,15 +16,15 @@ class AspectRatio(str, Enum): WIDESCREEN_V = "9:16 (Portrait Widescreen)" -ASPECT_RATIOS: dict[str, tuple[int, int]] = { - "1:1 (Square)": (1, 1), - "3:2 (Photo)": (3, 2), - "4:3 (Standard)": (4, 3), - "16:9 (Widescreen)": (16, 9), - "21:9 (Ultrawide)": (21, 9), - "2:3 (Portrait Photo)": (2, 3), - "3:4 (Portrait Standard)": (3, 4), - "9:16 (Portrait Widescreen)": (9, 16), +ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = { + AspectRatio.SQUARE: (1, 1), + AspectRatio.PHOTO_H: (3, 2), + AspectRatio.STANDARD_H: (4, 3), + AspectRatio.WIDESCREEN_H: (16, 9), + AspectRatio.ULTRAWIDE_H: (21, 9), + AspectRatio.PHOTO_V: (2, 3), + AspectRatio.STANDARD_V: (3, 4), + AspectRatio.WIDESCREEN_V: (9, 16), } @@ -55,8 +55,12 @@ class ResolutionSelector(io.ComfyNode): ), ], outputs=[ - io.Int.Output("width", tooltip="Calculated width in pixels (multiple of 8)."), - io.Int.Output("height", tooltip="Calculated height in pixels (multiple of 8)."), + io.Int.Output( + "width", tooltip="Calculated width in pixels (multiple of 8)." + ), + io.Int.Output( + "height", tooltip="Calculated height in pixels (multiple of 8)." + ), ], ) From 95e1059661f7a1584b5f84a6ece72ed8d8992b73 Mon Sep 17 00:00:00 2001 From: fappaz Date: Sat, 28 Feb 2026 19:18:40 +1300 Subject: [PATCH 22/81] fix(ace15): handle missing lm_metadata in memory estimation during checkpoint export #12669 (#12686) --- comfy/text_encoders/ace15.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index f135d74c1..853f021ae 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -328,14 +328,14 @@ class ACE15TEModel(torch.nn.Module): return getattr(self, self.lm_model).load_sd(sd) def memory_estimation_function(self, token_weight_pairs, device=None): - lm_metadata = token_weight_pairs["lm_metadata"] + lm_metadata = token_weight_pairs.get("lm_metadata", {}) constant = self.constant if comfy.model_management.should_use_bf16(device): constant *= 0.5 token_weight_pairs = token_weight_pairs.get("lm_prompt", []) num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - num_tokens += lm_metadata['min_tokens'] + num_tokens += lm_metadata.get("min_tokens", 0) return num_tokens * constant * 1024 * 1024 def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"): From 1f6744162f606cce895f2d9818207ddecbce5932 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 28 Feb 2026 23:49:12 +0200 Subject: [PATCH 23/81] feat: Support SCAIL WanVideo model (#12614) --- comfy/ldm/wan/model.py | 115 ++++++++++++++++++++++++++++++++++++++ comfy/model_base.py | 38 +++++++++++++ comfy/model_detection.py | 2 + comfy/supported_models.py | 12 +++- comfy_extras/nodes_wan.py | 58 +++++++++++++++++++ node_helpers.py | 31 ++++++++++ 6 files changed, 255 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index ea123acb4..b2287dba9 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + +class SCAILWanModel(WanModel): + def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs): + super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs) + + self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) + + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs): + + if reference_latent is not None: + x = torch.cat((reference_latent, x), dim=2) + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes + x = x.flatten(2).transpose(1, 2) + + scail_pose_seq_len = 0 + if pose_latents is not None: + scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype) + scail_x = scail_x.flatten(2).transpose(1, 2) + scail_pose_seq_len = scail_x.shape[1] + x = torch.cat([x, scail_x], dim=1) + del scail_x + + # time embeddings + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.cat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + + # head + x = self.head(x, e) + + if scail_pose_seq_len > 0: + x = x[:, :-scail_pose_seq_len] + + # unpatchify + x = self.unpatchify(x, grid_sizes) + + if reference_latent is not None: + x = x[:, :, reference_latent.shape[2]:] + + return x + + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}): + main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) + + if pose_latents is None: + return main_freqs + + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + + F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] + + # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames + h_scale = h / H_pose + w_scale = w / W_pose + + # 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code + h_shift = (h_scale - 1) / 2 + w_shift = (w_scale - 1) / 2 + pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}} + pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options) + + return torch.cat([main_freqs, pose_freqs], dim=1) + + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs): + bs, c, t, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + if pose_latents is not None: + pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size) + + t_len = t + if time_dim_concat is not None: + time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) + x = torch.cat([x, time_dim_concat], dim=2) + t_len = x.shape[2] + + reference_latent = None + if "reference_latent" in kwargs: + reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) + t_len += reference_latent.shape[2] + + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] diff --git a/comfy/model_base.py b/comfy/model_base.py index 85cd30bae..a1c690b9b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1502,6 +1502,44 @@ class WAN21_FlowRVS(WAN21): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) self.image_to_video = image_to_video +class WAN21_SCAIL(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel) + self.memory_usage_factor_conds = ("reference_latent", "pose_latents") + self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + ref_latent = self.process_latent_in(reference_latents[-1]) + ref_mask = torch.ones_like(ref_latent[:, :4]) + ref_latent = torch.cat([ref_latent, ref_mask], dim=1) + out['reference_latent'] = comfy.conds.CONDRegular(ref_latent) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + pose_latents = self.process_latent_in(pose_latents) + pose_mask = torch.ones_like(pose_latents[:, :4]) + pose_latents = torch.cat([pose_latents, pose_mask], dim=1) + out['pose_latents'] = comfy.conds.CONDRegular(pose_latents) + + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]] + + return out + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8a1d8ea4d..3faa950ca 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -498,6 +498,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "humo" elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "animate" + elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "scail" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 473fbbfd4..4f63e8327 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1268,6 +1268,16 @@ class WAN21_FlowRVS(WAN21_T2V): out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device) return out +class WAN21_SCAIL(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "scail", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1710,6 +1720,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index effa994d1..e50bfcd2c 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1456,6 +1456,63 @@ class WanInfiniteTalkToVideo(io.ComfyNode): return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) +class WanSCAILToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSCAILToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("reference_image", optional=True), + io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), + io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), + io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."), + io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + ref_latent = None + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(reference_image[:, :, :, :3]) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength + positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1476,6 +1533,7 @@ class WanExtension(ComfyExtension): WanAnimateToVideo, Wan22ImageToVideoLatent, WanInfiniteTalkToVideo, + WanSCAILToVideo, ] async def comfy_entrypoint() -> WanExtension: diff --git a/node_helpers.py b/node_helpers.py index 4ff960ef8..d3d834516 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -1,5 +1,6 @@ import hashlib import torch +import logging from comfy.cli_args import args @@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False): return c +def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0): + """ + Apply values to conditioning only during [start_percent, end_percent], keeping the + original conditioning active outside that range. Respects existing per-entry ranges. + """ + if start_percent > end_percent: + logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})") + return conditioning + + EPS = 1e-5 # the sampler gates entries with strict > / <, shift boundaries slightly to ensure only one conditioning is active per timestep + c = [] + for t in conditioning: + cond_start = t[1].get("start_percent", 0.0) + cond_end = t[1].get("end_percent", 1.0) + intersect_start = max(start_percent, cond_start) + intersect_end = min(end_percent, cond_end) + + if intersect_start >= intersect_end: # no overlap: emit unchanged + c.append(t) + continue + + if intersect_start > cond_start: # part before the requested range + c.extend(conditioning_set_values([t], {"start_percent": cond_start, "end_percent": intersect_start - EPS})) + + c.extend(conditioning_set_values([t], {**values, "start_percent": intersect_start, "end_percent": intersect_end})) + + if intersect_end < cond_end: # part after the requested range + c.extend(conditioning_set_values([t], {"start_percent": intersect_end + EPS, "end_percent": cond_end})) + return c + def pillow(fn, arg): prev_value = None try: From 5f41584e960d3ad90f6581278e57f7b52e771db4 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 28 Feb 2026 13:50:18 -0800 Subject: [PATCH 24/81] Disable dynamic_vram when weight hooks applied (#12653) * sd: add support for clip model reconstruction * nodes: SetClipHooks: Demote the dynamic model patcher * mp: Make dynamic_disable more robust The backup need to not be cloned. In addition add a delegate object to ModelPatcherDynamic so that non-cloning code can do ModelPatcherDynamic demotion * sampler_helpers: Demote to non-dynamic model patcher when hooking * code rabbit review comments --- comfy/model_patcher.py | 29 ++++++++++++++++++++-------- comfy/sampler_helpers.py | 12 ++++++++++++ comfy/samplers.py | 2 ++ comfy/sd.py | 38 +++++++++++++++++++++++++++---------- comfy_extras/nodes_hooks.py | 2 +- 5 files changed, 64 insertions(+), 19 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1c9ba8096..3fc76d9db 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -308,15 +308,22 @@ class ModelPatcher: def get_free_memory(self, device): return comfy.model_management.get_free_memory(device) - def clone(self, disable_dynamic=False): + def get_clone_model_override(self): + return self.model, (self.backup, self.object_patches_backup, self.pinned) + + def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ - model = self.model if self.is_dynamic() and disable_dynamic: class_ = ModelPatcher - temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) - model = temp_model_patcher.model + if model_override is None: + if self.cached_patcher_init is None: + raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.") + temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) + model_override = temp_model_patcher.get_clone_model_override() + if model_override is None: + model_override = self.get_clone_model_override() - n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) + n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -325,13 +332,12 @@ class ModelPatcher: n.object_patches = self.object_patches.copy() n.weight_wrapper_patches = self.weight_wrapper_patches.copy() n.model_options = comfy.utils.deepcopy_list_dict(self.model_options) - n.backup = self.backup - n.object_patches_backup = self.object_patches_backup n.parent = self - n.pinned = self.pinned n.force_cast_weights = self.force_cast_weights + n.backup, n.object_patches_backup, n.pinned = model_override[1] + # attachments n.attachments = {} for k in self.attachments: @@ -1435,6 +1441,7 @@ class ModelPatcherDynamic(ModelPatcher): del self.model.model_loaded_weight_memory if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} + self.non_dynamic_delegate_model = None assert load_device is not None def is_dynamic(self): @@ -1669,4 +1676,10 @@ class ModelPatcherDynamic(ModelPatcher): def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: pass + def get_non_dynamic_delegate(self): + model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model) + self.non_dynamic_delegate_model = model_patcher.get_clone_model_override() + return model_patcher + + CoreModelPatcher = ModelPatcher diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1f75f2ba7..bbba09e26 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -66,6 +66,18 @@ def convert_cond(cond): out.append(temp) return out +def cond_has_hooks(cond): + for c in cond: + temp = c[1] + if "hooks" in temp: + return True + if "control" in temp: + control = temp["control"] + extra_hooks = control.get_extra_hooks() + if len(extra_hooks) > 0: + return True + return False + def get_additional_models(conds, dtype): """loads additional models in conditioning""" cnets: list[ControlBase] = [] diff --git a/comfy/samplers.py b/comfy/samplers.py index 8b9782956..8be449ef7 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -946,6 +946,8 @@ class CFGGuider: def inner_set_conds(self, conds): for k in conds: + if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]): + self.model_patcher = self.model_patcher.get_non_dynamic_delegate() self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) def __call__(self, *args, **kwargs): diff --git a/comfy/sd.py b/comfy/sd.py index 7713d4678..a9ad7c2d2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -204,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip class CLIP: - def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): + def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False): if no_init: return params = target.params.copy() @@ -233,7 +233,8 @@ class CLIP: model_management.archive_model_dtypes(self.cond_stage_model) self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -267,9 +268,9 @@ class CLIP: logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) self.tokenizer_options = {} - def clone(self): + def clone(self, disable_dynamic=False): n = CLIP(no_init=True) - n.patcher = self.patcher.clone() + n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic) n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx @@ -1164,14 +1165,21 @@ class CLIPType(Enum): LONGCAT_IMAGE = 26 -def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): + +def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): + clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic) + return clip.patcher + +def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): clip_data = [] for p in ckpt_paths: sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) if model_options.get("custom_operations", None) is None: sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) clip_data.append(sd) - return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) + clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic) + clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options)) + return clip class TEModel(Enum): @@ -1276,7 +1284,7 @@ def llama_detect(clip_data): return {} -def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): +def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): clip_data = state_dicts class EmptyClass: @@ -1496,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) - clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options) + clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic) return clip def load_gligen(ckpt_path): @@ -1541,8 +1549,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) - if output_model: + if output_model and out[0] is not None: out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options)) + if output_clip and out[1] is not None: + out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options)) return out def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): @@ -1553,6 +1563,14 @@ def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, disable_dynamic=disable_dynamic) return model +def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + _, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False, + embedding_directory=embedding_directory, output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic) + return clip.patcher + def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False): clip = None clipvision = None @@ -1638,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: parameters = comfy.utils.calculate_parameters(clip_sd) - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic) else: logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index be7d600cd..056369e86 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -248,7 +248,7 @@ class SetClipHooks: def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): if hooks is not None: - clip = clip.clone() + clip = clip.clone(disable_dynamic=True) if apply_to_conds: clip.apply_hooks_to_conds = hooks clip.patcher.forced_hooks = hooks.clone() From 48bb0bd18aa90bba0eac7b4c1a1400c4f7110046 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 28 Feb 2026 13:52:30 -0800 Subject: [PATCH 25/81] cli_args: Default comfy to DynamicVram mode (#12658) --- comfy/cli_args.py | 4 ++-- main.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 63daca861..13079c7bc 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -146,6 +146,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") +parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") @@ -159,7 +160,6 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" - DynamicVRAM = "dynamic_vram" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) @@ -260,4 +260,4 @@ else: args.fast = set(args.fast) def enables_dynamic_vram(): - return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only + return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu diff --git a/main.py b/main.py index 3fe8f0589..a0545d9b3 100644 --- a/main.py +++ b/main.py @@ -192,7 +192,7 @@ import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher -if enables_dynamic_vram(): +if enables_dynamic_vram() and comfy.model_management.is_nvidia(): if comfy.model_management.torch_version_numeric < (2, 8): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): From 17106cb124fcfa0b75ea24993c65aa024059fc8d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 28 Feb 2026 19:21:32 -0800 Subject: [PATCH 26/81] Move parsing of requirements logic to function. (#12701) --- app/frontend_management.py | 42 ++------------------ tests-unit/app_test/frontend_manager_test.py | 6 +++ utils/install_util.py | 33 +++++++++++++++ 3 files changed, 42 insertions(+), 39 deletions(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index bdaa85812..f753ef0de 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -17,7 +17,7 @@ from importlib.metadata import version import requests from typing_extensions import NotRequired -from utils.install_util import get_missing_requirements_message, requirements_path +from utils.install_util import get_missing_requirements_message, get_required_packages_versions from comfy.cli_args import DEFAULT_VERSION_STRING import app.logger @@ -45,25 +45,7 @@ def get_installed_frontend_version(): def get_required_frontend_version(): - """Get the required frontend version from requirements.txt.""" - try: - with open(requirements_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line.startswith("comfyui-frontend-package=="): - version_str = line.split("==")[-1] - if not is_valid_version(version_str): - logging.error(f"Invalid version format in requirements.txt: {version_str}") - return None - return version_str - logging.error("comfyui-frontend-package not found in requirements.txt") - return None - except FileNotFoundError: - logging.error("requirements.txt not found. Cannot determine required frontend version.") - return None - except Exception as e: - logging.error(f"Error reading requirements.txt: {e}") - return None + return get_required_packages_versions().get("comfyui-frontend-package", None) def check_frontend_version(): @@ -217,25 +199,7 @@ class FrontendManager: @classmethod def get_required_templates_version(cls) -> str: - """Get the required workflow templates version from requirements.txt.""" - try: - with open(requirements_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line.startswith("comfyui-workflow-templates=="): - version_str = line.split("==")[-1] - if not is_valid_version(version_str): - logging.error(f"Invalid templates version format in requirements.txt: {version_str}") - return None - return version_str - logging.error("comfyui-workflow-templates not found in requirements.txt") - return None - except FileNotFoundError: - logging.error("requirements.txt not found. Cannot determine required templates version.") - return None - except Exception as e: - logging.error(f"Error reading requirements.txt: {e}") - return None + return get_required_packages_versions().get("comfyui-workflow-templates", None) @classmethod def default_frontend_path(cls) -> str: diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py index 643f04e72..1d5a84b47 100644 --- a/tests-unit/app_test/frontend_manager_test.py +++ b/tests-unit/app_test/frontend_manager_test.py @@ -49,6 +49,12 @@ def mock_provider(mock_releases): return provider +@pytest.fixture(autouse=True) +def clear_cache(): + import utils.install_util + utils.install_util.PACKAGE_VERSIONS = {} + + def test_get_release(mock_provider, mock_releases): version = "1.0.0" release = mock_provider.get_release(version) diff --git a/utils/install_util.py b/utils/install_util.py index 0f59bcf91..34489aec5 100644 --- a/utils/install_util.py +++ b/utils/install_util.py @@ -1,5 +1,7 @@ from pathlib import Path import sys +import logging +import re # The path to the requirements.txt file requirements_path = Path(__file__).parents[1] / "requirements.txt" @@ -16,3 +18,34 @@ Please install the updated requirements.txt file by running: {sys.executable} {extra}-m pip install -r {requirements_path} If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem. """.strip() + + +def is_valid_version(version: str) -> bool: + """Validate if a string is a valid semantic version (X.Y.Z format).""" + pattern = r"^(\d+)\.(\d+)\.(\d+)$" + return bool(re.match(pattern, version)) + + +PACKAGE_VERSIONS = {} +def get_required_packages_versions(): + if len(PACKAGE_VERSIONS) > 0: + return PACKAGE_VERSIONS.copy() + out = PACKAGE_VERSIONS + try: + with open(requirements_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip().replace(">=", "==") + s = line.split("==") + if len(s) == 2: + version_str = s[-1] + if not is_valid_version(version_str): + logging.error(f"Invalid version format in requirements.txt: {version_str}") + continue + out[s[0]] = version_str + return out.copy() + except FileNotFoundError: + logging.error("requirements.txt not found.") + return None + except Exception as e: + logging.error(f"Error reading requirements.txt: {e}") + return None From 1080bd442a7509d29bfe0b29cac9222de406c994 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 28 Feb 2026 19:23:28 -0800 Subject: [PATCH 27/81] Disable dynamic vram on wsl. (#12706) --- comfy/model_management.py | 8 ++++++++ main.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index f73613f17..86f840ada 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -180,6 +180,14 @@ def is_ixuca(): return True return False +def is_wsl(): + version = platform.uname().release + if version.endswith("-Microsoft"): + return True + elif version.endswith("microsoft-standard-WSL2"): + return True + return False + def get_torch_device(): global directml_enabled global cpu_state diff --git a/main.py b/main.py index a0545d9b3..af701f8df 100644 --- a/main.py +++ b/main.py @@ -192,7 +192,7 @@ import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher -if enables_dynamic_vram() and comfy.model_management.is_nvidia(): +if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl(): if comfy.model_management.torch_version_numeric < (2, 8): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): From d159142615e0a1a7ae4eb711a6ae9f66a5f2d76e Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 28 Feb 2026 20:59:24 -0800 Subject: [PATCH 28/81] refactor: rename Mahiro CFG to Similarity-Adaptive Guidance (#12172) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: rename Mahiro CFG to Similarity-Adaptive Guidance Rename the display name to better describe what the node does: adaptively blends guidance based on cosine similarity between positive and negative conditions. Amp-Thread-ID: https://ampcode.com/threads/T-019c0d36-8b43-745f-b7b2-e35b53f17fa1 Co-authored-by: Amp * feat: add search aliases for old mahiro name Amp-Thread-ID: https://ampcode.com/threads/T-019c0d36-8b43-745f-b7b2-e35b53f17fa1 * rename: Similarity-Adaptive Guidance → Positive-Biased Guidance (per reviewer) - display_name changed to 'Positive-Biased Guidance' to avoid SAG acronym collision - search_aliases expanded: mahiro, mahiro cfg, similarity-adaptive guidance, positive-biased cfg - ruff format applied --------- Co-authored-by: Amp Co-authored-by: Jedrzej Kosinski --- comfy_extras/nodes_mahiro.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 6459ca8c1..a25226e6d 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Mahiro", - display_name="Mahiro CFG", + display_name="Positive-Biased Guidance", category="_for_testing", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", inputs=[ @@ -20,27 +20,35 @@ class Mahiro(io.ComfyNode): io.Model.Output(display_name="patched_model"), ], is_experimental=True, + search_aliases=[ + "mahiro", + "mahiro cfg", + "similarity-adaptive guidance", + "positive-biased cfg", + ], ) @classmethod def execute(cls, model) -> io.NodeOutput: m = model.clone() + def mahiro_normd(args): - scale: float = args['cond_scale'] - cond_p: torch.Tensor = args['cond_denoised'] - uncond_p: torch.Tensor = args['uncond_denoised'] - #naive leap + scale: float = args["cond_scale"] + cond_p: torch.Tensor = args["cond_denoised"] + uncond_p: torch.Tensor = args["uncond_denoised"] + # naive leap leap = cond_p * scale - #sim with uncond leap + # sim with uncond leap u_leap = uncond_p * scale cfg = args["denoised"] merge = (leap + cfg) / 2 normu = torch.sqrt(u_leap.abs()) * u_leap.sign() normm = torch.sqrt(merge.abs()) * merge.sign() sim = F.cosine_similarity(normu, normm).mean() - simsc = 2 * (sim+1) - wm = (simsc*cfg + (4-simsc)*leap) / 4 + simsc = 2 * (sim + 1) + wm = (simsc * cfg + (4 - simsc) * leap) / 4 return wm + m.set_model_sampler_post_cfg_function(mahiro_normd) return io.NodeOutput(m) From 850e8b42ff67cec295edb686c4b85dc7811f5e7f Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 28 Feb 2026 21:38:19 -0800 Subject: [PATCH 29/81] feat: add text preview support to jobs API (#12169) * feat: add text preview support to jobs API Amp-Thread-ID: https://ampcode.com/threads/T-019c0be0-9fc6-71ac-853a-7c7cc846b375 Co-authored-by: Amp * test: update tests to expect text as previewable media type Amp-Thread-ID: https://ampcode.com/threads/T-019c0be0-9fc6-71ac-853a-7c7cc846b375 --------- --- comfy_execution/jobs.py | 53 ++++++++++++++++++++++++++++++------ tests/execution/test_jobs.py | 6 ++-- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 370014fb6..fcd7ef735 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -20,7 +20,7 @@ class JobStatus: # Media types that can be previewed in the frontend -PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'}) +PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'}) # 3D file extensions for preview fallback (no dedicated media_type exists) THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'}) @@ -75,6 +75,23 @@ def normalize_outputs(outputs: dict) -> dict: normalized[node_id] = normalized_node return normalized +# Text preview truncation limit (1024 characters) to prevent preview_output bloat +TEXT_PREVIEW_MAX_LENGTH = 1024 + + +def _create_text_preview(value: str) -> dict: + """Create a text preview dict with optional truncation. + + Returns: + dict with 'content' and optionally 'truncated' flag + """ + if len(value) <= TEXT_PREVIEW_MAX_LENGTH: + return {'content': value} + return { + 'content': value[:TEXT_PREVIEW_MAX_LENGTH], + 'truncated': True + } + def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: """Extract create_time and workflow_id from extra_data. @@ -221,23 +238,43 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]: continue for item in items: - normalized = normalize_output_item(item) - if normalized is None: - continue + if not isinstance(item, dict): + # Handle text outputs (non-dict items like strings or tuples) + normalized = normalize_output_item(item) + if normalized is None: + # Not a 3D file string — check for text preview + if media_type == 'text': + count += 1 + if preview_output is None: + if isinstance(item, tuple): + text_value = item[0] if item else '' + else: + text_value = str(item) + text_preview = _create_text_preview(text_value) + enriched = { + **text_preview, + 'nodeId': node_id, + 'mediaType': media_type + } + if fallback_preview is None: + fallback_preview = enriched + continue + # normalize_output_item returned a dict (e.g. 3D file) + item = normalized count += 1 if preview_output is not None: continue - if isinstance(normalized, dict) and is_previewable(media_type, normalized): + if is_previewable(media_type, item): enriched = { - **normalized, + **item, 'nodeId': node_id, } - if 'mediaType' not in normalized: + if 'mediaType' not in item: enriched['mediaType'] = media_type - if normalized.get('type') == 'output': + if item.get('type') == 'output': preview_output = enriched elif fallback_preview is None: fallback_preview = enriched diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 83c36fe48..814af5c13 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -38,13 +38,13 @@ class TestIsPreviewable: """Unit tests for is_previewable()""" def test_previewable_media_types(self): - """Images, video, audio, 3d media types should be previewable.""" - for media_type in ['images', 'video', 'audio', '3d']: + """Images, video, audio, 3d, text media types should be previewable.""" + for media_type in ['images', 'video', 'audio', '3d', 'text']: assert is_previewable(media_type, {}) is True def test_non_previewable_media_types(self): """Other media types should not be previewable.""" - for media_type in ['latents', 'text', 'metadata', 'files']: + for media_type in ['latents', 'metadata', 'files']: assert is_previewable(media_type, {}) is False def test_3d_extensions_previewable(self): From 4d79f4f0280da6c0a0e37123b9c80f24e2403536 Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Sun, 1 Mar 2026 10:38:30 -0700 Subject: [PATCH 30/81] fix: handle substep sigmas in context window set_step (#12719) Multi-step samplers (eg. dpmpp_2s_ancestral) call the model at intermediate sigma values not present in the schedule. This caused set_step to crash with "No sample_sigmas matched current timestep" when context windows were enabled. The fix is to keep self._step from the last exact match when a substep sigma is encountered, since substeps are still logically part of their parent step and should use the same context windows. Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com> --- comfy/context_windows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 2f82d51da..b54f7f39a 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC): mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: - raise Exception("No sample_sigmas matched current timestep; something went wrong.") + return # substep from multi-step sampler: keep self._step from the last full step self._step = int(matches[0].item()) def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: From c0d472e5b9b256d9e802ecac703bb6a8ca5f9eb8 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Mar 2026 11:14:56 -0800 Subject: [PATCH 31/81] comfy-aimdo 0.2.3 (#12720) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 1b2bd0ae6..35fa3f18f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.2 +comfy-aimdo>=0.2.3 requests #non essential dependencies: From 602f6bd82c1f8b31d1b10b5f9ae4aa9637772ad5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 1 Mar 2026 12:28:39 -0800 Subject: [PATCH 32/81] Make --disable-smart-memory disable dynamic vram. (#12722) --- comfy/cli_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 13079c7bc..bfb61c825 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -260,4 +260,4 @@ else: args.fast = set(args.fast) def enables_dynamic_vram(): - return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu + return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu and not args.disable_smart_memory From dfbf99a06172a5c54002d80abf3e74c0d82c10b9 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Mar 2026 19:18:56 -0800 Subject: [PATCH 33/81] model_mangament: make dynamic --disable-smart-memory work (#12724) This was previously considering the pool of dynamic models as one giant entity for the sake of smart memory, but that isnt really the useful or what a user would reasonably expect. Make Dynamic VRAM properly purge its models just like the old --disable-smart-memory but conditioning the dynamic-for-dynamic bypass on smart memory. Re-enable dynamic smart memory. --- comfy/cli_args.py | 2 +- comfy/model_management.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index bfb61c825..13079c7bc 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -260,4 +260,4 @@ else: args.fast = set(args.fast) def enables_dynamic_vram(): - return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu and not args.disable_smart_memory + return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu diff --git a/comfy/model_management.py b/comfy/model_management.py index 86f840ada..c817d43b5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -639,12 +639,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ if not DISABLE_SMART_MEMORY: memory_to_free = memory_required - get_free_memory(device) ram_to_free = ram_required - get_free_ram() - - if current_loaded_models[i].model.is_dynamic() and for_dynamic: - #don't actually unload dynamic models for the sake of other dynamic models - #as that works on-demand. - memory_required -= current_loaded_models[i].model.loaded_size() - memory_to_free = 0 + if current_loaded_models[i].model.is_dynamic() and for_dynamic: + #don't actually unload dynamic models for the sake of other dynamic models + #as that works on-demand. + memory_required -= current_loaded_models[i].model.loaded_size() + memory_to_free = 0 if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) From 7175c11a4ed41278c9cb9e6961b8d8776ef69f00 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:21:41 -0800 Subject: [PATCH 34/81] comfy aimdo 0.2.4 (#12727) Comfy Aimdo 0.2.4 fixes a VRAM buffer alignment issue that happens in someworkflows where action is able to bypass the pytorch allocator and go straight to the cuda hook. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 35fa3f18f..71019c16f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.3 +comfy-aimdo>=0.2.4 requests #non essential dependencies: From afb54219fac341fa8614fdab090fe8096d0aec1e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 2 Mar 2026 09:24:33 +0200 Subject: [PATCH 35/81] feat(api-nodes): allow to use "IMAGE+TEXT" in NanoBanana2 (#12729) --- comfy_api_nodes/nodes_gemini.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 3fe804e0b..d83d2fc15 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -789,8 +789,6 @@ class GeminiImage2(IO.ComfyNode): validate_string(prompt, strip_whitespace=True, min_length=1) if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": model = "gemini-3.1-flash-image-preview" - if response_modalities == "IMAGE+TEXT": - raise ValueError("IMAGE+TEXT is not currently available for the Nano Banana 2 model.") parts: list[GeminiPart] = [GeminiPart(text=prompt)] if images is not None: @@ -895,7 +893,7 @@ class GeminiNanoBanana2(IO.ComfyNode): ), IO.Combo.Input( "response_modalities", - options=["IMAGE"], + options=["IMAGE", "IMAGE+TEXT"], advanced=True, ), IO.Combo.Input( @@ -925,6 +923,7 @@ class GeminiNanoBanana2(IO.ComfyNode): ], outputs=[ IO.Image.Output(), + IO.String.Output(), ], hidden=[ IO.Hidden.auth_token_comfy_org, From f1f8996e1562c3753666d1c568b2ff629edb9e36 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 3 Mar 2026 01:13:42 +0800 Subject: [PATCH 36/81] chore: update workflow templates to v0.9.5 (#12732) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 71019c16f..608b0cfa6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.4 +comfyui-workflow-templates==0.9.5 comfyui-embedded-docs==0.4.3 torch torchsde From 57dd6c1aadf500d90f635a8d3c15418c0d6d6ecd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 2 Mar 2026 15:54:18 -0800 Subject: [PATCH 37/81] Support loading zeta chroma weights properly. (#12734) --- comfy/model_detection.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 3faa950ca..9f4a26e61 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -423,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" return dit_config - if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 + if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 dit_config = {} dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 @@ -533,8 +533,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config - if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1 - + if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1 dit_config = {} dit_config["image_model"] = "hunyuan3d2_1" dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] @@ -1055,6 +1054,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix) + elif 'noise_refiner.0.attention.norm_k.weight' in state_dict: + n_layers = count_blocks(state_dict, 'layers.{}.') + dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0] + sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix) + for k in state_dict: # For zeta chroma + if k not in sd_map: + sd_map[k] = k elif 'x_embedder.weight' in state_dict: #Flux depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') From 9ebee0a2179b361a24c20838c1848d7988320636 Mon Sep 17 00:00:00 2001 From: Lodestone Date: Tue, 3 Mar 2026 07:43:47 +0700 Subject: [PATCH 38/81] Feat: z-image pixel space (model still training atm) (#12709) * draft zeta (z-image pixel space) * revert gitignore * model loaded and able to run however vector direction still wrong tho * flip the vector direction to original again this time * Move wrongly positioned Z image pixel space class * inherit Radiance LatentFormat class * Fix parameters in classes for Zeta x0 dino * remove arbitrary nn.init instances * Remove unused import of lru_cache --------- Co-authored-by: silveroxides --- comfy/latent_formats.py | 7 + comfy/ldm/lumina/model.py | 265 ++++++++++++++++++++++++++++++++++++++ comfy/model_base.py | 5 + comfy/model_detection.py | 23 ++++ comfy/supported_models.py | 16 ++- 5 files changed, 315 insertions(+), 1 deletion(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f59999af6..6a57bca1c 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat): def process_out(self, latent): return latent + + +class ZImagePixelSpace(ChromaRadiance): + """Pixel-space latent format for ZImage DCT variant. + No VAE encoding/decoding — the model operates directly on RGB pixels. + """ + pass diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 77d1abc97..9e432d5c0 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension import comfy.utils +from comfy.ldm.chroma_radiance.layers import NerfEmbedder def invert_slices(slices, length): @@ -858,3 +859,267 @@ class NextDiT(nn.Module): img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] return -img + +############################################################################# +# Pixel Space Decoder Components # +############################################################################# + +def _modulate_shift_scale(x, shift, scale): + return x * (1 + scale) + shift + + +class PixelResBlock(nn.Module): + """ + Residual block with AdaLN modulation, zero-initialised so it starts as + an identity at the beginning of training. + """ + + def __init__(self, channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device) + self.mlp = nn.Sequential( + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1) + h = _modulate_shift_scale(self.in_ln(x), shift, scale) + h = self.mlp(h) + return x + gate * h + + +class DCTFinalLayer(nn.Module): + """Zero-initialised output projection (adopted from DiT).""" + + def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.norm_final(x)) + + +class SimpleMLPAdaLN(nn.Module): + """ + Small MLP decoder head for the pixel-space variant. + + Takes per-patch pixel values and a per-patch conditioning vector from the + transformer backbone and predicts the denoised pixel values. + + x : [B*N, P^2, C] – noisy pixel values per patch position + c : [B*N, dim] – backbone hidden state per patch (conditioning) + → [B*N, P^2, C] + """ + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + z_channels: int, + num_res_blocks: int, + max_freqs: int = 8, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + + # Project backbone hidden state → per-patch conditioning + self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device) + + # Input projection with DCT positional encoding + self.input_embedder = NerfEmbedder( + in_channels=in_channels, + hidden_size_input=model_channels, + max_freqs=max_freqs, + dtype=dtype, + device=device, + operations=operations, + ) + + # Residual blocks + self.res_blocks = nn.ModuleList([ + PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks) + ]) + + # Output projection + self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + # x: [B*N, 1, P^2*C], c: [B*N, dim] + original_dtype = x.dtype + weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype) + x = self.input_embedder(x) # [B*N, 1, model_channels] + y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels] + x = x.to(weight_dtype) + for block in self.res_blocks: + x = block(x, y) + return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C] + + +############################################################################# +# NextDiT – Pixel Space # +############################################################################# + +class NextDiTPixelSpace(NextDiT): + """ + Pixel-space variant of NextDiT. + + Identical transformer backbone to NextDiT, but the output head is replaced + with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values + per patch rather than a single affine projection. + + Key differences vs NextDiT: + • ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead. + • ``_forward`` stores the raw patchified pixel values before the backbone + embedding and feeds them to ``dec_net`` together with the per-patch + backbone hidden states. + • Supports optional x0 prediction via ``use_x0``. + """ + + def __init__( + self, + # decoder-specific + decoder_hidden_size: int = 3840, + decoder_num_res_blocks: int = 4, + decoder_max_freqs: int = 8, + decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels) + use_x0: bool = False, + # all NextDiT args forwarded unchanged + **kwargs, + ): + super().__init__(**kwargs) + + # Remove the latent-space final layer – not used in pixel space + del self.final_layer + + patch_size = kwargs.get("patch_size", 2) + in_channels = kwargs.get("in_channels", 4) + dim = kwargs.get("dim", 4096) + + # decoder_in_channels is the full flattened patch: patch_size^2 * in_channels + dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels + + self.dec_net = SimpleMLPAdaLN( + in_channels=dec_in_ch, + model_channels=decoder_hidden_size, + out_channels=dec_in_ch, + z_channels=dim, + num_res_blocks=decoder_num_res_blocks, + max_freqs=decoder_max_freqs, + dtype=kwargs.get("dtype"), + device=kwargs.get("device"), + operations=kwargs.get("operations"), + ) + + if use_x0: + self.register_buffer("__x0__", torch.tensor([])) + + # ------------------------------------------------------------------ + # Forward — mirrors NextDiT._forward exactly, replacing final_layer + # with the pixel-space dec_net decoder. + # ------------------------------------------------------------------ + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) + + t = 1.0 - timesteps + cap_feats = context + cap_mask = attention_mask + bs, c, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) + adaln_input = t + + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + + # ---- capture raw pixel patches before patchify_and_embed embeds them ---- + pH = pW = self.patch_size + B, C, H, W = x.shape + pixel_patches = ( + x.view(B, C, H // pH, pH, W // pW, pW) + .permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C] + .flatten(3) # [B, Ht, Wt, pH*pW*C] + .flatten(1, 2) # [B, N, pH*pW*C] + ) + N = pixel_patches.shape[1] + # decoder sees one token per patch: [B*N, 1, P^2*C] + pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C) + + patches = transformer_options.get("patches", {}) + x_is_tensor = isinstance(x, torch.Tensor) + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed( + x, cap_feats, cap_mask, adaln_input, num_tokens, + ref_latents=ref_latents, ref_contexts=ref_contexts, + siglip_feats=siglip_feats, transformer_options=transformer_options + ) + freqs_cis = freqs_cis.to(img.device) + + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" + img_input = img + for i, layer in enumerate(self.layers): + transformer_options["block_index"] = i + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] + + # ---- pixel-space decoder (replaces final_layer + unpatchify) ---- + # img may have padding tokens beyond N; only the first N are real image patches + img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim] + decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim] + + output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C] + output = output.reshape(B, N, -1) # [B, N, P^2*C] + + # prepend zero cap placeholder so unpatchify indexing works unchanged + cap_placeholder = torch.zeros( + B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype + ) + img_out = self.unpatchify( + torch.cat([cap_placeholder, output], dim=1), + img_size, cap_size, return_tensor=x_is_tensor + )[:, :, :h, :w] + + return -img_out + + def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + # _forward returns neg_x0 = -x0 (negated decoder output). + # + # Reference inference (working_inference_reference.py): + # out = _forward(img, t) # = -x0 + # pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual] + # img += (t_prev - t_curr) * pred # Euler step + # + # ComfyUI's Euler sampler does the same: + # x_next = x + (sigma_next - sigma) * model_output + # So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t + neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) + + return (x - neg_x0) / timesteps.view(-1, 1, 1, 1) diff --git a/comfy/model_base.py b/comfy/model_base.py index a1c690b9b..1e01e9edc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1263,6 +1263,11 @@ class Lumina2(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class ZImagePixelSpace(Lumina2): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) + self.memory_usage_factor_conds = ("ref_latents",) + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 9f4a26e61..6eace4628 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -464,6 +464,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if sig_weight is not None: dit_config["siglip_feat_dim"] = sig_weight.shape[0] + dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix) + if dec_cond_key in state_dict_keys: # pixel-space variant + dit_config["image_model"] = "zimage_pixel" + # patch_size and in_channels are derived from x_embedder: + # x_embedder: Linear(patch_size * patch_size * in_channels, dim) + # The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim. + x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] + dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0] + # patch_size: infer from decoder final layer output matching x_embedder input + # in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2) + embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)] + dec_in_ch = dec_out # decoder in == decoder out (same pixel space) + dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3) + dit_config["in_channels"] = 3 + dit_config["decoder_in_channels"] = dec_in_ch + dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0] + dit_config["decoder_num_res_blocks"] = count_blocks( + state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.' + ) + dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5) + if '{}__x0__'.format(key_prefix) in state_dict_keys: + dit_config["use_x0"] = True + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 4f63e8327..c0d3f387f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1118,6 +1118,20 @@ class ZImage(Lumina2): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect)) +class ZImagePixelSpace(ZImage): + unet_config = { + "image_model": "zimage_pixel", + } + + # Pixel-space model: no spatial compression, operates on raw RGB patches. + latent_format = latent_formats.ZImagePixelSpace + + # Much lower memory than latent-space models (no VAE, small patches). + memory_usage_factor = 0.05 # TODO: figure out the optimal value for this. + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ZImagePixelSpace(self, device=device) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -1720,6 +1734,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] From dff0a4a15887383c90a031e3fd48ebc41f6928e7 Mon Sep 17 00:00:00 2001 From: xeinherjer <112741359+xeinherjer-dev@users.noreply.github.com> Date: Tue, 3 Mar 2026 10:17:51 +0900 Subject: [PATCH 39/81] Fix VAEDecodeAudioTiled ignoring tile_size input (#12735) (#12738) --- comfy_extras/nodes_audio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 43df0512f..5d8d9bf6f 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -96,7 +96,7 @@ class VAEEncodeAudio(IO.ComfyNode): def vae_decode_audio(vae, samples, tile=None, overlap=None): if tile is not None: - audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1) + audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1) else: audio = vae.decode(samples["samples"]).movedim(-1, 1) From 09bcbddfcf804634f008f53c1827b7ba9a3956ec Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Mar 2026 08:50:33 -0800 Subject: [PATCH 40/81] ModelPatcherDynamic: Force load all non-comfy weights (#12739) * model_management: Remove non-comfy dynamic _v caster * Force pre-load non-comfy weights to GPU in ModelPatcherDynamic Non-comfy weights may expect to be pre-cast to the target device without in-model casting. Previously they were allocated in the vbar with _v which required the _v fault path in cast_to. Instead, back up the original CPU weight and move it directly to GPU at load time. --- comfy/model_management.py | 40 ------------------------------- comfy/model_patcher.py | 50 ++++++++++++++++----------------------- 2 files changed, 21 insertions(+), 69 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index c817d43b5..0e0e96672 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -32,9 +32,6 @@ import comfy.memory_management import comfy.utils import comfy.quant_ops -import comfy_aimdo.torch -import comfy_aimdo.model_vbar - class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -1206,43 +1203,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): - if hasattr(weight, "_v"): - #Unexpected usage patterns. There is no reason these don't work but they - #have no testing and no callers do this. - assert r is None - assert stream is None - - cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ]) - - if dtype is None: - dtype = weight._model_dtype - - signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) - if signature is not None: - if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): - v_tensor = weight._v_tensor - else: - raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) - v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] - weight._v_tensor = v_tensor - weight._v_signature = signature - #Send it over - v_tensor.copy_(weight, non_blocking=non_blocking) - return v_tensor.to(dtype=dtype) - - r = torch.empty_like(weight, dtype=dtype, device=device) - - if weight.dtype != r.dtype and weight.dtype != weight._model_dtype: - #Offloaded casting could skip this, however it would make the quantizations - #inconsistent between loaded and offloaded weights. So force the double casting - #that would happen in regular flow to make offload deterministic. - cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device) - cast_buffer.copy_(weight, non_blocking=non_blocking) - weight = cast_buffer - r.copy_(weight, non_blocking=non_blocking) - - return r - if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3fc76d9db..e380e406b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1435,10 +1435,6 @@ class ModelPatcherDynamic(ModelPatcher): def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): super().__init__(model, load_device, offload_device, size, weight_inplace_update) - #this is now way more dynamic and we dont support the same base model for both Dynamic - #and non-dynamic patchers. - if hasattr(self.model, "model_loaded_weight_memory"): - del self.model.model_loaded_weight_memory if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} self.non_dynamic_delegate_model = None @@ -1461,9 +1457,7 @@ class ModelPatcherDynamic(ModelPatcher): def loaded_size(self): vbar = self._vbar_get() - if vbar is None: - return 0 - return vbar.loaded_size() + return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory def get_free_memory(self, device): #NOTE: on high condition / batch counts, estimate should have already vacated @@ -1504,6 +1498,7 @@ class ModelPatcherDynamic(ModelPatcher): num_patches = 0 allocated_size = 0 + self.model.model_loaded_weight_memory = 0 with self.use_ejected(): self.unpatch_hooks() @@ -1512,10 +1507,6 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None: vbar.prioritize() - #We force reserve VRAM for the non comfy-weight so we dont have to deal - #with pin and unpin syncrhonization which can be expensive for small weights - #with a high layer rate (e.g. autoregressive LLMs). - #prioritize the non-comfy weights (note the order reverse). loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) loading.sort(reverse=True) @@ -1558,6 +1549,9 @@ class ModelPatcherDynamic(ModelPatcher): if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) self.patch_weight_to_device(key, device_to=device_to) + weight, _, _ = get_key_weight(self.model, key) + if weight is not None: + self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True @@ -1583,21 +1577,15 @@ class ModelPatcherDynamic(ModelPatcher): for param in params: key = key_param_name_to_key(n, param) weight, _, _ = get_key_weight(self.model, key) - weight.seed_key = key - set_dirty(weight, dirty) - geometry = weight - model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype - geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) - weight_size = geometry.numel() * geometry.element_size() - if vbar is not None and not hasattr(weight, "_v"): - weight._v = vbar.alloc(weight_size) - weight._model_dtype = model_dtype - allocated_size += weight_size - vbar.set_watermark_limit(allocated_size) + if key not in self.backup: + self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) + comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to)) + self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() move_weight_functions(m, device_to) - logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") self.model.device = device_to self.model.current_weight_patches_uuid = self.patches_uuid @@ -1613,7 +1601,16 @@ class ModelPatcherDynamic(ModelPatcher): assert self.load_device != torch.device("cpu") vbar = self._vbar_get() - return 0 if vbar is None else vbar.free_memory(memory_to_free) + freed = 0 if vbar is None else vbar.free_memory(memory_to_free) + + if freed < memory_to_free: + for key in list(self.backup.keys()): + bk = self.backup.pop(key) + comfy.utils.set_attr_param(self.model, key, bk.weight) + freed += self.model.model_loaded_weight_memory + self.model.model_loaded_weight_memory = 0 + + return freed def partially_unload_ram(self, ram_to_unload): loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) @@ -1640,11 +1637,6 @@ class ModelPatcherDynamic(ModelPatcher): for m in self.model.modules(): move_weight_functions(m, device_to) - keys = list(self.backup.keys()) - for k in keys: - bk = self.backup[k] - comfy.utils.set_attr_param(self.model, k, bk.weight) - def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): assert not force_patch_weights #See above with self.use_ejected(skip_and_inject_on_exit_only=True): From 174fd6759deee5ea73e4cde4ba2936e8d62d8d66 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Mar 2026 08:51:15 -0800 Subject: [PATCH 41/81] main: Load aimdo after logger is setup (#12743) This was too early. Aimdo can use the logger in error paths and this causes a rogue default init if aimdo has something to log. --- main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index af701f8df..0f58d57b8 100644 --- a/main.py +++ b/main.py @@ -16,11 +16,6 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags -import comfy_aimdo.control - -if enables_dynamic_vram(): - comfy_aimdo.control.init() - if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' @@ -28,6 +23,11 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +import comfy_aimdo.control + +if enables_dynamic_vram(): + comfy_aimdo.control.init() + if os.name == "nt": os.environ['MIMALLOC_PURGE_DELAY'] = '0' From f719a9d928049f85b07b8ecc2259fba4832d37bb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Mar 2026 14:35:22 -0800 Subject: [PATCH 42/81] Adjust memory usage factor of zeta model. (#12746) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c0d3f387f..07feb31b3 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1127,7 +1127,7 @@ class ZImagePixelSpace(ZImage): latent_format = latent_formats.ZImagePixelSpace # Much lower memory than latent-space models (no VAE, small patches). - memory_usage_factor = 0.05 # TODO: figure out the optimal value for this. + memory_usage_factor = 0.03 # TODO: figure out the optimal value for this. def get_model(self, state_dict, prefix="", device=None): return model_base.ZImagePixelSpace(self, device=device) From b6ddc590ed8dafd50df8aad1e626b78276a690c0 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Tue, 3 Mar 2026 19:58:53 -0500 Subject: [PATCH 43/81] CURVE type (#12581) * CURVE type * fix: update typed wrapper unwrap keys to __type__ and __value__ * code improve * code improve --- comfy_api/latest/_io.py | 14 ++++++++++++++ execution.py | 14 ++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 189d7d9bc..050031dc0 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1240,6 +1240,19 @@ class BoundingBox(ComfyTypeIO): return d +@comfytype(io_type="CURVE") +class Curve(ComfyTypeIO): + CurvePoint = tuple[float, float] + Type = list[CurvePoint] + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = [(0.0, 0.0), (1.0, 1.0)] + + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -2226,5 +2239,6 @@ __all__ = [ "PriceBadgeDepends", "PriceBadge", "BoundingBox", + "Curve", "NodeReplace", ] diff --git a/execution.py b/execution.py index 75b021892..7ccdbf93e 100644 --- a/execution.py +++ b/execution.py @@ -876,12 +876,14 @@ async def validate_inputs(prompt_id, prompt, item, validated): continue else: try: - # Unwraps values wrapped in __value__ key. This is used to pass - # list widget value to execution, as by default list value is - # reserved to represent the connection between nodes. - if isinstance(val, dict) and "__value__" in val: - val = val["__value__"] - inputs[x] = val + # Unwraps values wrapped in __value__ key or typed wrapper. + # This is used to pass list widget values to execution, + # as by default list value is reserved to represent the + # connection between nodes. + if isinstance(val, dict): + if "__value__" in val: + val = val["__value__"] + inputs[x] = val if input_type == "INT": val = int(val) From ac6513e142f881202c40eacc5e337982b777ccd0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Mar 2026 18:19:40 -0800 Subject: [PATCH 44/81] DynamicVram: Add casting / fix torch Buffer weights (#12749) * respect model dtype in non-comfy caster * utils: factor out parent and name functionality of set_attr * utils: implement set_attr_buffer for torch buffers * ModelPatcherDynamic: Implement torch Buffer loading If there is a buffer in dynamic - force load it. --- comfy/model_management.py | 2 ++ comfy/model_patcher.py | 22 ++++++++++++++++++---- comfy/utils.py | 19 +++++++++++++++---- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0e0e96672..0f5966371 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -796,6 +796,8 @@ def archive_model_dtypes(model): for name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + for buf_name, buf in module.named_buffers(recurse=False): + setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype) def cleanup_models(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e380e406b..70f78a089 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -241,6 +241,7 @@ class ModelPatcher: self.patches = {} self.backup = {} + self.backup_buffers = {} self.object_patches = {} self.object_patches_backup = {} self.weight_wrapper_patches = {} @@ -309,7 +310,7 @@ class ModelPatcher: return comfy.model_management.get_free_memory(device) def get_clone_model_override(self): - return self.model, (self.backup, self.object_patches_backup, self.pinned) + return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ @@ -336,7 +337,7 @@ class ModelPatcher: n.force_cast_weights = self.force_cast_weights - n.backup, n.object_patches_backup, n.pinned = model_override[1] + n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1] # attachments n.attachments = {} @@ -1579,11 +1580,22 @@ class ModelPatcherDynamic(ModelPatcher): weight, _, _ = get_key_weight(self.model, key) if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) - comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to)) - self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() + model_dtype = getattr(m, param + "_comfy_model_dtype", None) + casted_weight = weight.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_param(self.model, key, casted_weight) + self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size() move_weight_functions(m, device_to) + for key, buf in self.model.named_buffers(recurse=True): + if key not in self.backup_buffers: + self.backup_buffers[key] = buf + module, buf_name = comfy.utils.resolve_attr(self.model, key) + model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None) + casted_buf = buf.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_buffer(self.model, key, casted_buf) + self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size() + force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") @@ -1607,6 +1619,8 @@ class ModelPatcherDynamic(ModelPatcher): for key in list(self.backup.keys()): bk = self.backup.pop(key) comfy.utils.set_attr_param(self.model, key, bk.weight) + for key in list(self.backup_buffers.keys()): + comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key)) freed += self.model.model_loaded_weight_memory self.model.model_loaded_weight_memory = 0 diff --git a/comfy/utils.py b/comfy/utils.py index 0769cef44..6e1d14419 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): ATTR_UNSET={} -def set_attr(obj, attr, value): +def resolve_attr(obj, attr): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1], ATTR_UNSET) + return obj, attrs[-1] + +def set_attr(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) if value is ATTR_UNSET: - delattr(obj, attrs[-1]) + delattr(obj, name) else: - setattr(obj, attrs[-1], value) + setattr(obj, name, value) return prev def set_attr_param(obj, attr, value): return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) +def set_attr_buffer(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) + persistent = name not in getattr(obj, "_non_persistent_buffers_set", set()) + obj.register_buffer(name, value, persistent=persistent) + return prev + def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".") From eb011733b6e4d8a9f7b67a1787d817bfc8c0a5b4 Mon Sep 17 00:00:00 2001 From: Arthur R Longbottom Date: Tue, 3 Mar 2026 21:29:00 -0800 Subject: [PATCH 45/81] Fix VideoFromComponents.save_to crash when writing to BytesIO (#12683) * Fix VideoFromComponents.save_to crash when writing to BytesIO When `get_container_format()` or `get_stream_source()` is called on a tensor-based video (VideoFromComponents), it calls `save_to(BytesIO())`. Since BytesIO has no file extension, `av.open` can't infer the output format and throws `ValueError: Could not determine output format`. The sibling class `VideoFromFile` already handles this correctly via `get_open_write_kwargs()`, which detects BytesIO and sets the format explicitly. `VideoFromComponents` just never got the same treatment. This surfaces when any downstream node validates the container format of a tensor-based video, like TopazVideoEnhance or any node that calls `validate_container_format_is_mp4()`. Three-line fix in `comfy_api/latest/_input_impl/video_types.py`. * Add docstring to save_to to satisfy CI coverage check --- comfy_api/latest/_input_impl/video_types.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index a3d48c87f..58a37c9e8 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput): codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, ): + """Save the video to a file path or BytesIO buffer.""" if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: @@ -408,6 +409,10 @@ class VideoFromComponents(VideoInput): extra_kwargs = {} if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: extra_kwargs["format"] = format.value + elif isinstance(path, io.BytesIO): + # BytesIO has no file extension, so av.open can't infer the format. + # Default to mp4 since that's the only supported format anyway. + extra_kwargs["format"] = "mp4" with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: # Add metadata before writing any streams if metadata is not None: From d531e3fb2a885d675d5b6d3a496b4af5d9757af1 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 07:47:44 -0800 Subject: [PATCH 46/81] model_patcher: Improve dynamic offload heuristic (#12759) Define a threshold below which a weight loading takes priority. This actually makes the offload consistent with non-dynamic, because what happens, is when non-dynamic fills ints to_load list, it will fill-up any left-over pieces that could fix large weights with small weights and load them, even though they were lower priority. This actually improves performance because the timy weights dont cost any VRAM and arent worth the control overhead of the DMA etc. --- comfy/model_patcher.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 70f78a089..168ce8430 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -699,7 +699,7 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self, prio_comfy_cast_weights=False, default_device=None): + def _load_list(self, for_dynamic=False, default_device=None): loading = [] for n, m in self.model.named_modules(): default = False @@ -727,8 +727,13 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () - loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) + # Dynamic: small weights (<64KB) first, then larger weights prioritized by size. + # Non-dynamic: prioritize by module offload cost. + if for_dynamic: + sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem) + else: + sort_criteria = (module_offload_mem,) + loading.append(sort_criteria + (module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -1508,11 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None: vbar.prioritize() - loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) - loading.sort(reverse=True) + loading = self._load_list(for_dynamic=True, default_device=device_to) + loading.sort() for x in loading: - _, _, _, n, m, params = x + *_, module_mem, n, m, params = x def set_dirty(item, dirty): if dirty or not hasattr(item, "_v_signature"): @@ -1627,9 +1632,9 @@ class ModelPatcherDynamic(ModelPatcher): return freed def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) + loading = self._load_list(for_dynamic=True, default_device=self.offload_device) for x in loading: - _, _, _, _, m, _ = x + *_, m, _ = x ram_to_unload -= comfy.pinned_memory.unpin_memory(m) if ram_to_unload <= 0: return From 9b85cf955858b0aca6b7b30c30b404470ea0c964 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 07:49:13 -0800 Subject: [PATCH 47/81] Comfy Aimdo 0.2.5 + Fix offload performance in DynamicVram (#12754) * ops: dont unpin nothing This was calling into aimdo in the none case (offloaded weight). Whats worse, is aimdo syncs for unpinning an offloaded weight, as that is the corner case of a weight getting evicted by its own use which does require a sync. But this was heppening every offloaded weight causing slowdown. * mp: fix get_free_memory policy The ModelPatcherDynamic get_free_memory was deducting the model from to try and estimate the conceptual free memory with doing any offloading. This is kind of what the old memory_memory_required was estimating in ModelPatcher load logic, however in practical reality, between over-estimates and padding, the loader usually underloaded models enough such that sampling could send CFG +/- through together even when partially loaded. So don't regress from the status quo and instead go all in on the idea that offloading is less of an issue than debatching. Tell the sampler it can use everything. --- comfy/model_patcher.py | 14 +++++++------- comfy/ops.py | 4 ++-- requirements.txt | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 168ce8430..7e5ad7aa4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -307,7 +307,13 @@ class ModelPatcher: return self.model.lowvram_patch_counter def get_free_memory(self, device): - return comfy.model_management.get_free_memory(device) + #Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In + #the vast majority of setups a little bit of offloading on the giant model more + #than pays for CFG. So return everything both torch and Aimdo could give us + aimdo_mem = 0 + if comfy.memory_management.aimdo_enabled: + aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() + return comfy.model_management.get_free_memory(device) + aimdo_mem def get_clone_model_override(self): return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) @@ -1465,12 +1471,6 @@ class ModelPatcherDynamic(ModelPatcher): vbar = self._vbar_get() return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory - def get_free_memory(self, device): - #NOTE: on high condition / batch counts, estimate should have already vacated - #all non-dynamic models so this is safe even if its not 100% true that this - #would all be avaiable for inference use. - return comfy.model_management.get_total_memory(device) - self.model_size() - #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. def pin_weight_to_device(self, key): diff --git a/comfy/ops.py b/comfy/ops.py index 6ee6075fb..8275dd0a5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -269,8 +269,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream): return os, weight_a, bias_a = offload_stream device=None - #FIXME: This is not good RTTI - if not isinstance(weight_a, torch.Tensor): + #FIXME: This is really bad RTTI + if weight_a is not None and not isinstance(weight_a, torch.Tensor): comfy_aimdo.model_vbar.vbar_unpin(s._v) device = weight_a if os is None: diff --git a/requirements.txt b/requirements.txt index 608b0cfa6..110568cd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.4 +comfy-aimdo>=0.2.5 requests #non essential dependencies: From 0a7446ade4bbeecfaf36e9a70eeabbeb0f6e59ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:59:56 +0200 Subject: [PATCH 48/81] Pass tokens when loading text gen model for text generation (#12755) Co-authored-by: Jedrzej Kosinski --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index a9ad7c2d2..8bcd09582 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -428,7 +428,7 @@ class CLIP: def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): self.cond_stage_model.reset_clip_options() - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"layer": None}) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) From 8811db52db5d0aea49c1dbedd733a6b9304b83a9 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 12:12:37 -0800 Subject: [PATCH 49/81] comfy-aimdo 0.2.6 (#12764) Comfy Aimdo 0.2.6 fixes a GPU virtual address leak. This would manfiest as an error after a number of workflow runs. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 110568cd3..dae46d873 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.5 +comfy-aimdo>=0.2.6 requests #non essential dependencies: From ac4a943ff364885166def5d418582db971554caf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:33:14 -0800 Subject: [PATCH 50/81] Initial load device should be cpu when using dynamic vram. (#12766) --- comfy/model_management.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0f5966371..809600815 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -830,11 +830,14 @@ def unet_offload_device(): return torch.device("cpu") def unet_inital_load_device(parameters, dtype): + cpu_dev = torch.device("cpu") + if comfy.memory_management.aimdo_enabled: + return cpu_dev + torch_dev = get_torch_device() if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: return torch_dev - cpu_dev = torch.device("cpu") if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM: return cpu_dev @@ -842,7 +845,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled: + if mem_dev > mem_cpu and model_size < mem_dev: return torch_dev else: return cpu_dev @@ -945,6 +948,9 @@ def text_encoder_device(): return torch.device("cpu") def text_encoder_initial_device(load_device, offload_device, model_size=0): + if comfy.memory_management.aimdo_enabled: + return offload_device + if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device From 43c64b6308f93c331f057e12799bad0a68be5117 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 17:06:20 -0800 Subject: [PATCH 51/81] Support the LTXAV 2.3 model. (#12773) --- comfy/ldm/lightricks/av_model.py | 185 ++++++- comfy/ldm/lightricks/embeddings_connector.py | 4 + comfy/ldm/lightricks/model.py | 186 ++++++- comfy/ldm/lightricks/vae/audio_vae.py | 7 +- .../vae/causal_audio_autoencoder.py | 67 +-- .../vae/causal_video_autoencoder.py | 48 +- comfy/ldm/lightricks/vocoders/vocoder.py | 523 +++++++++++++++++- comfy/model_base.py | 2 +- comfy/sd.py | 2 +- comfy/text_encoders/lt.py | 68 ++- 10 files changed, 959 insertions(+), 133 deletions(-) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 553fd5b38..08d686b7b 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -2,11 +2,16 @@ from typing import Tuple import torch import torch.nn as nn from comfy.ldm.lightricks.model import ( + ADALN_BASE_PARAMS_COUNT, + ADALN_CROSS_ATTN_PARAMS_COUNT, CrossAttention, FeedForward, AdaLayerNormSingle, PixArtAlphaTextProjection, + NormSingleLinearTextProjection, LTXVModel, + apply_cross_attention_adaln, + compute_prompt_timestep, ) from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module): v_context_dim=None, a_context_dim=None, attn_precision=None, + apply_gated_attention=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=v_dim, @@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=vd_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=ad_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module): heads=v_heads, dim_head=vd_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module): a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations ) - self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype)) self.audio_scale_shift_table = nn.Parameter( - torch.empty(6, a_dim, device=device, dtype=dtype) + torch.empty(num_ada_params, a_dim, device=device, dtype=dtype) ) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype)) + self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype)) + self.scale_shift_table_a2v_ca_audio = nn.Parameter( torch.empty(5, a_dim, device=device, dtype=dtype) ) @@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module): return (*scale_shift_ada_values, *gate_ada_values) + def _apply_text_cross_attention( + self, x, context, attn, scale_shift_table, prompt_scale_shift_table, + timestep, prompt_timestep, attention_mask, transformer_options, + ): + """Apply text cross-attention, with optional ADaLN modulation.""" + if self.cross_attention_adaln: + shift_q, scale_q, gate = self.get_ada_values( + scale_shift_table, x.shape[0], timestep, slice(6, 9) + ) + return apply_cross_attention_adaln( + x, context, attn, shift_q, scale_q, gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask, transformer_options, + ) + return attn( + comfy.ldm.common_dit.rms_norm(x), context=context, + mask=attention_mask, transformer_options=transformer_options, + ) + def forward( self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None, + v_prompt_timestep=None, a_prompt_timestep=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module): vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] vx.addcmul_(attn1_out, vgate_msa) del vgate_msa, attn1_out - vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options)) + vx.add_(self._apply_text_cross_attention( + vx, v_context, self.attn2, self.scale_shift_table, + getattr(self, 'prompt_scale_shift_table', None), + v_timestep, v_prompt_timestep, attention_mask, transformer_options,) + ) # audio if run_ax: @@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module): agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0] ax.addcmul_(attn1_out, agate_msa) del agate_msa, attn1_out - ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options)) + ax.add_(self._apply_text_cross_attention( + ax, a_context, self.audio_attn2, self.audio_scale_shift_table, + getattr(self, 'audio_prompt_scale_shift_table', None), + a_timestep, a_prompt_timestep, attention_mask, transformer_options,) + ) # video - audio cross attention. if run_a2v or run_v2a: @@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel): use_middle_indices_grid=False, timestep_scale_multiplier=1000.0, av_ca_timestep_scale_multiplier=1.0, + apply_gated_attention=False, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel): self.audio_attention_head_dim = audio_attention_head_dim self.audio_num_attention_heads = audio_num_attention_heads self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.apply_gated_attention = apply_gated_attention # Calculate audio dimensions self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim @@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel): ) # Audio-specific AdaLN + audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.audio_adaln_single = AdaLayerNormSingle( self.audio_inner_dim, + embedding_coefficient=audio_embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations, ) + if self.cross_attention_adaln: + self.audio_prompt_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=2, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_prompt_adaln_single = None + num_scale_shift_values = 4 self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( self.inner_dim, @@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel): ) # Audio caption projection - self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.audio_inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.audio_caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_caption_projection = lambda a: a + else: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + connector_split_rope = kwargs.get("rope_type", "split") == "split" + connector_gated_attention = kwargs.get("connector_apply_gated_attention", False) + attention_head_dim = kwargs.get("connector_attention_head_dim", 128) + num_attention_heads = kwargs.get("connector_num_attention_heads", 30) + num_layers = kwargs.get("connector_num_layers", 2) self.audio_embeddings_connector = Embeddings1DConnector( - split_rope=True, + attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim), + num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads), + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) self.video_embeddings_connector = Embeddings1DConnector( - split_rope=True, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) - def preprocess_text_embeds(self, context): - if context.shape[-1] == self.caption_channels * 2: - return context - out_vid = self.video_embeddings_connector(context)[0] - out_audio = self.audio_embeddings_connector(context)[0] + def preprocess_text_embeds(self, context, unprocessed=False): + # LTXv2 fully processed context has dimension of self.caption_channels * 2 + # LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim + if not unprocessed: + if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2): + return context + if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim: + context_vid = context[:, :, :self.cross_attention_dim] + context_audio = context[:, :, self.cross_attention_dim:] + else: + context_vid = context + context_audio = context + if self.caption_proj_before_connector: + context_vid = self.caption_projection(context_vid) + context_audio = self.audio_caption_projection(context_audio) + out_vid = self.video_embeddings_connector(context_vid)[0] + out_audio = self.audio_embeddings_connector(context_audio)[0] return torch.concat((out_vid, out_audio), dim=-1) def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel): ad_head=self.audio_attention_head_dim, v_context_dim=self.cross_attention_dim, a_context_dim=self.audio_cross_attention_dim, + apply_gated_attention=self.apply_gated_attention, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel): v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) + v_prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + # Prepare audio timestep a_timestep = kwargs.get("a_timestep") if a_timestep is not None: @@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel): # Cross-attention timesteps - compress these too av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( - a_timestep_flat, + timestep.max().expand_as(a_timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( - timestep_flat, + a_timestep.max().expand_as(timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( - timestep_flat * av_ca_factor, + a_timestep.max().expand_as(timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( - a_timestep_flat * av_ca_factor, + timestep.max().expand_as(a_timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel): # Audio timesteps a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1]) + + a_prompt_timestep = compute_prompt_timestep( + self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype + ) else: a_timestep = timestep_scaled a_embedded_timestep = kwargs.get("embedded_timestep") cross_av_timestep_ss = [] + a_prompt_timestep = None - return [v_timestep, a_timestep, cross_av_timestep_ss], [ + return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [ v_embedded_timestep, a_embedded_timestep, - ] + ], None def _prepare_context(self, context, batch_size, x, attention_mask=None): vx = x[0] ax = x[1] + video_dim = vx.shape[-1] + audio_dim = ax.shape[-1] + + v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim + a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim + v_context, a_context = torch.split( - context, int(context.shape[-1] / 2), len(context.shape) - 1 + context, [v_context_dim, a_context_dim], len(context.shape) - 1 ) v_context, attention_mask = super()._prepare_context( v_context, batch_size, vx, attention_mask ) - if self.audio_caption_projection is not None: + if self.caption_proj_before_connector is False: a_context = self.audio_caption_projection(a_context) - a_context = a_context.view(batch_size, -1, ax.shape[-1]) + a_context = a_context.view(batch_size, -1, audio_dim) return [v_context, a_context], attention_mask @@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel): av_ca_v2a_gate_noise_timestep, ) = timestep[2] + v_prompt_timestep = timestep[3] + a_prompt_timestep = timestep[4] + """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=args["a_cross_gate_timestep"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), + v_prompt_timestep=args.get("v_prompt_timestep"), + a_prompt_timestep=args.get("a_prompt_timestep"), ) return out @@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel): "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, + "v_prompt_timestep": v_prompt_timestep, + "a_prompt_timestep": a_prompt_timestep, }, {"original_block": block_wrap}, ) @@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + v_prompt_timestep=v_prompt_timestep, + a_prompt_timestep=a_prompt_timestep, ) return [vx, ax] diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index 33adb9671..2811080be 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module): d_head, context_dim=None, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module): heads=n_heads, dim_head=d_head, context_dim=None, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module): positional_embedding_max_pos=[4096], causal_temporal_positioning=False, num_learnable_registers: Optional[int] = 128, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module): num_attention_heads, attention_head_dim, context_dim=cross_attention_dim, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 60d760d29..bfbc08357 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module): return hidden_states +class NormSingleLinearTextProjection(nn.Module): + """Text projection for 20B models - single linear with RMSNorm (no activation).""" + + def __init__( + self, in_features, hidden_size, dtype=None, device=None, operations=None + ): + super().__init__() + if operations is None: + operations = comfy.ops.disable_weight_init + self.in_norm = operations.RMSNorm( + in_features, eps=1e-6, elementwise_affine=False + ) + self.linear_1 = operations.Linear( + in_features, hidden_size, bias=True, dtype=dtype, device=device + ) + self.hidden_size = hidden_size + self.in_features = in_features + + def forward(self, caption): + caption = self.in_norm(caption) + caption = caption * (self.hidden_size / self.in_features) ** 0.5 + return self.linear_1(caption) + + class GELU_approx(nn.Module): def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None): super().__init__() @@ -343,6 +367,7 @@ class CrossAttention(nn.Module): dim_head=64, dropout=0.0, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -362,6 +387,12 @@ class CrossAttention(nn.Module): self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) + # Optional per-head gating + if apply_gated_attention: + self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device) + else: + self.to_gate_logits = None + self.to_out = nn.Sequential( operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -383,16 +414,30 @@ class CrossAttention(nn.Module): out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) + + # Apply per-head gating if enabled + if self.to_gate_logits is not None: + gate_logits = self.to_gate_logits(x) # (B, T, H) + b, t, _ = out.shape + out = out.view(b, t, self.heads, self.dim_head) + gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity + out = out * gates.unsqueeze(-1) + out = out.view(b, t, self.heads * self.dim_head) + return self.to_out(out) +# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate) +ADALN_BASE_PARAMS_COUNT = 6 +ADALN_CROSS_ATTN_PARAMS_COUNT = 9 class BasicTransformerBlock(nn.Module): def __init__( - self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None ): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, @@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module): operations=operations, ) - self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype)) - attn1_input = comfy.ldm.common_dit.rms_norm(x) - attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) - attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) - x.addcmul_(attn1_input, gate_msa) - del attn1_input + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2) - x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) + x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa + + if self.cross_attention_adaln: + shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2) + x += apply_cross_attention_adaln( + x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca, + self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options, + ) + else: + x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) y = comfy.ldm.common_dit.rms_norm(x) y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp) @@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module): return x +def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype): + """Compute a single global prompt timestep for cross-attention ADaLN. + + Uses the max across tokens (matching JAX max_per_segment) and broadcasts + over text tokens. Returns None when *adaln_module* is None. + """ + if adaln_module is None: + return None + ts_input = ( + timestep_scaled.max(dim=1, keepdim=True).values.flatten() + if timestep_scaled.dim() > 1 + else timestep_scaled.flatten() + ) + prompt_ts, _ = adaln_module( + ts_input, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1]) + + +def apply_cross_attention_adaln( + x, context, attn, q_shift, q_scale, q_gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask=None, transformer_options={}, +): + """Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV). + + Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so + that both regular tensors and CompressedTimestep are supported. + """ + batch_size = x.shape[0] + shift_kv, scale_kv = ( + prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1) + ).unbind(dim=2) + attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift + encoder_hidden_states = context * (1 + scale_kv) + shift_kv + return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate + def get_fractional_positions(indices_grid, max_pos): n_pos_dims = indices_grid.shape[1] assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})' @@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC): vae_scale_factors: tuple = (8, 32, 32), use_middle_indices_grid=False, timestep_scale_multiplier = 1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, + caption_projection_first_linear=True, dtype=None, device=None, operations=None, @@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC): self.causal_temporal_positioning = causal_temporal_positioning self.operations = operations self.timestep_scale_multiplier = timestep_scale_multiplier + self.caption_proj_before_connector = caption_proj_before_connector + self.cross_attention_adaln = cross_attention_adaln + self.caption_projection_first_linear = caption_projection_first_linear # Common dimensions self.inner_dim = num_attention_heads * attention_head_dim @@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC): self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device ) + embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations ) - self.caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.cross_attention_adaln: + self.prompt_adaln_single = AdaLayerNormSingle( + self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + ) + else: + self.prompt_adaln_single = None + + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.caption_projection = lambda a: a + else: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) @abstractmethod def _init_model_components(self, device, dtype, **kwargs): @@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC): if grid_mask is not None: timestep = timestep[:, grid_mask] - timestep = timestep * self.timestep_scale_multiplier + timestep_scaled = timestep * self.timestep_scale_multiplier timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), + timestep_scaled.flatten(), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC): timestep = timestep.view(batch_size, -1, timestep.shape[-1]) embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) - return timestep, embedded_timestep + prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + + return timestep, embedded_timestep, prompt_timestep def _prepare_context(self, context, batch_size, x, attention_mask=None): """Prepare context for transformer blocks.""" - if self.caption_projection is not None: + if self.caption_proj_before_connector is False: context = self.caption_projection(context) - context = context.view(batch_size, -1, x.shape[-1]) + context = context.view(batch_size, -1, x.shape[-1]) return context, attention_mask def _precompute_freqs_cis( @@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC): merged_args.update(additional_args) # Prepare timestep and context - timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + merged_args["prompt_timestep"] = prompt_timestep context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) # Prepare attention mask and positional embeddings @@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel): causal_temporal_positioning=False, vae_scale_factors=(8, 32, 32), use_middle_indices_grid=False, - timestep_scale_multiplier = 1000.0, + timestep_scale_multiplier=1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel): def _init_model_components(self, device, dtype, **kwargs): """Initialize LTXV-specific components.""" - # No additional components needed for LTXV beyond base class pass def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel): self.num_attention_heads, self.attention_head_dim, context_dim=self.cross_attention_dim, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prompt_timestep = kwargs.get("prompt_timestep", None) for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask")) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep")) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel): pe=pe, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + prompt_timestep=prompt_timestep, ) return x diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index 55a074661..fa0a00748 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( CausalityAxis, CausalAudioAutoencoder, ) -from comfy.ldm.lightricks.vocoders.vocoder import Vocoder +from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE LATENT_DOWNSAMPLE_FACTOR = 4 @@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module): vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) - self.vocoder = Vocoder(config=component_config.vocoder) + if "bwe" in component_config.vocoder: + self.vocoder = VocoderWithBWE(config=component_config.vocoder) + else: + self.vocoder = Vocoder(config=component_config.vocoder) self.autoencoder.load_state_dict(vae_sd, strict=False) self.vocoder.load_state_dict(vocoder_sd, strict=False) diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py index f12b9bb53..b556b128f 100644 --- a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module): super().__init__() if config is None: - config = self._guess_config() + config = self.get_default_config() - # Extract encoder and decoder configs from the new format model_config = config.get("model", {}).get("params", {}) - variables_config = config.get("variables", {}) - self.sampling_rate = variables_config.get( - "sampling_rate", - model_config.get("sampling_rate", config.get("sampling_rate", 16000)), + self.sampling_rate = model_config.get( + "sampling_rate", config.get("sampling_rate", 16000) ) encoder_config = model_config.get("encoder", model_config.get("ddconfig", {})) decoder_config = model_config.get("decoder", encoder_config) # Load mel spectrogram parameters self.mel_bins = encoder_config.get("mel_bins", 64) - self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) - self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) + self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) + self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) # Store causality configuration at VAE level (not just in encoder internals) - causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value) + causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value) self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) self.is_causal = self.causality_axis == CausalityAxis.HEIGHT @@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module): self.per_channel_statistics = processor() - def _guess_config(self): - encoder_config = { - # Required parameters - based on ltx-video-av-1679000 model metadata - "ch": 128, - "out_ch": 8, - "ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8] - "num_res_blocks": 2, - "attn_resolutions": [], # Based on metadata: empty list, no attention - "dropout": 0.0, - "resamp_with_conv": True, - "in_channels": 2, # stereo - "resolution": 256, - "z_channels": 8, + def get_default_config(self): + ddconfig = { "double_z": True, - "attn_type": "vanilla", - "mid_block_add_attention": False, # Based on metadata: false + "mel_bins": 64, + "z_channels": 8, + "resolution": 256, + "downsample_time": False, + "in_channels": 2, + "out_ch": 2, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + "mid_block_add_attention": False, "norm_type": "pixel", - "causality_axis": "height", # Based on metadata - "mel_bins": 64, # Based on metadata: mel_bins = 64 - } - - decoder_config = { - # Inherits encoder config, can override specific params - **encoder_config, - "out_ch": 2, # Stereo audio output (2 channels) - "give_pre_end": False, - "tanh_out": False, + "causality_axis": "height", } config = { - "_class_name": "CausalAudioAutoencoder", - "sampling_rate": 16000, "model": { "params": { - "encoder": encoder_config, - "decoder": decoder_config, + "ddconfig": ddconfig, + "sampling_rate": 16000, } }, + "preprocessing": { + "stft": { + "filter_length": 1024, + "hop_length": 160, + }, + }, } return config diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index cbfdf412d..5b57dfc5e 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init +def in_meta_context(): + return torch.device("meta") == torch.empty(0).device + def mark_conv3d_ended(module): tid = threading.get_ident() for _, m in module.named_modules(): @@ -350,6 +353,10 @@ class Decoder(nn.Module): output_channel = output_channel * block_params.get("multiplier", 2) if block_name == "compress_all": output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_space": + output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_time": + output_channel = output_channel * block_params.get("multiplier", 1) self.conv_in = make_conv_nd( dims, @@ -395,17 +402,21 @@ class Decoder(nn.Module): spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(2, 1, 1), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(1, 2, 2), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all": @@ -455,6 +466,15 @@ class Decoder(nn.Module): output_channel * 2, 0, operations=ops, ) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) + else: + self.register_buffer( + "last_scale_shift_table", + torch.tensor( + [0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(2, output_channel), + persistent=False, + ) # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: @@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module): self.scale_shift_table = nn.Parameter( torch.randn(4, in_channels) / in_channels**0.5 ) + else: + self.register_buffer( + "scale_shift_table", + torch.tensor( + [0.0, 0.0, 0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(4, in_channels), + persistent=False, + ) self.temporal_cache_state={} @@ -1012,9 +1041,6 @@ class processor(nn.Module): super().__init__() self.register_buffer("std-of-means", torch.empty(128)) self.register_buffer("mean-of-means", torch.empty(128)) - self.register_buffer("mean-of-stds", torch.empty(128)) - self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128)) - self.register_buffer("channel", torch.empty(128)) def un_normalize(self, x): return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x) @@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module): super().__init__() if config is None: - config = self.guess_config(version) + config = self.get_default_config(version) + self.config = config self.timestep_conditioning = config.get("timestep_conditioning", False) + self.decode_noise_scale = config.get("decode_noise_scale", 0.025) + self.decode_timestep = config.get("decode_timestep", 0.05) double_z = config.get("double_z", True) latent_log_var = config.get( "latent_log_var", "per_channel" if double_z else "none" @@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module): latent_log_var=latent_log_var, norm_layer=config.get("norm_layer", "group_norm"), spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + base_channels=config.get("encoder_base_channels", 128), ) self.decoder = Decoder( @@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module): in_channels=config["latent_channels"], out_channels=config.get("out_channels", 3), blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), + base_channels=config.get("decoder_base_channels", 128), patch_size=config.get("patch_size", 1), norm_layer=config.get("norm_layer", "group_norm"), causal=config.get("causal_decoder", False), @@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module): self.per_channel_statistics = processor() - def guess_config(self, version): + def get_default_config(self, version): if version == 0: config = { "_class_name": "CausalVideoAutoencoder", @@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module): means, logvar = torch.chunk(self.encoder(x), 2, dim=1) return self.per_channel_statistics.normalize(means) - def decode(self, x, timestep=0.05, noise_scale=0.025): + def decode(self, x): if self.timestep_conditioning: #TODO: seed - x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x - return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep) - + x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x + return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index b1f15f2c5..6c4028aa8 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -3,6 +3,7 @@ import torch.nn.functional as F import torch.nn as nn import comfy.ops import numpy as np +import math ops = comfy.ops.disable_weight_init @@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) +# --------------------------------------------------------------------------- +# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2 +# Adopted from https://github.com/NVIDIA/BigVGAN +# --------------------------------------------------------------------------- + + +def _sinc(x: torch.Tensor): + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time) + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride=1, + padding=True, + padding_mode="replicate", + kernel_size=12, + ): + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + def forward(self, x): + _, C, _ = x.shape + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"): + super().__init__() + self.ratio = ratio + self.stride = ratio + + if window_type == "hann": + # Hann-windowed sinc filter — identical to torchaudio.functional.resample + # with its default parameters (rolloff=0.99, lowpass_filter_width=6). + # Uses replicate boundary padding, matching the reference resampler exactly. + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + t = (torch.arange(self.kernel_size) / ratio - width) * rolloff + t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2 + filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser-windowed sinc filter (BigVGAN default). + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + + self.register_buffer("filter", filter, persistent=persistent) + + def forward(self, x): + _, C, _ = x.shape + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + return self.lowpass(x) + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio=2, + down_ratio=2, + up_kernel_size=12, + down_kernel_size=12, + ): + super().__init__() + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +# --------------------------------------------------------------------------- +# BigVGAN v2 activations (Snake / SnakeBeta) +# --------------------------------------------------------------------------- + + +class Snake(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = self.alpha.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + a = torch.exp(a) + return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2) + + +class SnakeBeta(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.beta = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.beta.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = self.alpha.unsqueeze(0).unsqueeze(-1) + b = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + a = torch.exp(a) + b = torch.exp(b) + return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2) + + +# --------------------------------------------------------------------------- +# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity) +# --------------------------------------------------------------------------- + + +class AMPBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"): + super().__init__() + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + self.acts1 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs1))] + ) + self.acts2 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs2))] + ) + + def forward(self, x): + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = x + xt + return x + + +# --------------------------------------------------------------------------- +# HiFi-GAN residual blocks +# --------------------------------------------------------------------------- + + class ResBlock1(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlock1, self).__init__() @@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module): """ Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. + Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1"). """ def __init__(self, config=None): @@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module): config = self.get_default_config() resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) - upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2]) - upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]) + upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4]) resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) upsample_initial_channel = config.get("upsample_initial_channel", 1024) stereo = config.get("stereo", True) - resblock = config.get("resblock", "1") + activation = config.get("activation", "snake") + use_bias_at_final = config.get("use_bias_at_final", True) + + # "output_sample_rate" is not present in recent checkpoint configs. + # When absent (None), AudioVAE.output_sample_rate computes it as: + # sample_rate * vocoder.upsample_factor / mel_hop_length + # where upsample_factor = product of all upsample stride lengths, + # and mel_hop_length is loaded from the autoencoder config at + # preprocessing.stft.hop_length (see CausalAudioAutoencoder). self.output_sample_rate = config.get("output_sample_rate") + self.resblock = config.get("resblock", "1") + self.use_tanh_at_final = config.get("use_tanh_at_final", True) + self.apply_final_activation = config.get("apply_final_activation", True) self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) - resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + if self.resblock == "1": + resblock_cls = ResBlock1 + elif self.resblock == "2": + resblock_cls = ResBlock2 + elif self.resblock == "AMP1": + resblock_cls = AMPBlock1 + else: + raise ValueError(f"Unknown resblock type: {self.resblock}") self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): @@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module): self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock_class(ch, k, d)) + for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes): + if self.resblock == "AMP1": + self.resblocks.append(resblock_cls(ch, k, d, activation=activation)) + else: + self.resblocks.append(resblock_cls(ch, k, d)) out_channels = 2 if stereo else 1 - self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3) + if self.resblock == "AMP1": + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.act_post = Activation1d(act_cls(ch)) + else: + self.act_post = nn.LeakyReLU() + + self.conv_post = ops.Conv1d( + ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final + ) self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + def get_default_config(self): """Generate default configuration for the vocoder.""" config = { "resblock_kernel_sizes": [3, 7, 11], - "upsample_rates": [6, 5, 2, 2, 2], - "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_initial_channel": 1024, "stereo": True, "resblock": "1", + "activation": "snake", + "use_bias_at_final": True, + "use_tanh_at_final": True, } return config @@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module): assert x.shape[1] == 2, "Input must have 2 channels for stereo" x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) x = self.conv_pre(x) + for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) + if self.resblock != "AMP1": + x = F.leaky_relu(x, LRELU_SLOPE) x = self.ups[i](x) xs = None for j in range(self.num_kernels): @@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module): else: xs += self.resblocks[i * self.num_kernels + j](x) x = xs / self.num_kernels - x = F.leaky_relu(x) + + x = self.act_post(x) x = self.conv_post(x) - x = torch.tanh(x) + + if self.apply_final_activation: + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, -1, 1) return x + + +class _STFTFn(nn.Module): + """Implements STFT as a convolution with precomputed DFT × Hann-window bases. + + The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal + Hann window are stored as buffers and loaded from the checkpoint. Using the exact + bfloat16 bases from training ensures the mel values fed to the BWE generator are + bit-identical to what it was trained on. + """ + + def __init__(self, filter_length: int, hop_length: int, win_length: int): + super().__init__() + self.hop_length = hop_length + self.win_length = win_length + n_freqs = filter_length // 2 + 1 + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + + def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Compute magnitude and phase spectrogram from a batch of waveforms. + + Applies causal (left-only) padding of win_length - hop_length samples so that + each output frame depends only on past and present input — no lookahead. + The STFT is computed by convolving the padded signal with forward_basis. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + Computed in float32 for numerical stability, then cast back to + the input dtype. + """ + if y.dim() == 2: + y = y.unsqueeze(1) # (B, 1, T) + left_pad = max(0, self.win_length - self.hop_length) # causal: left-only + y = F.pad(y, (left_pad, 0)) + spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real ** 2 + imag ** 2) + phase = torch.atan2(imag.float(), real.float()).to(real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint. + + Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input + waveform and projecting the linear magnitude spectrum onto the mel filterbank. + + The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint + (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis). + """ + + def __init__( + self, + filter_length: int, + hop_length: int, + win_length: int, + n_mel_channels: int, + sampling_rate: int, + mel_fmin: float, + mel_fmax: float, + ): + super().__init__() + self.stft_fn = _STFTFn(filter_length, hop_length, win_length) + + n_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs)) + + def mel_spectrogram( + self, y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute log-mel spectrogram and auxiliary spectral quantities. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames). + Computed as log(clamp(mel_basis @ magnitude, min=1e-5)). + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames). + """ + magnitude, phase = self.stft_fn(y) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class VocoderWithBWE(torch.nn.Module): + """Vocoder with bandwidth extension (BWE) for higher sample rate output. + + Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples + to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform. + """ + + def __init__(self, config): + super().__init__() + vocoder_config = config["vocoder"] + bwe_config = config["bwe"] + + self.vocoder = Vocoder(config=vocoder_config) + self.bwe_generator = Vocoder( + config={**bwe_config, "apply_final_activation": False} + ) + + self.input_sample_rate = bwe_config["input_sampling_rate"] + self.output_sample_rate = bwe_config["output_sampling_rate"] + self.hop_length = bwe_config["hop_length"] + + self.mel_stft = MelSTFT( + filter_length=bwe_config["n_fft"], + hop_length=bwe_config["hop_length"], + win_length=bwe_config["n_fft"], + n_mel_channels=bwe_config["num_mels"], + sampling_rate=bwe_config["input_sampling_rate"], + mel_fmin=0.0, + mel_fmax=bwe_config["input_sampling_rate"] / 2.0, + ) + self.resampler = UpSample1d( + ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"], + persistent=False, + window_type="hann", + ) + + def _compute_mel(self, audio): + """Compute log-mel spectrogram from waveform using causal STFT bases.""" + B, C, T = audio.shape + flat = audio.reshape(B * C, -1) # (B*C, T) + mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames) + return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames) + + def forward(self, mel_spec): + x = self.vocoder(mel_spec) + _, _, T_low = x.shape + T_out = T_low * self.output_sample_rate // self.input_sample_rate + + remainder = T_low % self.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + mel = self._compute_mel(x) + residual = self.bwe_generator(mel) + skip = self.resampler(x) + assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" + + return torch.clamp(residual + skip, -1, 1)[..., :T_out] diff --git a/comfy/model_base.py b/comfy/model_base.py index 1e01e9edc..d9d5a9293 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1021,7 +1021,7 @@ class LTXAV(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: if hasattr(self.diffusion_model, "preprocess_text_embeds"): - cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference())) + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False)) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) diff --git a/comfy/sd.py b/comfy/sd.py index 8bcd09582..888ef1e77 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1467,7 +1467,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage elif clip_type == CLIPType.LTXV: - clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data)) + clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.NEWBIE: diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e86ea9f4e..5e1273c6e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel): comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is +class DualLinearProjection(torch.nn.Module): + def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None): + super().__init__() + self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device) + self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device) + + def forward(self, x): + source_dim = x.shape[-1] + x = x.movedim(1, -1) + x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2) + + video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim)) + audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim)) + return torch.cat((video, audio), dim=-1) + class LTXAVTEModel(torch.nn.Module): - def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}): super().__init__() self.dtypes = set() self.dtypes.add(dtype) self.compat_mode = False + self.text_projection_type = text_projection_type self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None) self.dtypes.add(dtype_llama) operations = self.gemma3_12b.operations # TODO - self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + + if self.text_projection_type == "single_linear": + self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + elif self.text_projection_type == "dual_linear": + self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations) + def enable_compat_mode(self): # TODO: remove from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module): out_device = out.device if comfy.model_management.should_use_bf16(self.execution_device): out = out.to(device=self.execution_device, dtype=torch.bfloat16) - out = out.movedim(1, -1).to(self.execution_device) - out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) - out = out.reshape((out.shape[0], out.shape[1], -1)) - out = self.text_embedding_projection(out) - out = out.float() - if self.compat_mode: - out_vid = self.video_embeddings_connector(out)[0] - out_audio = self.audio_embeddings_connector(out)[0] - out = torch.concat((out_vid, out_audio), dim=-1) + if self.text_projection_type == "single_linear": + out = out.movedim(1, -1).to(self.execution_device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) - return out.to(out_device), pooled + if self.compat_mode: + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + extra = {} + else: + extra = {"unprocessed_ltxav_embeds": True} + elif self.text_projection_type == "dual_linear": + out = self.text_embedding_projection(out) + extra = {"unprocessed_ltxav_embeds": True} + + return out.to(device=out_device, dtype=torch.float), pooled, extra def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed) @@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module): if "model.layers.47.self_attn.q_norm.weight" in sd: return self.gemma3_12b.load_sd(sd) else: - sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True) + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True) if len(sdo) == 0: sdo = sd @@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module): num_tokens = max(num_tokens, 642) return num_tokens * constant * 1024 * 1024 -def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): +def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"): class LTXAVTEModel_(LTXAVTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: @@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama - super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options) return LTXAVTEModel_ + +def sd_detect(state_dict_list, prefix=""): + for sd in state_dict_list: + if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd: + return {"text_projection_type": "dual_linear"} + if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd: + return {"text_projection_type": "single_linear"} + return {} + + def gemma3_te(dtype_llama=None, llama_quantization_metadata=None): class Gemma3_12BModel_(Gemma3_12BModel): def __init__(self, device="cpu", dtype=None, model_options={}): From f2ee7f2d367f98bb8a33bcb4a224bda441eb8a07 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 22:21:55 -0800 Subject: [PATCH 52/81] Fix cublas ops on dynamic vram. (#12776) --- comfy/ops.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 8275dd0a5..3e19cd1b6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -660,23 +660,29 @@ class fp8_ops(manual_cast): CUBLAS_IS_AVAILABLE = False try: - from cublas_ops import CublasLinear + from cublas_ops import CublasLinear, cublas_half_matmul CUBLAS_IS_AVAILABLE = True except ImportError: pass if CUBLAS_IS_AVAILABLE: - class cublas_ops(disable_weight_init): - class Linear(CublasLinear, disable_weight_init.Linear): + class cublas_ops(manual_cast): + class Linear(CublasLinear, manual_cast.Linear): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): - return super().forward(input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): - return super().forward(*args, **kwargs) - + run_every_op() + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) # ============================================================================== # Mixed Precision Operations From c5fe8ace68c432a262a5093bdd84b3ed70b9d283 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 5 Mar 2026 15:37:35 +0800 Subject: [PATCH 53/81] chore: update workflow templates to v0.9.6 (#12778) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index dae46d873..5f99407b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.5 +comfyui-workflow-templates==0.9.6 comfyui-embedded-docs==0.4.3 torch torchsde From 4941671b5a5c65fea48be922caa76b7f6a0a4595 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 23:39:51 -0800 Subject: [PATCH 54/81] Fix cuda getting initialized in cpu mode. (#12779) --- comfy/model_management.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 809600815..ee28ea107 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1666,12 +1666,16 @@ def lora_compute_dtype(device): return dtype def synchronize(): + if cpu_mode(): + return if is_intel_xpu(): torch.xpu.synchronize() elif torch.cuda.is_available(): torch.cuda.synchronize() def soft_empty_cache(force=False): + if cpu_mode(): + return global cpu_state if cpu_state == CPUState.MPS: torch.mps.empty_cache() From c8428541a6b6e4b1e0fbd685e9c846efcb60179e Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 5 Mar 2026 16:58:25 +0800 Subject: [PATCH 55/81] chore: update workflow templates to v0.9.7 (#12780) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5f99407b7..866818e08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.6 +comfyui-workflow-templates==0.9.7 comfyui-embedded-docs==0.4.3 torch torchsde From e04d0dbeb8266aa9262b5a4c3934ba4e4a371e37 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 04:06:29 -0500 Subject: [PATCH 56/81] ComfyUI v0.16.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 6a35c6de3..0aea18d3a 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.15.1" +__version__ = "0.16.0" diff --git a/pyproject.toml b/pyproject.toml index 1b2318273..f2133d99c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.15.1" +version = "0.16.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From bd21363563ce8e312c9271a0c64a0145335df8a9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:29:39 +0200 Subject: [PATCH 57/81] feat(api-nodes-xAI): updated models, pricing, added features (#12756) --- comfy_api_nodes/apis/grok.py | 14 ++++-- comfy_api_nodes/nodes_grok.py | 92 +++++++++++++++++++++++++++++------ 2 files changed, 87 insertions(+), 19 deletions(-) diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py index 8e3c79ab9..c56c8aecc 100644 --- a/comfy_api_nodes/apis/grok.py +++ b/comfy_api_nodes/apis/grok.py @@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel): aspect_ratio: str = Field(...) n: int = Field(...) seed: int = Field(...) - response_for: str = Field("url") + response_format: str = Field("url") + resolution: str = Field(...) class InputUrlObject(BaseModel): @@ -16,12 +17,13 @@ class InputUrlObject(BaseModel): class ImageEditRequest(BaseModel): model: str = Field(...) - image: InputUrlObject = Field(...) + images: list[InputUrlObject] = Field(...) prompt: str = Field(...) resolution: str = Field(...) n: int = Field(...) seed: int = Field(...) - response_for: str = Field("url") + response_format: str = Field("url") + aspect_ratio: str | None = Field(...) class VideoGenerationRequest(BaseModel): @@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel): revised_prompt: str | None = Field(None) +class UsageObject(BaseModel): + cost_in_usd_ticks: int | None = Field(None) + + class ImageGenerationResponse(BaseModel): data: list[ImageResponseObject] = Field(...) + usage: UsageObject | None = Field(None) class VideoGenerationResponse(BaseModel): @@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel): status: str | None = Field(None) video: VideoResponseObject | None = Field(None) model: str | None = Field(None) + usage: UsageObject | None = Field(None) diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index da15e97ea..0716d6239 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -27,6 +27,12 @@ from comfy_api_nodes.util import ( ) +def _extract_grok_price(response) -> float | None: + if response.usage and response.usage.cost_in_usd_ticks is not None: + return response.usage.cost_in_usd_ticks / 10_000_000_000 + return None + + class GrokImageNode(IO.ComfyNode): @classmethod @@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode): category="api node/image/Grok", description="Generate images using Grok based on a text prompt", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-image-beta"]), + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), IO.String.Input( "prompt", multiline=True, @@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), + IO.Combo.Input("resolution", options=["1K", "2K"], optional=True), ], outputs=[ IO.Image.Output(), @@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), - expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""", + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; + {"type":"usd","usd": $rate * widgets.number_of_images} + ) + """, ), ) @@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode): aspect_ratio: str, number_of_images: int, seed: int, + resolution: str = "1K", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) response = await sync_op( @@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode): aspect_ratio=aspect_ratio, n=number_of_images, seed=seed, + resolution=resolution.lower(), ), response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) @@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode): category="api node/image/Grok", description="Modify an existing image based on a text prompt", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-image-beta"]), - IO.Image.Input("image"), + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), + IO.Image.Input("image", display_name="images"), IO.String.Input( "prompt", multiline=True, tooltip="The text prompt used to generate the image", ), - IO.Combo.Input("resolution", options=["1K"]), + IO.Combo.Input("resolution", options=["1K", "2K"]), IO.Int.Input( "number_of_images", default=1, @@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "9:16", + "16:9", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", + ], + optional=True, + tooltip="Only allowed when multiple images are connected to the image input.", + ), ], outputs=[ IO.Image.Output(), @@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), - expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""", + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; + {"type":"usd","usd": 0.002 + $rate * widgets.number_of_images} + ) + """, ), ) @@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode): resolution: str, number_of_images: int, seed: int, + aspect_ratio: str = "auto", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) - if get_number_of_images(image) != 1: - raise ValueError("Only one input image is supported.") + if model == "grok-imagine-image-pro": + if get_number_of_images(image) > 1: + raise ValueError("The pro model supports only 1 input image.") + elif get_number_of_images(image) > 3: + raise ValueError("A maximum of 3 input images is supported.") + if aspect_ratio != "auto" and get_number_of_images(image) == 1: + raise ValueError( + "Custom aspect ratio is only allowed when multiple images are connected to the image input." + ) response = await sync_op( cls, ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), data=ImageEditRequest( model=model, - image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"), + images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image], prompt=prompt, resolution=resolution.lower(), n=number_of_images, seed=seed, + aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, ), response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) @@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode): category="api node/video/Grok", description="Generate video from a prompt or an image", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), IO.String.Input( "prompt", multiline=True, @@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]), + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]), expr=""" ( - $base := 0.181 * widgets.duration; + $rate := widgets.resolution = "720p" ? 0.07 : 0.05; + $base := $rate * widgets.duration; {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} ) """, @@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url)) @@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode): category="api node/video/Grok", description="Edit an existing video based on a text prompt.", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), IO.String.Input( "prompt", multiline=True, @@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""", + expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""", ), ) @@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url)) From 9cdfd7403bc46f75d12be16ba6041b8bcdd3f7fd Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 5 Mar 2026 17:12:38 +0200 Subject: [PATCH 58/81] feat(api-nodes): enable Kling 3.0 Motion Control (#12785) --- comfy_api_nodes/apis/kling.py | 1 + comfy_api_nodes/nodes_kling.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/comfy_api_nodes/apis/kling.py b/comfy_api_nodes/apis/kling.py index a5bd5f1d3..fe0f97cb3 100644 --- a/comfy_api_nodes/apis/kling.py +++ b/comfy_api_nodes/apis/kling.py @@ -148,3 +148,4 @@ class MotionControlRequest(BaseModel): keep_original_sound: str = Field(...) character_orientation: str = Field(...) mode: str = Field(..., description="'pro' or 'std'") + model_name: str = Field(...) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 74fa078ff..8963c335d 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -2747,6 +2747,7 @@ class MotionControl(IO.ComfyNode): "but the character orientation matches the reference image (camera/other details via prompt).", ), IO.Combo.Input("mode", options=["pro", "std"]), + IO.Combo.Input("model", options=["kling-v3", "kling-v2-6"], optional=True), ], outputs=[ IO.Video.Output(), @@ -2777,6 +2778,7 @@ class MotionControl(IO.ComfyNode): keep_original_sound: bool, character_orientation: str, mode: str, + model: str = "kling-v2-6", ) -> IO.NodeOutput: validate_string(prompt, max_length=2500) validate_image_dimensions(reference_image, min_width=340, min_height=340) @@ -2797,6 +2799,7 @@ class MotionControl(IO.ComfyNode): keep_original_sound="yes" if keep_original_sound else "no", character_orientation=character_orientation, mode=mode, + model_name=model, ), ) if response.code: From da29b797ce00b491c269e864cc3b8fceb279e530 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 5 Mar 2026 23:23:23 +0800 Subject: [PATCH 59/81] Update workflow templates to v0.9.8 (#12788) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 866818e08..3fd44e0cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.7 +comfyui-workflow-templates==0.9.8 comfyui-embedded-docs==0.4.3 torch torchsde From 6ef82a89b83a49247081dc57b154172573c9e313 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 10:38:33 -0500 Subject: [PATCH 60/81] ComfyUI v0.16.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 0aea18d3a..e58e0fb63 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.0" +__version__ = "0.16.1" diff --git a/pyproject.toml b/pyproject.toml index f2133d99c..199a90364 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.0" +version = "0.16.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 6481569ad4c3606bc50e9de39ce810651690ae79 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 5 Mar 2026 09:04:24 -0800 Subject: [PATCH 61/81] comfy-aimdo 0.2.7 (#12791) Comfy-aimdo 0.2.7 fixes a crash when a spurious cudaAsyncFree comes in and would cause an infinite stack overflow (via detours hooks). A lock is also introduced on the link list holding the free sections to avoid any possibility of threaded miscellaneous cuda allocations being the root cause. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3fd44e0cf..f7098b730 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.6 +comfy-aimdo>=0.2.7 requests #non essential dependencies: From 42e0e023eee6a19c1adb7bd3dc11c81ff6dcc9c8 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 5 Mar 2026 10:22:17 -0800 Subject: [PATCH 62/81] ops: Handle CPU weight in VBAR caster (#12792) This shouldn't happen but custom nodes gets there. Handle it as best we can. --- comfy/ops.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index 3e19cd1b6..06aa41d4f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -80,6 +80,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): + + #vbar doesn't support CPU weights, but some custom nodes have weird paths + #that might switch the layer to the CPU and expect it to work. We have to take + #a clone conservatively as we are mmapped and some SFT files are packed misaligned + #If you are a custom node author reading this, please move your layer to the GPU + #or declare your ModelPatcher as CPU in the first place. + if device is not None and device.type == "cpu": + weight = s.weight.to(dtype=dtype, copy=True) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + bias = None + if s.bias is not None: + bias = s.bias.to(dtype=bias_dtype, copy=True) + return weight, bias, (None, None, None) + offload_stream = None xfer_dest = None From 5073da57ad20a2abb921f79458e49a7f7d608740 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 6 Mar 2026 02:22:38 +0800 Subject: [PATCH 63/81] chore: update workflow templates to v0.9.10 (#12793) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f7098b730..9a674fac5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.8 +comfyui-workflow-templates==0.9.10 comfyui-embedded-docs==0.4.3 torch torchsde From 1c3b651c0a1539a374e3d29a3ce695b5844ac5fc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Mar 2026 10:35:56 -0800 Subject: [PATCH 64/81] Refactor. (#12794) --- comfy/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 06aa41d4f..87b36b5c5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -86,7 +86,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu #a clone conservatively as we are mmapped and some SFT files are packed misaligned #If you are a custom node author reading this, please move your layer to the GPU #or declare your ModelPatcher as CPU in the first place. - if device is not None and device.type == "cpu": + if comfy.model_management.is_device_cpu(device): weight = s.weight.to(dtype=dtype, copy=True) if isinstance(weight, QuantizedTensor): weight = weight.dequantize() From 50549aa252903b936b2ed00b5de418c8b47f0841 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 13:41:06 -0500 Subject: [PATCH 65/81] ComfyUI v0.16.2 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index e58e0fb63..bc49f2218 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.1" +__version__ = "0.16.2" diff --git a/pyproject.toml b/pyproject.toml index 199a90364..73bfd1007 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.1" +version = "0.16.2" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 8befce5c7b84ff3451a6bd3bcbae1355ad322855 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 5 Mar 2026 22:37:25 +0200 Subject: [PATCH 66/81] Add manual cast to LTX2 vocoder conv_transpose1d (#12795) * Add manual cast to LTX2 vocoder * Update vocoder.py --- comfy/ldm/lightricks/vocoders/vocoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index 6c4028aa8..a0e03cada 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F import torch.nn as nn import comfy.ops +import comfy.model_management import numpy as np import math @@ -125,7 +126,7 @@ class UpSample1d(nn.Module): _, C, _ = x.shape x = F.pad(x, (self.pad, self.pad), mode="replicate") x = self.ratio * F.conv_transpose1d( - x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C ) x = x[..., self.pad_left : -self.pad_right] return x From 17b43c2b87eba43f0f071471b855e0ed659a2627 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Mar 2026 13:31:28 -0800 Subject: [PATCH 67/81] LTX audio vae novram fixes. (#12796) --- comfy/ldm/lightricks/vocoders/vocoder.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index a0e03cada..2481d8bdd 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -82,7 +82,7 @@ class LowPassFilter1d(nn.Module): _, C, _ = x.shape if self.padding: x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) - return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C) class UpSample1d(nn.Module): @@ -191,7 +191,7 @@ class Snake(nn.Module): self.eps = 1e-9 def forward(self, x): - a = self.alpha.unsqueeze(0).unsqueeze(-1) + a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) if self.alpha_logscale: a = torch.exp(a) return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2) @@ -218,8 +218,8 @@ class SnakeBeta(nn.Module): self.eps = 1e-9 def forward(self, x): - a = self.alpha.unsqueeze(0).unsqueeze(-1) - b = self.beta.unsqueeze(0).unsqueeze(-1) + a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) + b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) if self.alpha_logscale: a = torch.exp(a) b = torch.exp(b) @@ -597,7 +597,7 @@ class _STFTFn(nn.Module): y = y.unsqueeze(1) # (B, 1, T) left_pad = max(0, self.win_length - self.hop_length) # causal: left-only y = F.pad(y, (left_pad, 0)) - spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) + spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0) n_freqs = spec.shape[1] // 2 real, imag = spec[:, :n_freqs], spec[:, n_freqs:] magnitude = torch.sqrt(real ** 2 + imag ** 2) @@ -648,7 +648,7 @@ class MelSTFT(nn.Module): """ magnitude, phase = self.stft_fn(y) energy = torch.norm(magnitude, dim=1) - mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude) log_mel = torch.log(torch.clamp(mel, min=1e-5)) return log_mel, magnitude, phase, energy From 58017e8726bdddae89704b1e0123bedc29994424 Mon Sep 17 00:00:00 2001 From: Tavi Halperin Date: Thu, 5 Mar 2026 23:51:20 +0200 Subject: [PATCH 68/81] feat: add causal_fix parameter to add_keyframe_index and append_keyframe (#12797) Allows explicit control over the causal_fix flag passed to latent_to_pixel_coords. Defaults to frame_idx == 0 when not specified, fixing the previous heuristic. --- comfy_extras/nodes_lt.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 32fe921ff..c05571143 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -253,10 +253,12 @@ class LTXVAddGuide(io.ComfyNode): return frame_idx, latent_idx @classmethod - def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1): + def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1, causal_fix=None): keyframe_idxs, _ = get_keyframe_idxs(cond) _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent) - pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 + if causal_fix is None: + causal_fix = frame_idx == 0 or guiding_latent.shape[2] == 1 + pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=causal_fix) pixel_coords[:, 0] += frame_idx # The following adjusts keyframe end positions for small grid IC-LoRA. @@ -278,12 +280,12 @@ class LTXVAddGuide(io.ComfyNode): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1): + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1, causal_fix=None): if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: raise ValueError("Adding guide to a combined AV latent is not supported.") - positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) - negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) + positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix) + negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix) if guide_mask is not None: target_h = max(noise_mask.shape[3], guide_mask.shape[3]) From 1c218282369a6cc80651d878fc51fa33d7bf34e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 17:25:49 -0500 Subject: [PATCH 69/81] ComfyUI v0.16.3 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index bc49f2218..5da21150b 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.2" +__version__ = "0.16.3" diff --git a/pyproject.toml b/pyproject.toml index 73bfd1007..6a83c5c63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.2" +version = "0.16.3" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From e544c65db91df5a070be69a0a9b922201fe79335 Mon Sep 17 00:00:00 2001 From: Dante Date: Fri, 6 Mar 2026 11:51:28 +0900 Subject: [PATCH 70/81] feat: add Math Expression node with simpleeval evaluation (#12687) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add EagerEval dataclass for frontend-side node evaluation Add EagerEval to the V3 API schema, enabling nodes to declare frontend-evaluated JSONata expressions. The frontend uses this to display computation results as badges without a backend round-trip. Co-Authored-By: Claude Opus 4.6 * feat: add Math Expression node with JSONata evaluation Add ComfyMathExpression node that evaluates JSONata expressions against dynamically-grown numeric inputs using Autogrow + MatchType. Sends input context via ui output so the frontend can re-evaluate when the expression changes without a backend round-trip. Co-Authored-By: Claude Opus 4.6 * feat: register nodes_math.py in extras_files loader list Co-Authored-By: Claude Opus 4.6 * fix: address CodeRabbit review feedback - Harden EagerEval.validate with type checks and strip() for empty strings - Add _positional_alias for spreadsheet-style names beyond z (aa, ab...) - Validate JSONata result is numeric before returning - Add jsonata to requirements.txt Co-Authored-By: Claude Opus 4.6 * refactor: remove EagerEval, scope PR to math node only Remove EagerEval dataclass from _io.py and eager_eval usage from nodes_math.py. Eager execution will be designed as a general-purpose system in a separate effort. Co-Authored-By: Claude Opus 4.6 * fix: use TemplateNames, cap inputs at 26, improve error message Address Kosinkadink review feedback: - Switch from Autogrow.TemplatePrefix to Autogrow.TemplateNames so input slots are named a-z, matching expression variables directly - Cap max inputs at 26 (a-z) instead of 100 - Simplify execute() by removing dual-mapping hack - Include expression and result value in error message Co-Authored-By: Claude Opus 4.6 * test: add unit tests for Math Expression node Add tests for _positional_alias (a-z mapping) and execute() covering arithmetic operations, float inputs, $sum(values), and error cases. Co-Authored-By: Claude Opus 4.6 * refactor: replace jsonata with simpleeval for math evaluation jsonata PyPI package has critical issues: no Python 3.12/3.13 wheels, no ARM/Apple Silicon wheels, abandoned (last commit 2023), C extension. Replace with simpleeval (pure Python, 3.4M downloads/month, MIT, AST-based security). Add math module functions (sqrt, ceil, floor, log, sin, cos, tan) and variadic sum() supporting both sum(values) and sum(a, b, c). Pin version to >=1.0,<2.0. Co-Authored-By: Claude Opus 4.6 * test: update tests for simpleeval migration Update JSONata syntax to Python syntax ($sum -> sum, $string -> str), add tests for math functions (sqrt, ceil, floor, sin, log10) and variadic sum(a, b, c). Co-Authored-By: Claude Opus 4.6 * refactor: replace MatchType with MultiType inputs and dual FLOAT/INT outputs Allow mixing INT and FLOAT connections on the same node by switching from MatchType (which forces all inputs to the same type) to MultiType. Output both FLOAT and INT so users can pick the type they need. Co-Authored-By: Claude Opus 4.6 * test: update tests for mixed INT/FLOAT inputs and dual outputs Add assertions for both FLOAT (result[0]) and INT (result[1]) outputs. Add test_mixed_int_float_inputs and test_mixed_resolution_scale to verify the primary use case of multiplying resolutions by a float factor. Co-Authored-By: Claude Opus 4.6 * feat: make expression input multiline and validate empty expression - Add multiline=True to expression input for better UX with longer expressions - Add empty expression validation with clear "Expression cannot be empty." message Co-Authored-By: Claude Opus 4.6 * test: add tests for empty expression validation Co-Authored-By: Claude Opus 4.6 * fix: address review feedback — safe pow, isfinite guard, test coverage - Wrap pow() with _safe_pow to prevent DoS via huge exponents (pow() bypasses simpleeval's safe_power guard on **) - Add math.isfinite() check to catch inf/nan before int() conversion - Add int/float converters to MATH_FUNCTIONS for explicit casting - Add "calculator" search alias - Replace _positional_alias helper with string.ascii_lowercase - Narrow test assertions and add error path + function coverage tests Co-Authored-By: Claude Opus 4.6 * Update requirements.txt --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: Jedrzej Kosinski Co-authored-by: Christian Byrne --- comfy_extras/nodes_math.py | 119 +++++++++++ nodes.py | 1 + requirements.txt | 1 + .../comfy_extras_test/nodes_math_test.py | 197 ++++++++++++++++++ 4 files changed, 318 insertions(+) create mode 100644 comfy_extras/nodes_math.py create mode 100644 tests-unit/comfy_extras_test/nodes_math_test.py diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py new file mode 100644 index 000000000..6417bacf1 --- /dev/null +++ b/comfy_extras/nodes_math.py @@ -0,0 +1,119 @@ +"""Math expression node using simpleeval for safe evaluation. + +Provides a ComfyMathExpression node that evaluates math expressions +against dynamically-grown numeric inputs. +""" + +from __future__ import annotations + +import math +import string + +from simpleeval import simple_eval +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +MAX_EXPONENT = 4000 + + +def _variadic_sum(*args): + """Support both sum(values) and sum(a, b, c).""" + if len(args) == 1 and hasattr(args[0], "__iter__"): + return sum(args[0]) + return sum(args) + + +def _safe_pow(base, exp): + """Wrap pow() with an exponent cap to prevent DoS via huge exponents. + + The ** operator is already guarded by simpleeval's safe_power, but + pow() as a callable bypasses that guard. + """ + if abs(exp) > MAX_EXPONENT: + raise ValueError(f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})") + return pow(base, exp) + + +MATH_FUNCTIONS = { + "sum": _variadic_sum, + "min": min, + "max": max, + "abs": abs, + "round": round, + "pow": _safe_pow, + "sqrt": math.sqrt, + "ceil": math.ceil, + "floor": math.floor, + "log": math.log, + "log2": math.log2, + "log10": math.log10, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "int": int, + "float": float, +} + + +class MathExpressionNode(io.ComfyNode): + """Evaluates a math expression against dynamically-grown inputs.""" + + @classmethod + def define_schema(cls) -> io.Schema: + autogrow = io.Autogrow.TemplateNames( + input=io.MultiType.Input("value", [io.Float, io.Int]), + names=list(string.ascii_lowercase), + min=1, + ) + return io.Schema( + node_id="ComfyMathExpression", + display_name="Math Expression", + category="math", + search_aliases=[ + "expression", "formula", "calculate", "calculator", + "eval", "math", + ], + inputs=[ + io.String.Input("expression", default="a + b", multiline=True), + io.Autogrow.Input("values", template=autogrow), + ], + outputs=[ + io.Float.Output(display_name="FLOAT"), + io.Int.Output(display_name="INT"), + ], + ) + + @classmethod + def execute( + cls, expression: str, values: io.Autogrow.Type + ) -> io.NodeOutput: + if not expression.strip(): + raise ValueError("Expression cannot be empty.") + + context: dict = dict(values) + context["values"] = list(values.values()) + + result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS) + # bool check must come first because bool is a subclass of int in Python + if isinstance(result, bool) or not isinstance(result, (int, float)): + raise ValueError( + f"Math Expression '{expression}' must evaluate to a numeric result, " + f"got {type(result).__name__}: {result!r}" + ) + if not math.isfinite(result): + raise ValueError( + f"Math Expression '{expression}' produced a non-finite result: {result}" + ) + return io.NodeOutput(float(result), int(result)) + + +class MathExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [MathExpressionNode] + + +async def comfy_entrypoint() -> MathExtension: + return MathExtension() diff --git a/nodes.py b/nodes.py index 5be9b16f9..0ef23b640 100644 --- a/nodes.py +++ b/nodes.py @@ -2449,6 +2449,7 @@ async def init_builtin_extra_nodes(): "nodes_replacements.py", "nodes_nag.py", "nodes_sdpose.py", + "nodes_math.py", ] import_failed = [] diff --git a/requirements.txt b/requirements.txt index 9a674fac5..7bf12247c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,6 +24,7 @@ av>=14.2.0 comfy-kitchen>=0.2.7 comfy-aimdo>=0.2.7 requests +simpleeval>=1.0 #non essential dependencies: kornia>=0.7.1 diff --git a/tests-unit/comfy_extras_test/nodes_math_test.py b/tests-unit/comfy_extras_test/nodes_math_test.py new file mode 100644 index 000000000..fa4cdcac3 --- /dev/null +++ b/tests-unit/comfy_extras_test/nodes_math_test.py @@ -0,0 +1,197 @@ +import math + +import pytest +from collections import OrderedDict +from unittest.mock import patch, MagicMock + +mock_nodes = MagicMock() +mock_nodes.MAX_RESOLUTION = 16384 +mock_server = MagicMock() + +with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}): + from comfy_extras.nodes_math import MathExpressionNode + + +class TestMathExpressionExecute: + @staticmethod + def _exec(expression: str, **kwargs) -> object: + values = OrderedDict(kwargs) + return MathExpressionNode.execute(expression, values) + + def test_addition(self): + result = self._exec("a + b", a=3, b=4) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_subtraction(self): + result = self._exec("a - b", a=10, b=3) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_multiplication(self): + result = self._exec("a * b", a=3, b=5) + assert result[0] == 15.0 + assert result[1] == 15 + + def test_division(self): + result = self._exec("a / b", a=10, b=4) + assert result[0] == 2.5 + assert result[1] == 2 + + def test_single_input(self): + result = self._exec("a * 2", a=5) + assert result[0] == 10.0 + assert result[1] == 10 + + def test_three_inputs(self): + result = self._exec("a + b + c", a=1, b=2, c=3) + assert result[0] == 6.0 + assert result[1] == 6 + + def test_float_inputs(self): + result = self._exec("a + b", a=1.5, b=2.5) + assert result[0] == 4.0 + assert result[1] == 4 + + def test_mixed_int_float_inputs(self): + result = self._exec("a * b", a=1024, b=1.5) + assert result[0] == 1536.0 + assert result[1] == 1536 + + def test_mixed_resolution_scale(self): + result = self._exec("a * b", a=512, b=0.75) + assert result[0] == 384.0 + assert result[1] == 384 + + def test_sum_values_array(self): + result = self._exec("sum(values)", a=1, b=2, c=3) + assert result[0] == 6.0 + + def test_sum_variadic(self): + result = self._exec("sum(a, b, c)", a=1, b=2, c=3) + assert result[0] == 6.0 + + def test_min_values(self): + result = self._exec("min(values)", a=5, b=2, c=8) + assert result[0] == 2.0 + + def test_max_values(self): + result = self._exec("max(values)", a=5, b=2, c=8) + assert result[0] == 8.0 + + def test_abs_function(self): + result = self._exec("abs(a)", a=-7) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_sqrt(self): + result = self._exec("sqrt(a)", a=16) + assert result[0] == 4.0 + assert result[1] == 4 + + def test_ceil(self): + result = self._exec("ceil(a)", a=2.3) + assert result[0] == 3.0 + assert result[1] == 3 + + def test_floor(self): + result = self._exec("floor(a)", a=2.7) + assert result[0] == 2.0 + assert result[1] == 2 + + def test_sin(self): + result = self._exec("sin(a)", a=0) + assert result[0] == 0.0 + + def test_log10(self): + result = self._exec("log10(a)", a=100) + assert result[0] == 2.0 + assert result[1] == 2 + + def test_float_output_type(self): + result = self._exec("a + b", a=1, b=2) + assert isinstance(result[0], float) + + def test_int_output_type(self): + result = self._exec("a + b", a=1, b=2) + assert isinstance(result[1], int) + + def test_non_numeric_result_raises(self): + with pytest.raises(ValueError, match="must evaluate to a numeric result"): + self._exec("'hello'", a=42) + + def test_undefined_function_raises(self): + with pytest.raises(Exception, match="not defined"): + self._exec("str(a)", a=42) + + def test_boolean_result_raises(self): + with pytest.raises(ValueError, match="got bool"): + self._exec("a > b", a=5, b=3) + + def test_empty_expression_raises(self): + with pytest.raises(ValueError, match="Expression cannot be empty"): + self._exec("", a=1) + + def test_whitespace_only_expression_raises(self): + with pytest.raises(ValueError, match="Expression cannot be empty"): + self._exec(" ", a=1) + + # --- Missing function coverage (round, pow, log, log2, cos, tan) --- + + def test_round(self): + result = self._exec("round(a)", a=2.7) + assert result[0] == 3.0 + assert result[1] == 3 + + def test_round_with_ndigits(self): + result = self._exec("round(a, 2)", a=3.14159) + assert result[0] == pytest.approx(3.14) + + def test_pow(self): + result = self._exec("pow(a, b)", a=2, b=10) + assert result[0] == 1024.0 + assert result[1] == 1024 + + def test_log(self): + result = self._exec("log(a)", a=math.e) + assert result[0] == pytest.approx(1.0) + + def test_log2(self): + result = self._exec("log2(a)", a=8) + assert result[0] == pytest.approx(3.0) + + def test_cos(self): + result = self._exec("cos(a)", a=0) + assert result[0] == 1.0 + + def test_tan(self): + result = self._exec("tan(a)", a=0) + assert result[0] == 0.0 + + # --- int/float converter functions --- + + def test_int_converter(self): + result = self._exec("int(a / b)", a=7, b=2) + assert result[1] == 3 + + def test_float_converter(self): + result = self._exec("float(a)", a=5) + assert result[0] == 5.0 + + # --- Error path tests --- + + def test_division_by_zero_raises(self): + with pytest.raises(ZeroDivisionError): + self._exec("a / b", a=1, b=0) + + def test_sqrt_negative_raises(self): + with pytest.raises(ValueError, match="math domain error"): + self._exec("sqrt(a)", a=-1) + + def test_overflow_inf_raises(self): + with pytest.raises(ValueError, match="non-finite result"): + self._exec("a * b", a=1e308, b=10) + + def test_pow_huge_exponent_raises(self): + with pytest.raises(ValueError, match="Exponent .* exceeds maximum"): + self._exec("pow(a, b)", a=10, b=10000000) From 3b93d5d571cb3e018da65f822cd11b60202b11c2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 6 Mar 2026 11:04:48 +0200 Subject: [PATCH 71/81] feat(api-nodes): add TencentSmartTopology node (#12741) * feat(api-nodes): add TencentSmartTopology node * feat(api-nodes): enable TencentModelTo3DUV node * chore(Tencent endpoints): add "wait" to queued statuses --- comfy_api_nodes/apis/hunyuan3d.py | 16 ++++- comfy_api_nodes/nodes_hunyuan3d.py | 109 ++++++++++++++++++++++++++--- comfy_api_nodes/util/client.py | 2 +- 3 files changed, 114 insertions(+), 13 deletions(-) diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py index e84eba31e..dad9bc2fa 100644 --- a/comfy_api_nodes/apis/hunyuan3d.py +++ b/comfy_api_nodes/apis/hunyuan3d.py @@ -66,13 +66,17 @@ class To3DProTaskQueryRequest(BaseModel): JobId: str = Field(...) -class To3DUVFileInput(BaseModel): +class TaskFile3DInput(BaseModel): Type: str = Field(..., description="File type: GLB, OBJ, or FBX") Url: str = Field(...) class To3DUVTaskRequest(BaseModel): - File: To3DUVFileInput = Field(...) + File: TaskFile3DInput = Field(...) + + +class To3DPartTaskRequest(BaseModel): + File: TaskFile3DInput = Field(...) class TextureEditImageInfo(BaseModel): @@ -80,7 +84,13 @@ class TextureEditImageInfo(BaseModel): class TextureEditTaskRequest(BaseModel): - File3D: To3DUVFileInput = Field(...) + File3D: TaskFile3DInput = Field(...) Image: TextureEditImageInfo | None = Field(None) Prompt: str | None = Field(None) EnablePBR: bool | None = Field(None) + + +class SmartTopologyRequest(BaseModel): + File3D: TaskFile3DInput = Field(...) + PolygonType: str | None = Field(...) + FaceLevel: str | None = Field(...) diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index d1d9578ec..bd8bde997 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -5,18 +5,19 @@ from comfy_api_nodes.apis.hunyuan3d import ( Hunyuan3DViewImage, InputGenerateType, ResultFile3D, + SmartTopologyRequest, + TaskFile3DInput, TextureEditTaskRequest, + To3DPartTaskRequest, To3DProTaskCreateResponse, To3DProTaskQueryRequest, To3DProTaskRequest, To3DProTaskResultResponse, - To3DUVFileInput, To3DUVTaskRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_file_3d, - download_url_to_image_tensor, downscale_image_tensor_by_max_side, poll_op, sync_op, @@ -344,7 +345,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode): outputs=[ IO.File3DOBJ.Output(display_name="OBJ"), IO.File3DFBX.Output(display_name="FBX"), - IO.Image.Output(), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -375,7 +375,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode): ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"), response_model=To3DProTaskCreateResponse, data=To3DUVTaskRequest( - File=To3DUVFileInput( + File=TaskFile3DInput( Type=file_format.upper(), Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format), ) @@ -394,7 +394,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode): return IO.NodeOutput( await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), - await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url), ) @@ -463,7 +462,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode): ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"), response_model=To3DProTaskCreateResponse, data=TextureEditTaskRequest( - File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url), + File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), Prompt=prompt, EnablePBR=True, ), @@ -538,8 +537,8 @@ class Tencent3DPartNode(IO.ComfyNode): cls, ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"), response_model=To3DProTaskCreateResponse, - data=To3DUVTaskRequest( - File=To3DUVFileInput(Type=file_format.upper(), Url=model_url), + data=To3DPartTaskRequest( + File=TaskFile3DInput(Type=file_format.upper(), Url=model_url), ), is_rate_limited=_is_tencent_rate_limited, ) @@ -557,15 +556,107 @@ class Tencent3DPartNode(IO.ComfyNode): ) +class TencentSmartTopologyNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentSmartTopologyNode", + display_name="Hunyuan3D: Smart Topology", + category="api node/3d/Tencent", + description="Perform smart retopology on a 3D model. " + "Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.", + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DAny], + tooltip="Input 3D model (GLB or OBJ)", + ), + IO.Combo.Input( + "polygon_type", + options=["triangle", "quadrilateral"], + tooltip="Surface composition type.", + ), + IO.Combo.Input( + "face_level", + options=["medium", "high", "low"], + tooltip="Polygon reduction level.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.File3DOBJ.Output(display_name="OBJ"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge(expr='{"type":"usd","usd":1.0}'), + ) + + SUPPORTED_FORMATS = {"glb", "obj"} + + @classmethod + async def execute( + cls, + model_3d: Types.File3D, + polygon_type: str, + face_level: str, + seed: int, + ) -> IO.NodeOutput: + _ = seed + file_format = model_3d.format.lower() + if file_format not in cls.SUPPORTED_FORMATS: + raise ValueError( + f"Unsupported file format: '{file_format}'. " f"Supported: {', '.join(sorted(cls.SUPPORTED_FORMATS))}." + ) + model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology", method="POST"), + response_model=To3DProTaskCreateResponse, + data=SmartTopologyRequest( + File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), + PolygonType=polygon_type, + FaceLevel=face_level, + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed: [{response.Error.Code}] {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + return IO.NodeOutput( + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), + ) + + class TencentHunyuan3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ TencentTextToModelNode, TencentImageToModelNode, - # TencentModelTo3DUVNode, + TencentModelTo3DUVNode, # Tencent3DTextureEditNode, Tencent3DPartNode, + TencentSmartTopologyNode, ] diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 94886af7b..79ffb77c1 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -83,7 +83,7 @@ class _PollUIState: _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] -QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"] async def sync_op( From 34e55f006156801a6b5988d046d9041cb681f12d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 6 Mar 2026 19:54:27 +0200 Subject: [PATCH 72/81] feat(api-nodes): add Gemini 3.1 Flash Lite model to LLM node (#12803) --- comfy_api_nodes/nodes_gemini.py | 47 +++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index d83d2fc15..8225ea67e 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -72,18 +72,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( ) -class GeminiModel(str, Enum): - """ - Gemini Model Names allowed by comfy-api - """ - - gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06" - gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17" - gemini_2_5_pro = "gemini-2.5-pro" - gemini_2_5_flash = "gemini-2.5-flash" - gemini_3_0_pro = "gemini-3-pro-preview" - - class GeminiImageModel(str, Enum): """ Gemini Image Model Names allowed by comfy-api @@ -237,10 +225,14 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N input_tokens_price = 0.30 output_text_tokens_price = 2.50 output_image_tokens_price = 30.0 - elif response.modelVersion == "gemini-3-pro-preview": + elif response.modelVersion in ("gemini-3-pro-preview", "gemini-3.1-pro-preview"): input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 0.0 + elif response.modelVersion == "gemini-3.1-flash-lite-preview": + input_tokens_price = 0.25 + output_text_tokens_price = 1.50 + output_image_tokens_price = 0.0 elif response.modelVersion == "gemini-3-pro-image-preview": input_tokens_price = 2 output_text_tokens_price = 12.0 @@ -292,8 +284,16 @@ class GeminiNode(IO.ComfyNode): ), IO.Combo.Input( "model", - options=GeminiModel, - default=GeminiModel.gemini_2_5_pro, + options=[ + "gemini-2.5-pro-preview-05-06", + "gemini-2.5-flash-preview-04-17", + "gemini-2.5-pro", + "gemini-2.5-flash", + "gemini-3-pro-preview", + "gemini-3-1-pro", + "gemini-3-1-flash-lite", + ], + default="gemini-3-1-pro", tooltip="The Gemini model to use for generating responses.", ), IO.Int.Input( @@ -363,11 +363,16 @@ class GeminiNode(IO.ComfyNode): "usd": [0.00125, 0.01], "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } } - : $contains($m, "gemini-3-pro-preview") ? { + : ($contains($m, "gemini-3-pro-preview") or $contains($m, "gemini-3-1-pro")) ? { "type": "list_usd", "usd": [0.002, 0.012], "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } } + : $contains($m, "gemini-3-1-flash-lite") ? { + "type": "list_usd", + "usd": [0.00025, 0.0015], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } : {"type":"text", "text":"Token-based"} ) """, @@ -436,12 +441,14 @@ class GeminiNode(IO.ComfyNode): files: list[GeminiPart] | None = None, system_prompt: str = "", ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=False) + if model == "gemini-3-pro-preview": + model = "gemini-3.1-pro-preview" # model "gemini-3-pro-preview" will be soon deprecated by Google + elif model == "gemini-3-1-pro": + model = "gemini-3.1-pro-preview" + elif model == "gemini-3-1-flash-lite": + model = "gemini-3.1-flash-lite-preview" - # Create parts list with text prompt as the first part parts: list[GeminiPart] = [GeminiPart(text=prompt)] - - # Add other modal parts if images is not None: parts.extend(await create_image_parts(cls, images)) if audio is not None: From f466b066017b9ebe5df67decfcbd09f78c5c66fa Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:20:07 -0800 Subject: [PATCH 73/81] Fix fp16 audio encoder models (#12811) * mp: respect model_defined_dtypes in default caster This is needed for parametrizations when the dtype changes between sd and model. * audio_encoders: archive model dtypes Archive model dtypes to stop the state dict load override the dtypes defined by the core for compute etc. --- comfy/audio_encoders/audio_encoders.py | 1 + comfy/model_patcher.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 16998af94..0de7584b0 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -27,6 +27,7 @@ class AudioEncoderModel(): self.model.eval() self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 + comfy.model_management.archive_model_dtypes(self.model) def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7e5ad7aa4..745384271 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -715,8 +715,8 @@ class ModelPatcher: default = True # default random weights in non leaf modules break if default and default_device is not None: - for param in params.values(): - param.data = param.data.to(device=default_device) + for param_name, param in params.items(): + param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None)) if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0): module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem From d69d30819b91aa020d0bb888df2a5b917f83bb7e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 6 Mar 2026 16:11:16 -0800 Subject: [PATCH 74/81] Don't run TE on cpu when dynamic vram enabled. (#12815) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index ee28ea107..39b4aa483 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -939,7 +939,7 @@ def text_encoder_offload_device(): def text_encoder_device(): if args.gpu_only: return get_torch_device() - elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: + elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled: if should_use_fp16(prioritize_performance=False): return get_torch_device() else: From afc00f00553885eeb96ded329878fe732f6b9f7a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 6 Mar 2026 17:10:53 -0800 Subject: [PATCH 75/81] Fix requirements version. (#12817) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7bf12247c..26e2ecdec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,7 +24,7 @@ av>=14.2.0 comfy-kitchen>=0.2.7 comfy-aimdo>=0.2.7 requests -simpleeval>=1.0 +simpleeval>=1.0.0 #non essential dependencies: kornia>=0.7.1 From 6ac8152fc80734b084d12865460e5e9a5d9a4e1b Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sat, 7 Mar 2026 15:54:09 +0800 Subject: [PATCH 76/81] chore: update workflow templates to v0.9.11 (#12821) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 26e2ecdec..dc9a9ded0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.10 +comfyui-workflow-templates==0.9.11 comfyui-embedded-docs==0.4.3 torch torchsde From bcf1a1fab1e9efe0d4999ea14e9c0318409e0000 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 7 Mar 2026 09:38:08 -0800 Subject: [PATCH 77/81] mm: reset_cast_buffers: sync compute stream before free (#12822) Sync the compute stream before freeing the cast buffers. This can cause use after free issues when the cast stream frees the buffer while the compute stream is behind enough to still needs a casted weight. --- comfy/model_management.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 39b4aa483..07bc8ad67 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1148,6 +1148,7 @@ def reset_cast_buffers(): LARGEST_CASTED_WEIGHT = (None, 0) for offload_stream in STREAM_CAST_BUFFERS: offload_stream.synchronize() + synchronize() STREAM_CAST_BUFFERS.clear() soft_empty_cache() From a7a6335be538f55faa2abf7404c9b8e970847d1f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 7 Mar 2026 16:52:39 -0500 Subject: [PATCH 78/81] ComfyUI v0.16.4 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 5da21150b..2723d02e7 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.3" +__version__ = "0.16.4" diff --git a/pyproject.toml b/pyproject.toml index 6a83c5c63..753b219b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.3" +version = "0.16.4" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 29b24cb5177e9d5aa5b3d2e5869999efb4d538c7 Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Sat, 7 Mar 2026 17:37:25 -0800 Subject: [PATCH 79/81] refactor(assets): modular architecture + async two-phase scanner & background seeder (#12621) --- .../0002_merge_to_asset_references.py | 267 +++++ app/assets/api/routes.py | 810 ++++++++----- app/assets/api/schemas_in.py | 79 +- app/assets/api/schemas_out.py | 6 +- app/assets/api/upload.py | 171 +++ app/assets/database/bulk_ops.py | 204 ---- app/assets/database/models.py | 210 ++-- app/assets/database/queries.py | 976 ---------------- app/assets/database/queries/__init__.py | 121 ++ app/assets/database/queries/asset.py | 140 +++ .../database/queries/asset_reference.py | 1033 +++++++++++++++++ app/assets/database/queries/common.py | 54 + app/assets/database/queries/tags.py | 356 ++++++ app/assets/database/tags.py | 62 - app/assets/hashing.py | 75 -- app/assets/helpers.py | 319 +---- app/assets/manager.py | 516 -------- app/assets/scanner.py | 768 ++++++++---- app/assets/seeder.py | 794 +++++++++++++ app/assets/services/__init__.py | 87 ++ app/assets/services/asset_management.py | 309 +++++ app/assets/services/bulk_ingest.py | 280 +++++ app/assets/services/file_utils.py | 70 ++ app/assets/services/hashing.py | 95 ++ app/assets/services/ingest.py | 375 ++++++ app/assets/services/metadata_extract.py | 327 ++++++ app/assets/services/path_utils.py | 167 +++ app/assets/services/schemas.py | 109 ++ app/assets/services/tagging.py | 75 ++ app/database/db.py | 81 +- comfy/cli_args.py | 2 +- comfy_api/feature_flags.py | 1 + main.py | 35 +- requirements.txt | 2 + server.py | 19 +- tests-unit/assets_test/conftest.py | 16 +- tests-unit/assets_test/helpers.py | 28 + tests-unit/assets_test/queries/conftest.py | 20 + tests-unit/assets_test/queries/test_asset.py | 144 +++ .../assets_test/queries/test_asset_info.py | 517 +++++++++ .../assets_test/queries/test_cache_state.py | 499 ++++++++ .../assets_test/queries/test_metadata.py | 184 +++ tests-unit/assets_test/queries/test_tags.py | 366 ++++++ tests-unit/assets_test/services/__init__.py | 1 + tests-unit/assets_test/services/conftest.py | 54 + .../services/test_asset_management.py | 268 +++++ .../assets_test/services/test_bulk_ingest.py | 137 +++ .../assets_test/services/test_enrich.py | 207 ++++ .../assets_test/services/test_ingest.py | 229 ++++ .../assets_test/services/test_tagging.py | 197 ++++ .../assets_test/test_assets_missing_sync.py | 2 +- tests-unit/assets_test/test_crud.py | 138 ++- tests-unit/assets_test/test_downloads.py | 4 +- tests-unit/assets_test/test_file_utils.py | 121 ++ tests-unit/assets_test/test_list_filter.py | 40 +- .../assets_test/test_prune_orphaned_assets.py | 2 +- .../assets_test/test_sync_references.py | 482 ++++++++ .../{test_tags.py => test_tags_api.py} | 4 +- tests-unit/assets_test/test_uploads.py | 22 +- tests-unit/requirements.txt | 1 - tests-unit/seeder_test/test_seeder.py | 900 ++++++++++++++ utils/mime_types.py | 37 + 62 files changed, 10737 insertions(+), 2878 deletions(-) create mode 100644 alembic_db/versions/0002_merge_to_asset_references.py create mode 100644 app/assets/api/upload.py delete mode 100644 app/assets/database/bulk_ops.py delete mode 100644 app/assets/database/queries.py create mode 100644 app/assets/database/queries/__init__.py create mode 100644 app/assets/database/queries/asset.py create mode 100644 app/assets/database/queries/asset_reference.py create mode 100644 app/assets/database/queries/common.py create mode 100644 app/assets/database/queries/tags.py delete mode 100644 app/assets/database/tags.py delete mode 100644 app/assets/hashing.py delete mode 100644 app/assets/manager.py create mode 100644 app/assets/seeder.py create mode 100644 app/assets/services/__init__.py create mode 100644 app/assets/services/asset_management.py create mode 100644 app/assets/services/bulk_ingest.py create mode 100644 app/assets/services/file_utils.py create mode 100644 app/assets/services/hashing.py create mode 100644 app/assets/services/ingest.py create mode 100644 app/assets/services/metadata_extract.py create mode 100644 app/assets/services/path_utils.py create mode 100644 app/assets/services/schemas.py create mode 100644 app/assets/services/tagging.py create mode 100644 tests-unit/assets_test/helpers.py create mode 100644 tests-unit/assets_test/queries/conftest.py create mode 100644 tests-unit/assets_test/queries/test_asset.py create mode 100644 tests-unit/assets_test/queries/test_asset_info.py create mode 100644 tests-unit/assets_test/queries/test_cache_state.py create mode 100644 tests-unit/assets_test/queries/test_metadata.py create mode 100644 tests-unit/assets_test/queries/test_tags.py create mode 100644 tests-unit/assets_test/services/__init__.py create mode 100644 tests-unit/assets_test/services/conftest.py create mode 100644 tests-unit/assets_test/services/test_asset_management.py create mode 100644 tests-unit/assets_test/services/test_bulk_ingest.py create mode 100644 tests-unit/assets_test/services/test_enrich.py create mode 100644 tests-unit/assets_test/services/test_ingest.py create mode 100644 tests-unit/assets_test/services/test_tagging.py create mode 100644 tests-unit/assets_test/test_file_utils.py create mode 100644 tests-unit/assets_test/test_sync_references.py rename tests-unit/assets_test/{test_tags.py => test_tags_api.py} (98%) create mode 100644 tests-unit/seeder_test/test_seeder.py create mode 100644 utils/mime_types.py diff --git a/alembic_db/versions/0002_merge_to_asset_references.py b/alembic_db/versions/0002_merge_to_asset_references.py new file mode 100644 index 000000000..1ac1b980c --- /dev/null +++ b/alembic_db/versions/0002_merge_to_asset_references.py @@ -0,0 +1,267 @@ +""" +Merge AssetInfo and AssetCacheState into unified asset_references table. + +This migration drops old tables and creates the new unified schema. +All existing data is discarded. + +Revision ID: 0002_merge_to_asset_references +Revises: 0001_assets +Create Date: 2025-02-11 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0002_merge_to_asset_references" +down_revision = "0001_assets" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Drop old tables (order matters due to FK constraints) + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_table("asset_cache_state") + + op.drop_index("ix_assets_info_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + # Truncate assets table (cascades handled by dropping dependent tables first) + op.execute("DELETE FROM assets") + + # Create asset_references table + op.create_table( + "asset_references", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column( + "asset_id", + sa.String(length=36), + sa.ForeignKey("assets.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("file_path", sa.Text(), nullable=True), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column( + "needs_verify", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column( + "is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false") + ), + sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column( + "preview_id", + sa.String(length=36), + sa.ForeignKey("assets.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=False), nullable=True), + sa.CheckConstraint( + "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" + ), + sa.CheckConstraint( + "enrichment_level >= 0 AND enrichment_level <= 2", + name="ck_ar_enrichment_level_range", + ), + ) + op.create_index( + "uq_asset_references_file_path", "asset_references", ["file_path"], unique=True + ) + op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"]) + op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"]) + op.create_index("ix_asset_references_name", "asset_references", ["name"]) + op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"]) + op.create_index( + "ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"] + ) + op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"]) + op.create_index( + "ix_asset_references_last_access_time", "asset_references", ["last_access_time"] + ) + op.create_index( + "ix_asset_references_owner_name", "asset_references", ["owner_id", "name"] + ) + op.create_index("ix_asset_references_deleted_at", "asset_references", ["deleted_at"]) + + # Create asset_reference_tags table + op.create_table( + "asset_reference_tags", + sa.Column( + "asset_reference_id", + sa.String(length=36), + sa.ForeignKey("asset_references.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "tag_name", + sa.String(length=512), + sa.ForeignKey("tags.name", ondelete="RESTRICT"), + nullable=False, + ), + sa.Column( + "origin", sa.String(length=32), nullable=False, server_default="manual" + ), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint( + "asset_reference_id", "tag_name", name="pk_asset_reference_tags" + ), + ) + op.create_index( + "ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"] + ) + op.create_index( + "ix_asset_reference_tags_asset_reference_id", + "asset_reference_tags", + ["asset_reference_id"], + ) + + # Create asset_reference_meta table + op.create_table( + "asset_reference_meta", + sa.Column( + "asset_reference_id", + sa.String(length=36), + sa.ForeignKey("asset_references.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint( + "asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta" + ), + ) + op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"]) + op.create_index( + "ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"] + ) + op.create_index( + "ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"] + ) + op.create_index( + "ix_asset_reference_meta_key_val_bool", + "asset_reference_meta", + ["key", "val_bool"], + ) + + +def downgrade() -> None: + """Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema. + + NOTE: Data is not recoverable. The upgrade discards all rows from the old + tables and truncates assets. After downgrade the old schema will be empty. + A filesystem rescan will repopulate data once the older code is running. + """ + # Drop new tables (order matters due to FK constraints) + op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta") + op.drop_table("asset_reference_meta") + + op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags") + op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags") + op.drop_table("asset_reference_tags") + + op.drop_index("ix_asset_references_deleted_at", table_name="asset_references") + op.drop_index("ix_asset_references_owner_name", table_name="asset_references") + op.drop_index("ix_asset_references_last_access_time", table_name="asset_references") + op.drop_index("ix_asset_references_created_at", table_name="asset_references") + op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references") + op.drop_index("ix_asset_references_is_missing", table_name="asset_references") + op.drop_index("ix_asset_references_name", table_name="asset_references") + op.drop_index("ix_asset_references_owner_id", table_name="asset_references") + op.drop_index("ix_asset_references_asset_id", table_name="asset_references") + op.drop_index("uq_asset_references_file_path", table_name="asset_references") + op.drop_table("asset_references") + + # Truncate assets (upgrade deleted all rows; downgrade starts fresh too) + op.execute("DELETE FROM assets") + + # Recreate old tables from 0001_assets schema + op.create_table( + "assets_info", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"]) + + op.create_table( + "asset_cache_state", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("file_path", sa.Text(), nullable=False), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) + + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 7676e50b4..40dee9f46 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -1,56 +1,144 @@ +import asyncio +import functools +import json import logging -import uuid -import urllib.parse import os -import contextlib -from aiohttp import web +import urllib.parse +import uuid +from typing import Any +from aiohttp import web from pydantic import ValidationError -import app.assets.manager as manager -from app import user_manager -from app.assets.api import schemas_in -from app.assets.helpers import get_query_dict -from app.assets.scanner import seed_assets - import folder_paths +from app import user_manager +from app.assets.api import schemas_in, schemas_out +from app.assets.api.schemas_in import ( + AssetValidationError, + UploadError, +) +from app.assets.helpers import validate_blake3_hash +from app.assets.api.upload import ( + delete_temp_file_if_exists, + parse_multipart_upload, +) +from app.assets.seeder import ScanInProgressError, asset_seeder +from app.assets.services import ( + DependencyMissingError, + HashMismatchError, + apply_tags, + asset_exists, + create_from_hash, + delete_asset_reference, + get_asset_detail, + list_assets_page, + list_tags, + remove_tags, + resolve_asset_for_download, + update_asset_metadata, + upload_from_temp_path, +) ROUTES = web.RouteTableDef() USER_MANAGER: user_manager.UserManager | None = None +_ASSETS_ENABLED = False + + +def _require_assets_feature_enabled(handler): + @functools.wraps(handler) + async def wrapper(request: web.Request) -> web.Response: + if not _ASSETS_ENABLED: + return _build_error_response( + 503, + "SERVICE_DISABLED", + "Assets system is disabled. Start the server with --enable-assets to use this feature.", + ) + return await handler(request) + + return wrapper + # UUID regex (canonical hyphenated form, case-insensitive) UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" -# Note to any custom node developers reading this code: -# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same. -def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: - global USER_MANAGER - USER_MANAGER = user_manager_instance +def get_query_dict(request: web.Request) -> dict[str, Any]: + """Gets a dictionary of query parameters from the request. + + request.query is a MultiMapping[str], needs to be converted to a dict + to be validated by Pydantic. + """ + query_dict = { + key: request.query.getall(key) + if len(request.query.getall(key)) > 1 + else request.query.get(key) + for key in request.query.keys() + } + return query_dict + + +# Note to any custom node developers reading this code: +# The assets system is not yet fully implemented, +# do not rely on the code in /app/assets remaining the same. + + +def register_assets_routes( + app: web.Application, + user_manager_instance: user_manager.UserManager | None = None, +) -> None: + global USER_MANAGER, _ASSETS_ENABLED + if user_manager_instance is not None: + USER_MANAGER = user_manager_instance + _ASSETS_ENABLED = True app.add_routes(ROUTES) -def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response: - return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) + +def disable_assets_routes() -> None: + """Disable asset routes at runtime (e.g. after DB init failure).""" + global _ASSETS_ENABLED + _ASSETS_ENABLED = False -def _validation_error_response(code: str, ve: ValidationError) -> web.Response: - return _error_response(400, code, "Validation failed.", {"errors": ve.json()}) +def _build_error_response( + status: int, code: str, message: str, details: dict | None = None +) -> web.Response: + return web.json_response( + {"error": {"code": code, "message": message, "details": details or {}}}, + status=status, + ) + + +def _build_validation_error_response(code: str, ve: ValidationError) -> web.Response: + errors = json.loads(ve.json()) + return _build_error_response(400, code, "Validation failed.", {"errors": errors}) + + +def _validate_sort_field(requested: str | None) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" @ROUTES.head("/api/assets/hash/{hash}") +@_require_assets_feature_enabled async def head_asset_by_hash(request: web.Request) -> web.Response: hash_str = request.match_info.get("hash", "").strip().lower() - if not hash_str or ":" not in hash_str: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - algo, digest = hash_str.split(":", 1) - if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - exists = manager.asset_exists(asset_hash=hash_str) + try: + hash_str = validate_blake3_hash(hash_str) + except ValueError: + return _build_error_response( + 400, "INVALID_HASH", "hash must be like 'blake3:'" + ) + exists = asset_exists(hash_str) return web.Response(status=200 if exists else 404) @ROUTES.get("/api/assets") -async def list_assets(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def list_assets_route(request: web.Request) -> web.Response: """ GET request to list assets. """ @@ -58,78 +146,140 @@ async def list_assets(request: web.Request) -> web.Response: try: q = schemas_in.ListAssetsQuery.model_validate(query_dict) except ValidationError as ve: - return _validation_error_response("INVALID_QUERY", ve) + return _build_validation_error_response("INVALID_QUERY", ve) - payload = manager.list_assets( + sort = _validate_sort_field(q.sort) + order_candidate = (q.order or "desc").lower() + order = order_candidate if order_candidate in {"asc", "desc"} else "desc" + + result = list_assets_page( + owner_id=USER_MANAGER.get_request_user_id(request), include_tags=q.include_tags, exclude_tags=q.exclude_tags, name_contains=q.name_contains, metadata_filter=q.metadata_filter, limit=q.limit, offset=q.offset, - sort=q.sort, - order=q.order, - owner_id=USER_MANAGER.get_request_user_id(request), + sort=sort, + order=order, + ) + + summaries = [ + schemas_out.AssetSummary( + id=item.ref.id, + name=item.ref.name, + asset_hash=item.asset.hash if item.asset else None, + size=int(item.asset.size_bytes) if item.asset else None, + mime_type=item.asset.mime_type if item.asset else None, + tags=item.tags, + created_at=item.ref.created_at, + updated_at=item.ref.updated_at, + last_access_time=item.ref.last_access_time, + ) + for item in result.items + ] + + payload = schemas_out.AssetsList( + assets=summaries, + total=result.total, + has_more=(q.offset + len(summaries)) < result.total, ) return web.json_response(payload.model_dump(mode="json", exclude_none=True)) @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") -async def get_asset(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def get_asset_route(request: web.Request) -> web.Response: """ GET request to get an asset's info as JSON. """ - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - result = manager.get_asset( - asset_info_id=asset_info_id, + result = get_asset_detail( + reference_id=reference_id, owner_id=USER_MANAGER.get_request_user_id(request), ) + if not result: + return _build_error_response( + 404, + "ASSET_NOT_FOUND", + f"AssetReference {reference_id} not found", + {"id": reference_id}, + ) + + payload = schemas_out.AssetDetail( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash if result.asset else None, + size=int(result.asset.size_bytes) if result.asset else None, + mime_type=result.asset.mime_type if result.asset else None, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + ) except ValueError as e: - return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(e), {"id": reference_id} + ) except Exception: logging.exception( - "get_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "get_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") +@_require_assets_feature_enabled async def download_asset_content(request: web.Request) -> web.Response: - # question: do we need disposition? could we just stick with one of these? disposition = request.query.get("disposition", "attachment").lower().strip() if disposition not in {"inline", "attachment"}: disposition = "attachment" try: - abs_path, content_type, filename = manager.resolve_asset_content_for_download( - asset_info_id=str(uuid.UUID(request.match_info["id"])), + result = resolve_asset_for_download( + reference_id=str(uuid.UUID(request.match_info["id"])), owner_id=USER_MANAGER.get_request_user_id(request), ) + abs_path = result.abs_path + content_type = result.content_type + filename = result.download_name except ValueError as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve)) + return _build_error_response(404, "ASSET_NOT_FOUND", str(ve)) except NotImplementedError as nie: - return _error_response(501, "BACKEND_UNSUPPORTED", str(nie)) + return _build_error_response(501, "BACKEND_UNSUPPORTED", str(nie)) except FileNotFoundError: - return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") + return _build_error_response( + 404, "FILE_NOT_FOUND", "Underlying file not found on disk." + ) - quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'") - cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' + _DANGEROUS_MIME_TYPES = { + "text/html", "text/html-sandboxed", "application/xhtml+xml", + "text/javascript", "text/css", + } + if content_type in _DANGEROUS_MIME_TYPES: + content_type = "application/octet-stream" + + safe_name = (filename or "").replace("\r", "").replace("\n", "") + encoded = urllib.parse.quote(safe_name) + cd = f"{disposition}; filename*=UTF-8''{encoded}" file_size = os.path.getsize(abs_path) + size_mb = file_size / (1024 * 1024) logging.info( - "download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s", + "download_asset_content: path=%s, size=%d bytes (%.2f MB), type=%s, name=%s", abs_path, file_size, - file_size / (1024 * 1024), + size_mb, content_type, filename, ) - async def file_sender(): + async def stream_file_chunks(): chunk_size = 64 * 1024 with open(abs_path, "rb") as f: while True: @@ -139,26 +289,30 @@ async def download_asset_content(request: web.Request) -> web.Response: yield chunk return web.Response( - body=file_sender(), + body=stream_file_chunks(), content_type=content_type, headers={ "Content-Disposition": cd, "Content-Length": str(file_size), + "X-Content-Type-Options": "nosniff", }, ) @ROUTES.post("/api/assets/from-hash") -async def create_asset_from_hash(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def create_asset_from_hash_route(request: web.Request) -> web.Response: try: payload = await request.json() body = schemas_in.CreateFromHashBody.model_validate(payload) except ValidationError as ve: - return _validation_error_response("INVALID_BODY", ve) + return _build_validation_error_response("INVALID_BODY", ve) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) - result = manager.create_asset_from_hash( + result = create_from_hash( hash_str=body.hash, name=body.name, tags=body.tags, @@ -166,246 +320,209 @@ async def create_asset_from_hash(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), ) if result is None: - return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") - return web.json_response(result.model_dump(mode="json"), status=201) + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" + ) + + payload_out = schemas_out.AssetCreated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash, + size=int(result.asset.size_bytes), + mime_type=result.asset.mime_type, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + created_new=result.created_new, + ) + return web.json_response(payload_out.model_dump(mode="json"), status=201) @ROUTES.post("/api/assets") +@_require_assets_feature_enabled async def upload_asset(request: web.Request) -> web.Response: """Multipart/form-data endpoint for Asset uploads.""" - if not (request.content_type or "").lower().startswith("multipart/"): - return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.") - - reader = await request.multipart() - - file_present = False - file_client_name: str | None = None - tags_raw: list[str] = [] - provided_name: str | None = None - user_metadata_raw: str | None = None - provided_hash: str | None = None - provided_hash_exists: bool | None = None - - file_written = 0 - tmp_path: str | None = None - while True: - field = await reader.next() - if field is None: - break - - fname = getattr(field, "name", "") or "" - - if fname == "hash": - try: - s = ((await field.text()) or "").strip().lower() - except Exception: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - - if s: - if ":" not in s: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - provided_hash = f"{algo}:{digest}" - try: - provided_hash_exists = manager.asset_exists(asset_hash=provided_hash) - except Exception: - provided_hash_exists = None # do not fail the whole request here - - elif fname == "file": - file_present = True - file_client_name = (field.filename or "").strip() - - if provided_hash and provided_hash_exists is True: - # If client supplied a hash that we know exists, drain but do not write to disk - try: - while True: - chunk = await field.read_chunk(8 * 1024 * 1024) - if not chunk: - break - file_written += len(chunk) - except Exception: - return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.") - continue # Do not create temp file; we will create AssetInfo from the existing content - - # Otherwise, store to temp for hashing/ingest - uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") - unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) - os.makedirs(unique_dir, exist_ok=True) - tmp_path = os.path.join(unique_dir, ".upload.part") - - try: - with open(tmp_path, "wb") as f: - while True: - chunk = await field.read_chunk(8 * 1024 * 1024) - if not chunk: - break - f.write(chunk) - file_written += len(chunk) - except Exception: - try: - if os.path.exists(tmp_path or ""): - os.remove(tmp_path) - finally: - return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.") - elif fname == "tags": - tags_raw.append((await field.text()) or "") - elif fname == "name": - provided_name = (await field.text()) or None - elif fname == "user_metadata": - user_metadata_raw = (await field.text()) or None - - # If client did not send file, and we are not doing a from-hash fast path -> error - if not file_present and not (provided_hash and provided_hash_exists): - return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.") - - if file_present and file_written == 0 and not (provided_hash and provided_hash_exists): - # Empty upload is only acceptable if we are fast-pathing from existing hash - try: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - finally: - return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.") - try: - spec = schemas_in.UploadAssetSpec.model_validate({ - "tags": tags_raw, - "name": provided_name, - "user_metadata": user_metadata_raw, - "hash": provided_hash, - }) - except ValidationError as ve: - try: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - finally: - return _validation_error_response("INVALID_BODY", ve) - - # Validate models category against configured folders (consistent with previous behavior) - if spec.tags and spec.tags[0] == "models": - if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - return _error_response( - 400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'" - ) + parsed = await parse_multipart_upload(request, check_hash_exists=asset_exists) + except UploadError as e: + return _build_error_response(e.status, e.code, e.message) owner_id = USER_MANAGER.get_request_user_id(request) - # Fast path: if a valid provided hash exists, create AssetInfo without writing anything - if spec.hash and provided_hash_exists is True: - try: - result = manager.create_asset_from_hash( + try: + spec = schemas_in.UploadAssetSpec.model_validate( + { + "tags": parsed.tags_raw, + "name": parsed.provided_name, + "user_metadata": parsed.user_metadata_raw, + "hash": parsed.provided_hash, + } + ) + except ValidationError as ve: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response( + 400, "INVALID_BODY", f"Validation failed: {ve.json()}" + ) + + if spec.tags and spec.tags[0] == "models": + if ( + len(spec.tags) < 2 + or spec.tags[1] not in folder_paths.folder_names_and_paths + ): + delete_temp_file_if_exists(parsed.tmp_path) + category = spec.tags[1] if len(spec.tags) >= 2 else "" + return _build_error_response( + 400, "INVALID_BODY", f"unknown models category '{category}'" + ) + + try: + # Fast path: hash exists, create AssetReference without writing anything + if spec.hash and parsed.provided_hash_exists is True: + result = create_from_hash( hash_str=spec.hash, name=spec.name or (spec.hash.split(":", 1)[1]), tags=spec.tags, user_metadata=spec.user_metadata or {}, owner_id=owner_id, ) - except Exception: - logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id) - return _error_response(500, "INTERNAL", "Unexpected server error.") + if result is None: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist" + ) + delete_temp_file_if_exists(parsed.tmp_path) + else: + # Otherwise, we must have a temp file path to ingest + if not parsed.tmp_path or not os.path.exists(parsed.tmp_path): + return _build_error_response( + 400, + "MISSING_INPUT", + "Provided hash not found and no file uploaded.", + ) - if result is None: - return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist") - - # Drain temp if we accidentally saved (e.g., hash field came after file) - if tmp_path and os.path.exists(tmp_path): - with contextlib.suppress(Exception): - os.remove(tmp_path) - - status = 200 if (not result.created_new) else 201 - return web.json_response(result.model_dump(mode="json"), status=status) - - # Otherwise, we must have a temp file path to ingest - if not tmp_path or not os.path.exists(tmp_path): - # The only case we reach here without a temp file is: client sent a hash that does not exist and no file - return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.") - - try: - created = manager.upload_asset_from_temp_path( - spec, - temp_path=tmp_path, - client_filename=file_client_name, - owner_id=owner_id, - expected_asset_hash=spec.hash, - ) - status = 201 if created.created_new else 200 - return web.json_response(created.model_dump(mode="json"), status=status) - except ValueError as e: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - msg = str(e) - if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH": - return _error_response( - 400, - "HASH_MISMATCH", - "Uploaded file hash does not match provided hash.", + result = upload_from_temp_path( + temp_path=parsed.tmp_path, + name=spec.name, + tags=spec.tags, + user_metadata=spec.user_metadata or {}, + client_filename=parsed.file_client_name, + owner_id=owner_id, + expected_hash=spec.hash, ) - return _error_response(400, "BAD_REQUEST", "Invalid inputs.") + except AssetValidationError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, e.code, str(e)) + except ValueError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, "BAD_REQUEST", str(e)) + except HashMismatchError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, "HASH_MISMATCH", str(e)) + except DependencyMissingError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(503, "DEPENDENCY_MISSING", e.message) except Exception: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id) - return _error_response(500, "INTERNAL", "Unexpected server error.") + delete_temp_file_if_exists(parsed.tmp_path) + logging.exception("upload_asset failed for owner_id=%s", owner_id) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + + payload = schemas_out.AssetCreated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash, + size=int(result.asset.size_bytes), + mime_type=result.asset.mime_type, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + created_new=result.created_new, + ) + status = 201 if result.created_new else 200 + return web.json_response(payload.model_dump(mode="json"), status=status) @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") -async def update_asset(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) +@_require_assets_feature_enabled +async def update_asset_route(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) try: body = schemas_in.UpdateAssetBody.model_validate(await request.json()) except ValidationError as ve: - return _validation_error_response("INVALID_BODY", ve) + return _build_validation_error_response("INVALID_BODY", ve) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.update_asset( - asset_info_id=asset_info_id, + result = update_asset_metadata( + reference_id=reference_id, name=body.name, user_metadata=body.user_metadata, owner_id=USER_MANAGER.get_request_user_id(request), ) - except (ValueError, PermissionError) as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + payload = schemas_out.AssetUpdated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash if result.asset else None, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + updated_at=result.ref.updated_at, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) + except ValueError as ve: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "update_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "update_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") -async def delete_asset(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) - delete_content = request.query.get("delete_content") - delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"} +@_require_assets_feature_enabled +async def delete_asset_route(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) + delete_content_param = request.query.get("delete_content") + delete_content = ( + False + if delete_content_param is None + else delete_content_param.lower() not in {"0", "false", "no"} + ) try: - deleted = manager.delete_asset_reference( - asset_info_id=asset_info_id, + deleted = delete_asset_reference( + reference_id=reference_id, owner_id=USER_MANAGER.get_request_user_id(request), delete_content_if_orphan=delete_content, ) except Exception: logging.exception( - "delete_asset_reference failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "delete_asset_reference failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") if not deleted: - return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.") + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"AssetReference {reference_id} not found." + ) return web.Response(status=204) @ROUTES.get("/api/tags") +@_require_assets_feature_enabled async def get_tags(request: web.Request) -> web.Response: """ GET request to list all tags based on query parameters. @@ -415,12 +532,14 @@ async def get_tags(request: web.Request) -> web.Response: try: query = schemas_in.TagsListQuery.model_validate(query_map) except ValidationError as e: - return web.json_response( - {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}}, - status=400, + return _build_error_response( + 400, + "INVALID_QUERY", + "Invalid query parameters", + {"errors": json.loads(e.json())}, ) - result = manager.list_tags( + rows, total = list_tags( prefix=query.prefix, limit=query.limit, offset=query.offset, @@ -428,87 +547,212 @@ async def get_tags(request: web.Request) -> web.Response: include_zero=query.include_zero, owner_id=USER_MANAGER.get_request_user_id(request), ) - return web.json_response(result.model_dump(mode="json")) + + tags = [ + schemas_out.TagUsage(name=name, count=count, type=tag_type) + for (name, tag_type, count) in rows + ] + payload = schemas_out.TagsList( + tags=tags, total=total, has_more=(query.offset + len(tags)) < total + ) + return web.json_response(payload.model_dump(mode="json")) @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled async def add_asset_tags(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - payload = await request.json() - data = schemas_in.TagsAdd.model_validate(payload) + json_payload = await request.json() + data = schemas_in.TagsAdd.model_validate(json_payload) except ValidationError as ve: - return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()}) + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags add.", + {"errors": ve.errors()}, + ) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.add_tags_to_asset( - asset_info_id=asset_info_id, + result = apply_tags( + reference_id=reference_id, tags=data.tags, origin="manual", owner_id=USER_MANAGER.get_request_user_id(request), ) - except (ValueError, PermissionError) as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + payload = schemas_out.TagsAdd( + added=result.added, + already_present=result.already_present, + total_tags=result.total_tags, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) + except ValueError as ve: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "add_tags_to_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled async def delete_asset_tags(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - payload = await request.json() - data = schemas_in.TagsRemove.model_validate(payload) + json_payload = await request.json() + data = schemas_in.TagsRemove.model_validate(json_payload) except ValidationError as ve: - return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()}) + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags remove.", + {"errors": ve.errors()}, + ) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.remove_tags_from_asset( - asset_info_id=asset_info_id, + result = remove_tags( + reference_id=reference_id, tags=data.tags, owner_id=USER_MANAGER.get_request_user_id(request), ) + payload = schemas_out.TagsRemove( + removed=result.removed, + not_present=result.not_present, + total_tags=result.total_tags, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) except ValueError as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "remove_tags_from_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.post("/api/assets/seed") -async def seed_assets_endpoint(request: web.Request) -> web.Response: - """Trigger asset seeding for specified roots (models, input, output).""" +@_require_assets_feature_enabled +async def seed_assets(request: web.Request) -> web.Response: + """Trigger asset seeding for specified roots (models, input, output). + + Query params: + wait: If "true", block until scan completes (synchronous behavior for tests) + + Returns: + 202 Accepted if scan started + 409 Conflict if scan already running + 200 OK with final stats if wait=true + """ try: payload = await request.json() roots = payload.get("roots", ["models", "input", "output"]) except Exception: roots = ["models", "input", "output"] - valid_roots = [r for r in roots if r in ("models", "input", "output")] + valid_roots = tuple(r for r in roots if r in ("models", "input", "output")) if not valid_roots: - return _error_response(400, "INVALID_BODY", "No valid roots specified") + return _build_error_response(400, "INVALID_BODY", "No valid roots specified") + wait_param = request.query.get("wait", "").lower() + should_wait = wait_param in ("true", "1", "yes") + + started = asset_seeder.start(roots=valid_roots) + if not started: + return web.json_response({"status": "already_running"}, status=409) + + if should_wait: + await asyncio.to_thread(asset_seeder.wait) + status = asset_seeder.get_status() + return web.json_response( + { + "status": "completed", + "progress": { + "scanned": status.progress.scanned if status.progress else 0, + "total": status.progress.total if status.progress else 0, + "created": status.progress.created if status.progress else 0, + "skipped": status.progress.skipped if status.progress else 0, + }, + "errors": status.errors, + }, + status=200, + ) + + return web.json_response({"status": "started"}, status=202) + + +@ROUTES.get("/api/assets/seed/status") +@_require_assets_feature_enabled +async def get_seed_status(request: web.Request) -> web.Response: + """Get current scan status and progress.""" + status = asset_seeder.get_status() + return web.json_response( + { + "state": status.state.value, + "progress": { + "scanned": status.progress.scanned, + "total": status.progress.total, + "created": status.progress.created, + "skipped": status.progress.skipped, + } + if status.progress + else None, + "errors": status.errors, + }, + status=200, + ) + + +@ROUTES.post("/api/assets/seed/cancel") +@_require_assets_feature_enabled +async def cancel_seed(request: web.Request) -> web.Response: + """Request cancellation of in-progress scan.""" + cancelled = asset_seeder.cancel() + if cancelled: + return web.json_response({"status": "cancelling"}, status=200) + return web.json_response({"status": "idle"}, status=200) + + +@ROUTES.post("/api/assets/prune") +@_require_assets_feature_enabled +async def mark_missing_assets(request: web.Request) -> web.Response: + """Mark assets as missing when outside all known root prefixes. + + This is a non-destructive soft-delete operation. Assets and metadata + are preserved, but references are flagged as missing. They can be + restored if the file reappears in a future scan. + + Returns: + 200 OK with count of marked assets + 409 Conflict if a scan is currently running + """ try: - seed_assets(tuple(valid_roots)) - except Exception: - logging.exception("seed_assets failed for roots=%s", valid_roots) - return _error_response(500, "INTERNAL", "Seed operation failed") - - return web.json_response({"seeded": valid_roots}, status=200) + marked = asset_seeder.mark_missing_outside_prefixes() + except ScanInProgressError: + return web.json_response( + {"status": "scan_running", "marked": 0}, + status=409, + ) + return web.json_response({"status": "completed", "marked": marked}, status=200) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 6707ffb0c..d255c938e 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -1,6 +1,8 @@ import json +from dataclasses import dataclass from typing import Any, Literal +from app.assets.helpers import validate_blake3_hash from pydantic import ( BaseModel, ConfigDict, @@ -10,6 +12,41 @@ from pydantic import ( model_validator, ) + +class UploadError(Exception): + """Error during upload parsing with HTTP status and code.""" + + def __init__(self, status: int, code: str, message: str): + super().__init__(message) + self.status = status + self.code = code + self.message = message + + +class AssetValidationError(Exception): + """Validation error in asset processing (invalid tags, metadata, etc.).""" + + def __init__(self, code: str, message: str): + super().__init__(message) + self.code = code + self.message = message + + +@dataclass +class ParsedUpload: + """Result of parsing a multipart upload request.""" + + file_present: bool + file_written: int + file_client_name: str | None + tmp_path: str | None + tags_raw: list[str] + provided_name: str | None + user_metadata_raw: str | None + provided_hash: str | None + provided_hash_exists: bool | None + + class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) @@ -21,7 +58,9 @@ class ListAssetsQuery(BaseModel): limit: conint(ge=1, le=500) = 20 offset: conint(ge=0) = 0 - sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" + sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = ( + "created_at" + ) order: Literal["asc", "desc"] = "desc" @field_validator("include_tags", "exclude_tags", mode="before") @@ -61,7 +100,7 @@ class UpdateAssetBody(BaseModel): user_metadata: dict[str, Any] | None = None @model_validator(mode="after") - def _at_least_one(self): + def _validate_at_least_one_field(self): if self.name is None and self.user_metadata is None: raise ValueError("Provide at least one of: name, user_metadata.") return self @@ -78,19 +117,11 @@ class CreateFromHashBody(BaseModel): @field_validator("hash") @classmethod def _require_blake3(cls, v): - s = (v or "").strip().lower() - if ":" not in s: - raise ValueError("hash must be 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3": - raise ValueError("only canonical 'blake3:' is accepted here") - if not digest or any(c for c in digest if c not in "0123456789abcdef"): - raise ValueError("hash digest must be lowercase hex") - return s + return validate_blake3_hash(v or "") @field_validator("tags", mode="before") @classmethod - def _tags_norm(cls, v): + def _normalize_tags_field(cls, v): if v is None: return [] if isinstance(v, list): @@ -154,15 +185,16 @@ class TagsRemove(TagsAdd): class UploadAssetSpec(BaseModel): """Upload Asset operation. + - tags: ordered; first is root ('models'|'input'|'output'); - if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths + if root == 'models', second must be a valid category - name: display name - user_metadata: arbitrary JSON object (optional) - - hash: optional canonical 'blake3:' provided by the client for validation / fast-path + - hash: optional canonical 'blake3:' for validation / fast-path - Files created via this endpoint are stored on disk using the **content hash** as the filename stem - and the original extension is preserved when available. + Files are stored using the content hash as filename stem. """ + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) tags: list[str] = Field(..., min_length=1) @@ -175,17 +207,10 @@ class UploadAssetSpec(BaseModel): def _parse_hash(cls, v): if v is None: return None - s = str(v).strip().lower() + s = str(v).strip() if not s: return None - if ":" not in s: - raise ValueError("hash must be 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3": - raise ValueError("only canonical 'blake3:' is accepted here") - if not digest or any(c for c in digest if c not in "0123456789abcdef"): - raise ValueError("hash digest must be lowercase hex") - return f"{algo}:{digest}" + return validate_blake3_hash(s) @field_validator("tags", mode="before") @classmethod @@ -260,5 +285,7 @@ class UploadAssetSpec(BaseModel): raise ValueError("first tag must be one of: models, input, output") if root == "models": if len(self.tags) < 2: - raise ValueError("models uploads require a category tag as the second tag") + raise ValueError( + "models uploads require a category tag as the second tag" + ) return self diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index b6fb3da0c..f36447856 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -19,7 +19,7 @@ class AssetSummary(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("created_at", "updated_at", "last_access_time") - def _ser_dt(self, v: datetime | None, _info): + def _serialize_datetime(self, v: datetime | None, _info): return v.isoformat() if v else None @@ -40,7 +40,7 @@ class AssetUpdated(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("updated_at") - def _ser_updated(self, v: datetime | None, _info): + def _serialize_updated_at(self, v: datetime | None, _info): return v.isoformat() if v else None @@ -59,7 +59,7 @@ class AssetDetail(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("created_at", "last_access_time") - def _ser_dt(self, v: datetime | None, _info): + def _serialize_datetime(self, v: datetime | None, _info): return v.isoformat() if v else None diff --git a/app/assets/api/upload.py b/app/assets/api/upload.py new file mode 100644 index 000000000..721c12f4d --- /dev/null +++ b/app/assets/api/upload.py @@ -0,0 +1,171 @@ +import logging +import os +import uuid +from typing import Callable + +from aiohttp import web + +import folder_paths +from app.assets.api.schemas_in import ParsedUpload, UploadError +from app.assets.helpers import validate_blake3_hash + + +def normalize_and_validate_hash(s: str) -> str: + """Validate and normalize a hash string. + + Returns canonical 'blake3:' or raises UploadError. + """ + try: + return validate_blake3_hash(s) + except ValueError: + raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:'") + + +async def parse_multipart_upload( + request: web.Request, + check_hash_exists: Callable[[str], bool], +) -> ParsedUpload: + """ + Parse a multipart/form-data upload request. + + Args: + request: The aiohttp request + check_hash_exists: Callable(hash_str) -> bool to check if a hash exists + + Returns: + ParsedUpload with parsed fields and temp file path + + Raises: + UploadError: On validation or I/O errors + """ + if not (request.content_type or "").lower().startswith("multipart/"): + raise UploadError( + 415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads." + ) + + reader = await request.multipart() + + file_present = False + file_client_name: str | None = None + tags_raw: list[str] = [] + provided_name: str | None = None + user_metadata_raw: str | None = None + provided_hash: str | None = None + provided_hash_exists: bool | None = None + + file_written = 0 + tmp_path: str | None = None + + while True: + field = await reader.next() + if field is None: + break + + fname = getattr(field, "name", "") or "" + + if fname == "hash": + try: + s = ((await field.text()) or "").strip().lower() + except Exception: + raise UploadError( + 400, "INVALID_HASH", "hash must be like 'blake3:'" + ) + + if s: + provided_hash = normalize_and_validate_hash(s) + try: + provided_hash_exists = check_hash_exists(provided_hash) + except Exception as e: + logging.exception( + "check_hash_exists failed for hash=%s: %s", provided_hash, e + ) + raise UploadError( + 500, + "HASH_CHECK_FAILED", + "Backend error while checking asset hash.", + ) + + elif fname == "file": + file_present = True + file_client_name = (field.filename or "").strip() + + if provided_hash and provided_hash_exists is True: + # Hash exists - drain file but don't write to disk + try: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + file_written += len(chunk) + except Exception: + raise UploadError( + 500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file." + ) + continue + + uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") + unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) + os.makedirs(unique_dir, exist_ok=True) + tmp_path = os.path.join(unique_dir, ".upload.part") + + try: + with open(tmp_path, "wb") as f: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + f.write(chunk) + file_written += len(chunk) + except Exception: + delete_temp_file_if_exists(tmp_path) + raise UploadError( + 500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file." + ) + + elif fname == "tags": + tags_raw.append((await field.text()) or "") + elif fname == "name": + provided_name = (await field.text()) or None + elif fname == "user_metadata": + user_metadata_raw = (await field.text()) or None + + if not file_present and not (provided_hash and provided_hash_exists): + raise UploadError( + 400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'." + ) + + if ( + file_present + and file_written == 0 + and not (provided_hash and provided_hash_exists) + ): + delete_temp_file_if_exists(tmp_path) + raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.") + + return ParsedUpload( + file_present=file_present, + file_written=file_written, + file_client_name=file_client_name, + tmp_path=tmp_path, + tags_raw=tags_raw, + provided_name=provided_name, + user_metadata_raw=user_metadata_raw, + provided_hash=provided_hash, + provided_hash_exists=provided_hash_exists, + ) + + +def delete_temp_file_if_exists(tmp_path: str | None) -> None: + """Safely remove a temp file and its parent directory if empty.""" + if tmp_path: + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + except OSError as e: + logging.debug("Failed to delete temp file %s: %s", tmp_path, e) + try: + parent = os.path.dirname(tmp_path) + if parent and os.path.isdir(parent): + os.rmdir(parent) # only succeeds if empty + except OSError: + pass diff --git a/app/assets/database/bulk_ops.py b/app/assets/database/bulk_ops.py deleted file mode 100644 index c7b75290a..000000000 --- a/app/assets/database/bulk_ops.py +++ /dev/null @@ -1,204 +0,0 @@ -import os -import uuid -import sqlalchemy -from typing import Iterable -from sqlalchemy.orm import Session -from sqlalchemy.dialects import sqlite - -from app.assets.helpers import utcnow -from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta - -MAX_BIND_PARAMS = 800 - -def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]: - if not rows: - return [] - rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row)) - for i in range(0, len(rows), rows_per_stmt): - yield rows[i:i + rows_per_stmt] - -def _iter_chunks(seq, n: int): - for i in range(0, len(seq), n): - yield seq[i:i + n] - -def _rows_per_stmt(cols: int) -> int: - return max(1, MAX_BIND_PARAMS // max(1, cols)) - - -def seed_from_paths_batch( - session: Session, - *, - specs: list[dict], - owner_id: str = "", -) -> dict: - """Each spec is a dict with keys: - - abs_path: str - - size_bytes: int - - mtime_ns: int - - info_name: str - - tags: list[str] - - fname: Optional[str] - """ - if not specs: - return {"inserted_infos": 0, "won_states": 0, "lost_states": 0} - - now = utcnow() - asset_rows: list[dict] = [] - state_rows: list[dict] = [] - path_to_asset: dict[str, str] = {} - asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row - path_list: list[str] = [] - - for sp in specs: - ap = os.path.abspath(sp["abs_path"]) - aid = str(uuid.uuid4()) - iid = str(uuid.uuid4()) - path_list.append(ap) - path_to_asset[ap] = aid - - asset_rows.append( - { - "id": aid, - "hash": None, - "size_bytes": sp["size_bytes"], - "mime_type": None, - "created_at": now, - } - ) - state_rows.append( - { - "asset_id": aid, - "file_path": ap, - "mtime_ns": sp["mtime_ns"], - } - ) - asset_to_info[aid] = { - "id": iid, - "owner_id": owner_id, - "name": sp["info_name"], - "asset_id": aid, - "preview_id": None, - "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None, - "created_at": now, - "updated_at": now, - "last_access_time": now, - "_tags": sp["tags"], - "_filename": sp["fname"], - } - - # insert all seed Assets (hash=NULL) - ins_asset = sqlite.insert(Asset) - for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)): - session.execute(ins_asset, chunk) - - # try to claim AssetCacheState (file_path) - # Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted - ins_state = ( - sqlite.insert(AssetCacheState) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)): - session.execute(ins_state, chunk) - - # Query to find which of our paths won (were actually inserted) - winners_by_path: set[str] = set() - for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS): - result = session.execute( - sqlalchemy.select(AssetCacheState.file_path) - .where(AssetCacheState.file_path.in_(chunk)) - .where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk])) - ) - winners_by_path.update(result.scalars().all()) - - all_paths_set = set(path_list) - losers_by_path = all_paths_set - winners_by_path - lost_assets = [path_to_asset[p] for p in losers_by_path] - if lost_assets: # losers get their Asset removed - for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS): - session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk))) - - if not winners_by_path: - return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} - - # insert AssetInfo only for winners - # Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted - winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] - ins_info = ( - sqlite.insert(AssetInfo) - .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) - ) - for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)): - session.execute(ins_info, chunk) - - # Query to find which info rows were actually inserted (by matching our generated IDs) - all_info_ids = [row["id"] for row in winner_info_rows] - inserted_info_ids: set[str] = set() - for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS): - result = session.execute( - sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk)) - ) - inserted_info_ids.update(result.scalars().all()) - - # build and insert tag + meta rows for the AssetInfo - tag_rows: list[dict] = [] - meta_rows: list[dict] = [] - if inserted_info_ids: - for row in winner_info_rows: - iid = row["id"] - if iid not in inserted_info_ids: - continue - for t in row["_tags"]: - tag_rows.append({ - "asset_info_id": iid, - "tag_name": t, - "origin": "automatic", - "added_at": now, - }) - if row["_filename"]: - meta_rows.append( - { - "asset_info_id": iid, - "key": "filename", - "ordinal": 0, - "val_str": row["_filename"], - "val_num": None, - "val_bool": None, - "val_json": None, - } - ) - - bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS) - return { - "inserted_infos": len(inserted_info_ids), - "won_states": len(winners_by_path), - "lost_states": len(losers_by_path), - } - - -def bulk_insert_tags_and_meta( - session: Session, - *, - tag_rows: list[dict], - meta_rows: list[dict], - max_bind_params: int, -) -> None: - """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING. - - tag_rows keys: asset_info_id, tag_name, origin, added_at - - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json - """ - if tag_rows: - ins_links = ( - sqlite.insert(AssetInfoTag) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params): - session.execute(ins_links, chunk) - if meta_rows: - ins_meta = ( - sqlite.insert(AssetInfoMeta) - .on_conflict_do_nothing( - index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] - ) - ) - for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params): - session.execute(ins_meta, chunk) diff --git a/app/assets/database/models.py b/app/assets/database/models.py index 3cd28f68b..03c1c1707 100644 --- a/app/assets/database/models.py +++ b/app/assets/database/models.py @@ -2,8 +2,8 @@ from __future__ import annotations import uuid from datetime import datetime - from typing import Any + from sqlalchemy import ( JSON, BigInteger, @@ -16,102 +16,102 @@ from sqlalchemy import ( Numeric, String, Text, - UniqueConstraint, ) from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship -from app.assets.helpers import utcnow -from app.database.models import to_dict, Base +from app.assets.helpers import get_utc_now +from app.database.models import Base class Asset(Base): __tablename__ = "assets" - id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) hash: Mapped[str | None] = mapped_column(String(256), nullable=True) size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) mime_type: Mapped[str | None] = mapped_column(String(255)) created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow + DateTime(timezone=False), nullable=False, default=get_utc_now ) - infos: Mapped[list[AssetInfo]] = relationship( - "AssetInfo", + references: Mapped[list[AssetReference]] = relationship( + "AssetReference", back_populates="asset", - primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id), - foreign_keys=lambda: [AssetInfo.asset_id], + primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id), + foreign_keys=lambda: [AssetReference.asset_id], cascade="all,delete-orphan", passive_deletes=True, ) - preview_of: Mapped[list[AssetInfo]] = relationship( - "AssetInfo", + preview_of: Mapped[list[AssetReference]] = relationship( + "AssetReference", back_populates="preview_asset", - primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id), - foreign_keys=lambda: [AssetInfo.preview_id], + primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id), + foreign_keys=lambda: [AssetReference.preview_id], viewonly=True, ) - cache_states: Mapped[list[AssetCacheState]] = relationship( - back_populates="asset", - cascade="all, delete-orphan", - passive_deletes=True, - ) - __table_args__ = ( Index("uq_assets_hash", "hash", unique=True), Index("ix_assets_mime_type", "mime_type"), CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - return to_dict(self, include_none=include_none) - def __repr__(self) -> str: return f"" -class AssetCacheState(Base): - __tablename__ = "asset_cache_state" +class AssetReference(Base): + """Unified model combining file cache state and user-facing metadata. - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) - file_path: Mapped[str] = mapped_column(Text, nullable=False) - mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) - needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + Each row represents either: + - A filesystem reference (file_path is set) with cache state + - An API-created reference (file_path is NULL) without cache state + """ - asset: Mapped[Asset] = relationship(back_populates="cache_states") + __tablename__ = "asset_references" - __table_args__ = ( - Index("ix_asset_cache_state_file_path", "file_path"), - Index("ix_asset_cache_state_asset_id", "asset_id"), - CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), - UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) + asset_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - return to_dict(self, include_none=include_none) + # Cache state fields (from former AssetCacheState) + file_path: Mapped[str | None] = mapped_column(Text, nullable=True) + mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - def __repr__(self) -> str: - return f"" - - -class AssetInfo(Base): - __tablename__ = "assets_info" - - id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + # Info fields (from former AssetInfo) owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") name: Mapped[str] = mapped_column(String(512), nullable=False) - asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) - preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) - user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) - last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + preview_id: Mapped[str | None] = mapped_column( + String(36), ForeignKey("assets.id", ondelete="SET NULL") + ) + user_metadata: Mapped[dict[str, Any] | None] = mapped_column( + JSON(none_as_null=True) + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + last_access_time: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + deleted_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=False), nullable=True, default=None + ) asset: Mapped[Asset] = relationship( "Asset", - back_populates="infos", + back_populates="references", foreign_keys=[asset_id], lazy="selectin", ) @@ -121,51 +121,59 @@ class AssetInfo(Base): foreign_keys=[preview_id], ) - metadata_entries: Mapped[list[AssetInfoMeta]] = relationship( - back_populates="asset_info", + metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship( + back_populates="asset_reference", cascade="all,delete-orphan", passive_deletes=True, ) - tag_links: Mapped[list[AssetInfoTag]] = relationship( - back_populates="asset_info", + tag_links: Mapped[list[AssetReferenceTag]] = relationship( + back_populates="asset_reference", cascade="all,delete-orphan", passive_deletes=True, - overlaps="tags,asset_infos", + overlaps="tags,asset_references", ) tags: Mapped[list[Tag]] = relationship( - secondary="asset_info_tags", - back_populates="asset_infos", + secondary="asset_reference_tags", + back_populates="asset_references", lazy="selectin", viewonly=True, - overlaps="tag_links,asset_info_links,asset_infos,tag", + overlaps="tag_links,asset_reference_links,asset_references,tag", ) __table_args__ = ( - UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), - Index("ix_assets_info_owner_name", "owner_id", "name"), - Index("ix_assets_info_owner_id", "owner_id"), - Index("ix_assets_info_asset_id", "asset_id"), - Index("ix_assets_info_name", "name"), - Index("ix_assets_info_created_at", "created_at"), - Index("ix_assets_info_last_access_time", "last_access_time"), + Index("uq_asset_references_file_path", "file_path", unique=True), + Index("ix_asset_references_asset_id", "asset_id"), + Index("ix_asset_references_owner_id", "owner_id"), + Index("ix_asset_references_name", "name"), + Index("ix_asset_references_is_missing", "is_missing"), + Index("ix_asset_references_enrichment_level", "enrichment_level"), + Index("ix_asset_references_created_at", "created_at"), + Index("ix_asset_references_last_access_time", "last_access_time"), + Index("ix_asset_references_deleted_at", "deleted_at"), + Index("ix_asset_references_owner_name", "owner_id", "name"), + CheckConstraint( + "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" + ), + CheckConstraint( + "enrichment_level >= 0 AND enrichment_level <= 2", + name="ck_ar_enrichment_level_range", + ), ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - data = to_dict(self, include_none=include_none) - data["tags"] = [t.name for t in self.tags] - return data - def __repr__(self) -> str: - return f"" + path_part = f" path={self.file_path!r}" if self.file_path else "" + return f"" -class AssetInfoMeta(Base): - __tablename__ = "asset_info_meta" +class AssetReferenceMeta(Base): + __tablename__ = "asset_reference_meta" - asset_info_id: Mapped[str] = mapped_column( - String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + asset_reference_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("asset_references.id", ondelete="CASCADE"), + primary_key=True, ) key: Mapped[str] = mapped_column(String(256), primary_key=True) ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) @@ -175,36 +183,40 @@ class AssetInfoMeta(Base): val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True) val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True) - asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries") + asset_reference: Mapped[AssetReference] = relationship( + back_populates="metadata_entries" + ) __table_args__ = ( - Index("ix_asset_info_meta_key", "key"), - Index("ix_asset_info_meta_key_val_str", "key", "val_str"), - Index("ix_asset_info_meta_key_val_num", "key", "val_num"), - Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"), + Index("ix_asset_reference_meta_key", "key"), + Index("ix_asset_reference_meta_key_val_str", "key", "val_str"), + Index("ix_asset_reference_meta_key_val_num", "key", "val_num"), + Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"), ) -class AssetInfoTag(Base): - __tablename__ = "asset_info_tags" +class AssetReferenceTag(Base): + __tablename__ = "asset_reference_tags" - asset_info_id: Mapped[str] = mapped_column( - String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + asset_reference_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("asset_references.id", ondelete="CASCADE"), + primary_key=True, ) tag_name: Mapped[str] = mapped_column( String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True ) origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") added_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow + DateTime(timezone=False), nullable=False, default=get_utc_now ) - asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links") - tag: Mapped[Tag] = relationship(back_populates="asset_info_links") + asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links") + tag: Mapped[Tag] = relationship(back_populates="asset_reference_links") __table_args__ = ( - Index("ix_asset_info_tags_tag_name", "tag_name"), - Index("ix_asset_info_tags_asset_info_id", "asset_info_id"), + Index("ix_asset_reference_tags_tag_name", "tag_name"), + Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"), ) @@ -214,20 +226,18 @@ class Tag(Base): name: Mapped[str] = mapped_column(String(512), primary_key=True) tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") - asset_info_links: Mapped[list[AssetInfoTag]] = relationship( + asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship( back_populates="tag", - overlaps="asset_infos,tags", + overlaps="asset_references,tags", ) - asset_infos: Mapped[list[AssetInfo]] = relationship( - secondary="asset_info_tags", + asset_references: Mapped[list[AssetReference]] = relationship( + secondary="asset_reference_tags", back_populates="tags", viewonly=True, - overlaps="asset_info_links,tag_links,tags,asset_info", + overlaps="asset_reference_links,tag_links,tags,asset_reference", ) - __table_args__ = ( - Index("ix_tags_tag_type", "tag_type"), - ) + __table_args__ = (Index("ix_tags_tag_type", "tag_type"),) def __repr__(self) -> str: return f"" diff --git a/app/assets/database/queries.py b/app/assets/database/queries.py deleted file mode 100644 index d6b33ec7b..000000000 --- a/app/assets/database/queries.py +++ /dev/null @@ -1,976 +0,0 @@ -import os -import logging -import sqlalchemy as sa -from collections import defaultdict -from datetime import datetime -from typing import Iterable, Any -from sqlalchemy import select, delete, exists, func -from sqlalchemy.dialects import sqlite -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session, contains_eager, noload -from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag -from app.assets.helpers import ( - compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow -) -from typing import Sequence - - -def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: - """Build owner visibility predicate for reads. Owner-less rows are visible to everyone.""" - owner_id = (owner_id or "").strip() - if owner_id == "": - return AssetInfo.owner_id == "" - return AssetInfo.owner_id.in_(["", owner_id]) - - -def pick_best_live_path(states: Sequence[AssetCacheState]) -> str: - """ - Return the best on-disk path among cache states: - 1) Prefer a path that exists with needs_verify == False (already verified). - 2) Otherwise, pick the first path that exists. - 3) Otherwise return empty string. - """ - alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] - if not alive: - return "" - for s in alive: - if not getattr(s, "needs_verify", False): - return s.file_path - return alive[0].file_path - - -def apply_tag_filters( - stmt: sa.sql.Select, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, -) -> sa.sql.Select: - """include_tags: every tag must be present; exclude_tags: none may be present.""" - include_tags = normalize_tags(include_tags) - exclude_tags = normalize_tags(exclude_tags) - - if include_tags: - for tag_name in include_tags: - stmt = stmt.where( - exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name == tag_name) - ) - ) - - if exclude_tags: - stmt = stmt.where( - ~exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name.in_(exclude_tags)) - ) - ) - return stmt - - -def apply_metadata_filter( - stmt: sa.sql.Select, - metadata_filter: dict | None = None, -) -> sa.sql.Select: - """Apply filters using asset_info_meta projection table.""" - if not metadata_filter: - return stmt - - def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - return sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - *preds, - ) - - def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: - if value is None: - no_row_for_key = sa.not_( - sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - ) - ) - null_row = _exists_for_pred( - key, - AssetInfoMeta.val_json.is_(None), - AssetInfoMeta.val_str.is_(None), - AssetInfoMeta.val_num.is_(None), - AssetInfoMeta.val_bool.is_(None), - ) - return sa.or_(no_row_for_key, null_row) - - if isinstance(value, bool): - return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) - if isinstance(value, (int, float)): - from decimal import Decimal - num = value if isinstance(value, Decimal) else Decimal(str(value)) - return _exists_for_pred(key, AssetInfoMeta.val_num == num) - if isinstance(value, str): - return _exists_for_pred(key, AssetInfoMeta.val_str == value) - return _exists_for_pred(key, AssetInfoMeta.val_json == value) - - for k, v in metadata_filter.items(): - if isinstance(v, list): - ors = [_exists_clause_for_value(k, elem) for elem in v] - if ors: - stmt = stmt.where(sa.or_(*ors)) - else: - stmt = stmt.where(_exists_clause_for_value(k, v)) - return stmt - - -def asset_exists_by_hash( - session: Session, - *, - asset_hash: str, -) -> bool: - """ - Check if an asset with a given hash exists in database. - """ - row = ( - session.execute( - select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) - ) - ).first() - return row is not None - - -def asset_info_exists_for_asset_id( - session: Session, - *, - asset_id: str, -) -> bool: - q = ( - select(sa.literal(True)) - .select_from(AssetInfo) - .where(AssetInfo.asset_id == asset_id) - .limit(1) - ) - return (session.execute(q)).first() is not None - - -def get_asset_by_hash( - session: Session, - *, - asset_hash: str, -) -> Asset | None: - return ( - session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() - - -def get_asset_info_by_id( - session: Session, - *, - asset_info_id: str, -) -> AssetInfo | None: - return session.get(AssetInfo, asset_info_id) - - -def list_asset_infos_page( - session: Session, - owner_id: str = "", - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, - name_contains: str | None = None, - metadata_filter: dict | None = None, - limit: int = 20, - offset: int = 0, - sort: str = "created_at", - order: str = "desc", -) -> tuple[list[AssetInfo], dict[str, list[str]], int]: - base = ( - select(AssetInfo) - .join(Asset, Asset.id == AssetInfo.asset_id) - .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) - .where(visible_owner_clause(owner_id)) - ) - - if name_contains: - escaped, esc = escape_like_prefix(name_contains) - base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) - - base = apply_tag_filters(base, include_tags, exclude_tags) - base = apply_metadata_filter(base, metadata_filter) - - sort = (sort or "created_at").lower() - order = (order or "desc").lower() - sort_map = { - "name": AssetInfo.name, - "created_at": AssetInfo.created_at, - "updated_at": AssetInfo.updated_at, - "last_access_time": AssetInfo.last_access_time, - "size": Asset.size_bytes, - } - sort_col = sort_map.get(sort, AssetInfo.created_at) - sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() - - base = base.order_by(sort_exp).limit(limit).offset(offset) - - count_stmt = ( - select(sa.func.count()) - .select_from(AssetInfo) - .join(Asset, Asset.id == AssetInfo.asset_id) - .where(visible_owner_clause(owner_id)) - ) - if name_contains: - escaped, esc = escape_like_prefix(name_contains) - count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) - count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) - count_stmt = apply_metadata_filter(count_stmt, metadata_filter) - - total = int((session.execute(count_stmt)).scalar_one() or 0) - - infos = (session.execute(base)).unique().scalars().all() - - id_list: list[str] = [i.id for i in infos] - tag_map: dict[str, list[str]] = defaultdict(list) - if id_list: - rows = session.execute( - select(AssetInfoTag.asset_info_id, Tag.name) - .join(Tag, Tag.name == AssetInfoTag.tag_name) - .where(AssetInfoTag.asset_info_id.in_(id_list)) - .order_by(AssetInfoTag.added_at) - ) - for aid, tag_name in rows.all(): - tag_map[aid].append(tag_name) - - return infos, tag_map, total - - -def fetch_asset_info_asset_and_tags( - session: Session, - asset_info_id: str, - owner_id: str = "", -) -> tuple[AssetInfo, Asset, list[str]] | None: - stmt = ( - select(AssetInfo, Asset, Tag.name) - .join(Asset, Asset.id == AssetInfo.asset_id) - .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) - .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .options(noload(AssetInfo.tags)) - .order_by(Tag.name.asc()) - ) - - rows = (session.execute(stmt)).all() - if not rows: - return None - - first_info, first_asset, _ = rows[0] - tags: list[str] = [] - seen: set[str] = set() - for _info, _asset, tag_name in rows: - if tag_name and tag_name not in seen: - seen.add(tag_name) - tags.append(tag_name) - return first_info, first_asset, tags - - -def fetch_asset_info_and_asset( - session: Session, - *, - asset_info_id: str, - owner_id: str = "", -) -> tuple[AssetInfo, Asset] | None: - stmt = ( - select(AssetInfo, Asset) - .join(Asset, Asset.id == AssetInfo.asset_id) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .limit(1) - .options(noload(AssetInfo.tags)) - ) - row = session.execute(stmt) - pair = row.first() - if not pair: - return None - return pair[0], pair[1] - -def list_cache_states_by_asset_id( - session: Session, *, asset_id: str -) -> Sequence[AssetCacheState]: - return ( - session.execute( - select(AssetCacheState) - .where(AssetCacheState.asset_id == asset_id) - .order_by(AssetCacheState.id.asc()) - ) - ).scalars().all() - - -def touch_asset_info_by_id( - session: Session, - *, - asset_info_id: str, - ts: datetime | None = None, - only_if_newer: bool = True, -) -> None: - ts = ts or utcnow() - stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) - if only_if_newer: - stmt = stmt.where( - sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) - ) - session.execute(stmt.values(last_access_time=ts)) - - -def create_asset_info_for_existing_asset( - session: Session, - *, - asset_hash: str, - name: str, - user_metadata: dict | None = None, - tags: Sequence[str] | None = None, - tag_origin: str = "manual", - owner_id: str = "", -) -> AssetInfo: - """Create or return an existing AssetInfo for an Asset identified by asset_hash.""" - now = utcnow() - asset = get_asset_by_hash(session, asset_hash=asset_hash) - if not asset: - raise ValueError(f"Unknown asset hash {asset_hash}") - - info = AssetInfo( - owner_id=owner_id, - name=name, - asset_id=asset.id, - preview_id=None, - created_at=now, - updated_at=now, - last_access_time=now, - ) - try: - with session.begin_nested(): - session.add(info) - session.flush() - except IntegrityError: - existing = ( - session.execute( - select(AssetInfo) - .options(noload(AssetInfo.tags)) - .where( - AssetInfo.asset_id == asset.id, - AssetInfo.name == name, - AssetInfo.owner_id == owner_id, - ) - .limit(1) - ) - ).unique().scalars().first() - if not existing: - raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") - return existing - - # metadata["filename"] hack - new_meta = dict(user_metadata or {}) - computed_filename = None - try: - p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) - if p: - computed_filename = compute_relative_filename(p) - except Exception: - computed_filename = None - if computed_filename: - new_meta["filename"] = computed_filename - if new_meta: - replace_asset_info_metadata_projection( - session, - asset_info_id=info.id, - user_metadata=new_meta, - ) - - if tags is not None: - set_asset_info_tags( - session, - asset_info_id=info.id, - tags=tags, - origin=tag_origin, - ) - return info - - -def set_asset_info_tags( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", -) -> dict: - desired = normalize_tags(tags) - - current = set( - tag_name for (tag_name,) in ( - session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) - ).all() - ) - - to_add = [t for t in desired if t not in current] - to_remove = [t for t in current if t not in desired] - - if to_add: - ensure_tags_exist(session, to_add, tag_type="user") - session.add_all([ - AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) - for t in to_add - ]) - session.flush() - - if to_remove: - session.execute( - delete(AssetInfoTag) - .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) - ) - session.flush() - - return {"added": to_add, "removed": to_remove, "total": desired} - - -def replace_asset_info_metadata_projection( - session: Session, - *, - asset_info_id: str, - user_metadata: dict | None = None, -) -> None: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - info.user_metadata = user_metadata or {} - info.updated_at = utcnow() - session.flush() - - session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) - session.flush() - - if not user_metadata: - return - - rows: list[AssetInfoMeta] = [] - for k, v in user_metadata.items(): - for r in project_kv(k, v): - rows.append( - AssetInfoMeta( - asset_info_id=asset_info_id, - key=r["key"], - ordinal=int(r["ordinal"]), - val_str=r.get("val_str"), - val_num=r.get("val_num"), - val_bool=r.get("val_bool"), - val_json=r.get("val_json"), - ) - ) - if rows: - session.add_all(rows) - session.flush() - - -def ingest_fs_asset( - session: Session, - *, - asset_hash: str, - abs_path: str, - size_bytes: int, - mtime_ns: int, - mime_type: str | None = None, - info_name: str | None = None, - owner_id: str = "", - preview_id: str | None = None, - user_metadata: dict | None = None, - tags: Sequence[str] = (), - tag_origin: str = "manual", - require_existing_tags: bool = False, -) -> dict: - """ - Idempotently upsert: - - Asset by content hash (create if missing) - - AssetCacheState(file_path) pointing to asset_id - - Optionally AssetInfo + tag links and metadata projection - Returns flags and ids. - """ - locator = os.path.abspath(abs_path) - now = utcnow() - - if preview_id: - if not session.get(Asset, preview_id): - preview_id = None - - out: dict[str, Any] = { - "asset_created": False, - "asset_updated": False, - "state_created": False, - "state_updated": False, - "asset_info_id": None, - } - - # 1) Asset by hash - asset = ( - session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() - if not asset: - vals = { - "hash": asset_hash, - "size_bytes": int(size_bytes), - "mime_type": mime_type, - "created_at": now, - } - res = session.execute( - sqlite.insert(Asset) - .values(**vals) - .on_conflict_do_nothing(index_elements=[Asset.hash]) - ) - if int(res.rowcount or 0) > 0: - out["asset_created"] = True - asset = ( - session.execute( - select(Asset).where(Asset.hash == asset_hash).limit(1) - ) - ).scalars().first() - if not asset: - raise RuntimeError("Asset row not found after upsert.") - else: - changed = False - if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: - asset.size_bytes = int(size_bytes) - changed = True - if mime_type and asset.mime_type != mime_type: - asset.mime_type = mime_type - changed = True - if changed: - out["asset_updated"] = True - - # 2) AssetCacheState upsert by file_path (unique) - vals = { - "asset_id": asset.id, - "file_path": locator, - "mtime_ns": int(mtime_ns), - } - ins = ( - sqlite.insert(AssetCacheState) - .values(**vals) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - - res = session.execute(ins) - if int(res.rowcount or 0) > 0: - out["state_created"] = True - else: - upd = ( - sa.update(AssetCacheState) - .where(AssetCacheState.file_path == locator) - .where( - sa.or_( - AssetCacheState.asset_id != asset.id, - AssetCacheState.mtime_ns.is_(None), - AssetCacheState.mtime_ns != int(mtime_ns), - ) - ) - .values(asset_id=asset.id, mtime_ns=int(mtime_ns)) - ) - res2 = session.execute(upd) - if int(res2.rowcount or 0) > 0: - out["state_updated"] = True - - # 3) Optional AssetInfo + tags + metadata - if info_name: - try: - with session.begin_nested(): - info = AssetInfo( - owner_id=owner_id, - name=info_name, - asset_id=asset.id, - preview_id=preview_id, - created_at=now, - updated_at=now, - last_access_time=now, - ) - session.add(info) - session.flush() - out["asset_info_id"] = info.id - except IntegrityError: - pass - - existing_info = ( - session.execute( - select(AssetInfo) - .where( - AssetInfo.asset_id == asset.id, - AssetInfo.name == info_name, - (AssetInfo.owner_id == owner_id), - ) - .limit(1) - ) - ).unique().scalar_one_or_none() - if not existing_info: - raise RuntimeError("Failed to update or insert AssetInfo.") - - if preview_id and existing_info.preview_id != preview_id: - existing_info.preview_id = preview_id - - existing_info.updated_at = now - if existing_info.last_access_time < now: - existing_info.last_access_time = now - session.flush() - out["asset_info_id"] = existing_info.id - - norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] - if norm and out["asset_info_id"] is not None: - if not require_existing_tags: - ensure_tags_exist(session, norm, tag_type="user") - - existing_tag_names = set( - name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() - ) - missing = [t for t in norm if t not in existing_tag_names] - if missing and require_existing_tags: - raise ValueError(f"Unknown tags: {missing}") - - existing_links = set( - tag_name - for (tag_name,) in ( - session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) - ) - ).all() - ) - to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] - if to_add: - session.add_all( - [ - AssetInfoTag( - asset_info_id=out["asset_info_id"], - tag_name=t, - origin=tag_origin, - added_at=now, - ) - for t in to_add - ] - ) - session.flush() - - # metadata["filename"] hack - if out["asset_info_id"] is not None: - primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) - computed_filename = compute_relative_filename(primary_path) if primary_path else None - - current_meta = existing_info.user_metadata or {} - new_meta = dict(current_meta) - if user_metadata is not None: - for k, v in user_metadata.items(): - new_meta[k] = v - if computed_filename: - new_meta["filename"] = computed_filename - - if new_meta != current_meta: - replace_asset_info_metadata_projection( - session, - asset_info_id=out["asset_info_id"], - user_metadata=new_meta, - ) - - try: - remove_missing_tag_for_asset_id(session, asset_id=asset.id) - except Exception: - logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) - return out - - -def update_asset_info_full( - session: Session, - *, - asset_info_id: str, - name: str | None = None, - tags: Sequence[str] | None = None, - user_metadata: dict | None = None, - tag_origin: str = "manual", - asset_info_row: Any = None, -) -> AssetInfo: - if not asset_info_row: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - else: - info = asset_info_row - - touched = False - if name is not None and name != info.name: - info.name = name - touched = True - - computed_filename = None - try: - p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id)) - if p: - computed_filename = compute_relative_filename(p) - except Exception: - computed_filename = None - - if user_metadata is not None: - new_meta = dict(user_metadata) - if computed_filename: - new_meta["filename"] = computed_filename - replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - else: - if computed_filename: - current_meta = info.user_metadata or {} - if current_meta.get("filename") != computed_filename: - new_meta = dict(current_meta) - new_meta["filename"] = computed_filename - replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - - if tags is not None: - set_asset_info_tags( - session, - asset_info_id=asset_info_id, - tags=tags, - origin=tag_origin, - ) - touched = True - - if touched and user_metadata is None: - info.updated_at = utcnow() - session.flush() - - return info - - -def delete_asset_info_by_id( - session: Session, - *, - asset_info_id: str, - owner_id: str, -) -> bool: - stmt = sa.delete(AssetInfo).where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - return int((session.execute(stmt)).rowcount or 0) > 0 - - -def list_tags_with_usage( - session: Session, - prefix: str | None = None, - limit: int = 100, - offset: int = 0, - include_zero: bool = True, - order: str = "count_desc", - owner_id: str = "", -) -> tuple[list[tuple[str, str, int]], int]: - counts_sq = ( - select( - AssetInfoTag.tag_name.label("tag_name"), - func.count(AssetInfoTag.asset_info_id).label("cnt"), - ) - .select_from(AssetInfoTag) - .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) - .where(visible_owner_clause(owner_id)) - .group_by(AssetInfoTag.tag_name) - .subquery() - ) - - q = ( - select( - Tag.name, - Tag.tag_type, - func.coalesce(counts_sq.c.cnt, 0).label("count"), - ) - .select_from(Tag) - .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) - ) - - if prefix: - escaped, esc = escape_like_prefix(prefix.strip().lower()) - q = q.where(Tag.name.like(escaped + "%", escape=esc)) - - if not include_zero: - q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) - - if order == "name_asc": - q = q.order_by(Tag.name.asc()) - else: - q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) - - total_q = select(func.count()).select_from(Tag) - if prefix: - escaped, esc = escape_like_prefix(prefix.strip().lower()) - total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) - if not include_zero: - total_q = total_q.where( - Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) - ) - - rows = (session.execute(q.limit(limit).offset(offset))).all() - total = (session.execute(total_q)).scalar_one() - - rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] - return rows_norm, int(total or 0) - - -def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: - wanted = normalize_tags(list(names)) - if not wanted: - return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] - ins = ( - sqlite.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - session.execute(ins) - - -def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]: - return [ - tag_name for (tag_name,) in ( - session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - ] - - -def add_tags_to_asset_info( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", - create_if_missing: bool = True, - asset_info_row: Any = None, -) -> dict: - if not asset_info_row: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"added": [], "already_present": [], "total_tags": total} - - if create_if_missing: - ensure_tags_exist(session, norm, tag_type="user") - - current = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - want = set(norm) - to_add = sorted(want - current) - - if to_add: - with session.begin_nested() as nested: - try: - session.add_all( - [ - AssetInfoTag( - asset_info_id=asset_info_id, - tag_name=t, - origin=origin, - added_at=utcnow(), - ) - for t in to_add - ] - ) - session.flush() - except IntegrityError: - nested.rollback() - - after = set(get_asset_tags(session, asset_info_id=asset_info_id)) - return { - "added": sorted(((after - current) & want)), - "already_present": sorted(want & current), - "total_tags": sorted(after), - } - - -def remove_tags_from_asset_info( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], -) -> dict: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": [], "not_present": [], "total_tags": total} - - existing = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - to_remove = sorted(set(t for t in norm if t in existing)) - not_present = sorted(set(t for t in norm if t not in existing)) - - if to_remove: - session.execute( - delete(AssetInfoTag) - .where( - AssetInfoTag.asset_info_id == asset_info_id, - AssetInfoTag.tag_name.in_(to_remove), - ) - ) - session.flush() - - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": to_remove, "not_present": not_present, "total_tags": total} - - -def remove_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, -) -> None: - session.execute( - sa.delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), - AssetInfoTag.tag_name == "missing", - ) - ) - - -def set_asset_info_preview( - session: Session, - *, - asset_info_id: str, - preview_asset_id: str | None = None, -) -> None: - """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - if preview_asset_id is None: - info.preview_id = None - else: - # validate preview asset exists - if not session.get(Asset, preview_asset_id): - raise ValueError(f"Preview Asset {preview_asset_id} not found") - info.preview_id = preview_asset_id - - info.updated_at = utcnow() - session.flush() diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py new file mode 100644 index 000000000..7888d0645 --- /dev/null +++ b/app/assets/database/queries/__init__.py @@ -0,0 +1,121 @@ +from app.assets.database.queries.asset import ( + asset_exists_by_hash, + bulk_insert_assets, + get_asset_by_hash, + get_existing_asset_ids, + reassign_asset_references, + update_asset_hash_and_mime, + upsert_asset, +) +from app.assets.database.queries.asset_reference import ( + CacheStateRow, + UnenrichedReferenceRow, + bulk_insert_references_ignore_conflicts, + bulk_update_enrichment_level, + bulk_update_is_missing, + bulk_update_needs_verify, + convert_metadata_to_rows, + delete_assets_by_ids, + delete_orphaned_seed_asset, + delete_reference_by_id, + delete_references_by_ids, + fetch_reference_and_asset, + fetch_reference_asset_and_tags, + get_or_create_reference, + get_reference_by_file_path, + get_reference_by_id, + get_reference_with_owner_check, + get_reference_ids_by_ids, + get_references_by_paths_and_asset_ids, + get_references_for_prefixes, + get_unenriched_references, + get_unreferenced_unhashed_asset_ids, + insert_reference, + list_references_by_asset_id, + list_references_page, + mark_references_missing_outside_prefixes, + reference_exists_for_asset_id, + restore_references_by_paths, + set_reference_metadata, + set_reference_preview, + soft_delete_reference_by_id, + update_reference_access_time, + update_reference_name, + update_reference_timestamps, + update_reference_updated_at, + upsert_reference, +) +from app.assets.database.queries.tags import ( + AddTagsResult, + RemoveTagsResult, + SetTagsResult, + add_missing_tag_for_asset_id, + add_tags_to_reference, + bulk_insert_tags_and_meta, + ensure_tags_exist, + get_reference_tags, + list_tags_with_usage, + remove_missing_tag_for_asset_id, + remove_tags_from_reference, + set_reference_tags, + validate_tags_exist, +) + +__all__ = [ + "AddTagsResult", + "CacheStateRow", + "RemoveTagsResult", + "SetTagsResult", + "UnenrichedReferenceRow", + "add_missing_tag_for_asset_id", + "add_tags_to_reference", + "asset_exists_by_hash", + "bulk_insert_assets", + "bulk_insert_references_ignore_conflicts", + "bulk_insert_tags_and_meta", + "bulk_update_enrichment_level", + "bulk_update_is_missing", + "bulk_update_needs_verify", + "convert_metadata_to_rows", + "delete_assets_by_ids", + "delete_orphaned_seed_asset", + "delete_reference_by_id", + "delete_references_by_ids", + "ensure_tags_exist", + "fetch_reference_and_asset", + "fetch_reference_asset_and_tags", + "get_asset_by_hash", + "get_existing_asset_ids", + "get_or_create_reference", + "get_reference_by_file_path", + "get_reference_by_id", + "get_reference_with_owner_check", + "get_reference_ids_by_ids", + "get_reference_tags", + "get_references_by_paths_and_asset_ids", + "get_references_for_prefixes", + "get_unenriched_references", + "get_unreferenced_unhashed_asset_ids", + "insert_reference", + "list_references_by_asset_id", + "list_references_page", + "list_tags_with_usage", + "mark_references_missing_outside_prefixes", + "reassign_asset_references", + "reference_exists_for_asset_id", + "remove_missing_tag_for_asset_id", + "remove_tags_from_reference", + "restore_references_by_paths", + "set_reference_metadata", + "set_reference_preview", + "soft_delete_reference_by_id", + "set_reference_tags", + "update_asset_hash_and_mime", + "update_reference_access_time", + "update_reference_name", + "update_reference_timestamps", + "update_reference_updated_at", + "upsert_asset", + "upsert_reference", + "validate_tags_exist", +] diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py new file mode 100644 index 000000000..a21f5b68f --- /dev/null +++ b/app/assets/database/queries/asset.py @@ -0,0 +1,140 @@ +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.dialects import sqlite +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks + + +def asset_exists_by_hash( + session: Session, + asset_hash: str, +) -> bool: + """ + Check if an asset with a given hash exists in database. + """ + row = ( + session.execute( + select(sa.literal(True)) + .select_from(Asset) + .where(Asset.hash == asset_hash) + .limit(1) + ) + ).first() + return row is not None + + +def get_asset_by_hash( + session: Session, + asset_hash: str, +) -> Asset | None: + return ( + (session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))) + .scalars() + .first() + ) + + +def upsert_asset( + session: Session, + asset_hash: str, + size_bytes: int, + mime_type: str | None = None, +) -> tuple[Asset, bool, bool]: + """Upsert an Asset by hash. Returns (asset, created, updated).""" + vals = {"hash": asset_hash, "size_bytes": int(size_bytes)} + if mime_type: + vals["mime_type"] = mime_type + + ins = ( + sqlite.insert(Asset) + .values(**vals) + .on_conflict_do_nothing(index_elements=[Asset.hash]) + ) + res = session.execute(ins) + created = int(res.rowcount or 0) > 0 + + asset = ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + .scalars() + .first() + ) + if not asset: + raise RuntimeError("Asset row not found after upsert.") + + updated = False + if not created: + changed = False + if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: + asset.size_bytes = int(size_bytes) + changed = True + if mime_type and asset.mime_type != mime_type: + asset.mime_type = mime_type + changed = True + if changed: + updated = True + + return asset, created, updated + + +def bulk_insert_assets( + session: Session, + rows: list[dict], +) -> None: + """Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash.""" + if not rows: + return + ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash]) + for chunk in iter_chunks(rows, calculate_rows_per_statement(5)): + session.execute(ins, chunk) + + +def get_existing_asset_ids( + session: Session, + asset_ids: list[str], +) -> set[str]: + """Return the subset of asset_ids that exist in the database.""" + if not asset_ids: + return set() + found: set[str] = set() + for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): + rows = session.execute( + select(Asset.id).where(Asset.id.in_(chunk)) + ).fetchall() + found.update(row[0] for row in rows) + return found + + +def update_asset_hash_and_mime( + session: Session, + asset_id: str, + asset_hash: str | None = None, + mime_type: str | None = None, +) -> bool: + """Update asset hash and/or mime_type. Returns True if asset was found.""" + asset = session.get(Asset, asset_id) + if not asset: + return False + if asset_hash is not None: + asset.hash = asset_hash + if mime_type is not None: + asset.mime_type = mime_type + return True + + +def reassign_asset_references( + session: Session, + from_asset_id: str, + to_asset_id: str, + reference_id: str, +) -> None: + """Reassign a reference from one asset to another. + + Used when merging a stub asset into an existing asset with the same hash. + """ + ref = session.get(AssetReference, reference_id) + if ref and ref.asset_id == from_asset_id: + ref.asset_id = to_asset_id + + session.flush() diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py new file mode 100644 index 000000000..6524791cc --- /dev/null +++ b/app/assets/database/queries/asset_reference.py @@ -0,0 +1,1033 @@ +"""Query functions for the unified AssetReference table. + +This module replaces the separate asset_info.py and cache_state.py query modules, +providing a unified interface for the merged asset_references table. +""" + +from collections import defaultdict +from datetime import datetime +from decimal import Decimal +from typing import NamedTuple, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, exists, select +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session, noload + +from app.assets.database.models import ( + Asset, + AssetReference, + AssetReferenceMeta, + AssetReferenceTag, + Tag, +) +from app.assets.database.queries.common import ( + MAX_BIND_PARAMS, + build_prefix_like_conditions, + build_visible_owner_clause, + calculate_rows_per_statement, + iter_chunks, +) +from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags + + +def _check_is_scalar(v): + if v is None: + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + + +def _scalar_to_row(key: str, ordinal: int, value) -> dict: + """Convert a scalar value to a typed projection row.""" + if value is None: + return { + "key": key, + "ordinal": ordinal, + "val_str": None, + "val_num": None, + "val_bool": None, + "val_json": None, + } + if isinstance(value, bool): + return {"key": key, "ordinal": ordinal, "val_bool": bool(value)} + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return {"key": key, "ordinal": ordinal, "val_num": num} + if isinstance(value, str): + return {"key": key, "ordinal": ordinal, "val_str": value} + return {"key": key, "ordinal": ordinal, "val_json": value} + + +def convert_metadata_to_rows(key: str, value) -> list[dict]: + """Turn a metadata key/value into typed projection rows.""" + if value is None: + return [_scalar_to_row(key, 0, None)] + + if _check_is_scalar(value): + return [_scalar_to_row(key, 0, value)] + + if isinstance(value, list): + if all(_check_is_scalar(x) for x in value): + return [_scalar_to_row(key, i, x) for i, x in enumerate(value)] + return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)] + + return [{"key": key, "ordinal": 0, "val_json": value}] + + +def _apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def _apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None = None, +) -> sa.sql.Select: + """Apply filters using asset_reference_meta projection table.""" + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + no_row_for_key = sa.not_( + sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + ) + ) + null_row = _exists_for_pred( + key, + AssetReferenceMeta.val_json.is_(None), + AssetReferenceMeta.val_str.is_(None), + AssetReferenceMeta.val_num.is_(None), + AssetReferenceMeta.val_bool.is_(None), + ) + return sa.or_(no_row_for_key, null_row) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value)) + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetReferenceMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetReferenceMeta.val_str == value) + return _exists_for_pred(key, AssetReferenceMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt + + +def get_reference_by_id( + session: Session, + reference_id: str, +) -> AssetReference | None: + return session.get(AssetReference, reference_id) + + +def get_reference_with_owner_check( + session: Session, + reference_id: str, + owner_id: str, +) -> AssetReference: + """Fetch a reference and verify ownership. + + Raises: + ValueError: if reference not found or soft-deleted + PermissionError: if owner_id doesn't match + """ + ref = get_reference_by_id(session, reference_id=reference_id) + if not ref or ref.deleted_at is not None: + raise ValueError(f"AssetReference {reference_id} not found") + if ref.owner_id and ref.owner_id != owner_id: + raise PermissionError("not owner") + return ref + + +def get_reference_by_file_path( + session: Session, + file_path: str, +) -> AssetReference | None: + """Get a reference by its file path.""" + return ( + session.execute( + select(AssetReference).where(AssetReference.file_path == file_path).limit(1) + ) + .scalars() + .first() + ) + + +def reference_exists_for_asset_id( + session: Session, + asset_id: str, +) -> bool: + q = ( + select(sa.literal(True)) + .select_from(AssetReference) + .where(AssetReference.asset_id == asset_id) + .where(AssetReference.deleted_at.is_(None)) + .limit(1) + ) + return session.execute(q).first() is not None + + +def insert_reference( + session: Session, + asset_id: str, + name: str, + owner_id: str = "", + file_path: str | None = None, + mtime_ns: int | None = None, + preview_id: str | None = None, +) -> AssetReference | None: + """Insert a new AssetReference. Returns None if unique constraint violated.""" + now = get_utc_now() + try: + with session.begin_nested(): + ref = AssetReference( + asset_id=asset_id, + name=name, + owner_id=owner_id, + file_path=file_path, + mtime_ns=mtime_ns, + preview_id=preview_id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + except IntegrityError: + return None + + +def get_or_create_reference( + session: Session, + asset_id: str, + name: str, + owner_id: str = "", + file_path: str | None = None, + mtime_ns: int | None = None, + preview_id: str | None = None, +) -> tuple[AssetReference, bool]: + """Get existing or create new AssetReference. + + For filesystem references (file_path is set), uniqueness is by file_path. + For API references (file_path is None), we look for matching + asset_id + owner_id + name. + + Returns (reference, created). + """ + ref = insert_reference( + session, + asset_id=asset_id, + name=name, + owner_id=owner_id, + file_path=file_path, + mtime_ns=mtime_ns, + preview_id=preview_id, + ) + if ref: + return ref, True + + # Find existing - priority to file_path match, then name match + if file_path: + existing = get_reference_by_file_path(session, file_path) + else: + existing = ( + session.execute( + select(AssetReference) + .where( + AssetReference.asset_id == asset_id, + AssetReference.name == name, + AssetReference.owner_id == owner_id, + AssetReference.file_path.is_(None), + ) + .limit(1) + ) + .unique() + .scalar_one_or_none() + ) + if not existing: + raise RuntimeError("Failed to find AssetReference after insert conflict.") + return existing, False + + +def update_reference_timestamps( + session: Session, + reference: AssetReference, + preview_id: str | None = None, +) -> None: + """Update timestamps and optionally preview_id on existing AssetReference.""" + now = get_utc_now() + if preview_id and reference.preview_id != preview_id: + reference.preview_id = preview_id + reference.updated_at = now + + +def list_references_page( + session: Session, + owner_id: str = "", + limit: int = 100, + offset: int = 0, + name_contains: str | None = None, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + metadata_filter: dict | None = None, + sort: str | None = None, + order: str | None = None, +) -> tuple[list[AssetReference], dict[str, list[str]], int]: + """List references with pagination, filtering, and sorting. + + Returns (references, tag_map, total_count). + """ + base = ( + select(AssetReference) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .options(noload(AssetReference.tags)) + ) + + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) + + base = _apply_tag_filters(base, include_tags, exclude_tags) + base = _apply_metadata_filter(base, metadata_filter) + + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetReference.name, + "created_at": AssetReference.created_at, + "updated_at": AssetReference.updated_at, + "last_access_time": AssetReference.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetReference.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + count_stmt = ( + select(sa.func.count()) + .select_from(AssetReference) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + ) + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + count_stmt = count_stmt.where( + AssetReference.name.ilike(f"%{escaped}%", escape=esc) + ) + count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) + + total = int(session.execute(count_stmt).scalar_one() or 0) + refs = session.execute(base).unique().scalars().all() + + id_list: list[str] = [r.id for r in refs] + tag_map: dict[str, list[str]] = defaultdict(list) + if id_list: + rows = session.execute( + select(AssetReferenceTag.asset_reference_id, Tag.name) + .join(Tag, Tag.name == AssetReferenceTag.tag_name) + .where(AssetReferenceTag.asset_reference_id.in_(id_list)) + .order_by(AssetReferenceTag.added_at) + ) + for ref_id, tag_name in rows.all(): + tag_map[ref_id].append(tag_name) + + return list(refs), tag_map, total + + +def fetch_reference_asset_and_tags( + session: Session, + reference_id: str, + owner_id: str = "", +) -> tuple[AssetReference, Asset, list[str]] | None: + stmt = ( + select(AssetReference, Asset, Tag.name) + .join(Asset, Asset.id == AssetReference.asset_id) + .join( + AssetReferenceTag, + AssetReferenceTag.asset_reference_id == AssetReference.id, + isouter=True, + ) + .join(Tag, Tag.name == AssetReferenceTag.tag_name, isouter=True) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .options(noload(AssetReference.tags)) + .order_by(Tag.name.asc()) + ) + + rows = session.execute(stmt).all() + if not rows: + return None + + first_ref, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _ref, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_ref, first_asset, tags + + +def fetch_reference_and_asset( + session: Session, + reference_id: str, + owner_id: str = "", +) -> tuple[AssetReference, Asset] | None: + stmt = ( + select(AssetReference, Asset) + .join(Asset, Asset.id == AssetReference.asset_id) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .limit(1) + .options(noload(AssetReference.tags)) + ) + pair = session.execute(stmt).first() + if not pair: + return None + return pair[0], pair[1] + + +def update_reference_access_time( + session: Session, + reference_id: str, + ts: datetime | None = None, + only_if_newer: bool = True, +) -> None: + ts = ts or get_utc_now() + stmt = sa.update(AssetReference).where(AssetReference.id == reference_id) + if only_if_newer: + stmt = stmt.where( + sa.or_( + AssetReference.last_access_time.is_(None), + AssetReference.last_access_time < ts, + ) + ) + session.execute(stmt.values(last_access_time=ts)) + + +def update_reference_name( + session: Session, + reference_id: str, + name: str, +) -> None: + """Update the name of an AssetReference.""" + now = get_utc_now() + session.execute( + sa.update(AssetReference) + .where(AssetReference.id == reference_id) + .values(name=name, updated_at=now) + ) + + +def update_reference_updated_at( + session: Session, + reference_id: str, + ts: datetime | None = None, +) -> None: + """Update the updated_at timestamp of an AssetReference.""" + ts = ts or get_utc_now() + session.execute( + sa.update(AssetReference) + .where(AssetReference.id == reference_id) + .values(updated_at=ts) + ) + + +def set_reference_metadata( + session: Session, + reference_id: str, + user_metadata: dict | None = None, +) -> None: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + ref.user_metadata = user_metadata or {} + ref.updated_at = get_utc_now() + session.flush() + + session.execute( + delete(AssetReferenceMeta).where( + AssetReferenceMeta.asset_reference_id == reference_id + ) + ) + session.flush() + + if not user_metadata: + return + + rows: list[AssetReferenceMeta] = [] + for k, v in user_metadata.items(): + for r in convert_metadata_to_rows(k, v): + rows.append( + AssetReferenceMeta( + asset_reference_id=reference_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + session.flush() + + +def delete_reference_by_id( + session: Session, + reference_id: str, + owner_id: str, +) -> bool: + stmt = sa.delete(AssetReference).where( + AssetReference.id == reference_id, + build_visible_owner_clause(owner_id), + ) + return int(session.execute(stmt).rowcount or 0) > 0 + + +def soft_delete_reference_by_id( + session: Session, + reference_id: str, + owner_id: str, +) -> bool: + """Mark a reference as soft-deleted by setting deleted_at timestamp. + + Returns True if the reference was found and marked deleted. + """ + now = get_utc_now() + stmt = ( + sa.update(AssetReference) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .values(deleted_at=now) + ) + return int(session.execute(stmt).rowcount or 0) > 0 + + +def set_reference_preview( + session: Session, + reference_id: str, + preview_asset_id: str | None = None, +) -> None: + """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + if preview_asset_id is None: + ref.preview_id = None + else: + if not session.get(Asset, preview_asset_id): + raise ValueError(f"Preview Asset {preview_asset_id} not found") + ref.preview_id = preview_asset_id + + ref.updated_at = get_utc_now() + session.flush() + + +class CacheStateRow(NamedTuple): + """Row from reference query with cache state data.""" + + reference_id: str + file_path: str + mtime_ns: int | None + needs_verify: bool + asset_id: str + asset_hash: str | None + size_bytes: int | None + + +def list_references_by_asset_id( + session: Session, + asset_id: str, +) -> Sequence[AssetReference]: + return ( + session.execute( + select(AssetReference) + .where(AssetReference.asset_id == asset_id) + .order_by(AssetReference.id.asc()) + ) + .scalars() + .all() + ) + + +def upsert_reference( + session: Session, + asset_id: str, + file_path: str, + name: str, + mtime_ns: int, + owner_id: str = "", +) -> tuple[bool, bool]: + """Upsert a reference by file_path. Returns (created, updated). + + Also restores references that were previously marked as missing. + """ + now = get_utc_now() + vals = { + "asset_id": asset_id, + "file_path": file_path, + "name": name, + "owner_id": owner_id, + "mtime_ns": int(mtime_ns), + "is_missing": False, + "created_at": now, + "updated_at": now, + "last_access_time": now, + } + ins = ( + sqlite.insert(AssetReference) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetReference.file_path]) + ) + res = session.execute(ins) + created = int(res.rowcount or 0) > 0 + + if created: + return True, False + + upd = ( + sa.update(AssetReference) + .where(AssetReference.file_path == file_path) + .where( + sa.or_( + AssetReference.asset_id != asset_id, + AssetReference.mtime_ns.is_(None), + AssetReference.mtime_ns != int(mtime_ns), + AssetReference.is_missing == True, # noqa: E712 + AssetReference.deleted_at.isnot(None), + ) + ) + .values( + asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False, + deleted_at=None, updated_at=now, + ) + ) + res2 = session.execute(upd) + updated = int(res2.rowcount or 0) > 0 + return False, updated + + +def mark_references_missing_outside_prefixes( + session: Session, + valid_prefixes: list[str], +) -> int: + """Mark references as missing when file_path doesn't match any valid prefix. + + Returns number of references marked as missing. + """ + if not valid_prefixes: + return 0 + + conds = build_prefix_like_conditions(valid_prefixes) + matches_valid_prefix = sa.or_(*conds) + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(~matches_valid_prefix) + .where(AssetReference.is_missing == False) # noqa: E712 + .values(is_missing=True) + ) + return result.rowcount + + +def restore_references_by_paths(session: Session, file_paths: list[str]) -> int: + """Restore references that were previously marked as missing. + + Returns number of references restored. + """ + if not file_paths: + return 0 + + total = 0 + for chunk in iter_chunks(file_paths, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.file_path.in_(chunk)) + .where(AssetReference.is_missing == True) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .values(is_missing=False) + ) + total += result.rowcount + return total + + +def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]: + """Get IDs of unhashed assets (hash=None) with no active references. + + An asset is considered unreferenced if it has no references, + or all its references are marked as missing. + + Returns list of asset IDs that are unreferenced. + """ + active_ref_exists = ( + sa.select(sa.literal(1)) + .where(AssetReference.asset_id == Asset.id) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .correlate(Asset) + .exists() + ) + unreferenced_subq = sa.select(Asset.id).where( + Asset.hash.is_(None), ~active_ref_exists + ) + return [row[0] for row in session.execute(unreferenced_subq).all()] + + +def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int: + """Delete assets and their references by ID. + + Returns number of assets deleted. + """ + if not asset_ids: + return 0 + total = 0 + for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): + session.execute( + sa.delete(AssetReference).where(AssetReference.asset_id.in_(chunk)) + ) + result = session.execute(sa.delete(Asset).where(Asset.id.in_(chunk))) + total += result.rowcount + return total + + +def get_references_for_prefixes( + session: Session, + prefixes: list[str], + *, + include_missing: bool = False, +) -> list[CacheStateRow]: + """Get all references with file paths matching any of the given prefixes. + + Args: + session: Database session + prefixes: List of absolute directory prefixes to match + include_missing: If False (default), exclude references marked as missing + + Returns: + List of cache state rows with joined asset data + """ + if not prefixes: + return [] + + conds = build_prefix_like_conditions(prefixes) + + query = ( + sa.select( + AssetReference.id, + AssetReference.file_path, + AssetReference.mtime_ns, + AssetReference.needs_verify, + AssetReference.asset_id, + Asset.hash, + Asset.size_bytes, + ) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(sa.or_(*conds)) + ) + + if not include_missing: + query = query.where(AssetReference.is_missing == False) # noqa: E712 + + rows = session.execute( + query.order_by(AssetReference.asset_id.asc(), AssetReference.id.asc()) + ).all() + + return [ + CacheStateRow( + reference_id=row[0], + file_path=row[1], + mtime_ns=row[2], + needs_verify=row[3], + asset_id=row[4], + asset_hash=row[5], + size_bytes=int(row[6]) if row[6] is not None else None, + ) + for row in rows + ] + + +def bulk_update_needs_verify( + session: Session, reference_ids: list[str], value: bool +) -> int: + """Set needs_verify flag for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(chunk)) + .values(needs_verify=value) + ) + total += result.rowcount + return total + + +def bulk_update_is_missing( + session: Session, reference_ids: list[str], value: bool +) -> int: + """Set is_missing flag for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(chunk)) + .values(is_missing=value) + ) + total += result.rowcount + return total + + +def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int: + """Delete references by their IDs. + + Returns: Number of rows deleted + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.delete(AssetReference).where(AssetReference.id.in_(chunk)) + ) + total += result.rowcount + return total + + +def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool: + """Delete a seed asset (hash is None) and its references. + + Returns: True if asset was deleted, False if not found or has a hash + """ + asset = session.get(Asset, asset_id) + if not asset: + return False + if asset.hash is not None: + return False + session.execute( + sa.delete(AssetReference).where(AssetReference.asset_id == asset_id) + ) + session.delete(asset) + return True + + +class UnenrichedReferenceRow(NamedTuple): + """Row for references needing enrichment.""" + + reference_id: str + asset_id: str + file_path: str + enrichment_level: int + + +def get_unenriched_references( + session: Session, + prefixes: list[str], + max_level: int = 0, + limit: int = 1000, +) -> list[UnenrichedReferenceRow]: + """Get references that need enrichment (enrichment_level <= max_level). + + Args: + session: Database session + prefixes: List of absolute directory prefixes to scan + max_level: Maximum enrichment level to include (0=stubs, 1=metadata done) + limit: Maximum number of rows to return + + Returns: + List of unenriched reference rows with file paths + """ + if not prefixes: + return [] + + conds = build_prefix_like_conditions(prefixes) + + query = ( + sa.select( + AssetReference.id, + AssetReference.asset_id, + AssetReference.file_path, + AssetReference.enrichment_level, + ) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(sa.or_(*conds)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.enrichment_level <= max_level) + .order_by(AssetReference.id.asc()) + .limit(limit) + ) + + rows = session.execute(query).all() + return [ + UnenrichedReferenceRow( + reference_id=row[0], + asset_id=row[1], + file_path=row[2], + enrichment_level=row[3], + ) + for row in rows + ] + + +def bulk_update_enrichment_level( + session: Session, + reference_ids: list[str], + level: int, +) -> int: + """Update enrichment level for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(reference_ids)) + .values(enrichment_level=level) + ) + return result.rowcount + + +def bulk_insert_references_ignore_conflicts( + session: Session, + rows: list[dict], +) -> None: + """Bulk insert reference rows with ON CONFLICT DO NOTHING on file_path. + + Each dict should have: id, asset_id, file_path, name, owner_id, mtime_ns, etc. + The is_missing field is automatically set to False for new inserts. + """ + if not rows: + return + enriched_rows = [{**row, "is_missing": False} for row in rows] + ins = sqlite.insert(AssetReference).on_conflict_do_nothing( + index_elements=[AssetReference.file_path] + ) + for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(14)): + session.execute(ins, chunk) + + +def get_references_by_paths_and_asset_ids( + session: Session, + path_to_asset: dict[str, str], +) -> set[str]: + """Query references to find paths where our asset_id won the insert. + + Args: + path_to_asset: Mapping of file_path -> asset_id we tried to insert + + Returns: + Set of file_paths where our asset_id is present + """ + if not path_to_asset: + return set() + + pairs = list(path_to_asset.items()) + winners: set[str] = set() + + # Each pair uses 2 bind params, so chunk at MAX_BIND_PARAMS // 2 + for chunk in iter_chunks(pairs, MAX_BIND_PARAMS // 2): + pairwise = sa.tuple_(AssetReference.file_path, AssetReference.asset_id).in_( + chunk + ) + result = session.execute( + select(AssetReference.file_path).where(pairwise) + ) + winners.update(result.scalars().all()) + + return winners + + +def get_reference_ids_by_ids( + session: Session, + reference_ids: list[str], +) -> set[str]: + """Query to find which reference IDs exist in the database.""" + if not reference_ids: + return set() + + found: set[str] = set() + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + select(AssetReference.id).where(AssetReference.id.in_(chunk)) + ) + found.update(result.scalars().all()) + return found diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py new file mode 100644 index 000000000..194c39a1e --- /dev/null +++ b/app/assets/database/queries/common.py @@ -0,0 +1,54 @@ +"""Shared utilities for database query modules.""" + +import os +from typing import Iterable + +import sqlalchemy as sa + +from app.assets.database.models import AssetReference +from app.assets.helpers import escape_sql_like_string + +MAX_BIND_PARAMS = 800 + + +def calculate_rows_per_statement(cols: int) -> int: + """Calculate how many rows can fit in one statement given column count.""" + return max(1, MAX_BIND_PARAMS // max(1, cols)) + + +def iter_chunks(seq, n: int): + """Yield successive n-sized chunks from seq.""" + for i in range(0, len(seq), n): + yield seq[i : i + n] + + +def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]: + """Yield chunks of rows sized to fit within bind param limits.""" + if not rows: + return + yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row)) + + +def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. + + Owner-less rows are visible to everyone. + """ + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetReference.owner_id == "" + return AssetReference.owner_id.in_(["", owner_id]) + + +def build_prefix_like_conditions( + prefixes: list[str], +) -> list[sa.sql.ColumnElement]: + """Build LIKE conditions for matching file paths under directory prefixes.""" + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_sql_like_string(base) + conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) + return conds diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py new file mode 100644 index 000000000..8b25fee67 --- /dev/null +++ b/app/assets/database/queries/tags.py @@ -0,0 +1,356 @@ +from dataclasses import dataclass +from typing import Iterable, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, func, select +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from app.assets.database.models import ( + AssetReference, + AssetReferenceMeta, + AssetReferenceTag, + Tag, +) +from app.assets.database.queries.common import ( + build_visible_owner_clause, + iter_row_chunks, +) +from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags + + +@dataclass(frozen=True) +class AddTagsResult: + added: list[str] + already_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class RemoveTagsResult: + removed: list[str] + not_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class SetTagsResult: + added: list[str] + removed: list[str] + total: list[str] + + +def validate_tags_exist(session: Session, tags: list[str]) -> None: + """Raise ValueError if any of the given tag names do not exist.""" + existing_tag_names = set( + name + for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all() + ) + missing = [t for t in tags if t not in existing_tag_names] + if missing: + raise ValueError(f"Unknown tags: {missing}") + + +def ensure_tags_exist( + session: Session, names: Iterable[str], tag_type: str = "user" +) -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + session.execute(ins) + + +def get_reference_tags(session: Session, reference_id: str) -> list[str]: + return [ + tag_name + for (tag_name,) in ( + session.execute( + select(AssetReferenceTag.tag_name).where( + AssetReferenceTag.asset_reference_id == reference_id + ) + ) + ).all() + ] + + +def set_reference_tags( + session: Session, + reference_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> SetTagsResult: + desired = normalize_tags(tags) + + current = set(get_reference_tags(session, reference_id)) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + ensure_tags_exist(session, to_add, tag_type="user") + session.add_all( + [ + AssetReferenceTag( + asset_reference_id=reference_id, + tag_name=t, + origin=origin, + added_at=get_utc_now(), + ) + for t in to_add + ] + ) + session.flush() + + if to_remove: + session.execute( + delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id == reference_id, + AssetReferenceTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + return SetTagsResult(added=to_add, removed=to_remove, total=desired) + + +def add_tags_to_reference( + session: Session, + reference_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + reference_row: AssetReference | None = None, +) -> AddTagsResult: + if not reference_row: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_reference_tags(session, reference_id=reference_id) + return AddTagsResult(added=[], already_present=[], total_tags=total) + + if create_if_missing: + ensure_tags_exist(session, norm, tag_type="user") + + current = set(get_reference_tags(session, reference_id)) + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetReferenceTag( + asset_reference_id=reference_id, + tag_name=t, + origin=origin, + added_at=get_utc_now(), + ) + for t in to_add + ] + ) + session.flush() + except IntegrityError: + nested.rollback() + + after = set(get_reference_tags(session, reference_id=reference_id)) + return AddTagsResult( + added=sorted(((after - current) & want)), + already_present=sorted(want & current), + total_tags=sorted(after), + ) + + +def remove_tags_from_reference( + session: Session, + reference_id: str, + tags: Sequence[str], +) -> RemoveTagsResult: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_reference_tags(session, reference_id=reference_id) + return RemoveTagsResult(removed=[], not_present=[], total_tags=total) + + existing = set(get_reference_tags(session, reference_id)) + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + session.execute( + delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id == reference_id, + AssetReferenceTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + total = get_reference_tags(session, reference_id=reference_id) + return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total) + + +def add_missing_tag_for_asset_id( + session: Session, + asset_id: str, + origin: str = "automatic", +) -> None: + select_rows = ( + sa.select( + AssetReference.id.label("asset_reference_id"), + sa.literal("missing").label("tag_name"), + sa.literal(origin).label("origin"), + sa.literal(get_utc_now()).label("added_at"), + ) + .where(AssetReference.asset_id == asset_id) + .where( + sa.not_( + sa.exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == "missing") + ) + ) + ) + ) + session.execute( + sqlite.insert(AssetReferenceTag) + .from_select( + ["asset_reference_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing( + index_elements=[ + AssetReferenceTag.asset_reference_id, + AssetReferenceTag.tag_name, + ] + ) + ) + + +def remove_missing_tag_for_asset_id( + session: Session, + asset_id: str, +) -> None: + session.execute( + sa.delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id.in_( + sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id) + ), + AssetReferenceTag.tag_name == "missing", + ) + ) + + +def list_tags_with_usage( + session: Session, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetReferenceTag.tag_name.label("tag_name"), + func.count(AssetReferenceTag.asset_reference_id).label("cnt"), + ) + .select_from(AssetReferenceTag) + .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.deleted_at.is_(None)) + .group_by(AssetReferenceTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + escaped, esc = escape_sql_like_string(prefix.strip().lower()) + q = q.where(Tag.name.like(escaped + "%", escape=esc)) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + escaped, esc = escape_sql_like_string(prefix.strip().lower()) + total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if not include_zero: + visible_tags_sq = ( + select(AssetReferenceTag.tag_name) + .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.deleted_at.is_(None)) + .group_by(AssetReferenceTag.tag_name) + ) + total_q = total_q.where(Tag.name.in_(visible_tags_sq)) + + rows = (session.execute(q.limit(limit).offset(offset))).all() + total = (session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) + + +def bulk_insert_tags_and_meta( + session: Session, + tag_rows: list[dict], + meta_rows: list[dict], +) -> None: + """Batch insert into asset_reference_tags and asset_reference_meta. + + Uses ON CONFLICT DO NOTHING. + + Args: + session: Database session + tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at + meta_rows: Dicts with: asset_reference_id, key, ordinal, val_* + """ + if tag_rows: + ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing( + index_elements=[ + AssetReferenceTag.asset_reference_id, + AssetReferenceTag.tag_name, + ] + ) + for chunk in iter_row_chunks(tag_rows, cols_per_row=4): + session.execute(ins_tags, chunk) + + if meta_rows: + ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing( + index_elements=[ + AssetReferenceMeta.asset_reference_id, + AssetReferenceMeta.key, + AssetReferenceMeta.ordinal, + ] + ) + for chunk in iter_row_chunks(meta_rows, cols_per_row=7): + session.execute(ins_meta, chunk) diff --git a/app/assets/database/tags.py b/app/assets/database/tags.py deleted file mode 100644 index 3ab6497c2..000000000 --- a/app/assets/database/tags.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Iterable - -import sqlalchemy -from sqlalchemy.orm import Session -from sqlalchemy.dialects import sqlite - -from app.assets.helpers import normalize_tags, utcnow -from app.assets.database.models import Tag, AssetInfoTag, AssetInfo - - -def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: - wanted = normalize_tags(list(names)) - if not wanted: - return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] - ins = ( - sqlite.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - return session.execute(ins) - -def add_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, - origin: str = "automatic", -) -> None: - select_rows = ( - sqlalchemy.select( - AssetInfo.id.label("asset_info_id"), - sqlalchemy.literal("missing").label("tag_name"), - sqlalchemy.literal(origin).label("origin"), - sqlalchemy.literal(utcnow()).label("added_at"), - ) - .where(AssetInfo.asset_id == asset_id) - .where( - sqlalchemy.not_( - sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) - ) - ) - ) - session.execute( - sqlite.insert(AssetInfoTag) - .from_select( - ["asset_info_id", "tag_name", "origin", "added_at"], - select_rows, - ) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - -def remove_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, -) -> None: - session.execute( - sqlalchemy.delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), - AssetInfoTag.tag_name == "missing", - ) - ) diff --git a/app/assets/hashing.py b/app/assets/hashing.py deleted file mode 100644 index 4b72084b9..000000000 --- a/app/assets/hashing.py +++ /dev/null @@ -1,75 +0,0 @@ -from blake3 import blake3 -from typing import IO -import os -import asyncio - - -DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB - -# NOTE: this allows hashing different representations of a file-like object -def blake3_hash( - fp: str | IO[bytes], - chunk_size: int = DEFAULT_CHUNK, -) -> str: - """ - Returns a BLAKE3 hex digest for ``fp``, which may be: - - a filename (str/bytes) or PathLike - - an open binary file object - If ``fp`` is a file object, it must be opened in **binary** mode and support - ``read``, ``seek``, and ``tell``. The function will seek to the start before - reading and will attempt to restore the original position afterward. - """ - # duck typing to check if input is a file-like object - if hasattr(fp, "read"): - return _hash_file_obj(fp, chunk_size) - - with open(os.fspath(fp), "rb") as f: - return _hash_file_obj(f, chunk_size) - - -async def blake3_hash_async( - fp: str | IO[bytes], - chunk_size: int = DEFAULT_CHUNK, -) -> str: - """Async wrapper for ``blake3_hash_sync``. - Uses a worker thread so the event loop remains responsive. - """ - # If it is a path, open inside the worker thread to keep I/O off the loop. - if hasattr(fp, "read"): - return await asyncio.to_thread(blake3_hash, fp, chunk_size) - - def _worker() -> str: - with open(os.fspath(fp), "rb") as f: - return _hash_file_obj(f, chunk_size) - - return await asyncio.to_thread(_worker) - - -def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str: - """ - Hash an already-open binary file object by streaming in chunks. - - Seeks to the beginning before reading (if supported). - - Restores the original position afterward (if tell/seek are supported). - """ - if chunk_size <= 0: - chunk_size = DEFAULT_CHUNK - - # in case file object is already open and not at the beginning, track so can be restored after hashing - orig_pos = file_obj.tell() - - try: - # seek to the beginning before reading - if orig_pos != 0: - file_obj.seek(0) - - h = blake3() - while True: - chunk = file_obj.read(chunk_size) - if not chunk: - break - h.update(chunk) - return h.hexdigest() - finally: - # restore original position in file object, if needed - if orig_pos != 0: - file_obj.seek(orig_pos) diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 5030b123a..3798f3933 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -1,226 +1,42 @@ -import contextlib import os -from decimal import Decimal -from aiohttp import web from datetime import datetime, timezone -from pathlib import Path -from typing import Literal, Any - -import folder_paths +from typing import Sequence -RootType = Literal["models", "input", "output"] -ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") - -def get_query_dict(request: web.Request) -> dict[str, Any]: +def select_best_live_path(states: Sequence) -> str: """ - Gets a dictionary of query parameters from the request. - - 'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic. + Return the best on-disk path among cache states: + 1) Prefer a path that exists with needs_verify == False (already verified). + 2) Otherwise, pick the first path that exists. + 3) Otherwise return empty string. """ - query_dict = { - key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key) - for key in request.query.keys() - } - return query_dict + alive = [ + s + for s in states + if getattr(s, "file_path", None) and os.path.isfile(s.file_path) + ] + if not alive: + return "" + for s in alive: + if not getattr(s, "needs_verify", False): + return s.file_path + return alive[0].file_path -def list_tree(base_dir: str) -> list[str]: - out: list[str] = [] - base_abs = os.path.abspath(base_dir) - if not os.path.isdir(base_abs): - return out - for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): - for name in filenames: - out.append(os.path.abspath(os.path.join(dirpath, name))) - return out -def prefixes_for_root(root: RootType) -> list[str]: - if root == "models": - bases: list[str] = [] - for _bucket, paths in get_comfy_models_folders(): - bases.extend(paths) - return [os.path.abspath(p) for p in bases] - if root == "input": - return [os.path.abspath(folder_paths.get_input_directory())] - if root == "output": - return [os.path.abspath(folder_paths.get_output_directory())] - return [] +def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]: + """Escapes %, _ and the escape char in a LIKE prefix. -def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]: - """Escapes %, _ and the escape char itself in a LIKE prefix. - Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like(). + Returns (escaped_prefix, escape_char). """ s = s.replace(escape, escape + escape) # escape the escape char first s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards return s, escape -def fast_asset_file_check( - *, - mtime_db: int | None, - size_db: int | None, - stat_result: os.stat_result, -) -> bool: - if mtime_db is None: - return False - actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)) - if int(mtime_db) != int(actual_mtime_ns): - return False - sz = int(size_db or 0) - if sz > 0: - return int(stat_result.st_size) == sz - return True -def utcnow() -> datetime: +def get_utc_now() -> datetime: """Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC.""" return datetime.now(timezone.utc).replace(tzinfo=None) -def get_comfy_models_folders() -> list[tuple[str, list[str]]]: - """Build a list of (folder_name, base_paths[]) categories that are configured for model locations. - - We trust `folder_paths.folder_names_and_paths` and include a category if - *any* of its base paths lies under the Comfy `models_dir`. - """ - targets: list[tuple[str, list[str]]] = [] - models_root = os.path.abspath(folder_paths.models_dir) - for name, values in folder_paths.folder_names_and_paths.items(): - paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI - if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): - targets.append((name, paths)) - return targets - -def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: - """Validates and maps tags -> (base_dir, subdirs_for_fs)""" - root = tags[0] - if root == "models": - if len(tags) < 2: - raise ValueError("at least two tags required for model asset") - try: - bases = folder_paths.folder_names_and_paths[tags[1]][0] - except KeyError: - raise ValueError(f"unknown model category '{tags[1]}'") - if not bases: - raise ValueError(f"no base path configured for category '{tags[1]}'") - base_dir = os.path.abspath(bases[0]) - raw_subdirs = tags[2:] - else: - base_dir = os.path.abspath( - folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory() - ) - raw_subdirs = tags[1:] - for i in raw_subdirs: - if i in (".", ".."): - raise ValueError("invalid path component in tags") - - return base_dir, raw_subdirs if raw_subdirs else [] - -def ensure_within_base(candidate: str, base: str) -> None: - cand_abs = os.path.abspath(candidate) - base_abs = os.path.abspath(base) - try: - if os.path.commonpath([cand_abs, base_abs]) != base_abs: - raise ValueError("destination escapes base directory") - except Exception: - raise ValueError("invalid destination path") - -def compute_relative_filename(file_path: str) -> str | None: - """ - Return the model's path relative to the last well-known folder (the model category), - using forward slashes, eg: - /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" - /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" - - For non-model paths, returns None. - NOTE: this is a temporary helper, used only for initializing metadata["filename"] field. - """ - try: - root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path) - except ValueError: - return None - - p = Path(rel_path) - parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] - if not parts: - return None - - if root_category == "models": - # parts[0] is the category ("checkpoints", "vae", etc) – drop it - inside = parts[1:] if len(parts) > 1 else [parts[0]] - return "/".join(inside) - return "/".join(parts) # input/output: keep all parts - -def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]: - """Given an absolute or relative file path, determine which root category the path belongs to: - - 'input' if the file resides under `folder_paths.get_input_directory()` - - 'output' if the file resides under `folder_paths.get_output_directory()` - - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()` - - Returns: - (root_category, relative_path_inside_that_root) - For 'models', the relative path is prefixed with the category name: - e.g. ('models', 'vae/test/sub/ae.safetensors') - - Raises: - ValueError: if the path does not belong to input, output, or configured model bases. - """ - fp_abs = os.path.abspath(file_path) - - def _is_within(child: str, parent: str) -> bool: - try: - return os.path.commonpath([child, parent]) == parent - except Exception: - return False - - def _rel(child: str, parent: str) -> str: - return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep) - - # 1) input - input_base = os.path.abspath(folder_paths.get_input_directory()) - if _is_within(fp_abs, input_base): - return "input", _rel(fp_abs, input_base) - - # 2) output - output_base = os.path.abspath(folder_paths.get_output_directory()) - if _is_within(fp_abs, output_base): - return "output", _rel(fp_abs, output_base) - - # 3) models (check deepest matching base to avoid ambiguity) - best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) - for bucket, bases in get_comfy_models_folders(): - for b in bases: - base_abs = os.path.abspath(b) - if not _is_within(fp_abs, base_abs): - continue - cand = (len(base_abs), bucket, _rel(fp_abs, base_abs)) - if best is None or cand[0] > best[0]: - best = cand - - if best is not None: - _, bucket, rel_inside = best - combined = os.path.join(bucket, rel_inside) - return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) - - raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}") - -def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: - """Return a tuple (name, tags) derived from a filesystem path. - - Semantics: - - Root category is determined by `get_relative_to_root_category_path_of_asset`. - - The returned `name` is the base filename with extension from the relative path. - - The returned `tags` are: - [root_category] + parent folders of the relative path (in order) - For 'models', this means: - file '/.../ModelsDir/vae/test_tag/ae.safetensors' - -> root_category='models', some_path='vae/test_tag/ae.safetensors' - -> name='ae.safetensors', tags=['models', 'vae', 'test_tag'] - - Raises: - ValueError: if the path does not belong to input, output, or configured model bases. - """ - root_category, some_path = get_relative_to_root_category_path_of_asset(file_path) - p = Path(some_path) - parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] - return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) def normalize_tags(tags: list[str] | None) -> list[str]: """ @@ -228,85 +44,22 @@ def normalize_tags(tags: list[str] | None) -> list[str]: - Stripping whitespace and converting to lowercase. - Removing duplicates. """ - return [t.strip().lower() for t in (tags or []) if (t or "").strip()] + return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip())) -def collect_models_files() -> list[str]: - out: list[str] = [] - for folder_name, bases in get_comfy_models_folders(): - rel_files = folder_paths.get_filename_list(folder_name) or [] - for rel_path in rel_files: - abs_path = folder_paths.get_full_path(folder_name, rel_path) - if not abs_path: - continue - abs_path = os.path.abspath(abs_path) - allowed = False - for b in bases: - base_abs = os.path.abspath(b) - with contextlib.suppress(Exception): - if os.path.commonpath([abs_path, base_abs]) == base_abs: - allowed = True - break - if allowed: - out.append(abs_path) - return out -def is_scalar(v): - if v is None: - return True - if isinstance(v, bool): - return True - if isinstance(v, (int, float, Decimal, str)): - return True - return False +def validate_blake3_hash(s: str) -> str: + """Validate and normalize a blake3 hash string. -def project_kv(key: str, value): + Returns canonical 'blake3:' or raises ValueError. """ - Turn a metadata key/value into typed projection rows. - Returns list[dict] with keys: - key, ordinal, and one of val_str / val_num / val_bool / val_json (others None) - """ - rows: list[dict] = [] - - def _null_row(ordinal: int) -> dict: - return { - "key": key, "ordinal": ordinal, - "val_str": None, "val_num": None, "val_bool": None, "val_json": None - } - - if value is None: - rows.append(_null_row(0)) - return rows - - if is_scalar(value): - if isinstance(value, bool): - rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) - elif isinstance(value, (int, float, Decimal)): - num = value if isinstance(value, Decimal) else Decimal(str(value)) - rows.append({"key": key, "ordinal": 0, "val_num": num}) - elif isinstance(value, str): - rows.append({"key": key, "ordinal": 0, "val_str": value}) - else: - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows - - if isinstance(value, list): - if all(is_scalar(x) for x in value): - for i, x in enumerate(value): - if x is None: - rows.append(_null_row(i)) - elif isinstance(x, bool): - rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) - elif isinstance(x, (int, float, Decimal)): - num = x if isinstance(x, Decimal) else Decimal(str(x)) - rows.append({"key": key, "ordinal": i, "val_num": num}) - elif isinstance(x, str): - rows.append({"key": key, "ordinal": i, "val_str": x}) - else: - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - for i, x in enumerate(value): - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows + s = s.strip().lower() + if not s or ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if ( + algo != "blake3" + or len(digest) != 64 + or any(c for c in digest if c not in "0123456789abcdef") + ): + raise ValueError("hash must be 'blake3:'") + return f"{algo}:{digest}" diff --git a/app/assets/manager.py b/app/assets/manager.py deleted file mode 100644 index a68c8c8ae..000000000 --- a/app/assets/manager.py +++ /dev/null @@ -1,516 +0,0 @@ -import os -import mimetypes -import contextlib -from typing import Sequence - -from app.database.db import create_session -from app.assets.api import schemas_out, schemas_in -from app.assets.database.queries import ( - asset_exists_by_hash, - asset_info_exists_for_asset_id, - get_asset_by_hash, - get_asset_info_by_id, - fetch_asset_info_asset_and_tags, - fetch_asset_info_and_asset, - create_asset_info_for_existing_asset, - touch_asset_info_by_id, - update_asset_info_full, - delete_asset_info_by_id, - list_cache_states_by_asset_id, - list_asset_infos_page, - list_tags_with_usage, - get_asset_tags, - add_tags_to_asset_info, - remove_tags_from_asset_info, - pick_best_live_path, - ingest_fs_asset, - set_asset_info_preview, -) -from app.assets.helpers import resolve_destination_from_tags, ensure_within_base -from app.assets.database.models import Asset - - -def _safe_sort_field(requested: str | None) -> str: - if not requested: - return "created_at" - v = requested.lower() - if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: - return v - return "created_at" - - -def _get_size_mtime_ns(path: str) -> tuple[int, int]: - st = os.stat(path, follow_symlinks=True) - return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) - - -def _safe_filename(name: str | None, fallback: str) -> str: - n = os.path.basename((name or "").strip() or fallback) - if n: - return n - return fallback - - -def asset_exists(*, asset_hash: str) -> bool: - """ - Check if an asset with a given hash exists in database. - """ - with create_session() as session: - return asset_exists_by_hash(session, asset_hash=asset_hash) - - -def list_assets( - *, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, - name_contains: str | None = None, - metadata_filter: dict | None = None, - limit: int = 20, - offset: int = 0, - sort: str = "created_at", - order: str = "desc", - owner_id: str = "", -) -> schemas_out.AssetsList: - sort = _safe_sort_field(sort) - order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() - - with create_session() as session: - infos, tag_map, total = list_asset_infos_page( - session, - owner_id=owner_id, - include_tags=include_tags, - exclude_tags=exclude_tags, - name_contains=name_contains, - metadata_filter=metadata_filter, - limit=limit, - offset=offset, - sort=sort, - order=order, - ) - - summaries: list[schemas_out.AssetSummary] = [] - for info in infos: - asset = info.asset - tags = tag_map.get(info.id, []) - summaries.append( - schemas_out.AssetSummary( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset else None, - mime_type=asset.mime_type if asset else None, - tags=tags, - created_at=info.created_at, - updated_at=info.updated_at, - last_access_time=info.last_access_time, - ) - ) - - return schemas_out.AssetsList( - assets=summaries, - total=total, - has_more=(offset + len(summaries)) < total, - ) - - -def get_asset( - *, - asset_info_id: str, - owner_id: str = "", -) -> schemas_out.AssetDetail: - with create_session() as session: - res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not res: - raise ValueError(f"AssetInfo {asset_info_id} not found") - info, asset, tag_names = res - preview_id = info.preview_id - - return schemas_out.AssetDetail( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, - mime_type=asset.mime_type if asset else None, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - ) - - -def resolve_asset_content_for_download( - *, - asset_info_id: str, - owner_id: str = "", -) -> tuple[str, str, str]: - with create_session() as session: - pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not pair: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - info, asset = pair - states = list_cache_states_by_asset_id(session, asset_id=asset.id) - abs_path = pick_best_live_path(states) - if not abs_path: - raise FileNotFoundError - - touch_asset_info_by_id(session, asset_info_id=asset_info_id) - session.commit() - - ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" - download_name = info.name or os.path.basename(abs_path) - return abs_path, ctype, download_name - - -def upload_asset_from_temp_path( - spec: schemas_in.UploadAssetSpec, - *, - temp_path: str, - client_filename: str | None = None, - owner_id: str = "", - expected_asset_hash: str | None = None, -) -> schemas_out.AssetCreated: - """ - Create new asset or update existing asset from a temporary file path. - """ - try: - # NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment - import app.assets.hashing as hashing - digest = hashing.blake3_hash(temp_path) - except Exception as e: - raise RuntimeError(f"failed to hash uploaded file: {e}") - asset_hash = "blake3:" + digest - - if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower(): - raise ValueError("HASH_MISMATCH") - - with create_session() as session: - existing = get_asset_by_hash(session, asset_hash=asset_hash) - if existing is not None: - with contextlib.suppress(Exception): - if temp_path and os.path.exists(temp_path): - os.remove(temp_path) - - display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) - info = create_asset_info_for_existing_asset( - session, - asset_hash=asset_hash, - name=display_name, - user_metadata=spec.user_metadata or {}, - tags=spec.tags or [], - tag_origin="manual", - owner_id=owner_id, - ) - tag_names = get_asset_tags(session, asset_info_id=info.id) - session.commit() - - return schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=existing.hash, - size=int(existing.size_bytes) if existing.size_bytes is not None else None, - mime_type=existing.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=False, - ) - - base_dir, subdirs = resolve_destination_from_tags(spec.tags) - dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir - os.makedirs(dest_dir, exist_ok=True) - - src_for_ext = (client_filename or spec.name or "").strip() - _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" - ext = _ext if 0 < len(_ext) <= 16 else "" - hashed_basename = f"{digest}{ext}" - dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) - ensure_within_base(dest_abs, base_dir) - - content_type = ( - mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] - or mimetypes.guess_type(hashed_basename, strict=False)[0] - or "application/octet-stream" - ) - - try: - os.replace(temp_path, dest_abs) - except Exception as e: - raise RuntimeError(f"failed to move uploaded file into place: {e}") - - try: - size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs) - except OSError as e: - raise RuntimeError(f"failed to stat destination file: {e}") - - with create_session() as session: - result = ingest_fs_asset( - session, - asset_hash=asset_hash, - abs_path=dest_abs, - size_bytes=size_bytes, - mtime_ns=mtime_ns, - mime_type=content_type, - info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest), - owner_id=owner_id, - preview_id=None, - user_metadata=spec.user_metadata or {}, - tags=spec.tags, - tag_origin="manual", - require_existing_tags=False, - ) - info_id = result["asset_info_id"] - if not info_id: - raise RuntimeError("failed to create asset metadata") - - pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id) - if not pair: - raise RuntimeError("inconsistent DB state after ingest") - info, asset = pair - tag_names = get_asset_tags(session, asset_info_id=info.id) - created_result = schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=asset.hash, - size=int(asset.size_bytes), - mime_type=asset.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=result["asset_created"], - ) - session.commit() - - return created_result - - -def update_asset( - *, - asset_info_id: str, - name: str | None = None, - tags: list[str] | None = None, - user_metadata: dict | None = None, - owner_id: str = "", -) -> schemas_out.AssetUpdated: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - info = update_asset_info_full( - session, - asset_info_id=asset_info_id, - name=name, - tags=tags, - user_metadata=user_metadata, - tag_origin="manual", - asset_info_row=info_row, - ) - - tag_names = get_asset_tags(session, asset_info_id=asset_info_id) - result = schemas_out.AssetUpdated( - id=info.id, - name=info.name, - asset_hash=info.asset.hash if info.asset else None, - tags=tag_names, - user_metadata=info.user_metadata or {}, - updated_at=info.updated_at, - ) - session.commit() - - return result - - -def set_asset_preview( - *, - asset_info_id: str, - preview_asset_id: str | None = None, - owner_id: str = "", -) -> schemas_out.AssetDetail: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - set_asset_info_preview( - session, - asset_info_id=asset_info_id, - preview_asset_id=preview_asset_id, - ) - - res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not res: - raise RuntimeError("State changed during preview update") - info, asset, tags = res - result = schemas_out.AssetDetail( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, - mime_type=asset.mime_type if asset else None, - tags=tags, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - ) - session.commit() - - return result - - -def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - asset_id = info_row.asset_id if info_row else None - deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not deleted: - session.commit() - return False - - if not delete_content_if_orphan or not asset_id: - session.commit() - return True - - still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id) - if still_exists: - session.commit() - return True - - states = list_cache_states_by_asset_id(session, asset_id=asset_id) - file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] - - asset_row = session.get(Asset, asset_id) - if asset_row is not None: - session.delete(asset_row) - - session.commit() - for p in file_paths: - with contextlib.suppress(Exception): - if p and os.path.isfile(p): - os.remove(p) - return True - - -def create_asset_from_hash( - *, - hash_str: str, - name: str, - tags: list[str] | None = None, - user_metadata: dict | None = None, - owner_id: str = "", -) -> schemas_out.AssetCreated | None: - canonical = hash_str.strip().lower() - with create_session() as session: - asset = get_asset_by_hash(session, asset_hash=canonical) - if not asset: - return None - - info = create_asset_info_for_existing_asset( - session, - asset_hash=canonical, - name=_safe_filename(name, fallback=canonical.split(":", 1)[1]), - user_metadata=user_metadata or {}, - tags=tags or [], - tag_origin="manual", - owner_id=owner_id, - ) - tag_names = get_asset_tags(session, asset_info_id=info.id) - result = schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=asset.hash, - size=int(asset.size_bytes), - mime_type=asset.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=False, - ) - session.commit() - - return result - - -def add_tags_to_asset( - *, - asset_info_id: str, - tags: list[str], - origin: str = "manual", - owner_id: str = "", -) -> schemas_out.TagsAdd: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - data = add_tags_to_asset_info( - session, - asset_info_id=asset_info_id, - tags=tags, - origin=origin, - create_if_missing=True, - asset_info_row=info_row, - ) - session.commit() - return schemas_out.TagsAdd(**data) - - -def remove_tags_from_asset( - *, - asset_info_id: str, - tags: list[str], - owner_id: str = "", -) -> schemas_out.TagsRemove: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - data = remove_tags_from_asset_info( - session, - asset_info_id=asset_info_id, - tags=tags, - ) - session.commit() - return schemas_out.TagsRemove(**data) - - -def list_tags( - prefix: str | None = None, - limit: int = 100, - offset: int = 0, - order: str = "count_desc", - include_zero: bool = True, - owner_id: str = "", -) -> schemas_out.TagsList: - limit = max(1, min(1000, limit)) - offset = max(0, offset) - - with create_session() as session: - rows, total = list_tags_with_usage( - session, - prefix=prefix, - limit=limit, - offset=offset, - include_zero=include_zero, - order=order, - owner_id=owner_id, - ) - - tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] - return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 0172a5c2f..e27ea5123 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -1,263 +1,567 @@ -import contextlib -import time import logging import os -import sqlalchemy +from pathlib import Path +from typing import Callable, Literal, TypedDict import folder_paths -from app.database.db import create_session, dependencies_available -from app.assets.helpers import ( - collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path, - list_tree,prefixes_for_root, escape_like_prefix, - RootType +from app.assets.database.queries import ( + add_missing_tag_for_asset_id, + bulk_update_enrichment_level, + bulk_update_is_missing, + bulk_update_needs_verify, + delete_orphaned_seed_asset, + delete_references_by_ids, + ensure_tags_exist, + get_asset_by_hash, + get_references_for_prefixes, + get_unenriched_references, + mark_references_missing_outside_prefixes, + reassign_asset_references, + remove_missing_tag_for_asset_id, + set_reference_metadata, + update_asset_hash_and_mime, ) -from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id -from app.assets.database.bulk_ops import seed_from_paths_batch -from app.assets.database.models import Asset, AssetCacheState, AssetInfo +from app.assets.services.bulk_ingest import ( + SeedAssetSpec, + batch_insert_seed_assets, +) +from app.assets.services.file_utils import ( + get_mtime_ns, + is_visible, + list_files_recursively, + verify_file_unchanged, +) +from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash +from app.assets.services.metadata_extract import extract_file_metadata +from app.assets.services.path_utils import ( + compute_relative_filename, + get_comfy_models_folders, + get_name_and_tags_from_asset_path, +) +from app.database.db import create_session -def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None: - """ - Scan the given roots and seed the assets into the database. - """ - if not dependencies_available(): - if enable_logging: - logging.warning("Database dependencies not available, skipping assets scan") - return - t_start = time.perf_counter() - created = 0 - skipped_existing = 0 - orphans_pruned = 0 - paths: list[str] = [] - try: - existing_paths: set[str] = set() - for r in roots: - try: - survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True) - if survivors: - existing_paths.update(survivors) - except Exception as e: - logging.exception("fast DB scan failed for %s: %s", r, e) +class _RefInfo(TypedDict): + ref_id: str + file_path: str + exists: bool + stat_unchanged: bool + needs_verify: bool - try: - orphans_pruned = _prune_orphaned_assets(roots) - except Exception as e: - logging.exception("orphan pruning failed: %s", e) - if "models" in roots: - paths.extend(collect_models_files()) - if "input" in roots: - paths.extend(list_tree(folder_paths.get_input_directory())) - if "output" in roots: - paths.extend(list_tree(folder_paths.get_output_directory())) +class _AssetAccumulator(TypedDict): + hash: str | None + size_db: int + refs: list[_RefInfo] - specs: list[dict] = [] - tag_pool: set[str] = set() - for p in paths: - abs_p = os.path.abspath(p) - if abs_p in existing_paths: - skipped_existing += 1 + +RootType = Literal["models", "input", "output"] + + +def get_prefixes_for_root(root: RootType) -> list[str]: + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + return [os.path.abspath(p) for p in bases] + if root == "input": + return [os.path.abspath(folder_paths.get_input_directory())] + if root == "output": + return [os.path.abspath(folder_paths.get_output_directory())] + return [] + + +def get_all_known_prefixes() -> list[str]: + """Get all known asset prefixes across all root types.""" + all_roots: tuple[RootType, ...] = ("models", "input", "output") + return [p for root in all_roots for p in get_prefixes_for_root(root)] + + +def collect_models_files() -> list[str]: + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + rel_files = folder_paths.get_filename_list(folder_name) or [] + for rel_path in rel_files: + if not all(is_visible(part) for part in Path(rel_path).parts): continue - try: - stat_p = os.stat(abs_p, follow_symlinks=False) - except OSError: + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: continue - # skip empty files - if not stat_p.st_size: - continue - name, tags = get_name_and_tags_from_asset_path(abs_p) - specs.append( - { - "abs_path": abs_p, - "size_bytes": stat_p.st_size, - "mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)), - "info_name": name, - "tags": tags, - "fname": compute_relative_filename(abs_p), - } - ) - for t in tags: - tag_pool.add(t) - # if no file specs, nothing to do - if not specs: - return - with create_session() as sess: - if tag_pool: - ensure_tags_exist(sess, tag_pool, tag_type="user") - - result = seed_from_paths_batch(sess, specs=specs, owner_id="") - created += result["inserted_infos"] - sess.commit() - finally: - if enable_logging: - logging.info( - "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)", - roots, - time.perf_counter() - t_start, - created, - skipped_existing, - orphans_pruned, - len(paths), - ) + abs_path = os.path.abspath(abs_path) + allowed = False + abs_p = Path(abs_path) + for b in bases: + if abs_p.is_relative_to(os.path.abspath(b)): + allowed = True + break + if allowed: + out.append(abs_path) + return out -def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int: - """Prune cache states outside configured prefixes, then delete orphaned seed assets.""" - all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)] - if not all_prefixes: - return 0 - - def make_prefix_condition(prefix: str): - base = prefix if prefix.endswith(os.sep) else prefix + os.sep - escaped, esc = escape_like_prefix(base) - return AssetCacheState.file_path.like(escaped + "%", escape=esc) - - matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes]) - - orphan_subq = ( - sqlalchemy.select(Asset.id) - .outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id) - .where(Asset.hash.is_(None), AssetCacheState.id.is_(None)) - ).scalar_subquery() - - with create_session() as sess: - sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix)) - sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq))) - result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq))) - sess.commit() - return result.rowcount - - -def _fast_db_consistency_pass( +def sync_references_with_filesystem( + session, root: RootType, - *, collect_existing_paths: bool = False, update_missing_tags: bool = False, ) -> set[str] | None: - """Fast DB+FS pass for a root: - - Toggle needs_verify per state using fast check - - For hashed assets with at least one fast-ok state in this root: delete stale missing states - - For seed assets with all states missing: delete Asset and its AssetInfos - - Optionally add/remove 'missing' tags based on fast-ok in this root - - Optionally return surviving absolute paths + """Reconcile asset references with filesystem for a root. + + - Toggle needs_verify per reference using mtime/size stat check + - For hashed assets with at least one stat-unchanged ref: delete stale missing refs + - For seed assets with all refs missing: delete Asset and its references + - Optionally add/remove 'missing' tags based on stat check in this root + - Optionally return surviving absolute paths + + Args: + session: Database session + root: Root type to scan + collect_existing_paths: If True, return set of surviving file paths + update_missing_tags: If True, update 'missing' tags based on file status + + Returns: + Set of surviving absolute paths if collect_existing_paths=True, else None """ - prefixes = prefixes_for_root(root) + prefixes = get_prefixes_for_root(root) if not prefixes: return set() if collect_existing_paths else None - conds = [] - for p in prefixes: - base = os.path.abspath(p) - if not base.endswith(os.sep): - base += os.sep - escaped, esc = escape_like_prefix(base) - conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + rows = get_references_for_prefixes( + session, prefixes, include_missing=update_missing_tags + ) + + by_asset: dict[str, _AssetAccumulator] = {} + for row in rows: + acc = by_asset.get(row.asset_id) + if acc is None: + acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []} + by_asset[row.asset_id] = acc + + stat_unchanged = False + try: + exists = True + stat_unchanged = verify_file_unchanged( + mtime_db=row.mtime_ns, + size_db=acc["size_db"], + stat_result=os.stat(row.file_path, follow_symlinks=True), + ) + except FileNotFoundError: + exists = False + except PermissionError: + exists = True + logging.debug("Permission denied accessing %s", row.file_path) + except OSError as e: + exists = False + logging.debug("OSError checking %s: %s", row.file_path, e) + + acc["refs"].append( + { + "ref_id": row.reference_id, + "file_path": row.file_path, + "exists": exists, + "stat_unchanged": stat_unchanged, + "needs_verify": row.needs_verify, + } + ) + + to_set_verify: list[str] = [] + to_clear_verify: list[str] = [] + stale_ref_ids: list[str] = [] + to_mark_missing: list[str] = [] + to_clear_missing: list[str] = [] + survivors: set[str] = set() + + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + refs = acc["refs"] + any_unchanged = any(r["stat_unchanged"] for r in refs) + all_missing = all(not r["exists"] for r in refs) + + for r in refs: + if not r["exists"]: + to_mark_missing.append(r["ref_id"]) + continue + if r["stat_unchanged"]: + to_clear_missing.append(r["ref_id"]) + if r["needs_verify"]: + to_clear_verify.append(r["ref_id"]) + if not r["stat_unchanged"] and not r["needs_verify"]: + to_set_verify.append(r["ref_id"]) + + if a_hash is None: + if refs and all_missing: + delete_orphaned_seed_asset(session, aid) + else: + for r in refs: + if r["exists"]: + survivors.add(os.path.abspath(r["file_path"])) + continue + + if any_unchanged: + for r in refs: + if not r["exists"]: + stale_ref_ids.append(r["ref_id"]) + if update_missing_tags: + try: + remove_missing_tag_for_asset_id(session, asset_id=aid) + except Exception as e: + logging.warning( + "Failed to remove missing tag for asset %s: %s", aid, e + ) + elif update_missing_tags: + try: + add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic") + except Exception as e: + logging.warning("Failed to add missing tag for asset %s: %s", aid, e) + + for r in refs: + if r["exists"]: + survivors.add(os.path.abspath(r["file_path"])) + + delete_references_by_ids(session, stale_ref_ids) + stale_set = set(stale_ref_ids) + to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set] + bulk_update_is_missing(session, to_mark_missing, value=True) + bulk_update_is_missing(session, to_clear_missing, value=False) + bulk_update_needs_verify(session, to_set_verify, value=True) + bulk_update_needs_verify(session, to_clear_verify, value=False) + + return survivors if collect_existing_paths else None + + +def sync_root_safely(root: RootType) -> set[str]: + """Sync a single root's references with the filesystem. + + Returns survivors (existing paths) or empty set on failure. + """ + try: + with create_session() as sess: + survivors = sync_references_with_filesystem( + sess, + root, + collect_existing_paths=True, + update_missing_tags=True, + ) + sess.commit() + return survivors or set() + except Exception as e: + logging.exception("fast DB scan failed for %s: %s", root, e) + return set() + + +def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int: + """Mark references as missing when outside the given prefixes. + + This is a non-destructive soft-delete. Returns count marked or 0 on failure. + """ + try: + with create_session() as sess: + count = mark_references_missing_outside_prefixes(sess, prefixes) + sess.commit() + return count + except Exception as e: + logging.exception("marking missing assets failed: %s", e) + return 0 + + +def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]: + """Collect all file paths for the given roots.""" + paths: list[str] = [] + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_files_recursively(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_files_recursively(folder_paths.get_output_directory())) + return paths + + +def build_asset_specs( + paths: list[str], + existing_paths: set[str], + enable_metadata_extraction: bool = True, + compute_hashes: bool = False, +) -> tuple[list[SeedAssetSpec], set[str], int]: + """Build asset specs from paths, returning (specs, tag_pool, skipped_count). + + Args: + paths: List of file paths to process + existing_paths: Set of paths that already exist in the database + enable_metadata_extraction: If True, extract tier 1 & 2 metadata + compute_hashes: If True, compute blake3 hashes (slow for large files) + """ + specs: list[SeedAssetSpec] = [] + tag_pool: set[str] = set() + skipped = 0 + + for p in paths: + abs_p = os.path.abspath(p) + if abs_p in existing_paths: + skipped += 1 + continue + try: + stat_p = os.stat(abs_p, follow_symlinks=True) + except OSError: + continue + if not stat_p.st_size: + continue + name, tags = get_name_and_tags_from_asset_path(abs_p) + rel_fname = compute_relative_filename(abs_p) + + # Extract metadata (tier 1: filesystem, tier 2: safetensors header) + metadata = None + if enable_metadata_extraction: + metadata = extract_file_metadata( + abs_p, + stat_result=stat_p, + relative_filename=rel_fname, + ) + + # Compute hash if requested + asset_hash: str | None = None + if compute_hashes: + try: + digest, _ = compute_blake3_hash(abs_p) + asset_hash = "blake3:" + digest + except Exception as e: + logging.warning("Failed to hash %s: %s", abs_p, e) + + mime_type = metadata.content_type if metadata else None + specs.append( + { + "abs_path": abs_p, + "size_bytes": stat_p.st_size, + "mtime_ns": get_mtime_ns(stat_p), + "info_name": name, + "tags": tags, + "fname": rel_fname, + "metadata": metadata, + "hash": asset_hash, + "mime_type": mime_type, + } + ) + tag_pool.update(tags) + + return specs, tag_pool, skipped + + + +def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: + """Insert asset specs into database, returning count of created refs.""" + if not specs: + return 0 + with create_session() as sess: + if tag_pool: + ensure_tags_exist(sess, tag_pool, tag_type="user") + result = batch_insert_seed_assets(sess, specs=specs, owner_id="") + sess.commit() + return result.inserted_refs + + +# Enrichment level constants +ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only +ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type) +ENRICHMENT_HASHED = 2 # Hash computed (blake3) + + +def get_unenriched_assets_for_roots( + roots: tuple[RootType, ...], + max_level: int = ENRICHMENT_STUB, + limit: int = 1000, +) -> list: + """Get assets that need enrichment for the given roots. + + Args: + roots: Tuple of root types to scan + max_level: Maximum enrichment level to include + limit: Maximum number of rows to return + + Returns: + List of UnenrichedReferenceRow + """ + prefixes: list[str] = [] + for root in roots: + prefixes.extend(get_prefixes_for_root(root)) + + if not prefixes: + return [] with create_session() as sess: - rows = ( - sess.execute( - sqlalchemy.select( - AssetCacheState.id, - AssetCacheState.file_path, - AssetCacheState.mtime_ns, - AssetCacheState.needs_verify, - AssetCacheState.asset_id, - Asset.hash, - Asset.size_bytes, - ) - .join(Asset, Asset.id == AssetCacheState.asset_id) - .where(sqlalchemy.or_(*conds)) - .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + return get_unenriched_references( + sess, prefixes, max_level=max_level, limit=limit + ) + + +def enrich_asset( + session, + file_path: str, + reference_id: str, + asset_id: str, + extract_metadata: bool = True, + compute_hash: bool = False, + interrupt_check: Callable[[], bool] | None = None, + hash_checkpoints: dict[str, HashCheckpoint] | None = None, +) -> int: + """Enrich a single asset with metadata and/or hash. + + Args: + session: Database session (caller manages lifecycle) + file_path: Absolute path to the file + reference_id: ID of the reference to update + asset_id: ID of the asset to update (for mime_type and hash) + extract_metadata: If True, extract safetensors header and mime type + compute_hash: If True, compute blake3 hash + interrupt_check: Optional non-blocking callable that returns True if + the operation should be interrupted (e.g. paused or cancelled) + hash_checkpoints: Optional dict for saving/restoring hash progress + across interruptions, keyed by file path + + Returns: + New enrichment level achieved + """ + new_level = ENRICHMENT_STUB + + try: + stat_p = os.stat(file_path, follow_symlinks=True) + except OSError: + return new_level + + rel_fname = compute_relative_filename(file_path) + mime_type: str | None = None + metadata = None + + if extract_metadata: + metadata = extract_file_metadata( + file_path, + stat_result=stat_p, + relative_filename=rel_fname, + ) + if metadata: + mime_type = metadata.content_type + new_level = ENRICHMENT_METADATA + + full_hash: str | None = None + if compute_hash: + try: + mtime_before = get_mtime_ns(stat_p) + size_before = stat_p.st_size + + # Restore checkpoint if available and file unchanged + checkpoint = None + if hash_checkpoints is not None: + checkpoint = hash_checkpoints.get(file_path) + if checkpoint is not None: + cur_stat = os.stat(file_path, follow_symlinks=True) + if (checkpoint.mtime_ns != get_mtime_ns(cur_stat) + or checkpoint.file_size != cur_stat.st_size): + checkpoint = None + hash_checkpoints.pop(file_path, None) + else: + mtime_before = get_mtime_ns(cur_stat) + + digest, new_checkpoint = compute_blake3_hash( + file_path, + interrupt_check=interrupt_check, + checkpoint=checkpoint, ) - ).all() - by_asset: dict[str, dict] = {} - for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: - acc = by_asset.get(aid) - if acc is None: - acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} - by_asset[aid] = acc + if digest is None: + # Interrupted — save checkpoint for later resumption + if hash_checkpoints is not None and new_checkpoint is not None: + new_checkpoint.mtime_ns = mtime_before + new_checkpoint.file_size = size_before + hash_checkpoints[file_path] = new_checkpoint + return new_level + + # Completed — clear any saved checkpoint + if hash_checkpoints is not None: + hash_checkpoints.pop(file_path, None) + + stat_after = os.stat(file_path, follow_symlinks=True) + mtime_after = get_mtime_ns(stat_after) + if mtime_before != mtime_after: + logging.warning("File modified during hashing, discarding hash: %s", file_path) + else: + full_hash = f"blake3:{digest}" + metadata_ok = not extract_metadata or metadata is not None + if metadata_ok: + new_level = ENRICHMENT_HASHED + except Exception as e: + logging.warning("Failed to hash %s: %s", file_path, e) + + if extract_metadata and metadata: + user_metadata = metadata.to_user_metadata() + set_reference_metadata(session, reference_id, user_metadata) + + if full_hash: + existing = get_asset_by_hash(session, full_hash) + if existing and existing.id != asset_id: + reassign_asset_references(session, asset_id, existing.id, reference_id) + delete_orphaned_seed_asset(session, asset_id) + if mime_type: + update_asset_hash_and_mime(session, existing.id, mime_type=mime_type) + else: + update_asset_hash_and_mime(session, asset_id, full_hash, mime_type) + elif mime_type: + update_asset_hash_and_mime(session, asset_id, mime_type=mime_type) + + bulk_update_enrichment_level(session, [reference_id], new_level) + session.commit() + + return new_level + + +def enrich_assets_batch( + rows: list, + extract_metadata: bool = True, + compute_hash: bool = False, + interrupt_check: Callable[[], bool] | None = None, + hash_checkpoints: dict[str, HashCheckpoint] | None = None, +) -> tuple[int, list[str]]: + """Enrich a batch of assets. + + Uses a single DB session for the entire batch, committing after each + individual asset to avoid long-held transactions while eliminating + per-asset session creation overhead. + + Args: + rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots + extract_metadata: If True, extract metadata for each asset + compute_hash: If True, compute hash for each asset + interrupt_check: Optional non-blocking callable that returns True if + the operation should be interrupted (e.g. paused or cancelled) + hash_checkpoints: Optional dict for saving/restoring hash progress + across interruptions, keyed by file path + + Returns: + Tuple of (enriched_count, failed_reference_ids) + """ + enriched = 0 + failed_ids: list[str] = [] + + with create_session() as sess: + for row in rows: + if interrupt_check is not None and interrupt_check(): + break - fast_ok = False try: - exists = True - fast_ok = fast_asset_file_check( - mtime_db=mtime_db, - size_db=acc["size_db"], - stat_result=os.stat(fp, follow_symlinks=True), + new_level = enrich_asset( + sess, + file_path=row.file_path, + reference_id=row.reference_id, + asset_id=row.asset_id, + extract_metadata=extract_metadata, + compute_hash=compute_hash, + interrupt_check=interrupt_check, + hash_checkpoints=hash_checkpoints, ) - except FileNotFoundError: - exists = False - except OSError: - exists = False - - acc["states"].append({ - "sid": sid, - "fp": fp, - "exists": exists, - "fast_ok": fast_ok, - "needs_verify": bool(needs_verify), - }) - - to_set_verify: list[int] = [] - to_clear_verify: list[int] = [] - stale_state_ids: list[int] = [] - survivors: set[str] = set() - - for aid, acc in by_asset.items(): - a_hash = acc["hash"] - states = acc["states"] - any_fast_ok = any(s["fast_ok"] for s in states) - all_missing = all(not s["exists"] for s in states) - - for s in states: - if not s["exists"]: - continue - if s["fast_ok"] and s["needs_verify"]: - to_clear_verify.append(s["sid"]) - if not s["fast_ok"] and not s["needs_verify"]: - to_set_verify.append(s["sid"]) - - if a_hash is None: - if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists - sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid)) - asset = sess.get(Asset, aid) - if asset: - sess.delete(asset) + if new_level > row.enrichment_level: + enriched += 1 else: - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["fp"])) - continue + failed_ids.append(row.reference_id) + except Exception as e: + logging.warning("Failed to enrich %s: %s", row.file_path, e) + sess.rollback() + failed_ids.append(row.reference_id) - if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records - for s in states: - if not s["exists"]: - stale_state_ids.append(s["sid"]) - if update_missing_tags: - with contextlib.suppress(Exception): - remove_missing_tag_for_asset_id(sess, asset_id=aid) - elif update_missing_tags: - with contextlib.suppress(Exception): - add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") - - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["fp"])) - - if stale_state_ids: - sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids))) - if to_set_verify: - sess.execute( - sqlalchemy.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_set_verify)) - .values(needs_verify=True) - ) - if to_clear_verify: - sess.execute( - sqlalchemy.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_clear_verify)) - .values(needs_verify=False) - ) - sess.commit() - return survivors if collect_existing_paths else None + return enriched, failed_ids diff --git a/app/assets/seeder.py b/app/assets/seeder.py new file mode 100644 index 000000000..029448464 --- /dev/null +++ b/app/assets/seeder.py @@ -0,0 +1,794 @@ +"""Background asset seeder with thread management and cancellation support.""" + +import logging +import os +import threading +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Callable + +from app.assets.scanner import ( + ENRICHMENT_METADATA, + ENRICHMENT_STUB, + RootType, + build_asset_specs, + collect_paths_for_roots, + enrich_assets_batch, + get_all_known_prefixes, + get_prefixes_for_root, + get_unenriched_assets_for_roots, + insert_asset_specs, + mark_missing_outside_prefixes_safely, + sync_root_safely, +) +from app.database.db import dependencies_available + + +class ScanInProgressError(Exception): + """Raised when an operation cannot proceed because a scan is running.""" + + +class State(Enum): + """Seeder state machine states.""" + + IDLE = "IDLE" + RUNNING = "RUNNING" + PAUSED = "PAUSED" + CANCELLING = "CANCELLING" + + +class ScanPhase(Enum): + """Scan phase options.""" + + FAST = "fast" # Phase 1: filesystem only (stubs) + ENRICH = "enrich" # Phase 2: metadata + hash + FULL = "full" # Both phases sequentially + + +@dataclass +class Progress: + """Progress information for a scan operation.""" + + scanned: int = 0 + total: int = 0 + created: int = 0 + skipped: int = 0 + + +@dataclass +class ScanStatus: + """Current status of the asset seeder.""" + + state: State + progress: Progress | None + errors: list[str] = field(default_factory=list) + + +ProgressCallback = Callable[[Progress], None] + + +class _AssetSeeder: + """Background asset scanning manager. + + Spawns ephemeral daemon threads for scanning. + Each scan creates a new thread that exits when complete. + Use the module-level ``asset_seeder`` instance. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = State.IDLE + self._progress: Progress | None = None + self._last_progress: Progress | None = None + self._errors: list[str] = [] + self._thread: threading.Thread | None = None + self._cancel_event = threading.Event() + self._run_gate = threading.Event() + self._run_gate.set() # Start unpaused (set = running, clear = paused) + self._roots: tuple[RootType, ...] = () + self._phase: ScanPhase = ScanPhase.FULL + self._compute_hashes: bool = False + self._prune_first: bool = False + self._progress_callback: ProgressCallback | None = None + self._disabled: bool = False + + def disable(self) -> None: + """Disable the asset seeder, preventing any scans from starting.""" + self._disabled = True + logging.info("Asset seeder disabled") + + def is_disabled(self) -> bool: + """Check if the asset seeder is disabled.""" + return self._disabled + + def start( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + phase: ScanPhase = ScanPhase.FULL, + progress_callback: ProgressCallback | None = None, + prune_first: bool = False, + compute_hashes: bool = False, + ) -> bool: + """Start a background scan for the given roots. + + Args: + roots: Tuple of root types to scan (models, input, output) + phase: Scan phase to run (FAST, ENRICH, or FULL for both) + progress_callback: Optional callback called with progress updates + prune_first: If True, prune orphaned assets before scanning + compute_hashes: If True, compute blake3 hashes (slow) + + Returns: + True if scan was started, False if already running + """ + if self._disabled: + logging.debug("Asset seeder is disabled, skipping start") + return False + logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value) + with self._lock: + if self._state != State.IDLE: + logging.info("Asset seeder already running, skipping start") + return False + self._state = State.RUNNING + self._progress = Progress() + self._errors = [] + self._roots = roots + self._phase = phase + self._prune_first = prune_first + self._compute_hashes = compute_hashes + self._progress_callback = progress_callback + self._cancel_event.clear() + self._run_gate.set() # Ensure unpaused when starting + self._thread = threading.Thread( + target=self._run_scan, + name="_AssetSeeder", + daemon=True, + ) + self._thread.start() + return True + + def start_fast( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + progress_callback: ProgressCallback | None = None, + prune_first: bool = False, + ) -> bool: + """Start a fast scan (phase 1 only) - creates stub records. + + Args: + roots: Tuple of root types to scan + progress_callback: Optional callback for progress updates + prune_first: If True, prune orphaned assets before scanning + + Returns: + True if scan was started, False if already running + """ + return self.start( + roots=roots, + phase=ScanPhase.FAST, + progress_callback=progress_callback, + prune_first=prune_first, + compute_hashes=False, + ) + + def start_enrich( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + progress_callback: ProgressCallback | None = None, + compute_hashes: bool = False, + ) -> bool: + """Start an enrichment scan (phase 2 only) - extracts metadata and hashes. + + Args: + roots: Tuple of root types to scan + progress_callback: Optional callback for progress updates + compute_hashes: If True, compute blake3 hashes + + Returns: + True if scan was started, False if already running + """ + return self.start( + roots=roots, + phase=ScanPhase.ENRICH, + progress_callback=progress_callback, + prune_first=False, + compute_hashes=compute_hashes, + ) + + def cancel(self) -> bool: + """Request cancellation of the current scan. + + Returns: + True if cancellation was requested, False if not running or paused + """ + with self._lock: + if self._state not in (State.RUNNING, State.PAUSED): + return False + logging.info("Asset seeder cancelling (was %s)", self._state.value) + self._state = State.CANCELLING + self._cancel_event.set() + self._run_gate.set() # Unblock if paused so thread can exit + return True + + def stop(self) -> bool: + """Stop the current scan (alias for cancel). + + Returns: + True if stop was requested, False if not running + """ + return self.cancel() + + def pause(self) -> bool: + """Pause the current scan. + + The scan will complete its current batch before pausing. + + Returns: + True if pause was requested, False if not running + """ + with self._lock: + if self._state != State.RUNNING: + return False + logging.info("Asset seeder pausing") + self._state = State.PAUSED + self._run_gate.clear() + return True + + def resume(self) -> bool: + """Resume a paused scan. + + This is a noop if the scan is not in the PAUSED state + + Returns: + True if resumed, False if not paused + """ + with self._lock: + if self._state != State.PAUSED: + return False + logging.info("Asset seeder resuming") + self._state = State.RUNNING + self._run_gate.set() + self._emit_event("assets.seed.resumed", {}) + return True + + def restart( + self, + roots: tuple[RootType, ...] | None = None, + phase: ScanPhase | None = None, + progress_callback: ProgressCallback | None = None, + prune_first: bool | None = None, + compute_hashes: bool | None = None, + timeout: float = 5.0, + ) -> bool: + """Cancel any running scan and start a new one. + + Args: + roots: Roots to scan (defaults to previous roots) + phase: Scan phase (defaults to previous phase) + progress_callback: Progress callback (defaults to previous) + prune_first: Prune before scan (defaults to previous) + compute_hashes: Compute hashes (defaults to previous) + timeout: Max seconds to wait for current scan to stop + + Returns: + True if new scan was started, False if failed to stop previous + """ + logging.info("Asset seeder restart requested") + with self._lock: + prev_roots = self._roots + prev_phase = self._phase + prev_callback = self._progress_callback + prev_prune = self._prune_first + prev_hashes = self._compute_hashes + + self.cancel() + if not self.wait(timeout=timeout): + return False + + cb = progress_callback if progress_callback is not None else prev_callback + return self.start( + roots=roots if roots is not None else prev_roots, + phase=phase if phase is not None else prev_phase, + progress_callback=cb, + prune_first=prune_first if prune_first is not None else prev_prune, + compute_hashes=( + compute_hashes if compute_hashes is not None else prev_hashes + ), + ) + + def wait(self, timeout: float | None = None) -> bool: + """Wait for the current scan to complete. + + Args: + timeout: Maximum seconds to wait, or None for no timeout + + Returns: + True if scan completed, False if timeout expired or no scan running + """ + with self._lock: + thread = self._thread + if thread is None: + return True + thread.join(timeout=timeout) + return not thread.is_alive() + + def get_status(self) -> ScanStatus: + """Get the current status and progress of the seeder.""" + with self._lock: + src = self._progress or self._last_progress + return ScanStatus( + state=self._state, + progress=Progress( + scanned=src.scanned, + total=src.total, + created=src.created, + skipped=src.skipped, + ) + if src + else None, + errors=list(self._errors), + ) + + def shutdown(self, timeout: float = 5.0) -> None: + """Gracefully shutdown: cancel any running scan and wait for thread. + + Args: + timeout: Maximum seconds to wait for thread to exit + """ + self.cancel() + self.wait(timeout=timeout) + with self._lock: + self._thread = None + + def mark_missing_outside_prefixes(self) -> int: + """Mark references as missing when outside all known root prefixes. + + This is a non-destructive soft-delete operation. Assets and their + metadata are preserved, but references are flagged as missing. + They can be restored if the file reappears in a future scan. + + This operation is decoupled from scanning to prevent partial scans + from accidentally marking assets belonging to other roots. + + Should be called explicitly when cleanup is desired, typically after + a full scan of all roots or during maintenance. + + Returns: + Number of references marked as missing + + Raises: + ScanInProgressError: If a scan is currently running + """ + with self._lock: + if self._state != State.IDLE: + raise ScanInProgressError( + "Cannot mark missing assets while scan is running" + ) + self._state = State.RUNNING + + try: + if not dependencies_available(): + logging.warning( + "Database dependencies not available, skipping mark missing" + ) + return 0 + + all_prefixes = get_all_known_prefixes() + marked = mark_missing_outside_prefixes_safely(all_prefixes) + if marked > 0: + logging.info("Marked %d references as missing", marked) + return marked + finally: + with self._lock: + self._last_progress = self._progress + self._state = State.IDLE + self._progress = None + + def _is_cancelled(self) -> bool: + """Check if cancellation has been requested.""" + return self._cancel_event.is_set() + + def _is_paused_or_cancelled(self) -> bool: + """Non-blocking check: True if paused or cancelled. + + Use as interrupt_check for I/O-bound work (e.g. hashing) so that + file handles are released immediately on pause rather than held + open while blocked. The caller is responsible for blocking on + _check_pause_and_cancel() afterward. + """ + return not self._run_gate.is_set() or self._cancel_event.is_set() + + def _check_pause_and_cancel(self) -> bool: + """Block while paused, then check if cancelled. + + Call this at checkpoint locations in scan loops. It will: + 1. Block indefinitely while paused (until resume or cancel) + 2. Return True if cancelled, False to continue + + Returns: + True if scan should stop, False to continue + """ + if not self._run_gate.is_set(): + self._emit_event("assets.seed.paused", {}) + self._run_gate.wait() # Blocks if paused + return self._is_cancelled() + + def _emit_event(self, event_type: str, data: dict) -> None: + """Emit a WebSocket event if server is available.""" + try: + from server import PromptServer + + if hasattr(PromptServer, "instance") and PromptServer.instance: + PromptServer.instance.send_sync(event_type, data) + except Exception: + pass + + def _update_progress( + self, + scanned: int | None = None, + total: int | None = None, + created: int | None = None, + skipped: int | None = None, + ) -> None: + """Update progress counters (thread-safe).""" + callback: ProgressCallback | None = None + progress: Progress | None = None + + with self._lock: + if self._progress is None: + return + if scanned is not None: + self._progress.scanned = scanned + if total is not None: + self._progress.total = total + if created is not None: + self._progress.created = created + if skipped is not None: + self._progress.skipped = skipped + if self._progress_callback: + callback = self._progress_callback + progress = Progress( + scanned=self._progress.scanned, + total=self._progress.total, + created=self._progress.created, + skipped=self._progress.skipped, + ) + + if callback and progress: + try: + callback(progress) + except Exception: + pass + + _MAX_ERRORS = 200 + + def _add_error(self, message: str) -> None: + """Add an error message (thread-safe), capped at _MAX_ERRORS.""" + with self._lock: + if len(self._errors) < self._MAX_ERRORS: + self._errors.append(message) + + def _log_scan_config(self, roots: tuple[RootType, ...]) -> None: + """Log the directories that will be scanned.""" + import folder_paths + + for root in roots: + if root == "models": + logging.info( + "Asset scan [models] directory: %s", + os.path.abspath(folder_paths.models_dir), + ) + else: + prefixes = get_prefixes_for_root(root) + if prefixes: + logging.info("Asset scan [%s] directories: %s", root, prefixes) + + def _run_scan(self) -> None: + """Main scan loop running in background thread.""" + t_start = time.perf_counter() + roots = self._roots + phase = self._phase + cancelled = False + total_created = 0 + total_enriched = 0 + skipped_existing = 0 + total_paths = 0 + + try: + if not dependencies_available(): + self._add_error("Database dependencies not available") + self._emit_event( + "assets.seed.error", + {"message": "Database dependencies not available"}, + ) + return + + if self._prune_first: + all_prefixes = get_all_known_prefixes() + marked = mark_missing_outside_prefixes_safely(all_prefixes) + if marked > 0: + logging.info("Marked %d refs as missing before scan", marked) + + if self._check_pause_and_cancel(): + logging.info("Asset scan cancelled after pruning phase") + cancelled = True + return + + self._log_scan_config(roots) + + # Phase 1: Fast scan (stub records) + if phase in (ScanPhase.FAST, ScanPhase.FULL): + created, skipped, paths = self._run_fast_phase(roots) + total_created, skipped_existing, total_paths = created, skipped, paths + + if self._check_pause_and_cancel(): + cancelled = True + return + + self._emit_event( + "assets.seed.fast_complete", + { + "roots": list(roots), + "created": total_created, + "skipped": skipped_existing, + "total": total_paths, + }, + ) + + # Phase 2: Enrichment scan (metadata + hashes) + if phase in (ScanPhase.ENRICH, ScanPhase.FULL): + if self._check_pause_and_cancel(): + cancelled = True + return + + enrich_cancelled, total_enriched = self._run_enrich_phase(roots) + + if enrich_cancelled: + cancelled = True + return + + self._emit_event( + "assets.seed.enrich_complete", + { + "roots": list(roots), + "enriched": total_enriched, + }, + ) + + elapsed = time.perf_counter() - t_start + logging.info( + "Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d", + roots, + phase.value, + elapsed, + total_created, + total_enriched, + skipped_existing, + ) + + self._emit_event( + "assets.seed.completed", + { + "phase": phase.value, + "total": total_paths, + "created": total_created, + "enriched": total_enriched, + "skipped": skipped_existing, + "elapsed": round(elapsed, 3), + }, + ) + + except Exception as e: + self._add_error(f"Scan failed: {e}") + logging.exception("Asset scan failed") + self._emit_event("assets.seed.error", {"message": str(e)}) + finally: + if cancelled: + self._emit_event( + "assets.seed.cancelled", + { + "scanned": self._progress.scanned if self._progress else 0, + "total": total_paths, + "created": total_created, + }, + ) + with self._lock: + self._last_progress = self._progress + self._state = State.IDLE + self._progress = None + + def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]: + """Run phase 1: fast scan to create stub records. + + Returns: + Tuple of (total_created, skipped_existing, total_paths) + """ + t_fast_start = time.perf_counter() + total_created = 0 + skipped_existing = 0 + + existing_paths: set[str] = set() + t_sync = time.perf_counter() + for r in roots: + if self._check_pause_and_cancel(): + return total_created, skipped_existing, 0 + existing_paths.update(sync_root_safely(r)) + logging.debug( + "Fast scan: sync_root phase took %.3fs (%d existing paths)", + time.perf_counter() - t_sync, + len(existing_paths), + ) + + if self._check_pause_and_cancel(): + return total_created, skipped_existing, 0 + + t_collect = time.perf_counter() + paths = collect_paths_for_roots(roots) + logging.debug( + "Fast scan: collect_paths took %.3fs (%d paths found)", + time.perf_counter() - t_collect, + len(paths), + ) + total_paths = len(paths) + self._update_progress(total=total_paths) + + self._emit_event( + "assets.seed.started", + {"roots": list(roots), "total": total_paths, "phase": "fast"}, + ) + + # Use stub specs (no metadata extraction, no hashing) + t_specs = time.perf_counter() + specs, tag_pool, skipped_existing = build_asset_specs( + paths, + existing_paths, + enable_metadata_extraction=False, + compute_hashes=False, + ) + logging.debug( + "Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)", + time.perf_counter() - t_specs, + len(specs), + skipped_existing, + ) + self._update_progress(skipped=skipped_existing) + + if self._check_pause_and_cancel(): + return total_created, skipped_existing, total_paths + + batch_size = 500 + last_progress_time = time.perf_counter() + progress_interval = 1.0 + + for i in range(0, len(specs), batch_size): + if self._check_pause_and_cancel(): + logging.info( + "Fast scan cancelled after %d/%d files (created=%d)", + i, + len(specs), + total_created, + ) + return total_created, skipped_existing, total_paths + + batch = specs[i : i + batch_size] + batch_tags = {t for spec in batch for t in spec["tags"]} + try: + created = insert_asset_specs(batch, batch_tags) + total_created += created + except Exception as e: + self._add_error(f"Batch insert failed at offset {i}: {e}") + logging.exception("Batch insert failed at offset %d", i) + + scanned = i + len(batch) + now = time.perf_counter() + self._update_progress(scanned=scanned, created=total_created) + + if now - last_progress_time >= progress_interval: + self._emit_event( + "assets.seed.progress", + { + "phase": "fast", + "scanned": scanned, + "total": len(specs), + "created": total_created, + }, + ) + last_progress_time = now + + self._update_progress(scanned=len(specs), created=total_created) + logging.info( + "Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)", + time.perf_counter() - t_fast_start, + total_created, + skipped_existing, + total_paths, + ) + return total_created, skipped_existing, total_paths + + def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]: + """Run phase 2: enrich existing records with metadata and hashes. + + Returns: + Tuple of (cancelled, total_enriched) + """ + total_enriched = 0 + batch_size = 100 + last_progress_time = time.perf_counter() + progress_interval = 1.0 + + # Get the target enrichment level based on compute_hashes + if not self._compute_hashes: + target_max_level = ENRICHMENT_STUB + else: + target_max_level = ENRICHMENT_METADATA + + self._emit_event( + "assets.seed.started", + {"roots": list(roots), "phase": "enrich"}, + ) + + skip_ids: set[str] = set() + consecutive_empty = 0 + max_consecutive_empty = 3 + + # Hash checkpoints survive across batches so interrupted hashes + # can be resumed without re-reading the entire file. + hash_checkpoints: dict[str, object] = {} + + while True: + if self._check_pause_and_cancel(): + logging.info("Enrich scan cancelled after %d assets", total_enriched) + return True, total_enriched + + # Fetch next batch of unenriched assets + unenriched = get_unenriched_assets_for_roots( + roots, + max_level=target_max_level, + limit=batch_size, + ) + + # Filter out previously failed references + if skip_ids: + unenriched = [r for r in unenriched if r.reference_id not in skip_ids] + + if not unenriched: + break + + enriched, failed_ids = enrich_assets_batch( + unenriched, + extract_metadata=True, + compute_hash=self._compute_hashes, + interrupt_check=self._is_paused_or_cancelled, + hash_checkpoints=hash_checkpoints, + ) + total_enriched += enriched + skip_ids.update(failed_ids) + + if enriched == 0: + consecutive_empty += 1 + if consecutive_empty >= max_consecutive_empty: + logging.warning( + "Enrich phase stopping: %d consecutive batches with no progress (%d skipped)", + consecutive_empty, + len(skip_ids), + ) + break + else: + consecutive_empty = 0 + + now = time.perf_counter() + if now - last_progress_time >= progress_interval: + self._emit_event( + "assets.seed.progress", + { + "phase": "enrich", + "enriched": total_enriched, + }, + ) + last_progress_time = now + + return False, total_enriched + + +asset_seeder = _AssetSeeder() diff --git a/app/assets/services/__init__.py b/app/assets/services/__init__.py new file mode 100644 index 000000000..11fcb4122 --- /dev/null +++ b/app/assets/services/__init__.py @@ -0,0 +1,87 @@ +from app.assets.services.asset_management import ( + asset_exists, + delete_asset_reference, + get_asset_by_hash, + get_asset_detail, + list_assets_page, + resolve_asset_for_download, + set_asset_preview, + update_asset_metadata, +) +from app.assets.services.bulk_ingest import ( + BulkInsertResult, + batch_insert_seed_assets, + cleanup_unreferenced_assets, +) +from app.assets.services.file_utils import ( + get_mtime_ns, + get_size_and_mtime_ns, + list_files_recursively, + verify_file_unchanged, +) +from app.assets.services.ingest import ( + DependencyMissingError, + HashMismatchError, + create_from_hash, + upload_from_temp_path, +) +from app.assets.database.queries import ( + AddTagsResult, + RemoveTagsResult, +) +from app.assets.services.schemas import ( + AssetData, + AssetDetailResult, + AssetSummaryData, + DownloadResolutionResult, + IngestResult, + ListAssetsResult, + ReferenceData, + RegisterAssetResult, + TagUsage, + UploadResult, + UserMetadata, +) +from app.assets.services.tagging import ( + apply_tags, + list_tags, + remove_tags, +) + +__all__ = [ + "AddTagsResult", + "AssetData", + "AssetDetailResult", + "AssetSummaryData", + "ReferenceData", + "BulkInsertResult", + "DependencyMissingError", + "DownloadResolutionResult", + "HashMismatchError", + "IngestResult", + "ListAssetsResult", + "RegisterAssetResult", + "RemoveTagsResult", + "TagUsage", + "UploadResult", + "UserMetadata", + "apply_tags", + "asset_exists", + "batch_insert_seed_assets", + "create_from_hash", + "delete_asset_reference", + "get_asset_by_hash", + "get_asset_detail", + "get_mtime_ns", + "get_size_and_mtime_ns", + "list_assets_page", + "list_files_recursively", + "list_tags", + "cleanup_unreferenced_assets", + "remove_tags", + "resolve_asset_for_download", + "set_asset_preview", + "update_asset_metadata", + "upload_from_temp_path", + "verify_file_unchanged", +] diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py new file mode 100644 index 000000000..3fe7115c8 --- /dev/null +++ b/app/assets/services/asset_management.py @@ -0,0 +1,309 @@ +import contextlib +import mimetypes +import os +from typing import Sequence + + +from app.assets.database.models import Asset +from app.assets.database.queries import ( + asset_exists_by_hash, + reference_exists_for_asset_id, + delete_reference_by_id, + fetch_reference_and_asset, + soft_delete_reference_by_id, + fetch_reference_asset_and_tags, + get_asset_by_hash as queries_get_asset_by_hash, + get_reference_by_id, + get_reference_with_owner_check, + list_references_page, + list_references_by_asset_id, + set_reference_metadata, + set_reference_preview, + set_reference_tags, + update_reference_access_time, + update_reference_name, + update_reference_updated_at, +) +from app.assets.helpers import select_best_live_path +from app.assets.services.path_utils import compute_relative_filename +from app.assets.services.schemas import ( + AssetData, + AssetDetailResult, + AssetSummaryData, + DownloadResolutionResult, + ListAssetsResult, + UserMetadata, + extract_asset_data, + extract_reference_data, +) +from app.database.db import create_session + + +def get_asset_detail( + reference_id: str, + owner_id: str = "", +) -> AssetDetailResult | None: + with create_session() as session: + result = fetch_reference_asset_and_tags( + session, + reference_id=reference_id, + owner_id=owner_id, + ) + if not result: + return None + + ref, asset, tags = result + return AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tags, + ) + + +def update_asset_metadata( + reference_id: str, + name: str | None = None, + tags: Sequence[str] | None = None, + user_metadata: UserMetadata = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> AssetDetailResult: + with create_session() as session: + ref = get_reference_with_owner_check(session, reference_id, owner_id) + + touched = False + if name is not None and name != ref.name: + update_reference_name(session, reference_id=reference_id, name=name) + touched = True + + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None + + new_meta: dict | None = None + if user_metadata is not None: + new_meta = dict(user_metadata) + elif computed_filename: + current_meta = ref.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + + if new_meta is not None: + if computed_filename: + new_meta["filename"] = computed_filename + set_reference_metadata( + session, reference_id=reference_id, user_metadata=new_meta + ) + touched = True + + if tags is not None: + set_reference_tags( + session, + reference_id=reference_id, + tags=tags, + origin=tag_origin, + ) + touched = True + + if touched and user_metadata is None: + update_reference_updated_at(session, reference_id=reference_id) + + result = fetch_reference_asset_and_tags( + session, + reference_id=reference_id, + owner_id=owner_id, + ) + if not result: + raise RuntimeError("State changed during update") + + ref, asset, tag_list = result + detail = AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_list, + ) + session.commit() + + return detail + + +def delete_asset_reference( + reference_id: str, + owner_id: str, + delete_content_if_orphan: bool = True, +) -> bool: + with create_session() as session: + if not delete_content_if_orphan: + # Soft delete: mark the reference as deleted but keep everything + deleted = soft_delete_reference_by_id( + session, reference_id=reference_id, owner_id=owner_id + ) + session.commit() + return deleted + + ref_row = get_reference_by_id(session, reference_id=reference_id) + asset_id = ref_row.asset_id if ref_row else None + file_path = ref_row.file_path if ref_row else None + + deleted = delete_reference_by_id( + session, reference_id=reference_id, owner_id=owner_id + ) + if not deleted: + session.commit() + return False + + if not asset_id: + session.commit() + return True + + still_exists = reference_exists_for_asset_id(session, asset_id=asset_id) + if still_exists: + session.commit() + return True + + # Orphaned asset - delete it and its files + refs = list_references_by_asset_id(session, asset_id=asset_id) + file_paths = [ + r.file_path for r in (refs or []) if getattr(r, "file_path", None) + ] + # Also include the just-deleted file path + if file_path: + file_paths.append(file_path) + + asset_row = session.get(Asset, asset_id) + if asset_row is not None: + session.delete(asset_row) + + session.commit() + + # Delete files after commit + for p in file_paths: + with contextlib.suppress(Exception): + if p and os.path.isfile(p): + os.remove(p) + + return True + + +def set_asset_preview( + reference_id: str, + preview_asset_id: str | None = None, + owner_id: str = "", +) -> AssetDetailResult: + with create_session() as session: + get_reference_with_owner_check(session, reference_id, owner_id) + + set_reference_preview( + session, + reference_id=reference_id, + preview_asset_id=preview_asset_id, + ) + + result = fetch_reference_asset_and_tags( + session, reference_id=reference_id, owner_id=owner_id + ) + if not result: + raise RuntimeError("State changed during preview update") + + ref, asset, tags = result + detail = AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tags, + ) + session.commit() + + return detail + + +def asset_exists(asset_hash: str) -> bool: + with create_session() as session: + return asset_exists_by_hash(session, asset_hash=asset_hash) + + +def get_asset_by_hash(asset_hash: str) -> AssetData | None: + with create_session() as session: + asset = queries_get_asset_by_hash(session, asset_hash=asset_hash) + return extract_asset_data(asset) + + +def list_assets_page( + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> ListAssetsResult: + with create_session() as session: + refs, tag_map, total = list_references_page( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + + items: list[AssetSummaryData] = [] + for ref in refs: + items.append( + AssetSummaryData( + ref=extract_reference_data(ref), + asset=extract_asset_data(ref.asset), + tags=tag_map.get(ref.id, []), + ) + ) + + return ListAssetsResult(items=items, total=total) + + +def resolve_asset_for_download( + reference_id: str, + owner_id: str = "", +) -> DownloadResolutionResult: + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise ValueError(f"AssetReference {reference_id} not found") + + ref, asset = pair + + # For references with file_path, use that directly + if ref.file_path and os.path.isfile(ref.file_path): + abs_path = ref.file_path + else: + # For API-created refs without file_path, find a path from other refs + refs = list_references_by_asset_id(session, asset_id=asset.id) + abs_path = select_best_live_path(refs) + if not abs_path: + raise FileNotFoundError( + f"No live path for AssetReference {reference_id} " + f"(asset id={asset.id}, name={ref.name})" + ) + + # Capture ORM attributes before commit (commit expires loaded objects) + ref_name = ref.name + asset_mime = asset.mime_type + + update_reference_access_time(session, reference_id=reference_id) + session.commit() + + ctype = ( + asset_mime + or mimetypes.guess_type(ref_name or abs_path)[0] + or "application/octet-stream" + ) + download_name = ref_name or os.path.basename(abs_path) + return DownloadResolutionResult( + abs_path=abs_path, + content_type=ctype, + download_name=download_name, + ) diff --git a/app/assets/services/bulk_ingest.py b/app/assets/services/bulk_ingest.py new file mode 100644 index 000000000..54e72730c --- /dev/null +++ b/app/assets/services/bulk_ingest.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import os +import uuid +from dataclasses import dataclass +from datetime import datetime +from typing import TYPE_CHECKING, Any, TypedDict + +from sqlalchemy.orm import Session + +from app.assets.database.queries import ( + bulk_insert_assets, + bulk_insert_references_ignore_conflicts, + bulk_insert_tags_and_meta, + delete_assets_by_ids, + get_existing_asset_ids, + get_reference_ids_by_ids, + get_references_by_paths_and_asset_ids, + get_unreferenced_unhashed_asset_ids, + restore_references_by_paths, +) +from app.assets.helpers import get_utc_now + +if TYPE_CHECKING: + from app.assets.services.metadata_extract import ExtractedMetadata + + +class SeedAssetSpec(TypedDict): + """Spec for seeding an asset from filesystem.""" + + abs_path: str + size_bytes: int + mtime_ns: int + info_name: str + tags: list[str] + fname: str + metadata: ExtractedMetadata | None + hash: str | None + mime_type: str | None + + +class AssetRow(TypedDict): + """Row data for inserting an Asset.""" + + id: str + hash: str | None + size_bytes: int + mime_type: str | None + created_at: datetime + + +class ReferenceRow(TypedDict): + """Row data for inserting an AssetReference.""" + + id: str + asset_id: str + file_path: str + mtime_ns: int + owner_id: str + name: str + preview_id: str | None + user_metadata: dict[str, Any] | None + created_at: datetime + updated_at: datetime + last_access_time: datetime + + +class TagRow(TypedDict): + """Row data for inserting a Tag.""" + + asset_reference_id: str + tag_name: str + origin: str + added_at: datetime + + +class MetadataRow(TypedDict): + """Row data for inserting asset metadata.""" + + asset_reference_id: str + key: str + ordinal: int + val_str: str | None + val_num: float | None + val_bool: bool | None + val_json: dict[str, Any] | None + + +@dataclass +class BulkInsertResult: + """Result of bulk asset insertion.""" + + inserted_refs: int + won_paths: int + lost_paths: int + + +def batch_insert_seed_assets( + session: Session, + specs: list[SeedAssetSpec], + owner_id: str = "", +) -> BulkInsertResult: + """Seed assets from filesystem specs in batch. + + Each spec is a dict with keys: + - abs_path: str + - size_bytes: int + - mtime_ns: int + - info_name: str + - tags: list[str] + - fname: Optional[str] + + This function orchestrates: + 1. Insert seed Assets (hash=NULL) + 2. Claim references with ON CONFLICT DO NOTHING on file_path + 3. Query to find winners (paths where our asset_id was inserted) + 4. Delete Assets for losers (path already claimed by another asset) + 5. Insert tags and metadata for successfully inserted references + + Returns: + BulkInsertResult with inserted_refs, won_paths, lost_paths + """ + if not specs: + return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0) + + current_time = get_utc_now() + asset_rows: list[AssetRow] = [] + reference_rows: list[ReferenceRow] = [] + path_to_asset_id: dict[str, str] = {} + asset_id_to_ref_data: dict[str, dict] = {} + absolute_path_list: list[str] = [] + + for spec in specs: + absolute_path = os.path.abspath(spec["abs_path"]) + asset_id = str(uuid.uuid4()) + reference_id = str(uuid.uuid4()) + absolute_path_list.append(absolute_path) + path_to_asset_id[absolute_path] = asset_id + + mime_type = spec.get("mime_type") + asset_rows.append( + { + "id": asset_id, + "hash": spec.get("hash"), + "size_bytes": spec["size_bytes"], + "mime_type": mime_type, + "created_at": current_time, + } + ) + + # Build user_metadata from extracted metadata or fallback to filename + extracted_metadata = spec.get("metadata") + if extracted_metadata: + user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata() + elif spec["fname"]: + user_metadata = {"filename": spec["fname"]} + else: + user_metadata = None + + reference_rows.append( + { + "id": reference_id, + "asset_id": asset_id, + "file_path": absolute_path, + "mtime_ns": spec["mtime_ns"], + "owner_id": owner_id, + "name": spec["info_name"], + "preview_id": None, + "user_metadata": user_metadata, + "created_at": current_time, + "updated_at": current_time, + "last_access_time": current_time, + } + ) + + asset_id_to_ref_data[asset_id] = { + "reference_id": reference_id, + "tags": spec["tags"], + "filename": spec["fname"], + "extracted_metadata": extracted_metadata, + } + + bulk_insert_assets(session, asset_rows) + + # Filter reference rows to only those whose assets were actually inserted + # (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING) + inserted_asset_ids = get_existing_asset_ids( + session, [r["asset_id"] for r in reference_rows] + ) + reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids] + + bulk_insert_references_ignore_conflicts(session, reference_rows) + restore_references_by_paths(session, absolute_path_list) + winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id) + + inserted_paths = { + path + for path in absolute_path_list + if path_to_asset_id[path] in inserted_asset_ids + } + losing_paths = inserted_paths - winning_paths + lost_asset_ids = [path_to_asset_id[path] for path in losing_paths] + + if lost_asset_ids: + delete_assets_by_ids(session, lost_asset_ids) + + if not winning_paths: + return BulkInsertResult( + inserted_refs=0, + won_paths=0, + lost_paths=len(losing_paths), + ) + + # Get reference IDs for winners + winning_ref_ids = [ + asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"] + for path in winning_paths + ] + inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids) + + tag_rows: list[TagRow] = [] + metadata_rows: list[MetadataRow] = [] + + if inserted_ref_ids: + for path in winning_paths: + asset_id = path_to_asset_id[path] + ref_data = asset_id_to_ref_data[asset_id] + ref_id = ref_data["reference_id"] + + if ref_id not in inserted_ref_ids: + continue + + for tag in ref_data["tags"]: + tag_rows.append( + { + "asset_reference_id": ref_id, + "tag_name": tag, + "origin": "automatic", + "added_at": current_time, + } + ) + + # Use extracted metadata for meta rows if available + extracted_metadata = ref_data.get("extracted_metadata") + if extracted_metadata: + metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id)) + elif ref_data["filename"]: + # Fallback: just store filename + metadata_rows.append( + { + "asset_reference_id": ref_id, + "key": "filename", + "ordinal": 0, + "val_str": ref_data["filename"], + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) + + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows) + + return BulkInsertResult( + inserted_refs=len(inserted_ref_ids), + won_paths=len(winning_paths), + lost_paths=len(losing_paths), + ) + + +def cleanup_unreferenced_assets(session: Session) -> int: + """Hard-delete unhashed assets with no active references. + + This is a destructive operation intended for explicit cleanup. + Only deletes assets where hash=None and all references are missing. + + Returns: + Number of assets deleted + """ + unreferenced_ids = get_unreferenced_unhashed_asset_ids(session) + return delete_assets_by_ids(session, unreferenced_ids) diff --git a/app/assets/services/file_utils.py b/app/assets/services/file_utils.py new file mode 100644 index 000000000..c47ebe460 --- /dev/null +++ b/app/assets/services/file_utils.py @@ -0,0 +1,70 @@ +import os + + +def get_mtime_ns(stat_result: os.stat_result) -> int: + """Extract mtime in nanoseconds from a stat result.""" + return getattr( + stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000) + ) + + +def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]: + """Get file size in bytes and mtime in nanoseconds.""" + st = os.stat(path, follow_symlinks=follow_symlinks) + return st.st_size, get_mtime_ns(st) + + +def verify_file_unchanged( + mtime_db: int | None, + size_db: int | None, + stat_result: os.stat_result, +) -> bool: + """Check if a file is unchanged based on mtime and size. + + Returns True if the file's mtime and size match the database values. + Returns False if mtime_db is None or values don't match. + + size_db=None means don't check size; 0 is a valid recorded size. + """ + if mtime_db is None: + return False + actual_mtime_ns = get_mtime_ns(stat_result) + if int(mtime_db) != int(actual_mtime_ns): + return False + if size_db is not None: + return int(stat_result.st_size) == int(size_db) + return True + + +def is_visible(name: str) -> bool: + """Return True if a file or directory name is visible (not hidden).""" + return not name.startswith(".") + + +def list_files_recursively(base_dir: str) -> list[str]: + """Recursively list all files in a directory, following symlinks.""" + out: list[str] = [] + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + return out + # Track seen real directory identities to prevent circular symlink loops + seen_dirs: set[tuple[int, int]] = set() + for dirpath, subdirs, filenames in os.walk( + base_abs, topdown=True, followlinks=True + ): + try: + st = os.stat(dirpath) + dir_id = (st.st_dev, st.st_ino) + except OSError: + subdirs.clear() + continue + if dir_id in seen_dirs: + subdirs.clear() + continue + seen_dirs.add(dir_id) + subdirs[:] = [d for d in subdirs if is_visible(d)] + for name in filenames: + if not is_visible(name): + continue + out.append(os.path.abspath(os.path.join(dirpath, name))) + return out diff --git a/app/assets/services/hashing.py b/app/assets/services/hashing.py new file mode 100644 index 000000000..92aee6402 --- /dev/null +++ b/app/assets/services/hashing.py @@ -0,0 +1,95 @@ +import io +import os +from contextlib import contextmanager +from dataclasses import dataclass +from typing import IO, Any, Callable, Iterator + +from blake3 import blake3 + +DEFAULT_CHUNK = 8 * 1024 * 1024 + +InterruptCheck = Callable[[], bool] + + +@dataclass +class HashCheckpoint: + """Saved state for resuming an interrupted hash computation.""" + + bytes_processed: int + hasher: Any # blake3 hasher instance + mtime_ns: int = 0 + file_size: int = 0 + + +@contextmanager +def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]: + """Yield (file_object, is_path) with appropriate setup/teardown.""" + if hasattr(fp, "read"): + seekable = getattr(fp, "seekable", lambda: False)() + orig_pos = None + if seekable: + try: + orig_pos = fp.tell() + if orig_pos != 0: + fp.seek(0) + except io.UnsupportedOperation: + orig_pos = None + try: + yield fp, False + finally: + if orig_pos is not None: + fp.seek(orig_pos) + else: + with open(os.fspath(fp), "rb") as f: + yield f, True + + +def compute_blake3_hash( + fp: str | IO[bytes], + chunk_size: int = DEFAULT_CHUNK, + interrupt_check: InterruptCheck | None = None, + checkpoint: HashCheckpoint | None = None, +) -> tuple[str | None, HashCheckpoint | None]: + """Compute BLAKE3 hash of a file, with optional checkpoint support. + + Args: + fp: File path or file-like object + chunk_size: Size of chunks to read at a time + interrupt_check: Optional callable that returns True if the operation + should be interrupted (e.g. paused or cancelled). Must be + non-blocking so file handles are released immediately. Checked + between chunk reads. + checkpoint: Optional checkpoint to resume from (file paths only) + + Returns: + Tuple of (hex_digest, None) on completion, or + (None, checkpoint) on interruption (file paths only), or + (None, None) on interruption of a file object + """ + if chunk_size <= 0: + chunk_size = DEFAULT_CHUNK + + with _open_for_hashing(fp) as (f, is_path): + if checkpoint is not None and is_path: + f.seek(checkpoint.bytes_processed) + h = checkpoint.hasher + bytes_processed = checkpoint.bytes_processed + else: + h = blake3() + bytes_processed = 0 + + while True: + if interrupt_check is not None and interrupt_check(): + if is_path: + return None, HashCheckpoint( + bytes_processed=bytes_processed, + hasher=h, + ) + return None, None + chunk = f.read(chunk_size) + if not chunk: + break + h.update(chunk) + bytes_processed += len(chunk) + + return h.hexdigest(), None diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py new file mode 100644 index 000000000..44d7aef36 --- /dev/null +++ b/app/assets/services/ingest.py @@ -0,0 +1,375 @@ +import contextlib +import logging +import mimetypes +import os +from typing import Any, Sequence + +from sqlalchemy.orm import Session + +import app.assets.services.hashing as hashing +from app.assets.database.queries import ( + add_tags_to_reference, + fetch_reference_and_asset, + get_asset_by_hash, + get_existing_asset_ids, + get_reference_by_file_path, + get_reference_tags, + get_or_create_reference, + remove_missing_tag_for_asset_id, + set_reference_metadata, + set_reference_tags, + upsert_asset, + upsert_reference, + validate_tags_exist, +) +from app.assets.helpers import normalize_tags +from app.assets.services.file_utils import get_size_and_mtime_ns +from app.assets.services.path_utils import ( + compute_relative_filename, + resolve_destination_from_tags, + validate_path_within_base, +) +from app.assets.services.schemas import ( + IngestResult, + RegisterAssetResult, + UploadResult, + UserMetadata, + extract_asset_data, + extract_reference_data, +) +from app.database.db import create_session + + +def _ingest_file_from_path( + abs_path: str, + asset_hash: str, + size_bytes: int, + mtime_ns: int, + mime_type: str | None = None, + info_name: str | None = None, + owner_id: str = "", + preview_id: str | None = None, + user_metadata: UserMetadata = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + require_existing_tags: bool = False, +) -> IngestResult: + locator = os.path.abspath(abs_path) + user_metadata = user_metadata or {} + + asset_created = False + asset_updated = False + ref_created = False + ref_updated = False + reference_id: str | None = None + + with create_session() as session: + if preview_id: + if preview_id not in get_existing_asset_ids(session, [preview_id]): + preview_id = None + + asset, asset_created, asset_updated = upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=size_bytes, + mime_type=mime_type, + ) + + ref_created, ref_updated = upsert_reference( + session, + asset_id=asset.id, + file_path=locator, + name=info_name or os.path.basename(locator), + mtime_ns=mtime_ns, + owner_id=owner_id, + ) + + # Get the reference we just created/updated + ref = get_reference_by_file_path(session, locator) + if ref: + reference_id = ref.id + + if preview_id and ref.preview_id != preview_id: + ref.preview_id = preview_id + + norm = normalize_tags(list(tags)) + if norm: + if require_existing_tags: + validate_tags_exist(session, norm) + add_tags_to_reference( + session, + reference_id=reference_id, + tags=norm, + origin=tag_origin, + create_if_missing=not require_existing_tags, + ) + + _update_metadata_with_filename( + session, + reference_id=reference_id, + file_path=ref.file_path, + current_metadata=ref.user_metadata, + user_metadata=user_metadata, + ) + + try: + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + except Exception: + logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) + + session.commit() + + return IngestResult( + asset_created=asset_created, + asset_updated=asset_updated, + ref_created=ref_created, + ref_updated=ref_updated, + reference_id=reference_id, + ) + + +def _register_existing_asset( + asset_hash: str, + name: str, + user_metadata: UserMetadata = None, + tags: list[str] | None = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> RegisterAssetResult: + user_metadata = user_metadata or {} + + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=asset_hash) + if not asset: + raise ValueError(f"No asset with hash {asset_hash}") + + ref, ref_created = get_or_create_reference( + session, + asset_id=asset.id, + owner_id=owner_id, + name=name, + ) + + if not ref_created: + tag_names = get_reference_tags(session, reference_id=ref.id) + result = RegisterAssetResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created=False, + ) + session.commit() + return result + + new_meta = dict(user_metadata) + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta: + set_reference_metadata( + session, + reference_id=ref.id, + user_metadata=new_meta, + ) + + if tags is not None: + set_reference_tags( + session, + reference_id=ref.id, + tags=tags, + origin=tag_origin, + ) + + tag_names = get_reference_tags(session, reference_id=ref.id) + session.refresh(ref) + result = RegisterAssetResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created=True, + ) + session.commit() + + return result + + + +def _update_metadata_with_filename( + session: Session, + reference_id: str, + file_path: str | None, + current_metadata: dict | None, + user_metadata: dict[str, Any], +) -> None: + computed_filename = compute_relative_filename(file_path) if file_path else None + + current_meta = current_metadata or {} + new_meta = dict(current_meta) + for k, v in user_metadata.items(): + new_meta[k] = v + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta != current_meta: + set_reference_metadata( + session, + reference_id=reference_id, + user_metadata=new_meta, + ) + + +def _sanitize_filename(name: str | None, fallback: str) -> str: + n = os.path.basename((name or "").strip() or fallback) + return n if n else fallback + + +class HashMismatchError(Exception): + pass + + +class DependencyMissingError(Exception): + def __init__(self, message: str): + self.message = message + super().__init__(message) + + +def upload_from_temp_path( + temp_path: str, + name: str | None = None, + tags: list[str] | None = None, + user_metadata: dict | None = None, + client_filename: str | None = None, + owner_id: str = "", + expected_hash: str | None = None, +) -> UploadResult: + try: + digest, _ = hashing.compute_blake3_hash(temp_path) + except ImportError as e: + raise DependencyMissingError(str(e)) + except Exception as e: + raise RuntimeError(f"failed to hash uploaded file: {e}") + asset_hash = "blake3:" + digest + + if expected_hash and asset_hash != expected_hash.strip().lower(): + raise HashMismatchError("Uploaded file hash does not match provided hash.") + + with create_session() as session: + existing = get_asset_by_hash(session, asset_hash=asset_hash) + + if existing is not None: + with contextlib.suppress(Exception): + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + display_name = _sanitize_filename(name or client_filename, fallback=digest) + result = _register_existing_asset( + asset_hash=asset_hash, + name=display_name, + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + return UploadResult( + ref=result.ref, + asset=result.asset, + tags=result.tags, + created_new=False, + ) + + if not tags: + raise ValueError("tags are required for new asset uploads") + base_dir, subdirs = resolve_destination_from_tags(tags) + dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir + os.makedirs(dest_dir, exist_ok=True) + + src_for_ext = (client_filename or name or "").strip() + _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" + ext = _ext if 0 < len(_ext) <= 16 else "" + hashed_basename = f"{digest}{ext}" + dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) + validate_path_within_base(dest_abs, base_dir) + + content_type = ( + mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] + or mimetypes.guess_type(hashed_basename, strict=False)[0] + or "application/octet-stream" + ) + + try: + os.replace(temp_path, dest_abs) + except Exception as e: + raise RuntimeError(f"failed to move uploaded file into place: {e}") + + try: + size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs) + except OSError as e: + raise RuntimeError(f"failed to stat destination file: {e}") + + ingest_result = _ingest_file_from_path( + asset_hash=asset_hash, + abs_path=dest_abs, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=_sanitize_filename(name or client_filename, fallback=digest), + owner_id=owner_id, + preview_id=None, + user_metadata=user_metadata or {}, + tags=tags, + tag_origin="manual", + require_existing_tags=False, + ) + reference_id = ingest_result.reference_id + if not reference_id: + raise RuntimeError("failed to create asset reference") + + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + ref, asset = pair + tag_names = get_reference_tags(session, reference_id=ref.id) + + return UploadResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created_new=ingest_result.asset_created, + ) + + +def create_from_hash( + hash_str: str, + name: str, + tags: list[str] | None = None, + user_metadata: dict | None = None, + owner_id: str = "", +) -> UploadResult | None: + canonical = hash_str.strip().lower() + + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=canonical) + if not asset: + return None + + result = _register_existing_asset( + asset_hash=canonical, + name=_sanitize_filename( + name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical + ), + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + + return UploadResult( + ref=result.ref, + asset=result.asset, + tags=result.tags, + created_new=False, + ) diff --git a/app/assets/services/metadata_extract.py b/app/assets/services/metadata_extract.py new file mode 100644 index 000000000..a004929bc --- /dev/null +++ b/app/assets/services/metadata_extract.py @@ -0,0 +1,327 @@ +"""Metadata extraction for asset scanning. + +Tier 1: Filesystem metadata (zero parsing) +Tier 2: Safetensors header metadata (fast JSON read only) +""" + +from __future__ import annotations + +import json +import logging +import mimetypes +import os +import struct +from dataclasses import dataclass +from typing import Any + +from utils.mime_types import init_mime_types + +init_mime_types() + +# Supported safetensors extensions +SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"}) + +# Maximum safetensors header size to read (8MB) +MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024 + + +@dataclass +class ExtractedMetadata: + """Metadata extracted from a file during scanning.""" + + # Tier 1: Filesystem (always available) + filename: str = "" + file_path: str = "" # Full absolute path to the file + content_length: int = 0 + content_type: str | None = None + format: str = "" # file extension without dot + + # Tier 2: Safetensors header (if available) + base_model: str | None = None + trained_words: list[str] | None = None + air: str | None = None # CivitAI AIR identifier + has_preview_images: bool = False + + # Source provenance (populated if embedded in safetensors) + source_url: str | None = None + source_arn: str | None = None + repo_url: str | None = None + preview_url: str | None = None + source_hash: str | None = None + + # HuggingFace specific + repo_id: str | None = None + revision: str | None = None + filepath: str | None = None + resolve_url: str | None = None + + def to_user_metadata(self) -> dict[str, Any]: + """Convert to user_metadata dict for AssetReference.user_metadata JSON field.""" + data: dict[str, Any] = { + "filename": self.filename, + "content_length": self.content_length, + "format": self.format, + } + if self.file_path: + data["file_path"] = self.file_path + if self.content_type: + data["content_type"] = self.content_type + + # Tier 2 fields + if self.base_model: + data["base_model"] = self.base_model + if self.trained_words: + data["trained_words"] = self.trained_words + if self.air: + data["air"] = self.air + if self.has_preview_images: + data["has_preview_images"] = True + + # Source provenance + if self.source_url: + data["source_url"] = self.source_url + if self.source_arn: + data["source_arn"] = self.source_arn + if self.repo_url: + data["repo_url"] = self.repo_url + if self.preview_url: + data["preview_url"] = self.preview_url + if self.source_hash: + data["source_hash"] = self.source_hash + + # HuggingFace + if self.repo_id: + data["repo_id"] = self.repo_id + if self.revision: + data["revision"] = self.revision + if self.filepath: + data["filepath"] = self.filepath + if self.resolve_url: + data["resolve_url"] = self.resolve_url + + return data + + def to_meta_rows(self, reference_id: str) -> list[dict]: + """Convert to asset_reference_meta rows for typed/indexed querying.""" + rows: list[dict] = [] + + def add_str(key: str, val: str | None, ordinal: int = 0) -> None: + if val: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": ordinal, + "val_str": val[:2048] if len(val) > 2048 else val, + "val_num": None, + "val_bool": None, + "val_json": None, + }) + + def add_num(key: str, val: int | float | None) -> None: + if val is not None: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": 0, + "val_str": None, + "val_num": val, + "val_bool": None, + "val_json": None, + }) + + def add_bool(key: str, val: bool | None) -> None: + if val is not None: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": 0, + "val_str": None, + "val_num": None, + "val_bool": val, + "val_json": None, + }) + + # Tier 1 + add_str("filename", self.filename) + add_num("content_length", self.content_length) + add_str("content_type", self.content_type) + add_str("format", self.format) + + # Tier 2 + add_str("base_model", self.base_model) + add_str("air", self.air) + has_previews = self.has_preview_images if self.has_preview_images else None + add_bool("has_preview_images", has_previews) + + # trained_words as multiple rows with ordinals + if self.trained_words: + for i, word in enumerate(self.trained_words[:100]): # limit to 100 words + add_str("trained_words", word, ordinal=i) + + # Source provenance + add_str("source_url", self.source_url) + add_str("source_arn", self.source_arn) + add_str("repo_url", self.repo_url) + add_str("preview_url", self.preview_url) + add_str("source_hash", self.source_hash) + + # HuggingFace + add_str("repo_id", self.repo_id) + add_str("revision", self.revision) + add_str("filepath", self.filepath) + add_str("resolve_url", self.resolve_url) + + return rows + + +def _read_safetensors_header( + path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE +) -> dict[str, Any] | None: + """Read only the JSON header from a safetensors file. + + This is very fast - reads 8 bytes for header length, then the JSON header. + No tensor data is loaded. + + Args: + path: Absolute path to safetensors file + max_size: Maximum header size to read (default 8MB) + + Returns: + Parsed header dict or None if failed + """ + try: + with open(path, "rb") as f: + header_bytes = f.read(8) + if len(header_bytes) < 8: + return None + length_of_header = struct.unpack(" max_size: + return None + header_data = f.read(length_of_header) + if len(header_data) < length_of_header: + return None + return json.loads(header_data.decode("utf-8")) + except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error): + return None + + +def _extract_safetensors_metadata( + header: dict[str, Any], meta: ExtractedMetadata +) -> None: + """Extract metadata from safetensors header __metadata__ section. + + Modifies meta in-place. + """ + st_meta = header.get("__metadata__", {}) + if not isinstance(st_meta, dict): + return + + # Common model metadata + meta.base_model = ( + st_meta.get("ss_base_model_version") + or st_meta.get("modelspec.base_model") + or st_meta.get("base_model") + ) + + # Trained words / trigger words + trained_words = st_meta.get("ss_tag_frequency") + if trained_words and isinstance(trained_words, str): + try: + tag_freq = json.loads(trained_words) + # Extract unique tags from all datasets + all_tags: set[str] = set() + for dataset_tags in tag_freq.values(): + if isinstance(dataset_tags, dict): + all_tags.update(dataset_tags.keys()) + if all_tags: + meta.trained_words = sorted(all_tags)[:100] + except json.JSONDecodeError: + pass + + # Direct trained_words field (some formats) + if not meta.trained_words: + tw = st_meta.get("trained_words") + if isinstance(tw, str): + try: + parsed = json.loads(tw) + if isinstance(parsed, list): + meta.trained_words = [str(x) for x in parsed] + else: + meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()] + except json.JSONDecodeError: + meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()] + elif isinstance(tw, list): + meta.trained_words = [str(x) for x in tw] + + # CivitAI AIR + meta.air = st_meta.get("air") or st_meta.get("modelspec.air") + + # Preview images (ssmd_cover_images) + cover_images = st_meta.get("ssmd_cover_images") + if cover_images: + meta.has_preview_images = True + + # Source provenance fields + meta.source_url = st_meta.get("source_url") + meta.source_arn = st_meta.get("source_arn") + meta.repo_url = st_meta.get("repo_url") + meta.preview_url = st_meta.get("preview_url") + meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash") + + # HuggingFace fields + meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id") + meta.revision = st_meta.get("revision") or st_meta.get("hf_revision") + meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath") + meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url") + + +def extract_file_metadata( + abs_path: str, + stat_result: os.stat_result | None = None, + relative_filename: str | None = None, +) -> ExtractedMetadata: + """Extract metadata from a file using tier 1 and tier 2 methods. + + Tier 1: Filesystem metadata from path and stat + Tier 2: Safetensors header parsing if applicable + + Args: + abs_path: Absolute path to the file + stat_result: Optional pre-fetched stat result (saves a syscall) + relative_filename: Optional relative filename to use instead of basename + (e.g., "flux/123/model.safetensors" for model paths) + + Returns: + ExtractedMetadata with all available fields populated + """ + meta = ExtractedMetadata() + + # Tier 1: Filesystem metadata + meta.filename = relative_filename or os.path.basename(abs_path) + meta.file_path = abs_path + _, ext = os.path.splitext(abs_path) + meta.format = ext.lstrip(".").lower() if ext else "" + + mime_type, _ = mimetypes.guess_type(abs_path) + meta.content_type = mime_type + + # Size from stat + if stat_result is None: + try: + stat_result = os.stat(abs_path, follow_symlinks=True) + except OSError: + pass + + if stat_result: + meta.content_length = stat_result.st_size + + # Tier 2: Safetensors header (if applicable and enabled) + if ext.lower() in SAFETENSORS_EXTENSIONS: + header = _read_safetensors_header(abs_path) + if header: + try: + _extract_safetensors_metadata(header, meta) + except Exception as e: + logging.debug("Safetensors meta extract failed %s: %s", abs_path, e) + + return meta diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py new file mode 100644 index 000000000..f5dd7f7fd --- /dev/null +++ b/app/assets/services/path_utils.py @@ -0,0 +1,167 @@ +import os +from pathlib import Path +from typing import Literal + +import folder_paths +from app.assets.helpers import normalize_tags + + +_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"}) + + +def get_comfy_models_folders() -> list[tuple[str, list[str]]]: + """Build list of (folder_name, base_paths[]) for all model locations. + + Includes every category registered in folder_names_and_paths, + regardless of whether its paths are under the main models_dir, + but excludes non-model entries like custom_nodes. + """ + targets: list[tuple[str, list[str]]] = [] + for name, values in folder_paths.folder_names_and_paths.items(): + if name in _NON_MODEL_FOLDER_NAMES: + continue + paths, _exts = values[0], values[1] + if paths: + targets.append((name, paths)) + return targets + + +def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: + """Validates and maps tags -> (base_dir, subdirs_for_fs)""" + if not tags: + raise ValueError("tags must not be empty") + root = tags[0].lower() + if root == "models": + if len(tags) < 2: + raise ValueError("at least two tags required for model asset") + try: + bases = folder_paths.folder_names_and_paths[tags[1]][0] + except KeyError: + raise ValueError(f"unknown model category '{tags[1]}'") + if not bases: + raise ValueError(f"no base path configured for category '{tags[1]}'") + base_dir = os.path.abspath(bases[0]) + raw_subdirs = tags[2:] + elif root == "input": + base_dir = os.path.abspath(folder_paths.get_input_directory()) + raw_subdirs = tags[1:] + elif root == "output": + base_dir = os.path.abspath(folder_paths.get_output_directory()) + raw_subdirs = tags[1:] + else: + raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'") + _sep_chars = frozenset(("/", "\\", os.sep)) + for i in raw_subdirs: + if i in (".", "..") or _sep_chars & set(i): + raise ValueError("invalid path component in tags") + + return base_dir, raw_subdirs if raw_subdirs else [] + + +def validate_path_within_base(candidate: str, base: str) -> None: + cand_abs = Path(os.path.abspath(candidate)) + base_abs = Path(os.path.abspath(base)) + if not cand_abs.is_relative_to(base_abs): + raise ValueError("destination escapes base directory") + + +def compute_relative_filename(file_path: str) -> str | None: + """ + Return the model's path relative to the last well-known folder (the model category), + using forward slashes, eg: + /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" + /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + + For non-model paths, returns None. + """ + try: + root_category, rel_path = get_asset_category_and_relative_path(file_path) + except ValueError: + return None + + p = Path(rel_path) + parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + if not parts: + return None + + if root_category == "models": + # parts[0] is the category ("checkpoints", "vae", etc) – drop it + inside = parts[1:] if len(parts) > 1 else [parts[0]] + return "/".join(inside) + return "/".join(parts) # input/output: keep all parts + + +def get_asset_category_and_relative_path( + file_path: str, +) -> tuple[Literal["input", "output", "models"], str]: + """Determine which root category a file path belongs to. + + Categories: + - 'input': under folder_paths.get_input_directory() + - 'output': under folder_paths.get_output_directory() + - 'models': under any base path from get_comfy_models_folders() + + Returns: + (root_category, relative_path_inside_that_root) + + Raises: + ValueError: path does not belong to any known root. + """ + fp_abs = os.path.abspath(file_path) + + def _check_is_within(child: str, parent: str) -> bool: + return Path(child).is_relative_to(parent) + + def _compute_relative(child: str, parent: str) -> str: + # Normalize relative path, stripping any leading ".." components + # by anchoring to root (os.sep) then computing relpath back from it. + return os.path.relpath( + os.path.join(os.sep, os.path.relpath(child, parent)), os.sep + ) + + # 1) input + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _check_is_within(fp_abs, input_base): + return "input", _compute_relative(fp_abs, input_base) + + # 2) output + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _check_is_within(fp_abs, output_base): + return "output", _compute_relative(fp_abs, output_base) + + # 3) models (check deepest matching base to avoid ambiguity) + best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) + for bucket, bases in get_comfy_models_folders(): + for b in bases: + base_abs = os.path.abspath(b) + if not _check_is_within(fp_abs, base_abs): + continue + cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs)) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, bucket, rel_inside = best + combined = os.path.join(bucket, rel_inside) + return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + + raise ValueError( + f"Path is not within input, output, or configured model bases: {file_path}" + ) + + +def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: + """Return (name, tags) derived from a filesystem path. + + - name: base filename with extension + - tags: [root_category] + parent folder names in order + + Raises: + ValueError: path does not belong to any known root. + """ + root_category, some_path = get_asset_category_and_relative_path(file_path) + p = Path(some_path) + parent_parts = [ + part for part in p.parent.parts if part not in (".", "..", p.anchor) + ] + return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py new file mode 100644 index 000000000..8b1f1f4dc --- /dev/null +++ b/app/assets/services/schemas.py @@ -0,0 +1,109 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Any, NamedTuple + +from app.assets.database.models import Asset, AssetReference + +UserMetadata = dict[str, Any] | None + + +@dataclass(frozen=True) +class AssetData: + hash: str | None + size_bytes: int | None + mime_type: str | None + + +@dataclass(frozen=True) +class ReferenceData: + """Data transfer object for AssetReference.""" + + id: str + name: str + file_path: str | None + user_metadata: UserMetadata + preview_id: str | None + created_at: datetime + updated_at: datetime + last_access_time: datetime | None + + +@dataclass(frozen=True) +class AssetDetailResult: + ref: ReferenceData + asset: AssetData | None + tags: list[str] + + +@dataclass(frozen=True) +class RegisterAssetResult: + ref: ReferenceData + asset: AssetData + tags: list[str] + created: bool + + +@dataclass(frozen=True) +class IngestResult: + asset_created: bool + asset_updated: bool + ref_created: bool + ref_updated: bool + reference_id: str | None + + +class TagUsage(NamedTuple): + name: str + tag_type: str + count: int + + +@dataclass(frozen=True) +class AssetSummaryData: + ref: ReferenceData + asset: AssetData | None + tags: list[str] + + +@dataclass(frozen=True) +class ListAssetsResult: + items: list[AssetSummaryData] + total: int + + +@dataclass(frozen=True) +class DownloadResolutionResult: + abs_path: str + content_type: str + download_name: str + + +@dataclass(frozen=True) +class UploadResult: + ref: ReferenceData + asset: AssetData + tags: list[str] + created_new: bool + + +def extract_reference_data(ref: AssetReference) -> ReferenceData: + return ReferenceData( + id=ref.id, + name=ref.name, + file_path=ref.file_path, + user_metadata=ref.user_metadata, + preview_id=ref.preview_id, + created_at=ref.created_at, + updated_at=ref.updated_at, + last_access_time=ref.last_access_time, + ) + + +def extract_asset_data(asset: Asset | None) -> AssetData | None: + if asset is None: + return None + return AssetData( + hash=asset.hash, + size_bytes=asset.size_bytes, + mime_type=asset.mime_type, + ) diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py new file mode 100644 index 000000000..28900464d --- /dev/null +++ b/app/assets/services/tagging.py @@ -0,0 +1,75 @@ +from app.assets.database.queries import ( + AddTagsResult, + RemoveTagsResult, + add_tags_to_reference, + get_reference_with_owner_check, + list_tags_with_usage, + remove_tags_from_reference, +) +from app.assets.services.schemas import TagUsage +from app.database.db import create_session + + +def apply_tags( + reference_id: str, + tags: list[str], + origin: str = "manual", + owner_id: str = "", +) -> AddTagsResult: + with create_session() as session: + ref_row = get_reference_with_owner_check(session, reference_id, owner_id) + + result = add_tags_to_reference( + session, + reference_id=reference_id, + tags=tags, + origin=origin, + create_if_missing=True, + reference_row=ref_row, + ) + session.commit() + + return result + + +def remove_tags( + reference_id: str, + tags: list[str], + owner_id: str = "", +) -> RemoveTagsResult: + with create_session() as session: + get_reference_with_owner_check(session, reference_id, owner_id) + + result = remove_tags_from_reference( + session, + reference_id=reference_id, + tags=tags, + ) + session.commit() + + return result + + +def list_tags( + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + order: str = "count_desc", + include_zero: bool = True, + owner_id: str = "", +) -> tuple[list[TagUsage], int]: + limit = max(1, min(1000, limit)) + offset = max(0, offset) + + with create_session() as session: + rows, total = list_tags_with_usage( + session, + prefix=prefix, + limit=limit, + offset=offset, + include_zero=include_zero, + order=order, + owner_id=owner_id, + ) + + return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total diff --git a/app/database/db.py b/app/database/db.py index 1de8b80ed..0aab09a49 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -3,6 +3,7 @@ import os import shutil from app.logger import log_startup_warning from utils.install_util import get_missing_requirements_message +from filelock import FileLock, Timeout from comfy.cli_args import args _DB_AVAILABLE = False @@ -14,8 +15,12 @@ try: from alembic.config import Config from alembic.runtime.migration import MigrationContext from alembic.script import ScriptDirectory - from sqlalchemy import create_engine + from sqlalchemy import create_engine, event from sqlalchemy.orm import sessionmaker + from sqlalchemy.pool import StaticPool + + from app.database.models import Base + import app.assets.database.models # noqa: F401 — register models with Base.metadata _DB_AVAILABLE = True except ImportError as e: @@ -65,9 +70,69 @@ def get_db_path(): raise ValueError(f"Unsupported database URL '{url}'.") +_db_lock = None + +def _acquire_file_lock(db_path): + """Acquire an OS-level file lock to prevent multi-process access. + + Uses filelock for cross-platform support (macOS, Linux, Windows). + The OS automatically releases the lock when the process exits, even on crashes. + """ + global _db_lock + lock_path = db_path + ".lock" + _db_lock = FileLock(lock_path) + try: + _db_lock.acquire(timeout=0) + except Timeout: + raise RuntimeError( + f"Could not acquire lock on database '{db_path}'. " + "Another ComfyUI process may already be using it. " + "Use --database-url to specify a separate database file." + ) + + +def _is_memory_db(db_url): + """Check if the database URL refers to an in-memory SQLite database.""" + return db_url in ("sqlite:///:memory:", "sqlite://") + + def init_db(): db_url = args.database_url logging.debug(f"Database URL: {db_url}") + + if _is_memory_db(db_url): + _init_memory_db(db_url) + else: + _init_file_db(db_url) + + +def _init_memory_db(db_url): + """Initialize an in-memory SQLite database using metadata.create_all. + + Alembic migrations don't work with in-memory SQLite because each + connection gets its own separate database — tables created by Alembic's + internal connection are lost immediately. + """ + engine = create_engine( + db_url, + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + Base.metadata.create_all(engine) + + global Session + Session = sessionmaker(bind=engine) + + +def _init_file_db(db_url): + """Initialize a file-backed SQLite database using Alembic migrations.""" db_path = get_db_path() db_exists = os.path.exists(db_path) @@ -75,6 +140,14 @@ def init_db(): # Check if we need to upgrade engine = create_engine(db_url) + + # Enable foreign key enforcement for SQLite + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + conn = engine.connect() context = MigrationContext.configure(conn) @@ -104,6 +177,12 @@ def init_db(): logging.exception("Error upgrading database: ") raise e + # Acquire an OS-level file lock after migrations are complete. + # Alembic uses its own connection, so we must wait until it's done + # before locking — otherwise our own lock blocks the migration. + conn.close() + _acquire_file_lock(db_path) + global Session Session = sessionmaker(bind=engine) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 13079c7bc..e9832acaf 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -232,7 +232,7 @@ database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") -parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") +parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index a90a5ca40..9f6918315 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -15,6 +15,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = { "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, "node_replacements": True, + "assets": args.enable_assets, } diff --git a/main.py b/main.py index 0f58d57b8..a8fc1a28d 100644 --- a/main.py +++ b/main.py @@ -7,14 +7,16 @@ import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger -from app.assets.scanner import seed_assets +from app.assets.seeder import asset_seeder import itertools import utils.extra_config +from utils.mime_types import init_mime_types import logging import sys from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags +from app.database.db import init_db, dependencies_available if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. @@ -161,6 +163,7 @@ def execute_prestartup_script(): logging.info("") apply_custom_paths() +init_mime_types() if args.enable_manager: comfyui_manager.prestartup() @@ -258,6 +261,7 @@ def prompt_worker(q, server_instance): for k in sensitive: extra_data[k] = sensitive[k] + asset_seeder.pause() e.execute(item[2], prompt_id, extra_data, item[4]) need_gc = True @@ -302,6 +306,7 @@ def prompt_worker(q, server_instance): last_gc_collect = current_time need_gc = False hook_breaker_ac10a0.restore_functions() + asset_seeder.resume() async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): @@ -352,12 +357,29 @@ def cleanup_temp(): def setup_database(): try: - from app.database.db import init_db, dependencies_available if dependencies_available(): init_db() - if not args.disable_assets_autoscan: - seed_assets(["models"], enable_logging=True) + if args.enable_assets: + if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True): + logging.info("Background asset scan initiated for models, input, output") except Exception as e: + if "database is locked" in str(e): + logging.error( + "Database is locked. Another ComfyUI process is already using this database.\n" + "To resolve this, specify a separate database file for this instance:\n" + " --database-url sqlite:///path/to/another.db" + ) + sys.exit(1) + if args.enable_assets: + logging.error( + f"Failed to initialize database: {e}\n" + "The --enable-assets flag requires a working database connection.\n" + "To resolve this, try one of the following:\n" + " 1. Install the latest requirements: pip install -r requirements.txt\n" + " 2. Specify an alternative database URL: --database-url sqlite:///path/to/your.db\n" + " 3. Use an in-memory database: --database-url sqlite:///:memory:" + ) + sys.exit(1) logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") @@ -440,5 +462,6 @@ if __name__ == "__main__": event_loop.run_until_complete(x) except KeyboardInterrupt: logging.info("\nStopped server") - - cleanup_temp() + finally: + asset_seeder.shutdown() + cleanup_temp() diff --git a/requirements.txt b/requirements.txt index dc9a9ded0..9527135ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,11 +20,13 @@ tqdm psutil alembic SQLAlchemy +filelock av>=14.2.0 comfy-kitchen>=0.2.7 comfy-aimdo>=0.2.7 requests simpleeval>=1.0.0 +blake3 #non essential dependencies: kornia>=0.7.1 diff --git a/server.py b/server.py index 275bce5a7..76904ebc9 100644 --- a/server.py +++ b/server.py @@ -33,8 +33,8 @@ import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal -from app.assets.scanner import seed_assets -from app.assets.api.routes import register_assets_system +from app.assets.seeder import asset_seeder +from app.assets.api.routes import register_assets_routes from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -197,10 +197,6 @@ class PromptServer(): def __init__(self, loop): PromptServer.instance = self - mimetypes.init() - mimetypes.add_type('application/javascript; charset=utf-8', '.js') - mimetypes.add_type('image/webp', '.webp') - self.user_manager = UserManager() self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() @@ -239,7 +235,11 @@ class PromptServer(): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") - register_assets_system(self.app, self.user_manager) + if args.enable_assets: + register_assets_routes(self.app, self.user_manager) + else: + register_assets_routes(self.app) + asset_seeder.disable() routes = web.RouteTableDef() self.routes = routes self.last_node_id = None @@ -697,10 +697,7 @@ class PromptServer(): @routes.get("/object_info") async def get_object_info(request): - try: - seed_assets(["models"]) - except Exception as e: - logging.error(f"Failed to seed assets: {e}") + asset_seeder.start(roots=("models", "input", "output")) with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: diff --git a/tests-unit/assets_test/conftest.py b/tests-unit/assets_test/conftest.py index 0a57dd7b5..6c5c56113 100644 --- a/tests-unit/assets_test/conftest.py +++ b/tests-unit/assets_test/conftest.py @@ -108,7 +108,7 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest) "main.py", f"--base-directory={str(comfy_tmp_base_dir)}", f"--database-url={db_url}", - "--disable-assets-autoscan", + "--enable-assets", "--listen", "127.0.0.1", "--port", @@ -212,7 +212,7 @@ def asset_factory(http: requests.Session, api_base: str): for aid in created: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}", timeout=30) + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) @pytest.fixture @@ -258,14 +258,4 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str): break for aid in ids: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}", timeout=30) - - -def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: - """Force a fast sync/seed pass by calling the seed endpoint.""" - session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30) - time.sleep(0.2) - - -def get_asset_filename(asset_hash: str, extension: str) -> str: - return asset_hash.removeprefix("blake3:") + extension + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) diff --git a/tests-unit/assets_test/helpers.py b/tests-unit/assets_test/helpers.py new file mode 100644 index 000000000..770e011f4 --- /dev/null +++ b/tests-unit/assets_test/helpers.py @@ -0,0 +1,28 @@ +"""Helper functions for assets integration tests.""" +import time + +import requests + + +def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: + """Force a synchronous sync/seed pass by calling the seed endpoint with wait=true. + + Retries on 409 (already running) until the previous scan finishes. + """ + deadline = time.monotonic() + 60 + while True: + r = session.post( + base_url + "/api/assets/seed?wait=true", + json={"roots": ["models", "input", "output"]}, + timeout=60, + ) + if r.status_code != 409: + assert r.status_code == 200, f"seed endpoint returned {r.status_code}: {r.text}" + return + if time.monotonic() > deadline: + raise TimeoutError("seed endpoint stuck in 409 (already running)") + time.sleep(0.25) + + +def get_asset_filename(asset_hash: str, extension: str) -> str: + return asset_hash.removeprefix("blake3:") + extension diff --git a/tests-unit/assets_test/queries/conftest.py b/tests-unit/assets_test/queries/conftest.py new file mode 100644 index 000000000..4ca0e86a9 --- /dev/null +++ b/tests-unit/assets_test/queries/conftest.py @@ -0,0 +1,20 @@ +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import Base + + +@pytest.fixture +def session(): + """In-memory SQLite session for fast unit tests.""" + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + with Session(engine) as sess: + yield sess + + +@pytest.fixture(autouse=True) +def autoclean_unit_test_assets(): + """Override parent autouse fixture - query tests don't need server cleanup.""" + yield diff --git a/tests-unit/assets_test/queries/test_asset.py b/tests-unit/assets_test/queries/test_asset.py new file mode 100644 index 000000000..08f84cd11 --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset.py @@ -0,0 +1,144 @@ +import uuid + +import pytest +from sqlalchemy.orm import Session + +from app.assets.helpers import get_utc_now +from app.assets.database.models import Asset +from app.assets.database.queries import ( + asset_exists_by_hash, + get_asset_by_hash, + upsert_asset, + bulk_insert_assets, +) + + +class TestAssetExistsByHash: + @pytest.mark.parametrize( + "setup_hash,query_hash,expected", + [ + (None, "nonexistent", False), # No asset exists + ("blake3:abc123", "blake3:abc123", True), # Asset exists with matching hash + (None, "", False), # Null hash in DB doesn't match empty string + ], + ids=["nonexistent", "existing", "null_hash_no_match"], + ) + def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected): + if setup_hash is not None or query_hash == "": + asset = Asset(hash=setup_hash, size_bytes=100) + session.add(asset) + session.commit() + + assert asset_exists_by_hash(session, asset_hash=query_hash) is expected + + +class TestGetAssetByHash: + @pytest.mark.parametrize( + "setup_hash,query_hash,should_find", + [ + (None, "nonexistent", False), + ("blake3:def456", "blake3:def456", True), + ], + ids=["nonexistent", "existing"], + ) + def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find): + if setup_hash is not None: + asset = Asset(hash=setup_hash, size_bytes=200, mime_type="image/png") + session.add(asset) + session.commit() + + result = get_asset_by_hash(session, asset_hash=query_hash) + if should_find: + assert result is not None + assert result.size_bytes == 200 + assert result.mime_type == "image/png" + else: + assert result is None + + +class TestUpsertAsset: + @pytest.mark.parametrize( + "first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime", + [ + # New asset creation + (None, None, 1024, "application/octet-stream", True, False, 1024, "application/octet-stream"), + # Existing asset, same values - no update + (500, "text/plain", 500, "text/plain", False, False, 500, "text/plain"), + # Existing asset with size 0, update with new values + (0, None, 2048, "image/png", False, True, 2048, "image/png"), + # Existing asset, second call with size 0 - no update + (1000, None, 0, None, False, False, 1000, None), + ], + ids=["new_asset", "existing_no_change", "update_from_zero", "zero_size_no_update"], + ) + def test_upsert_scenarios( + self, + session: Session, + first_size, + first_mime, + second_size, + second_mime, + expect_created, + expect_updated, + final_size, + final_mime, + ): + asset_hash = f"blake3:test_{first_size}_{second_size}" + + # First upsert (if first_size is not None, we're testing the second call) + if first_size is not None: + upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=first_size, + mime_type=first_mime, + ) + session.commit() + + # The upsert call we're testing + asset, created, updated = upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=second_size, + mime_type=second_mime, + ) + session.commit() + + assert created is expect_created + assert updated is expect_updated + assert asset.size_bytes == final_size + assert asset.mime_type == final_mime + + +class TestBulkInsertAssets: + def test_inserts_multiple_assets(self, session: Session): + now = get_utc_now() + rows = [ + {"id": str(uuid.uuid4()), "hash": "blake3:bulk1", "size_bytes": 100, "mime_type": "text/plain", "created_at": now}, + {"id": str(uuid.uuid4()), "hash": "blake3:bulk2", "size_bytes": 200, "mime_type": "image/png", "created_at": now}, + {"id": str(uuid.uuid4()), "hash": "blake3:bulk3", "size_bytes": 300, "mime_type": None, "created_at": now}, + ] + bulk_insert_assets(session, rows) + session.commit() + + assets = session.query(Asset).all() + assert len(assets) == 3 + hashes = {a.hash for a in assets} + assert hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"} + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_assets(session, []) + session.commit() + assert session.query(Asset).count() == 0 + + def test_handles_large_batch(self, session: Session): + """Test chunking logic with more rows than MAX_BIND_PARAMS allows.""" + now = get_utc_now() + rows = [ + {"id": str(uuid.uuid4()), "hash": f"blake3:large{i}", "size_bytes": i, "mime_type": None, "created_at": now} + for i in range(200) + ] + bulk_insert_assets(session, rows) + session.commit() + + assert session.query(Asset).count() == 200 diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py new file mode 100644 index 000000000..8f6c7fcdb --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -0,0 +1,517 @@ +import time +import uuid +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta +from app.assets.database.queries import ( + reference_exists_for_asset_id, + get_reference_by_id, + insert_reference, + get_or_create_reference, + update_reference_timestamps, + list_references_page, + fetch_reference_asset_and_tags, + fetch_reference_and_asset, + update_reference_access_time, + set_reference_metadata, + delete_reference_by_id, + set_reference_preview, + bulk_insert_references_ignore_conflicts, + get_reference_ids_by_ids, + ensure_tags_exist, + add_tags_to_reference, +) +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream") + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestReferenceExistsForAssetId: + def test_returns_false_when_no_reference(self, session: Session): + asset = _make_asset(session, "hash1") + assert reference_exists_for_asset_id(session, asset_id=asset.id) is False + + def test_returns_true_when_reference_exists(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset) + assert reference_exists_for_asset_id(session, asset_id=asset.id) is True + + +class TestGetReferenceById: + def test_returns_none_for_nonexistent(self, session: Session): + assert get_reference_by_id(session, reference_id="nonexistent") is None + + def test_returns_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="myfile.txt") + + result = get_reference_by_id(session, reference_id=ref.id) + assert result is not None + assert result.name == "myfile.txt" + + +class TestListReferencesPage: + def test_empty_db(self, session: Session): + refs, tag_map, total = list_references_page(session) + assert refs == [] + assert tag_map == {} + assert total == 0 + + def test_returns_references_with_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["alpha", "beta"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["alpha", "beta"]) + session.commit() + + refs, tag_map, total = list_references_page(session) + assert len(refs) == 1 + assert refs[0].id == ref.id + assert set(tag_map[ref.id]) == {"alpha", "beta"} + assert total == 1 + + def test_name_contains_filter(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="model_v1.safetensors") + _make_reference(session, asset, name="config.json") + session.commit() + + refs, _, total = list_references_page(session, name_contains="model") + assert total == 1 + assert refs[0].name == "model_v1.safetensors" + + def test_owner_visibility(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="public", owner_id="") + _make_reference(session, asset, name="private", owner_id="user1") + session.commit() + + # Empty owner sees only public + refs, _, total = list_references_page(session, owner_id="") + assert total == 1 + assert refs[0].name == "public" + + # Owner sees both + refs, _, total = list_references_page(session, owner_id="user1") + assert total == 2 + + def test_include_tags_filter(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="tagged") + _make_reference(session, asset, name="untagged") + ensure_tags_exist(session, ["wanted"]) + add_tags_to_reference(session, reference_id=ref1.id, tags=["wanted"]) + session.commit() + + refs, _, total = list_references_page(session, include_tags=["wanted"]) + assert total == 1 + assert refs[0].name == "tagged" + + def test_exclude_tags_filter(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="keep") + ref_exclude = _make_reference(session, asset, name="exclude") + ensure_tags_exist(session, ["bad"]) + add_tags_to_reference(session, reference_id=ref_exclude.id, tags=["bad"]) + session.commit() + + refs, _, total = list_references_page(session, exclude_tags=["bad"]) + assert total == 1 + assert refs[0].name == "keep" + + def test_sorting(self, session: Session): + asset = _make_asset(session, "hash1", size=100) + asset2 = _make_asset(session, "hash2", size=500) + _make_reference(session, asset, name="small") + _make_reference(session, asset2, name="large") + session.commit() + + refs, _, _ = list_references_page(session, sort="size", order="desc") + assert refs[0].name == "large" + + refs, _, _ = list_references_page(session, sort="name", order="asc") + assert refs[0].name == "large" + + +class TestFetchReferenceAssetAndTags: + def test_returns_none_for_nonexistent(self, session: Session): + result = fetch_reference_asset_and_tags(session, "nonexistent") + assert result is None + + def test_returns_tuple(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["tag1"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["tag1"]) + session.commit() + + result = fetch_reference_asset_and_tags(session, ref.id) + assert result is not None + ret_ref, ret_asset, ret_tags = result + assert ret_ref.id == ref.id + assert ret_asset.id == asset.id + assert ret_tags == ["tag1"] + + +class TestFetchReferenceAndAsset: + def test_returns_none_for_nonexistent(self, session: Session): + result = fetch_reference_and_asset(session, reference_id="nonexistent") + assert result is None + + def test_returns_tuple(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + result = fetch_reference_and_asset(session, reference_id=ref.id) + assert result is not None + ret_ref, ret_asset = result + assert ret_ref.id == ref.id + assert ret_asset.id == asset.id + + +class TestUpdateReferenceAccessTime: + def test_updates_last_access_time(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + original_time = ref.last_access_time + session.commit() + + import time + time.sleep(0.01) + + update_reference_access_time(session, reference_id=ref.id) + session.commit() + + session.refresh(ref) + assert ref.last_access_time > original_time + + +class TestDeleteReferenceById: + def test_deletes_existing(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + result = delete_reference_by_id(session, reference_id=ref.id, owner_id="") + assert result is True + assert get_reference_by_id(session, reference_id=ref.id) is None + + def test_returns_false_for_nonexistent(self, session: Session): + result = delete_reference_by_id(session, reference_id="nonexistent", owner_id="") + assert result is False + + def test_respects_owner_visibility(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + result = delete_reference_by_id(session, reference_id=ref.id, owner_id="user2") + assert result is False + assert get_reference_by_id(session, reference_id=ref.id) is not None + + +class TestSetReferencePreview: + def test_sets_preview(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + session.commit() + + set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id) + session.commit() + + session.refresh(ref) + assert ref.preview_id == preview_asset.id + + def test_clears_preview(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + ref.preview_id = preview_asset.id + session.commit() + + set_reference_preview(session, reference_id=ref.id, preview_asset_id=None) + session.commit() + + session.refresh(ref) + assert ref.preview_id is None + + def test_raises_for_nonexistent_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None) + + def test_raises_for_nonexistent_preview(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + with pytest.raises(ValueError, match="Preview Asset"): + set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent") + + +class TestInsertReference: + def test_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = insert_reference( + session, asset_id=asset.id, owner_id="user1", name="test.bin" + ) + session.commit() + + assert ref is not None + assert ref.name == "test.bin" + assert ref.owner_id == "user1" + + def test_allows_duplicate_names(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = insert_reference(session, asset_id=asset.id, owner_id="user1", name="dup.bin") + session.commit() + + # Duplicate names are now allowed + ref2 = insert_reference( + session, asset_id=asset.id, owner_id="user1", name="dup.bin" + ) + session.commit() + + assert ref1 is not None + assert ref2 is not None + assert ref1.id != ref2.id + + +class TestGetOrCreateReference: + def test_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref, created = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="new.bin" + ) + session.commit() + + assert created is True + assert ref.name == "new.bin" + + def test_always_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref1, created1 = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="existing.bin" + ) + session.commit() + + # Duplicate names are allowed, so always creates new + ref2, created2 = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="existing.bin" + ) + session.commit() + + assert created1 is True + assert created2 is True + assert ref1.id != ref2.id + + +class TestUpdateReferenceTimestamps: + def test_updates_timestamps(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + original_updated_at = ref.updated_at + session.commit() + + time.sleep(0.01) + update_reference_timestamps(session, ref) + session.commit() + + session.refresh(ref) + assert ref.updated_at > original_updated_at + + def test_updates_preview_id(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + session.commit() + + update_reference_timestamps(session, ref, preview_id=preview_asset.id) + session.commit() + + session.refresh(ref) + assert ref.preview_id == preview_asset.id + + +class TestSetReferenceMetadata: + def test_sets_metadata(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"key": "value"} + ) + session.commit() + + session.refresh(ref) + assert ref.user_metadata == {"key": "value"} + # Check metadata table + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "key" + assert meta[0].val_str == "value" + + def test_replaces_existing_metadata(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"old": "data"} + ) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"new": "data"} + ) + session.commit() + + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "new" + + def test_clears_metadata_with_empty_dict(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"key": "value"} + ) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={} + ) + session.commit() + + session.refresh(ref) + assert ref.user_metadata == {} + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 0 + + def test_raises_for_nonexistent(self, session: Session): + with pytest.raises(ValueError, match="not found"): + set_reference_metadata( + session, reference_id="nonexistent", user_metadata={"key": "value"} + ) + + +class TestBulkInsertReferencesIgnoreConflicts: + def test_inserts_multiple_references(self, session: Session): + asset = _make_asset(session, "hash1") + now = get_utc_now() + rows = [ + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "bulk1.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "bulk2.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + refs = session.query(AssetReference).all() + assert len(refs) == 2 + + def test_allows_duplicate_names(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="existing.bin", owner_id="") + session.commit() + + now = get_utc_now() + rows = [ + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "existing.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "new.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + # Duplicate names allowed, so all 3 rows exist + refs = session.query(AssetReference).all() + assert len(refs) == 3 + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_references_ignore_conflicts(session, []) + assert session.query(AssetReference).count() == 0 + + +class TestGetReferenceIdsByIds: + def test_returns_existing_ids(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="a.bin") + ref2 = _make_reference(session, asset, name="b.bin") + session.commit() + + found = get_reference_ids_by_ids(session, [ref1.id, ref2.id, "nonexistent"]) + + assert found == {ref1.id, ref2.id} + + def test_empty_list_returns_empty(self, session: Session): + found = get_reference_ids_by_ids(session, []) + assert found == set() diff --git a/tests-unit/assets_test/queries/test_cache_state.py b/tests-unit/assets_test/queries/test_cache_state.py new file mode 100644 index 000000000..ead60e570 --- /dev/null +++ b/tests-unit/assets_test/queries/test_cache_state.py @@ -0,0 +1,499 @@ +"""Tests for cache_state (AssetReference file path) query functions.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ( + list_references_by_asset_id, + upsert_reference, + get_unreferenced_unhashed_asset_ids, + delete_assets_by_ids, + get_references_for_prefixes, + bulk_update_needs_verify, + delete_references_by_ids, + delete_orphaned_seed_asset, + bulk_insert_references_ignore_conflicts, + get_references_by_paths_and_asset_ids, + mark_references_missing_outside_prefixes, + restore_references_by_paths, +) +from app.assets.helpers import select_best_live_path, get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + file_path: str, + name: str = "test", + mtime_ns: int | None = None, + needs_verify: bool = False, +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + asset_id=asset.id, + file_path=file_path, + name=name, + mtime_ns=mtime_ns, + needs_verify=needs_verify, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestListReferencesByAssetId: + def test_returns_empty_for_no_references(self, session: Session): + asset = _make_asset(session, "hash1") + refs = list_references_by_asset_id(session, asset_id=asset.id) + assert list(refs) == [] + + def test_returns_references_for_asset(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/path/a.bin", name="a") + _make_reference(session, asset, "/path/b.bin", name="b") + session.commit() + + refs = list_references_by_asset_id(session, asset_id=asset.id) + paths = [r.file_path for r in refs] + assert set(paths) == {"/path/a.bin", "/path/b.bin"} + + def test_does_not_return_other_assets_references(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + _make_reference(session, asset1, "/path/asset1.bin", name="a1") + _make_reference(session, asset2, "/path/asset2.bin", name="a2") + session.commit() + + refs = list_references_by_asset_id(session, asset_id=asset1.id) + paths = [r.file_path for r in refs] + assert paths == ["/path/asset1.bin"] + + +class TestSelectBestLivePath: + def test_returns_empty_for_empty_list(self): + result = select_best_live_path([]) + assert result == "" + + def test_returns_empty_when_no_files_exist(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, "/nonexistent/path.bin") + session.commit() + + result = select_best_live_path([ref]) + assert result == "" + + def test_prefers_verified_path(self, session: Session, tmp_path): + """needs_verify=False should be preferred.""" + asset = _make_asset(session, "hash1") + + verified_file = tmp_path / "verified.bin" + verified_file.write_bytes(b"data") + + unverified_file = tmp_path / "unverified.bin" + unverified_file.write_bytes(b"data") + + ref_verified = _make_reference( + session, asset, str(verified_file), name="verified", needs_verify=False + ) + ref_unverified = _make_reference( + session, asset, str(unverified_file), name="unverified", needs_verify=True + ) + session.commit() + + refs = [ref_unverified, ref_verified] + result = select_best_live_path(refs) + assert result == str(verified_file) + + def test_falls_back_to_existing_unverified(self, session: Session, tmp_path): + """If all references need verification, return first existing path.""" + asset = _make_asset(session, "hash1") + + existing_file = tmp_path / "exists.bin" + existing_file.write_bytes(b"data") + + ref = _make_reference(session, asset, str(existing_file), needs_verify=True) + session.commit() + + result = select_best_live_path([ref]) + assert result == str(existing_file) + + +class TestSelectBestLivePathWithMocking: + def test_handles_missing_file_path_attr(self): + """Gracefully handle references with None file_path.""" + + class MockRef: + file_path = None + needs_verify = False + + result = select_best_live_path([MockRef()]) + assert result == "" + + +class TestUpsertReference: + @pytest.mark.parametrize( + "initial_mtime,second_mtime,expect_created,expect_updated,final_mtime", + [ + # New reference creation + (None, 12345, True, False, 12345), + # Existing reference, same mtime - no update + (100, 100, False, False, 100), + # Existing reference, different mtime - update + (100, 200, False, True, 200), + ], + ids=["new_reference", "existing_no_change", "existing_update_mtime"], + ) + def test_upsert_scenarios( + self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime + ): + asset = _make_asset(session, "hash1") + file_path = f"/path_{initial_mtime}_{second_mtime}.bin" + name = f"file_{initial_mtime}_{second_mtime}" + + # Create initial reference if needed + if initial_mtime is not None: + upsert_reference(session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=initial_mtime) + session.commit() + + # The upsert call we're testing + created, updated = upsert_reference( + session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=second_mtime + ) + session.commit() + + assert created is expect_created + assert updated is expect_updated + ref = session.query(AssetReference).filter_by(file_path=file_path).one() + assert ref.mtime_ns == final_mtime + + def test_upsert_restores_missing_reference(self, session: Session): + """Upserting a reference that was marked missing should restore it.""" + asset = _make_asset(session, "hash1") + file_path = "/restored/file.bin" + + ref = _make_reference(session, asset, file_path, mtime_ns=100) + ref.is_missing = True + session.commit() + + created, updated = upsert_reference( + session, asset_id=asset.id, file_path=file_path, name="restored", mtime_ns=100 + ) + session.commit() + + assert created is False + assert updated is True + restored_ref = session.query(AssetReference).filter_by(file_path=file_path).one() + assert restored_ref.is_missing is False + + +class TestRestoreReferencesByPaths: + def test_restores_missing_references(self, session: Session): + asset = _make_asset(session, "hash1") + missing_path = "/missing/file.bin" + active_path = "/active/file.bin" + + missing_ref = _make_reference(session, asset, missing_path, name="missing") + missing_ref.is_missing = True + _make_reference(session, asset, active_path, name="active") + session.commit() + + restored = restore_references_by_paths(session, [missing_path]) + session.commit() + + assert restored == 1 + ref = session.query(AssetReference).filter_by(file_path=missing_path).one() + assert ref.is_missing is False + + def test_empty_list_restores_nothing(self, session: Session): + restored = restore_references_by_paths(session, []) + assert restored == 0 + + +class TestMarkReferencesMissingOutsidePrefixes: + def test_marks_references_missing_outside_prefixes(self, session: Session, tmp_path): + asset = _make_asset(session, "hash1") + valid_dir = tmp_path / "valid" + valid_dir.mkdir() + invalid_dir = tmp_path / "invalid" + invalid_dir.mkdir() + + valid_path = str(valid_dir / "file.bin") + invalid_path = str(invalid_dir / "file.bin") + + _make_reference(session, asset, valid_path, name="valid") + _make_reference(session, asset, invalid_path, name="invalid") + session.commit() + + marked = mark_references_missing_outside_prefixes(session, [str(valid_dir)]) + session.commit() + + assert marked == 1 + all_refs = session.query(AssetReference).all() + assert len(all_refs) == 2 + + valid_ref = next(r for r in all_refs if r.file_path == valid_path) + invalid_ref = next(r for r in all_refs if r.file_path == invalid_path) + assert valid_ref.is_missing is False + assert invalid_ref.is_missing is True + + def test_empty_prefixes_marks_nothing(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/some/path.bin") + session.commit() + + marked = mark_references_missing_outside_prefixes(session, []) + + assert marked == 0 + + +class TestGetUnreferencedUnhashedAssetIds: + def test_returns_unreferenced_unhashed_assets(self, session: Session): + # Unhashed asset (hash=None) with no references (no file_path) + no_refs = _make_asset(session, hash_val=None) + # Unhashed asset with active reference (not unreferenced) + with_active_ref = _make_asset(session, hash_val=None) + _make_reference(session, with_active_ref, "/has/ref.bin", name="has_ref") + # Unhashed asset with only missing reference (should be unreferenced) + with_missing_ref = _make_asset(session, hash_val=None) + missing_ref = _make_reference(session, with_missing_ref, "/missing/ref.bin", name="missing_ref") + missing_ref.is_missing = True + # Regular asset (hash not None) - should not be returned + _make_asset(session, hash_val="blake3:regular") + session.commit() + + unreferenced = get_unreferenced_unhashed_asset_ids(session) + + assert no_refs.id in unreferenced + assert with_missing_ref.id in unreferenced + assert with_active_ref.id not in unreferenced + + +class TestDeleteAssetsByIds: + def test_deletes_assets_and_references(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/test/path.bin", name="test") + session.commit() + + deleted = delete_assets_by_ids(session, [asset.id]) + session.commit() + + assert deleted == 1 + assert session.query(Asset).count() == 0 + assert session.query(AssetReference).count() == 0 + + def test_empty_list_deletes_nothing(self, session: Session): + _make_asset(session, "hash1") + session.commit() + + deleted = delete_assets_by_ids(session, []) + + assert deleted == 0 + assert session.query(Asset).count() == 1 + + +class TestGetReferencesForPrefixes: + def test_returns_references_matching_prefix(self, session: Session, tmp_path): + asset = _make_asset(session, "hash1") + dir1 = tmp_path / "dir1" + dir1.mkdir() + dir2 = tmp_path / "dir2" + dir2.mkdir() + + path1 = str(dir1 / "file.bin") + path2 = str(dir2 / "file.bin") + + _make_reference(session, asset, path1, name="file1", mtime_ns=100) + _make_reference(session, asset, path2, name="file2", mtime_ns=200) + session.commit() + + rows = get_references_for_prefixes(session, [str(dir1)]) + + assert len(rows) == 1 + assert rows[0].file_path == path1 + + def test_empty_prefixes_returns_empty(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/some/path.bin") + session.commit() + + rows = get_references_for_prefixes(session, []) + + assert rows == [] + + +class TestBulkSetNeedsVerify: + def test_sets_needs_verify_flag(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, "/path1.bin", needs_verify=False) + ref2 = _make_reference(session, asset, "/path2.bin", needs_verify=False) + session.commit() + + updated = bulk_update_needs_verify(session, [ref1.id, ref2.id], True) + session.commit() + + assert updated == 2 + session.refresh(ref1) + session.refresh(ref2) + assert ref1.needs_verify is True + assert ref2.needs_verify is True + + def test_empty_list_updates_nothing(self, session: Session): + updated = bulk_update_needs_verify(session, [], True) + assert updated == 0 + + +class TestDeleteReferencesByIds: + def test_deletes_references_by_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, "/path1.bin") + _make_reference(session, asset, "/path2.bin") + session.commit() + + deleted = delete_references_by_ids(session, [ref1.id]) + session.commit() + + assert deleted == 1 + assert session.query(AssetReference).count() == 1 + + def test_empty_list_deletes_nothing(self, session: Session): + deleted = delete_references_by_ids(session, []) + assert deleted == 0 + + +class TestDeleteOrphanedSeedAsset: + @pytest.mark.parametrize( + "create_asset,expected_deleted,expected_count", + [ + (True, True, 0), # Existing asset gets deleted + (False, False, 0), # Nonexistent returns False + ], + ids=["deletes_existing", "nonexistent_returns_false"], + ) + def test_delete_orphaned_seed_asset( + self, session: Session, create_asset, expected_deleted, expected_count + ): + asset_id = "nonexistent-id" + if create_asset: + asset = _make_asset(session, hash_val=None) + asset_id = asset.id + _make_reference(session, asset, "/test/path.bin", name="test") + session.commit() + + deleted = delete_orphaned_seed_asset(session, asset_id) + if create_asset: + session.commit() + + assert deleted is expected_deleted + assert session.query(Asset).count() == expected_count + + +class TestBulkInsertReferencesIgnoreConflicts: + def test_inserts_multiple_references(self, session: Session): + asset = _make_asset(session, "hash1") + now = get_utc_now() + rows = [ + { + "asset_id": asset.id, + "file_path": "/bulk1.bin", + "name": "bulk1", + "mtime_ns": 100, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "asset_id": asset.id, + "file_path": "/bulk2.bin", + "name": "bulk2", + "mtime_ns": 200, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + assert session.query(AssetReference).count() == 2 + + def test_ignores_conflicts(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/existing.bin", mtime_ns=100) + session.commit() + + now = get_utc_now() + rows = [ + { + "asset_id": asset.id, + "file_path": "/existing.bin", + "name": "existing", + "mtime_ns": 999, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "asset_id": asset.id, + "file_path": "/new.bin", + "name": "new", + "mtime_ns": 200, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + assert session.query(AssetReference).count() == 2 + existing = session.query(AssetReference).filter_by(file_path="/existing.bin").one() + assert existing.mtime_ns == 100 # Original value preserved + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_references_ignore_conflicts(session, []) + assert session.query(AssetReference).count() == 0 + + +class TestGetReferencesByPathsAndAssetIds: + def test_returns_matching_paths(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + + _make_reference(session, asset1, "/path1.bin") + _make_reference(session, asset2, "/path2.bin") + session.commit() + + path_to_asset = { + "/path1.bin": asset1.id, + "/path2.bin": asset2.id, + } + winners = get_references_by_paths_and_asset_ids(session, path_to_asset) + + assert winners == {"/path1.bin", "/path2.bin"} + + def test_excludes_non_matching_asset_ids(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + + _make_reference(session, asset1, "/path1.bin") + session.commit() + + # Path exists but with different asset_id + path_to_asset = {"/path1.bin": asset2.id} + winners = get_references_by_paths_and_asset_ids(session, path_to_asset) + + assert winners == set() + + def test_empty_dict_returns_empty(self, session: Session): + winners = get_references_by_paths_and_asset_ids(session, {}) + assert winners == set() diff --git a/tests-unit/assets_test/queries/test_metadata.py b/tests-unit/assets_test/queries/test_metadata.py new file mode 100644 index 000000000..6a545e819 --- /dev/null +++ b/tests-unit/assets_test/queries/test_metadata.py @@ -0,0 +1,184 @@ +"""Tests for metadata filtering logic in asset_reference queries.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta +from app.assets.database.queries import list_references_page +from app.assets.database.queries.asset_reference import convert_metadata_to_rows +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str) -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str, + metadata: dict | None = None, +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id="", + name=name, + asset_id=asset.id, + user_metadata=metadata, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + + if metadata: + for key, val in metadata.items(): + for row in convert_metadata_to_rows(key, val): + meta_row = AssetReferenceMeta( + asset_reference_id=ref.id, + key=row["key"], + ordinal=row.get("ordinal", 0), + val_str=row.get("val_str"), + val_num=row.get("val_num"), + val_bool=row.get("val_bool"), + val_json=row.get("val_json"), + ) + session.add(meta_row) + session.flush() + + return ref + + +class TestMetadataFilterByType: + """Table-driven tests for metadata filtering by different value types.""" + + @pytest.mark.parametrize( + "match_meta,nomatch_meta,filter_key,filter_val", + [ + # String matching + ({"category": "models"}, {"category": "images"}, "category", "models"), + # Integer matching + ({"epoch": 5}, {"epoch": 10}, "epoch", 5), + # Float matching + ({"score": 0.95}, {"score": 0.5}, "score", 0.95), + # Boolean True matching + ({"enabled": True}, {"enabled": False}, "enabled", True), + # Boolean False matching + ({"enabled": False}, {"enabled": True}, "enabled", False), + ], + ids=["string", "int", "float", "bool_true", "bool_false"], + ) + def test_filter_matches_correct_value( + self, session: Session, match_meta, nomatch_meta, filter_key, filter_val + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "match", match_meta) + _make_reference(session, asset, "nomatch", nomatch_meta) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={filter_key: filter_val} + ) + assert total == 1 + assert refs[0].name == "match" + + @pytest.mark.parametrize( + "stored_meta,filter_key,filter_val", + [ + # String no match + ({"category": "models"}, "category", "other"), + # Int no match + ({"epoch": 5}, "epoch", 99), + # Float no match + ({"score": 0.5}, "score", 0.99), + ], + ids=["string_no_match", "int_no_match", "float_no_match"], + ) + def test_filter_returns_empty_when_no_match( + self, session: Session, stored_meta, filter_key, filter_val + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "item", stored_meta) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={filter_key: filter_val} + ) + assert total == 0 + + +class TestMetadataFilterNull: + """Tests for null/missing key filtering.""" + + @pytest.mark.parametrize( + "match_name,match_meta,nomatch_name,nomatch_meta,filter_key", + [ + # Null matches missing key + ("missing_key", {}, "has_key", {"optional": "value"}, "optional"), + # Null matches explicit null + ("explicit_null", {"nullable": None}, "has_value", {"nullable": "present"}, "nullable"), + ], + ids=["missing_key", "explicit_null"], + ) + def test_null_filter_matches( + self, session: Session, match_name, match_meta, nomatch_name, nomatch_meta, filter_key + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, match_name, match_meta) + _make_reference(session, asset, nomatch_name, nomatch_meta) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={filter_key: None}) + assert total == 1 + assert refs[0].name == match_name + + +class TestMetadataFilterList: + """Tests for list-based (OR) filtering.""" + + def test_filter_by_list_matches_any(self, session: Session): + """List values should match ANY of the values (OR).""" + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "cat_a", {"category": "a"}) + _make_reference(session, asset, "cat_b", {"category": "b"}) + _make_reference(session, asset, "cat_c", {"category": "c"}) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={"category": ["a", "b"]}) + assert total == 2 + names = {r.name for r in refs} + assert names == {"cat_a", "cat_b"} + + +class TestMetadataFilterMultipleKeys: + """Tests for multiple filter keys (AND semantics).""" + + def test_multiple_keys_must_all_match(self, session: Session): + """Multiple keys should ALL match (AND).""" + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "match", {"type": "model", "version": 2}) + _make_reference(session, asset, "wrong_type", {"type": "config", "version": 2}) + _make_reference(session, asset, "wrong_version", {"type": "model", "version": 1}) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={"type": "model", "version": 2} + ) + assert total == 1 + assert refs[0].name == "match" + + +class TestMetadataFilterEmptyDict: + """Tests for empty filter behavior.""" + + def test_empty_filter_returns_all(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "a", {"key": "val"}) + _make_reference(session, asset, "b", {}) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={}) + assert total == 2 diff --git a/tests-unit/assets_test/queries/test_tags.py b/tests-unit/assets_test/queries/test_tags.py new file mode 100644 index 000000000..4ed99aa37 --- /dev/null +++ b/tests-unit/assets_test/queries/test_tags.py @@ -0,0 +1,366 @@ +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, AssetReferenceMeta, Tag +from app.assets.database.queries import ( + ensure_tags_exist, + get_reference_tags, + set_reference_tags, + add_tags_to_reference, + remove_tags_from_reference, + add_missing_tag_for_asset_id, + remove_missing_tag_for_asset_id, + list_tags_with_usage, + bulk_insert_tags_and_meta, +) +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None) -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestEnsureTagsExist: + def test_creates_new_tags(self, session: Session): + ensure_tags_exist(session, ["alpha", "beta"], tag_type="user") + session.commit() + + tags = session.query(Tag).all() + assert {t.name for t in tags} == {"alpha", "beta"} + + def test_is_idempotent(self, session: Session): + ensure_tags_exist(session, ["alpha"], tag_type="user") + ensure_tags_exist(session, ["alpha"], tag_type="user") + session.commit() + + assert session.query(Tag).count() == 1 + + def test_normalizes_tags(self, session: Session): + ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"]) + session.commit() + + tags = session.query(Tag).all() + assert {t.name for t in tags} == {"alpha", "beta"} + + def test_empty_list_is_noop(self, session: Session): + ensure_tags_exist(session, []) + session.commit() + assert session.query(Tag).count() == 0 + + def test_tag_type_is_set(self, session: Session): + ensure_tags_exist(session, ["system-tag"], tag_type="system") + session.commit() + + tag = session.query(Tag).filter_by(name="system-tag").one() + assert tag.tag_type == "system" + + +class TestGetReferenceTags: + def test_returns_empty_for_no_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + tags = get_reference_tags(session, reference_id=ref.id) + assert tags == [] + + def test_returns_tags_for_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + ensure_tags_exist(session, ["tag1", "tag2"]) + session.add_all([ + AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag1", origin="manual", added_at=get_utc_now()), + AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag2", origin="manual", added_at=get_utc_now()), + ]) + session.flush() + + tags = get_reference_tags(session, reference_id=ref.id) + assert set(tags) == {"tag1", "tag2"} + + +class TestSetReferenceTags: + def test_adds_new_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + result = set_reference_tags(session, reference_id=ref.id, tags=["a", "b"]) + session.commit() + + assert set(result.added) == {"a", "b"} + assert result.removed == [] + assert set(result.total) == {"a", "b"} + + def test_removes_old_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + set_reference_tags(session, reference_id=ref.id, tags=["a", "b", "c"]) + result = set_reference_tags(session, reference_id=ref.id, tags=["a"]) + session.commit() + + assert result.added == [] + assert set(result.removed) == {"b", "c"} + assert result.total == ["a"] + + def test_replaces_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + set_reference_tags(session, reference_id=ref.id, tags=["a", "b"]) + result = set_reference_tags(session, reference_id=ref.id, tags=["b", "c"]) + session.commit() + + assert result.added == ["c"] + assert result.removed == ["a"] + assert set(result.total) == {"b", "c"} + + +class TestAddTagsToReference: + def test_adds_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"]) + session.commit() + + assert set(result.added) == {"x", "y"} + assert result.already_present == [] + + def test_reports_already_present(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["x"]) + result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"]) + session.commit() + + assert result.added == ["y"] + assert result.already_present == ["x"] + + def test_raises_for_missing_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + add_tags_to_reference(session, reference_id="nonexistent", tags=["x"]) + + +class TestRemoveTagsFromReference: + def test_removes_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["a", "b", "c"]) + result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "b"]) + session.commit() + + assert set(result.removed) == {"a", "b"} + assert result.not_present == [] + assert result.total_tags == ["c"] + + def test_reports_not_present(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["a"]) + result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "x"]) + session.commit() + + assert result.removed == ["a"] + assert result.not_present == ["x"] + + def test_raises_for_missing_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + remove_tags_from_reference(session, reference_id="nonexistent", tags=["x"]) + + +class TestMissingTagFunctions: + def test_add_missing_tag_for_asset_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + + add_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert "missing" in tags + + def test_add_missing_tag_is_idempotent(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + + add_missing_tag_for_asset_id(session, asset_id=asset.id) + add_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="missing").all() + assert len(links) == 1 + + def test_remove_missing_tag_for_asset_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + add_missing_tag_for_asset_id(session, asset_id=asset.id) + + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert "missing" not in tags + + +class TestListTagsWithUsage: + def test_returns_tags_with_counts(self, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags_with_usage(session) + + tag_dict = {name: count for name, _, count in rows} + assert tag_dict["used"] == 1 + assert tag_dict["unused"] == 0 + assert total == 2 + + def test_exclude_zero_counts(self, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags_with_usage(session, include_zero=False) + + tag_names = {name for name, _, _ in rows} + assert "used" in tag_names + assert "unused" not in tag_names + + def test_prefix_filter(self, session: Session): + ensure_tags_exist(session, ["alpha", "beta", "alphabet"]) + session.commit() + + rows, total = list_tags_with_usage(session, prefix="alph") + + tag_names = {name for name, _, _ in rows} + assert tag_names == {"alpha", "alphabet"} + + def test_order_by_name(self, session: Session): + ensure_tags_exist(session, ["zebra", "alpha", "middle"]) + session.commit() + + rows, _ = list_tags_with_usage(session, order="name_asc") + + names = [name for name, _, _ in rows] + assert names == ["alpha", "middle", "zebra"] + + def test_owner_visibility(self, session: Session): + ensure_tags_exist(session, ["shared-tag", "owner-tag"]) + + asset = _make_asset(session, "hash1") + shared_ref = _make_reference(session, asset, name="shared", owner_id="") + owner_ref = _make_reference(session, asset, name="owned", owner_id="user1") + + add_tags_to_reference(session, reference_id=shared_ref.id, tags=["shared-tag"]) + add_tags_to_reference(session, reference_id=owner_ref.id, tags=["owner-tag"]) + session.commit() + + # Empty owner sees only shared + rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False) + tag_dict = {name: count for name, _, count in rows} + assert tag_dict.get("shared-tag", 0) == 1 + assert tag_dict.get("owner-tag", 0) == 0 + + # User1 sees both + rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False) + tag_dict = {name: count for name, _, count in rows} + assert tag_dict.get("shared-tag", 0) == 1 + assert tag_dict.get("owner-tag", 0) == 1 + + +class TestBulkInsertTagsAndMeta: + def test_inserts_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["bulk-tag1", "bulk-tag2"]) + session.commit() + + now = get_utc_now() + tag_rows = [ + {"asset_reference_id": ref.id, "tag_name": "bulk-tag1", "origin": "manual", "added_at": now}, + {"asset_reference_id": ref.id, "tag_name": "bulk-tag2", "origin": "manual", "added_at": now}, + ] + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[]) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert set(tags) == {"bulk-tag1", "bulk-tag2"} + + def test_inserts_meta(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + meta_rows = [ + { + "asset_reference_id": ref.id, + "key": "meta-key", + "ordinal": 0, + "val_str": "meta-value", + "val_num": None, + "val_bool": None, + "val_json": None, + }, + ] + bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=meta_rows) + session.commit() + + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "meta-key" + assert meta[0].val_str == "meta-value" + + def test_ignores_conflicts(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["existing-tag"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["existing-tag"]) + session.commit() + + now = get_utc_now() + tag_rows = [ + {"asset_reference_id": ref.id, "tag_name": "existing-tag", "origin": "duplicate", "added_at": now}, + ] + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[]) + session.commit() + + # Should still have only one tag link + links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="existing-tag").all() + assert len(links) == 1 + # Origin should be original, not overwritten + assert links[0].origin == "manual" + + def test_empty_lists_is_noop(self, session: Session): + bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[]) + assert session.query(AssetReferenceTag).count() == 0 + assert session.query(AssetReferenceMeta).count() == 0 diff --git a/tests-unit/assets_test/services/__init__.py b/tests-unit/assets_test/services/__init__.py new file mode 100644 index 000000000..d0213422e --- /dev/null +++ b/tests-unit/assets_test/services/__init__.py @@ -0,0 +1 @@ +# Service layer tests diff --git a/tests-unit/assets_test/services/conftest.py b/tests-unit/assets_test/services/conftest.py new file mode 100644 index 000000000..31c763d48 --- /dev/null +++ b/tests-unit/assets_test/services/conftest.py @@ -0,0 +1,54 @@ +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import Base + + +@pytest.fixture(autouse=True) +def autoclean_unit_test_assets(): + """Override parent autouse fixture - service unit tests don't need server cleanup.""" + yield + + +@pytest.fixture +def db_engine(): + """In-memory SQLite engine for fast unit tests.""" + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def session(db_engine): + """Session fixture for tests that need direct DB access.""" + with Session(db_engine) as sess: + yield sess + + +@pytest.fixture +def mock_create_session(db_engine): + """Patch create_session to use our in-memory database.""" + from contextlib import contextmanager + from sqlalchemy.orm import Session as SASession + + @contextmanager + def _create_session(): + with SASession(db_engine) as sess: + yield sess + + with patch("app.assets.services.ingest.create_session", _create_session), \ + patch("app.assets.services.asset_management.create_session", _create_session), \ + patch("app.assets.services.tagging.create_session", _create_session): + yield _create_session + + +@pytest.fixture +def temp_dir(): + """Temporary directory for file operations.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) diff --git a/tests-unit/assets_test/services/test_asset_management.py b/tests-unit/assets_test/services/test_asset_management.py new file mode 100644 index 000000000..101ef7292 --- /dev/null +++ b/tests-unit/assets_test/services/test_asset_management.py @@ -0,0 +1,268 @@ +"""Tests for asset_management services.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference +from app.assets.helpers import get_utc_now +from app.assets.services import ( + get_asset_detail, + update_asset_metadata, + delete_asset_reference, + set_asset_preview, +) + + +def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream") + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestGetAssetDetail: + def test_returns_none_for_nonexistent(self, mock_create_session): + result = get_asset_detail(reference_id="nonexistent") + assert result is None + + def test_returns_asset_with_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["alpha", "beta"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["alpha", "beta"]) + session.commit() + + result = get_asset_detail(reference_id=ref.id) + + assert result is not None + assert result.ref.id == ref.id + assert result.asset.hash == asset.hash + assert set(result.tags) == {"alpha", "beta"} + + def test_respects_owner_visibility(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + # Wrong owner cannot see + result = get_asset_detail(reference_id=ref.id, owner_id="user2") + assert result is None + + # Correct owner can see + result = get_asset_detail(reference_id=ref.id, owner_id="user1") + assert result is not None + + +class TestUpdateAssetMetadata: + def test_updates_name(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, name="old_name.bin") + ref_id = ref.id + session.commit() + + update_asset_metadata( + reference_id=ref_id, + name="new_name.bin", + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.name == "new_name.bin" + + def test_updates_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["old"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["old"]) + session.commit() + + result = update_asset_metadata( + reference_id=ref.id, + tags=["new1", "new2"], + ) + + assert set(result.tags) == {"new1", "new2"} + assert "old" not in result.tags + + def test_updates_user_metadata(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ref_id = ref.id + session.commit() + + update_asset_metadata( + reference_id=ref_id, + user_metadata={"key": "value", "num": 42}, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.user_metadata["key"] == "value" + assert updated_ref.user_metadata["num"] == 42 + + def test_raises_for_nonexistent(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + update_asset_metadata(reference_id="nonexistent", name="fail") + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + update_asset_metadata( + reference_id=ref.id, + name="new", + owner_id="user2", + ) + + +class TestDeleteAssetReference: + def test_soft_deletes_reference(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ref_id = ref.id + session.commit() + + result = delete_asset_reference( + reference_id=ref_id, + owner_id="", + delete_content_if_orphan=False, + ) + + assert result is True + # Row still exists but is marked as soft-deleted + session.expire_all() + row = session.get(AssetReference, ref_id) + assert row is not None + assert row.deleted_at is not None + + def test_returns_false_for_nonexistent(self, mock_create_session): + result = delete_asset_reference( + reference_id="nonexistent", + owner_id="", + ) + assert result is False + + def test_returns_false_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + ref_id = ref.id + session.commit() + + result = delete_asset_reference( + reference_id=ref_id, + owner_id="user2", + ) + + assert result is False + assert session.get(AssetReference, ref_id) is not None + + def test_keeps_asset_if_other_references_exist(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref1 = _make_reference(session, asset, name="ref1") + _make_reference(session, asset, name="ref2") # Second ref keeps asset alive + asset_id = asset.id + session.commit() + + delete_asset_reference( + reference_id=ref1.id, + owner_id="", + delete_content_if_orphan=True, + ) + + # Asset should still exist + assert session.get(Asset, asset_id) is not None + + def test_deletes_orphaned_asset(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + asset_id = asset.id + ref_id = ref.id + session.commit() + + delete_asset_reference( + reference_id=ref_id, + owner_id="", + delete_content_if_orphan=True, + ) + + # Both ref and asset should be gone + assert session.get(AssetReference, ref_id) is None + assert session.get(Asset, asset_id) is None + + +class TestSetAssetPreview: + def test_sets_preview(self, mock_create_session, session: Session): + asset = _make_asset(session, hash_val="blake3:main") + preview_asset = _make_asset(session, hash_val="blake3:preview") + ref = _make_reference(session, asset) + ref_id = ref.id + preview_id = preview_asset.id + session.commit() + + set_asset_preview( + reference_id=ref_id, + preview_asset_id=preview_id, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.preview_id == preview_id + + def test_clears_preview(self, mock_create_session, session: Session): + asset = _make_asset(session) + preview_asset = _make_asset(session, hash_val="blake3:preview") + ref = _make_reference(session, asset) + ref.preview_id = preview_asset.id + ref_id = ref.id + session.commit() + + set_asset_preview( + reference_id=ref_id, + preview_asset_id=None, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.preview_id is None + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + set_asset_preview(reference_id="nonexistent") + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + set_asset_preview( + reference_id=ref.id, + preview_asset_id=None, + owner_id="user2", + ) diff --git a/tests-unit/assets_test/services/test_bulk_ingest.py b/tests-unit/assets_test/services/test_bulk_ingest.py new file mode 100644 index 000000000..26e22a01d --- /dev/null +++ b/tests-unit/assets_test/services/test_bulk_ingest.py @@ -0,0 +1,137 @@ +"""Tests for bulk ingest services.""" + +from pathlib import Path + +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets + + +class TestBatchInsertSeedAssets: + def test_populates_mime_type_for_model_files(self, session: Session, temp_dir: Path): + """Verify mime_type is stored in the Asset table for model files.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"fake safetensors content") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 24, + "mtime_ns": 1234567890000000000, + "info_name": "Test Model", + "tags": ["models"], + "fname": "model.safetensors", + "metadata": None, + "hash": None, + "mime_type": "application/safetensors", + } + ] + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == 1 + + # Verify Asset has mime_type populated + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].mime_type == "application/safetensors" + + def test_mime_type_none_when_not_provided(self, session: Session, temp_dir: Path): + """Verify mime_type is None when not provided in spec.""" + file_path = temp_dir / "unknown.bin" + file_path.write_bytes(b"binary data") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 11, + "mtime_ns": 1234567890000000000, + "info_name": "Unknown File", + "tags": [], + "fname": "unknown.bin", + "metadata": None, + "hash": None, + "mime_type": None, + } + ] + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == 1 + + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].mime_type is None + + def test_various_model_mime_types(self, session: Session, temp_dir: Path): + """Verify various model file types get correct mime_type.""" + test_cases = [ + ("model.safetensors", "application/safetensors"), + ("model.pt", "application/pytorch"), + ("model.ckpt", "application/pickle"), + ("model.gguf", "application/gguf"), + ] + + specs: list[SeedAssetSpec] = [] + for filename, mime_type in test_cases: + file_path = temp_dir / filename + file_path.write_bytes(b"content") + specs.append( + { + "abs_path": str(file_path), + "size_bytes": 7, + "mtime_ns": 1234567890000000000, + "info_name": filename, + "tags": [], + "fname": filename, + "metadata": None, + "hash": None, + "mime_type": mime_type, + } + ) + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == len(test_cases) + + for filename, expected_mime in test_cases: + ref = session.query(AssetReference).filter_by(name=filename).first() + assert ref is not None + asset = session.query(Asset).filter_by(id=ref.asset_id).first() + assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}" + + +class TestMetadataExtraction: + def test_extracts_mime_type_for_model_files(self, temp_dir: Path): + """Verify metadata extraction returns correct mime_type for model files.""" + from app.assets.services.metadata_extract import extract_file_metadata + + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"fake safetensors content") + + meta = extract_file_metadata(str(file_path)) + + assert meta.content_type == "application/safetensors" + + def test_mime_type_for_various_model_formats(self, temp_dir: Path): + """Verify various model file types get correct mime_type from metadata.""" + from app.assets.services.metadata_extract import extract_file_metadata + + test_cases = [ + ("model.safetensors", "application/safetensors"), + ("model.sft", "application/safetensors"), + ("model.pt", "application/pytorch"), + ("model.pth", "application/pytorch"), + ("model.ckpt", "application/pickle"), + ("model.pkl", "application/pickle"), + ("model.gguf", "application/gguf"), + ] + + for filename, expected_mime in test_cases: + file_path = temp_dir / filename + file_path.write_bytes(b"content") + + meta = extract_file_metadata(str(file_path)) + + assert meta.content_type == expected_mime, f"Expected {expected_mime} for {filename}, got {meta.content_type}" diff --git a/tests-unit/assets_test/services/test_enrich.py b/tests-unit/assets_test/services/test_enrich.py new file mode 100644 index 000000000..2bd79a01a --- /dev/null +++ b/tests-unit/assets_test/services/test_enrich.py @@ -0,0 +1,207 @@ +"""Tests for asset enrichment (mime_type and hash population).""" +from pathlib import Path + +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.scanner import ( + ENRICHMENT_HASHED, + ENRICHMENT_METADATA, + ENRICHMENT_STUB, + enrich_asset, +) + + +def _create_stub_asset( + session: Session, + file_path: str, + asset_id: str = "test-asset-id", + reference_id: str = "test-ref-id", + name: str | None = None, +) -> tuple[Asset, AssetReference]: + """Create a stub asset with reference for testing enrichment.""" + asset = Asset( + id=asset_id, + hash=None, + size_bytes=100, + mime_type=None, + ) + session.add(asset) + session.flush() + + ref = AssetReference( + id=reference_id, + asset_id=asset_id, + name=name or f"test-asset-{asset_id}", + owner_id="system", + file_path=file_path, + mtime_ns=1234567890000000000, + enrichment_level=ENRICHMENT_STUB, + ) + session.add(ref) + session.flush() + + return asset, ref + + +class TestEnrichAsset: + def test_extracts_mime_type_and_updates_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify mime_type is written to the Asset table during enrichment.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"\x00" * 100) + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-1", "ref-1" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=False, + ) + + assert new_level == ENRICHMENT_METADATA + + session.expire_all() + updated_asset = session.get(Asset, "asset-1") + assert updated_asset is not None + assert updated_asset.mime_type == "application/safetensors" + + def test_computes_hash_and_updates_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify hash is written to the Asset table during enrichment.""" + file_path = temp_dir / "data.bin" + file_path.write_bytes(b"test content for hashing") + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-2", "ref-2" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + assert new_level == ENRICHMENT_HASHED + + session.expire_all() + updated_asset = session.get(Asset, "asset-2") + assert updated_asset is not None + assert updated_asset.hash is not None + assert updated_asset.hash.startswith("blake3:") + + def test_enrichment_updates_both_mime_and_hash( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify both mime_type and hash are set when full enrichment runs.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"\x00" * 50) + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-3", "ref-3" + ) + session.commit() + + enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + session.expire_all() + updated_asset = session.get(Asset, "asset-3") + assert updated_asset is not None + assert updated_asset.mime_type == "application/safetensors" + assert updated_asset.hash is not None + assert updated_asset.hash.startswith("blake3:") + + def test_missing_file_returns_stub_level( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify missing files don't cause errors and return STUB level.""" + file_path = temp_dir / "nonexistent.bin" + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-4", "ref-4" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + assert new_level == ENRICHMENT_STUB + + session.expire_all() + updated_asset = session.get(Asset, "asset-4") + assert updated_asset.mime_type is None + assert updated_asset.hash is None + + def test_duplicate_hash_merges_into_existing_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify duplicate files merge into existing asset instead of failing.""" + file_path_1 = temp_dir / "file1.bin" + file_path_2 = temp_dir / "file2.bin" + content = b"identical content" + file_path_1.write_bytes(content) + file_path_2.write_bytes(content) + + asset1, ref1 = _create_stub_asset( + session, str(file_path_1), "asset-dup-1", "ref-dup-1" + ) + asset2, ref2 = _create_stub_asset( + session, str(file_path_2), "asset-dup-2", "ref-dup-2" + ) + session.commit() + + enrich_asset( + session, + file_path=str(file_path_1), + reference_id=ref1.id, + asset_id=asset1.id, + extract_metadata=True, + compute_hash=True, + ) + + enrich_asset( + session, + file_path=str(file_path_2), + reference_id=ref2.id, + asset_id=asset2.id, + extract_metadata=True, + compute_hash=True, + ) + + session.expire_all() + + updated_asset1 = session.get(Asset, "asset-dup-1") + assert updated_asset1 is not None + assert updated_asset1.hash is not None + + updated_asset2 = session.get(Asset, "asset-dup-2") + assert updated_asset2 is None + + updated_ref2 = session.get(AssetReference, "ref-dup-2") + assert updated_ref2 is not None + assert updated_ref2.asset_id == "asset-dup-1" diff --git a/tests-unit/assets_test/services/test_ingest.py b/tests-unit/assets_test/services/test_ingest.py new file mode 100644 index 000000000..367bc7721 --- /dev/null +++ b/tests-unit/assets_test/services/test_ingest.py @@ -0,0 +1,229 @@ +"""Tests for ingest services.""" +from pathlib import Path + +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, Tag +from app.assets.database.queries import get_reference_tags +from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset + + +class TestIngestFileFromPath: + def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "test_file.bin" + file_path.write_bytes(b"test content") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:abc123", + size_bytes=12, + mtime_ns=1234567890000000000, + mime_type="application/octet-stream", + ) + + assert result.asset_created is True + assert result.ref_created is True + assert result.reference_id is not None + + # Verify DB state + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].hash == "blake3:abc123" + + refs = session.query(AssetReference).all() + assert len(refs) == 1 + assert refs[0].file_path == str(file_path) + + def test_creates_reference_when_name_provided(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"model data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:def456", + size_bytes=10, + mtime_ns=1234567890000000000, + mime_type="application/octet-stream", + info_name="My Model", + owner_id="user1", + ) + + assert result.asset_created is True + assert result.reference_id is not None + + ref = session.query(AssetReference).first() + assert ref is not None + assert ref.name == "My Model" + assert ref.owner_id == "user1" + + def test_creates_tags_when_provided(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "tagged.bin" + file_path.write_bytes(b"data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:ghi789", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="Tagged Asset", + tags=["models", "checkpoints"], + ) + + assert result.reference_id is not None + + # Verify tags were created and linked + tags = session.query(Tag).all() + tag_names = {t.name for t in tags} + assert "models" in tag_names + assert "checkpoints" in tag_names + + ref_tags = get_reference_tags(session, reference_id=result.reference_id) + assert set(ref_tags) == {"models", "checkpoints"} + + def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "dup.bin" + file_path.write_bytes(b"content") + + # First ingest + r1 = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:repeat", + size_bytes=7, + mtime_ns=1234567890000000000, + ) + assert r1.asset_created is True + + # Second ingest with same hash - should update, not create + r2 = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:repeat", + size_bytes=7, + mtime_ns=1234567890000000001, # different mtime + ) + assert r2.asset_created is False + assert r2.ref_created is False + assert r2.ref_updated is True + + # Still only one asset + assets = session.query(Asset).all() + assert len(assets) == 1 + + def test_validates_preview_id(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "with_preview.bin" + file_path.write_bytes(b"data") + + # Create a preview asset first + preview_asset = Asset(hash="blake3:preview", size_bytes=100) + session.add(preview_asset) + session.commit() + preview_id = preview_asset.id + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:main", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="With Preview", + preview_id=preview_id, + ) + + assert result.reference_id is not None + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.preview_id == preview_id + + def test_invalid_preview_id_is_cleared(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "bad_preview.bin" + file_path.write_bytes(b"data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:badpreview", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="Bad Preview", + preview_id="nonexistent-uuid", + ) + + assert result.reference_id is not None + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.preview_id is None + + +class TestRegisterExistingAsset: + def test_creates_reference_for_existing_asset(self, mock_create_session, session: Session): + # Create existing asset + asset = Asset(hash="blake3:existing", size_bytes=1024, mime_type="image/png") + session.add(asset) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:existing", + name="Registered Asset", + user_metadata={"key": "value"}, + tags=["models"], + ) + + assert result.created is True + assert "models" in result.tags + + # Verify by re-fetching from DB + session.expire_all() + refs = session.query(AssetReference).filter_by(name="Registered Asset").all() + assert len(refs) == 1 + + def test_creates_new_reference_even_with_same_name(self, mock_create_session, session: Session): + # Create asset and reference + asset = Asset(hash="blake3:withref", size_bytes=512) + session.add(asset) + session.flush() + + from app.assets.helpers import get_utc_now + ref = AssetReference( + owner_id="", + name="Existing Ref", + asset_id=asset.id, + created_at=get_utc_now(), + updated_at=get_utc_now(), + last_access_time=get_utc_now(), + ) + session.add(ref) + session.flush() + ref_id = ref.id + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:withref", + name="Existing Ref", + owner_id="", + ) + + # Multiple files with same name are allowed + assert result.created is True + + # Verify two AssetReferences exist for this name + session.expire_all() + refs = session.query(AssetReference).filter_by(name="Existing Ref").all() + assert len(refs) == 2 + assert ref_id in [r.id for r in refs] + + def test_raises_for_nonexistent_hash(self, mock_create_session): + with pytest.raises(ValueError, match="No asset with hash"): + _register_existing_asset( + asset_hash="blake3:doesnotexist", + name="Fail", + ) + + def test_applies_tags_to_new_reference(self, mock_create_session, session: Session): + asset = Asset(hash="blake3:tagged", size_bytes=256) + session.add(asset) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:tagged", + name="Tagged Ref", + tags=["alpha", "beta"], + ) + + assert result.created is True + assert set(result.tags) == {"alpha", "beta"} diff --git a/tests-unit/assets_test/services/test_tagging.py b/tests-unit/assets_test/services/test_tagging.py new file mode 100644 index 000000000..ab69e5dc1 --- /dev/null +++ b/tests-unit/assets_test/services/test_tagging.py @@ -0,0 +1,197 @@ +"""Tests for tagging services.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference +from app.assets.helpers import get_utc_now +from app.assets.services import apply_tags, remove_tags, list_tags + + +def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestApplyTags: + def test_adds_new_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + session.commit() + + result = apply_tags( + reference_id=ref.id, + tags=["alpha", "beta"], + ) + + assert set(result.added) == {"alpha", "beta"} + assert result.already_present == [] + assert set(result.total_tags) == {"alpha", "beta"} + + def test_reports_already_present(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["existing"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["existing"]) + session.commit() + + result = apply_tags( + reference_id=ref.id, + tags=["existing", "new"], + ) + + assert result.added == ["new"] + assert result.already_present == ["existing"] + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + apply_tags(reference_id="nonexistent", tags=["x"]) + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + apply_tags( + reference_id=ref.id, + tags=["new"], + owner_id="user2", + ) + + +class TestRemoveTags: + def test_removes_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["a", "b", "c"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["a", "b", "c"]) + session.commit() + + result = remove_tags( + reference_id=ref.id, + tags=["a", "b"], + ) + + assert set(result.removed) == {"a", "b"} + assert result.not_present == [] + assert result.total_tags == ["c"] + + def test_reports_not_present(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["present"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["present"]) + session.commit() + + result = remove_tags( + reference_id=ref.id, + tags=["present", "absent"], + ) + + assert result.removed == ["present"] + assert result.not_present == ["absent"] + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + remove_tags(reference_id="nonexistent", tags=["x"]) + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + remove_tags( + reference_id=ref.id, + tags=["x"], + owner_id="user2", + ) + + +class TestListTags: + def test_returns_tags_with_counts(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + asset = _make_asset(session) + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags() + + tag_dict = {name: count for name, _, count in rows} + assert tag_dict["used"] == 1 + assert tag_dict["unused"] == 0 + assert total == 2 + + def test_excludes_zero_counts(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + asset = _make_asset(session) + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags(include_zero=False) + + tag_names = {name for name, _, _ in rows} + assert "used" in tag_names + assert "unused" not in tag_names + + def test_prefix_filter(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["alpha", "beta", "alphabet"]) + session.commit() + + rows, _ = list_tags(prefix="alph") + + tag_names = {name for name, _, _ in rows} + assert tag_names == {"alpha", "alphabet"} + + def test_order_by_name(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["zebra", "alpha", "middle"]) + session.commit() + + rows, _ = list_tags(order="name_asc") + + names = [name for name, _, _ in rows] + assert names == ["alpha", "middle", "zebra"] + + def test_pagination(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["a", "b", "c", "d", "e"]) + session.commit() + + rows, total = list_tags(limit=2, offset=1, order="name_asc") + + assert total == 5 + assert len(rows) == 2 + names = [name for name, _, _ in rows] + assert names == ["b", "c"] + + def test_clamps_limit(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["a"]) + session.commit() + + # Service should clamp limit to max 1000 + rows, _ = list_tags(limit=2000) + assert len(rows) <= 1000 diff --git a/tests-unit/assets_test/test_assets_missing_sync.py b/tests-unit/assets_test/test_assets_missing_sync.py index 78fa7b404..47dc130cb 100644 --- a/tests-unit/assets_test/test_assets_missing_sync.py +++ b/tests-unit/assets_test/test_assets_missing_sync.py @@ -4,7 +4,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets diff --git a/tests-unit/assets_test/test_crud.py b/tests-unit/assets_test/test_crud.py index d2b69f475..07310223e 100644 --- a/tests-unit/assets_test/test_crud.py +++ b/tests-unit/assets_test/test_crud.py @@ -4,7 +4,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets def test_create_from_hash_success( @@ -24,11 +24,11 @@ def test_create_from_hash_success( assert b1["created_new"] is False aid = b1["id"] - # Calling again with the same name should return the same AssetInfo id + # Calling again with the same name creates a new AssetReference (duplicates allowed) r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) b2 = r2.json() assert r2.status_code == 201, b2 - assert b2["id"] == aid + assert b2["id"] != aid # new reference, not the same one def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asset: dict): @@ -42,8 +42,8 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse assert "user_metadata" in detail assert "filename" in detail["user_metadata"] - # DELETE - rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + # DELETE (hard delete to also remove underlying asset and file) + rd = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) assert rd.status_code == 204 # GET again -> 404 @@ -53,6 +53,35 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse assert body["error"]["code"] == "ASSET_NOT_FOUND" +def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + asset_hash = seeded_asset["asset_hash"] + + # Soft-delete (default, no delete_content param) + rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + assert rd.status_code == 204 + + # GET by reference ID -> 404 (soft-deleted references are hidden) + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + assert rg.status_code == 404 + + # Asset identity is preserved (underlying content still exists) + rh = http.head(f"{api_base}/api/assets/hash/{asset_hash}", timeout=120) + assert rh.status_code == 200 + + # Soft-deleted reference should not appear in listings + rl = http.get( + f"{api_base}/api/assets", + params={"include_tags": "unit-tests", "limit": "500"}, + timeout=120, + ) + ids = [a["id"] for a in rl.json().get("assets", [])] + assert aid not in ids + + # Clean up: hard-delete the soft-deleted reference and orphaned asset + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) + + def test_delete_upon_reference_count( http: requests.Session, api_base: str, seeded_asset: dict ): @@ -70,21 +99,32 @@ def test_delete_upon_reference_count( assert copy["asset_hash"] == src_hash assert copy["created_new"] is False - # Delete original reference -> asset identity must remain + # Soft-delete original reference (default) -> asset identity must remain aid1 = seeded_asset["id"] rd1 = http.delete(f"{api_base}/api/assets/{aid1}", timeout=120) assert rd1.status_code == 204 rh1 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) - assert rh1.status_code == 200 # identity still present + assert rh1.status_code == 200 # identity still present (second ref exists) - # Delete the last reference with default semantics -> identity and cached files removed + # Soft-delete the last reference -> asset identity preserved (no hard delete) aid2 = copy["id"] rd2 = http.delete(f"{api_base}/api/assets/{aid2}", timeout=120) assert rd2.status_code == 204 rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) - assert rh2.status_code == 404 # orphan content removed + assert rh2.status_code == 200 # asset identity preserved (soft delete) + + # Re-associate via from-hash, then hard-delete -> orphan content removed + r3 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + assert r3.status_code == 201, r3.json() + aid3 = r3.json()["id"] + + rd3 = http.delete(f"{api_base}/api/assets/{aid3}?delete_content=true", timeout=120) + assert rd3.status_code == 204 + + rh3 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) + assert rh3.status_code == 404 # orphan content removed def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict): @@ -126,42 +166,52 @@ def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api assert body == b"" -def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str): - bogus = str(uuid.uuid4()) - r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120) +@pytest.mark.parametrize( + "method,endpoint_template,payload,expected_status,error_code", + [ + # Delete nonexistent asset + ("delete", "/api/assets/{uuid}", None, 404, "ASSET_NOT_FOUND"), + # Bad hash algorithm in from-hash + ( + "post", + "/api/assets/from-hash", + {"hash": "sha256:" + "0" * 64, "name": "x.bin", "tags": ["models", "checkpoints", "unit-tests"]}, + 400, + "INVALID_BODY", + ), + # Get with bad UUID format + ("get", "/api/assets/not-a-uuid", None, 404, None), + # Get content with bad UUID format + ("get", "/api/assets/not-a-uuid/content", None, 404, None), + ], + ids=["delete_nonexistent", "bad_hash_algorithm", "get_bad_uuid", "content_bad_uuid"], +) +def test_error_responses( + http: requests.Session, api_base: str, method, endpoint_template, payload, expected_status, error_code +): + # Replace {uuid} placeholder with a random UUID for delete test + endpoint = endpoint_template.replace("{uuid}", str(uuid.uuid4())) + url = f"{api_base}{endpoint}" + + if method == "get": + r = http.get(url, timeout=120) + elif method == "post": + r = http.post(url, json=payload, timeout=120) + elif method == "delete": + r = http.delete(url, timeout=120) + + assert r.status_code == expected_status + if error_code: + body = r.json() + assert body["error"]["code"] == error_code + + +def test_create_from_hash_invalid_json(http: requests.Session, api_base: str): + """Invalid JSON body requires special handling (data= instead of json=).""" + r = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120) body = r.json() - assert r.status_code == 404 - assert body["error"]["code"] == "ASSET_NOT_FOUND" - - -def test_create_from_hash_invalids(http: requests.Session, api_base: str): - # Bad hash algorithm - bad = { - "hash": "sha256:" + "0" * 64, - "name": "x.bin", - "tags": ["models", "checkpoints", "unit-tests"], - } - r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_BODY" - - # Invalid JSON body - r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_JSON" - - -def test_get_update_download_bad_ids(http: requests.Session, api_base: str): - # All endpoints should be not found, as we UUID regex directly in the route definition. - bad_id = "not-a-uuid" - - r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120) - assert r1.status_code == 404 - - r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120) - assert r3.status_code == 404 + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_JSON" def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict): diff --git a/tests-unit/assets_test/test_downloads.py b/tests-unit/assets_test/test_downloads.py index cdebf9082..672ba9728 100644 --- a/tests-unit/assets_test/test_downloads.py +++ b/tests-unit/assets_test/test_downloads.py @@ -6,7 +6,7 @@ from typing import Optional import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict): @@ -117,7 +117,7 @@ def test_download_missing_file_returns_404( assert body["error"]["code"] == "FILE_NOT_FOUND" finally: # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. - dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + dr = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) dr.content diff --git a/tests-unit/assets_test/test_file_utils.py b/tests-unit/assets_test/test_file_utils.py new file mode 100644 index 000000000..e3591d49b --- /dev/null +++ b/tests-unit/assets_test/test_file_utils.py @@ -0,0 +1,121 @@ +import os +import sys + +import pytest + +from app.assets.services.file_utils import is_visible, list_files_recursively + + +class TestIsVisible: + def test_visible_file(self): + assert is_visible("file.txt") is True + + def test_hidden_file(self): + assert is_visible(".hidden") is False + + def test_hidden_directory(self): + assert is_visible(".git") is False + + def test_visible_directory(self): + assert is_visible("src") is True + + def test_dotdot_is_hidden(self): + assert is_visible("..") is False + + def test_dot_is_hidden(self): + assert is_visible(".") is False + + +class TestListFilesRecursively: + def test_skips_hidden_files(self, tmp_path): + (tmp_path / "visible.txt").write_text("a") + (tmp_path / ".hidden").write_text("b") + + result = list_files_recursively(str(tmp_path)) + + assert len(result) == 1 + assert result[0].endswith("visible.txt") + + def test_skips_hidden_directories(self, tmp_path): + hidden_dir = tmp_path / ".hidden_dir" + hidden_dir.mkdir() + (hidden_dir / "file.txt").write_text("a") + + visible_dir = tmp_path / "visible_dir" + visible_dir.mkdir() + (visible_dir / "file.txt").write_text("b") + + result = list_files_recursively(str(tmp_path)) + + assert len(result) == 1 + assert "visible_dir" in result[0] + assert ".hidden_dir" not in result[0] + + def test_empty_directory(self, tmp_path): + result = list_files_recursively(str(tmp_path)) + assert result == [] + + def test_nonexistent_directory(self, tmp_path): + result = list_files_recursively(str(tmp_path / "nonexistent")) + assert result == [] + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_follows_symlinked_directories(self, tmp_path): + target = tmp_path / "real_dir" + target.mkdir() + (target / "model.safetensors").write_text("data") + + root = tmp_path / "root" + root.mkdir() + (root / "link").symlink_to(target) + + result = list_files_recursively(str(root)) + + assert len(result) == 1 + assert result[0].endswith("model.safetensors") + assert "link" in result[0] + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_follows_symlinked_files(self, tmp_path): + real_file = tmp_path / "real.txt" + real_file.write_text("content") + + root = tmp_path / "root" + root.mkdir() + (root / "link.txt").symlink_to(real_file) + + result = list_files_recursively(str(root)) + + assert len(result) == 1 + assert result[0].endswith("link.txt") + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_circular_symlinks_do_not_loop(self, tmp_path): + dir_a = tmp_path / "a" + dir_a.mkdir() + (dir_a / "file.txt").write_text("a") + # a/b -> a (circular) + (dir_a / "b").symlink_to(dir_a) + + result = list_files_recursively(str(dir_a)) + + assert len(result) == 1 + assert result[0].endswith("file.txt") + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_mutual_circular_symlinks(self, tmp_path): + dir_a = tmp_path / "a" + dir_b = tmp_path / "b" + dir_a.mkdir() + dir_b.mkdir() + (dir_a / "file_a.txt").write_text("a") + (dir_b / "file_b.txt").write_text("b") + # a/link_b -> b and b/link_a -> a + (dir_a / "link_b").symlink_to(dir_b) + (dir_b / "link_a").symlink_to(dir_a) + + result = list_files_recursively(str(dir_a)) + basenames = sorted(os.path.basename(p) for p in result) + + assert "file_a.txt" in basenames + assert "file_b.txt" in basenames diff --git a/tests-unit/assets_test/test_list_filter.py b/tests-unit/assets_test/test_list_filter.py index 82e109832..dcb7a73ca 100644 --- a/tests-unit/assets_test/test_list_filter.py +++ b/tests-unit/assets_test/test_list_filter.py @@ -1,6 +1,7 @@ import time import uuid +import pytest import requests @@ -283,30 +284,21 @@ def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asse assert b2["has_more"] is False -def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base): - r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_QUERY" - - r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_QUERY" - - -def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str): - # limit too small - r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_QUERY" - - # bad metadata JSON - r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_QUERY" +@pytest.mark.parametrize( + "params,error_code", + [ + ({"offset": "-1"}, "INVALID_QUERY"), + ({"limit": "abc"}, "INVALID_QUERY"), + ({"limit": "0"}, "INVALID_QUERY"), + ({"metadata_filter": "{not json"}, "INVALID_QUERY"), + ], + ids=["negative_offset", "non_int_limit", "zero_limit", "invalid_metadata_json"], +) +def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str, params, error_code): + r = http.get(api_base + "/api/assets", params=params, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == error_code def test_list_assets_name_contains_literal_underscore( diff --git a/tests-unit/assets_test/test_prune_orphaned_assets.py b/tests-unit/assets_test/test_prune_orphaned_assets.py index f602e5a77..1fbd4d4e2 100644 --- a/tests-unit/assets_test/test_prune_orphaned_assets.py +++ b/tests-unit/assets_test/test_prune_orphaned_assets.py @@ -3,7 +3,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets @pytest.fixture diff --git a/tests-unit/assets_test/test_sync_references.py b/tests-unit/assets_test/test_sync_references.py new file mode 100644 index 000000000..94cc255bc --- /dev/null +++ b/tests-unit/assets_test/test_sync_references.py @@ -0,0 +1,482 @@ +"""Tests for sync_references_with_filesystem in scanner.py.""" + +import os +import tempfile +from datetime import datetime +from pathlib import Path +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import ( + Asset, + AssetReference, + AssetReferenceTag, + Base, + Tag, +) +from app.assets.database.queries.asset_reference import ( + bulk_insert_references_ignore_conflicts, + get_references_for_prefixes, + get_unenriched_references, + restore_references_by_paths, +) +from app.assets.scanner import sync_references_with_filesystem +from app.assets.services.file_utils import get_mtime_ns + + +@pytest.fixture +def db_engine(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def session(db_engine): + with Session(db_engine) as sess: + yield sess + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +def _create_file(temp_dir: Path, name: str, content: bytes = b"\x00" * 100) -> str: + """Create a file and return its absolute path (no symlink resolution).""" + p = temp_dir / name + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(content) + return os.path.abspath(str(p)) + + +def _stat_mtime_ns(path: str) -> int: + return get_mtime_ns(os.stat(path, follow_symlinks=True)) + + +def _make_asset( + session: Session, + asset_id: str, + file_path: str, + ref_id: str, + *, + asset_hash: str | None = None, + size_bytes: int = 100, + mtime_ns: int | None = None, + needs_verify: bool = False, + is_missing: bool = False, +) -> tuple[Asset, AssetReference]: + """Insert an Asset + AssetReference and flush.""" + asset = session.get(Asset, asset_id) + if asset is None: + asset = Asset(id=asset_id, hash=asset_hash, size_bytes=size_bytes) + session.add(asset) + session.flush() + + ref = AssetReference( + id=ref_id, + asset_id=asset_id, + name=f"test-{ref_id}", + owner_id="system", + file_path=file_path, + mtime_ns=mtime_ns, + needs_verify=needs_verify, + is_missing=is_missing, + ) + session.add(ref) + session.flush() + return asset, ref + + +def _ensure_missing_tag(session: Session): + """Ensure the 'missing' tag exists.""" + if not session.get(Tag, "missing"): + session.add(Tag(name="missing", tag_type="system")) + session.flush() + + +class _VerifyCase: + def __init__(self, id, stat_unchanged, needs_verify_before, expect_needs_verify): + self.id = id + self.stat_unchanged = stat_unchanged + self.needs_verify_before = needs_verify_before + self.expect_needs_verify = expect_needs_verify + + +VERIFY_CASES = [ + _VerifyCase( + id="unchanged_clears_verify", + stat_unchanged=True, + needs_verify_before=True, + expect_needs_verify=False, + ), + _VerifyCase( + id="unchanged_keeps_clear", + stat_unchanged=True, + needs_verify_before=False, + expect_needs_verify=False, + ), + _VerifyCase( + id="changed_sets_verify", + stat_unchanged=False, + needs_verify_before=False, + expect_needs_verify=True, + ), + _VerifyCase( + id="changed_keeps_verify", + stat_unchanged=False, + needs_verify_before=True, + expect_needs_verify=True, + ), +] + + +@pytest.mark.parametrize("case", VERIFY_CASES, ids=lambda c: c.id) +def test_needs_verify_toggling(session, temp_dir, case): + """needs_verify is set/cleared based on mtime+size match.""" + fp = _create_file(temp_dir, "model.bin") + real_mtime = _stat_mtime_ns(fp) + + mtime_for_db = real_mtime if case.stat_unchanged else real_mtime + 1 + _make_asset( + session, "a1", fp, "r1", + asset_hash="blake3:abc", + mtime_ns=mtime_for_db, + needs_verify=case.needs_verify_before, + ) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.needs_verify is case.expect_needs_verify + + +class _MissingCase: + def __init__(self, id, file_exists, expect_is_missing): + self.id = id + self.file_exists = file_exists + self.expect_is_missing = expect_is_missing + + +MISSING_CASES = [ + _MissingCase(id="existing_file_not_missing", file_exists=True, expect_is_missing=False), + _MissingCase(id="missing_file_marked_missing", file_exists=False, expect_is_missing=True), +] + + +@pytest.mark.parametrize("case", MISSING_CASES, ids=lambda c: c.id) +def test_is_missing_flag(session, temp_dir, case): + """is_missing reflects whether the file exists on disk.""" + if case.file_exists: + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + else: + fp = str(temp_dir / "gone.bin") + mtime = 999 + + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.is_missing is case.expect_is_missing + + +def test_seed_asset_all_missing_deletes_asset(session, temp_dir): + """Seed asset with all refs missing gets deleted entirely.""" + fp = str(temp_dir / "gone.bin") + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + assert session.get(Asset, "seed1") is None + assert session.get(AssetReference, "r1") is None + + +def test_seed_asset_some_exist_returns_survivors(session, temp_dir): + """Seed asset with at least one existing ref survives and is returned.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + survivors = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + session.commit() + + assert session.get(Asset, "seed1") is not None + assert os.path.abspath(fp) in survivors + + +def test_hashed_asset_prunes_missing_refs_when_one_is_ok(session, temp_dir): + """Hashed asset with one stat-unchanged ref deletes missing refs.""" + fp_ok = _create_file(temp_dir, "good.bin") + fp_gone = str(temp_dir / "gone.bin") + mtime = _stat_mtime_ns(fp_ok) + + _make_asset(session, "h1", fp_ok, "r_ok", asset_hash="blake3:aaa", mtime_ns=mtime) + # Second ref on same asset, file missing + ref_gone = AssetReference( + id="r_gone", asset_id="h1", name="gone", + owner_id="system", file_path=fp_gone, mtime_ns=999, + ) + session.add(ref_gone) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + assert session.get(AssetReference, "r_ok") is not None + assert session.get(AssetReference, "r_gone") is None + + +def test_hashed_asset_all_missing_keeps_refs(session, temp_dir): + """Hashed asset with all refs missing keeps refs (no pruning).""" + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + assert session.get(AssetReference, "r1") is not None + ref = session.get(AssetReference, "r1") + assert ref.is_missing is True + + +def test_missing_tag_added_when_all_refs_gone(session, temp_dir): + """Missing tag is added to hashed asset when all refs are missing.""" + _ensure_missing_tag(session) + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=True, + ) + session.commit() + + session.expire_all() + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is not None + + +def test_missing_tag_removed_when_ref_ok(session, temp_dir): + """Missing tag is removed from hashed asset when a ref is stat-unchanged.""" + _ensure_missing_tag(session) + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=mtime) + # Pre-add a stale missing tag + session.add(AssetReferenceTag( + asset_reference_id="r1", tag_name="missing", origin="automatic", + )) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=True, + ) + session.commit() + + session.expire_all() + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is None + + +def test_missing_tags_not_touched_when_flag_false(session, temp_dir): + """Missing tags are not modified when update_missing_tags=False.""" + _ensure_missing_tag(session) + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=False, + ) + session.commit() + + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is None # tag was never added + + +def test_returns_none_when_collect_false(session, temp_dir): + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + result = sync_references_with_filesystem( + session, "models", collect_existing_paths=False, + ) + + assert result is None + + +def test_returns_empty_set_for_no_prefixes(session): + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[]): + result = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + + assert result == set() + + +def test_no_references_is_noop(session, temp_dir): + """No crash and no side effects when there are no references.""" + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + survivors = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + session.commit() + + assert survivors == set() + + +# --------------------------------------------------------------------------- +# Soft-delete persistence across scanner operations +# --------------------------------------------------------------------------- + +def _soft_delete_ref(session: Session, ref_id: str) -> None: + """Mark a reference as soft-deleted (mimics the API DELETE behaviour).""" + ref = session.get(AssetReference, ref_id) + ref.deleted_at = datetime(2025, 1, 1) + session.flush() + + +def test_soft_deleted_ref_excluded_from_get_references_for_prefixes(session, temp_dir): + """get_references_for_prefixes skips soft-deleted references.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + rows = get_references_for_prefixes(session, [str(temp_dir)], include_missing=True) + assert len(rows) == 0 + + +def test_sync_does_not_resurrect_soft_deleted_ref(session, temp_dir): + """Scanner sync leaves soft-deleted refs untouched even when file exists on disk.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.deleted_at is not None, "soft-deleted ref must stay deleted after sync" + + +def test_bulk_insert_does_not_overwrite_soft_deleted_ref(session, temp_dir): + """bulk_insert_references_ignore_conflicts cannot replace a soft-deleted row.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + now = datetime.now(tz=None) + bulk_insert_references_ignore_conflicts(session, [ + { + "id": "r_new", + "asset_id": "a1", + "file_path": fp, + "name": "model.bin", + "owner_id": "", + "mtime_ns": mtime, + "preview_id": None, + "user_metadata": None, + "created_at": now, + "updated_at": now, + "last_access_time": now, + } + ]) + session.commit() + + session.expire_all() + # Original row is still the soft-deleted one + ref = session.get(AssetReference, "r1") + assert ref is not None + assert ref.deleted_at is not None + # The new row was not inserted (conflict on file_path) + assert session.get(AssetReference, "r_new") is None + + +def test_restore_references_by_paths_skips_soft_deleted(session, temp_dir): + """restore_references_by_paths does not clear is_missing on soft-deleted refs.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset( + session, "a1", fp, "r1", + asset_hash="blake3:abc", mtime_ns=mtime, is_missing=True, + ) + _soft_delete_ref(session, "r1") + session.commit() + + restored = restore_references_by_paths(session, [fp]) + session.commit() + + assert restored == 0 + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.is_missing is True, "is_missing must not be cleared on soft-deleted ref" + assert ref.deleted_at is not None + + +def test_get_unenriched_references_excludes_soft_deleted(session, temp_dir): + """Enrichment queries do not pick up soft-deleted references.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + rows = get_unenriched_references(session, [str(temp_dir)], max_level=2) + assert len(rows) == 0 + + +def test_sync_ignores_soft_deleted_seed_asset(session, temp_dir): + """Soft-deleted seed ref is not garbage-collected even when file is missing.""" + fp = str(temp_dir / "gone.bin") # file does not exist + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=999) + _soft_delete_ref(session, "r1") + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + # Asset and ref must still exist — scanner did not see the soft-deleted row + assert session.get(Asset, "seed1") is not None + assert session.get(AssetReference, "r1") is not None diff --git a/tests-unit/assets_test/test_tags.py b/tests-unit/assets_test/test_tags_api.py similarity index 98% rename from tests-unit/assets_test/test_tags.py rename to tests-unit/assets_test/test_tags_api.py index 6b1047802..595bf29c6 100644 --- a/tests-unit/assets_test/test_tags.py +++ b/tests-unit/assets_test/test_tags_api.py @@ -69,8 +69,8 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, used_names = [t["name"] for t in body2["tags"]] assert custom_tag in used_names - # Delete the asset so the tag usage drops to zero - rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120) + # Hard-delete the asset so the tag usage drops to zero + rd = http.delete(f"{api_base}/api/assets/{_asset['id']}?delete_content=true", timeout=120) assert rd.status_code == 204 # Now the custom tag must not be returned when include_zero=false diff --git a/tests-unit/assets_test/test_uploads.py b/tests-unit/assets_test/test_uploads.py index 137d7391a..d68e5b5d7 100644 --- a/tests-unit/assets_test/test_uploads.py +++ b/tests-unit/assets_test/test_uploads.py @@ -18,25 +18,25 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma assert r1.status_code == 201, a1 assert a1["created_new"] is True - # Second upload with the same data and name should return created_new == False and the same asset + # Second upload with the same data and name creates a new AssetReference (duplicates allowed) + # Returns 200 because Asset already exists, but a new AssetReference is created files = {"file": (name, data, "application/octet-stream")} form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) a2 = r2.json() - assert r2.status_code == 200, a2 - assert a2["created_new"] is False + assert r2.status_code in (200, 201), a2 assert a2["asset_hash"] == a1["asset_hash"] - assert a2["id"] == a1["id"] # old reference + assert a2["id"] != a1["id"] # new reference with same content - # Third upload with the same data but new name should return created_new == False and the new AssetReference + # Third upload with the same data but different name also creates new AssetReference files = {"file": (name, data, "application/octet-stream")} form = {"tags": json.dumps(tags), "name": name + "_d", "user_metadata": json.dumps(meta)} - r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) - a3 = r2.json() - assert r2.status_code == 200, a3 - assert a3["created_new"] is False + r3 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + a3 = r3.json() + assert r3.status_code in (200, 201), a3 assert a3["asset_hash"] == a1["asset_hash"] - assert a3["id"] != a1["id"] # old reference + assert a3["id"] != a1["id"] + assert a3["id"] != a2["id"] def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str): @@ -116,7 +116,7 @@ def test_concurrent_upload_identical_bytes_different_names( ): """ Two concurrent uploads of identical bytes but different names. - Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True. + Expect a single Asset (same hash), two AssetReference rows, and exactly one created_new=True. """ scope = f"concupload-{uuid.uuid4().hex[:6]}" name1, name2 = "cu_a.bin", "cu_b.bin" diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index 2355b8000..3a6790ee0 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -2,4 +2,3 @@ pytest>=7.8.0 pytest-aiohttp pytest-asyncio websocket-client -blake3 diff --git a/tests-unit/seeder_test/test_seeder.py b/tests-unit/seeder_test/test_seeder.py new file mode 100644 index 000000000..db3795e48 --- /dev/null +++ b/tests-unit/seeder_test/test_seeder.py @@ -0,0 +1,900 @@ +"""Unit tests for the _AssetSeeder background scanning class.""" + +import threading +from unittest.mock import patch + +import pytest + +from app.assets.database.queries.asset_reference import UnenrichedReferenceRow +from app.assets.seeder import _AssetSeeder, Progress, ScanInProgressError, ScanPhase, State + + +@pytest.fixture +def fresh_seeder(): + """Create a fresh _AssetSeeder instance for testing.""" + seeder = _AssetSeeder() + yield seeder + seeder.shutdown(timeout=1.0) + + +@pytest.fixture +def mock_dependencies(): + """Mock all external dependencies for isolated testing.""" + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + yield + + +class TestSeederStateTransitions: + """Test state machine transitions.""" + + def test_initial_state_is_idle(self, fresh_seeder: _AssetSeeder): + assert fresh_seeder.get_status().state == State.IDLE + + def test_start_transitions_to_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + started = fresh_seeder.start(roots=("models",)) + assert started is True + assert reached.wait(timeout=2.0) + assert fresh_seeder.get_status().state == State.RUNNING + + barrier.set() + + def test_start_while_running_returns_false( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + second_start = fresh_seeder.start(roots=("models",)) + assert second_start is False + + barrier.set() + + def test_cancel_transitions_to_cancelling( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + cancelled = fresh_seeder.cancel() + assert cancelled is True + assert fresh_seeder.get_status().state == State.CANCELLING + + barrier.set() + + def test_cancel_when_idle_returns_false(self, fresh_seeder: _AssetSeeder): + cancelled = fresh_seeder.cancel() + assert cancelled is False + + def test_state_returns_to_idle_after_completion( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=5.0) + assert completed is True + assert fresh_seeder.get_status().state == State.IDLE + + +class TestSeederWait: + """Test wait() behavior.""" + + def test_wait_blocks_until_complete( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=5.0) + assert completed is True + assert fresh_seeder.get_status().state == State.IDLE + + def test_wait_returns_false_on_timeout( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=10.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=0.1) + assert completed is False + + barrier.set() + + def test_wait_when_idle_returns_true(self, fresh_seeder: _AssetSeeder): + completed = fresh_seeder.wait(timeout=1.0) + assert completed is True + + +class TestSeederProgress: + """Test progress tracking.""" + + def test_get_status_returns_progress_during_scan( + self, fresh_seeder: _AssetSeeder + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_build(*args, **kwargs): + reached.set() + barrier.wait(timeout=5.0) + return ([], set(), 0) + + paths = ["/path/file1.safetensors", "/path/file2.safetensors"] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), + patch("app.assets.seeder.build_asset_specs", side_effect=slow_build), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + status = fresh_seeder.get_status() + assert status.state == State.RUNNING + assert status.progress is not None + assert status.progress.total == 2 + + barrier.set() + + def test_progress_callback_is_invoked( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + progress_updates: list[Progress] = [] + + def callback(p: Progress): + progress_updates.append(p) + + with patch( + "app.assets.seeder.collect_paths_for_roots", + return_value=[f"/path/file{i}.safetensors" for i in range(10)], + ): + fresh_seeder.start(roots=("models",), progress_callback=callback) + fresh_seeder.wait(timeout=5.0) + + assert len(progress_updates) > 0 + + +class TestSeederCancellation: + """Test cancellation behavior.""" + + def test_scan_commits_partial_progress_on_cancellation( + self, fresh_seeder: _AssetSeeder + ): + insert_count = 0 + barrier = threading.Event() + first_insert_done = threading.Event() + + def slow_insert(specs, tags): + nonlocal insert_count + insert_count += 1 + if insert_count == 1: + first_insert_done.set() + if insert_count >= 2: + barrier.wait(timeout=5.0) + return len(specs) + + paths = [f"/path/file{i}.safetensors" for i in range(1500)] + specs = [ + { + "abs_path": p, + "size_bytes": 100, + "mtime_ns": 0, + "info_name": f"file{i}", + "tags": [], + "fname": f"file{i}", + "metadata": None, + "hash": None, + "mime_type": None, + } + for i, p in enumerate(paths) + ] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), + patch( + "app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0) + ), + patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + assert first_insert_done.wait(timeout=2.0) + + fresh_seeder.cancel() + barrier.set() + fresh_seeder.wait(timeout=5.0) + + assert 1 <= insert_count < 3 # 1500 paths / 500 batch = 3; cancel stopped early + + +class TestSeederErrorHandling: + """Test error handling behavior.""" + + def test_database_errors_captured_in_status(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch( + "app.assets.seeder.collect_paths_for_roots", + return_value=["/path/file.safetensors"], + ), + patch( + "app.assets.seeder.build_asset_specs", + return_value=( + [ + { + "abs_path": "/path/file.safetensors", + "size_bytes": 100, + "mtime_ns": 0, + "info_name": "file", + "tags": [], + "fname": "file", + "metadata": None, + "hash": None, + "mime_type": None, + } + ], + set(), + 0, + ), + ), + patch( + "app.assets.seeder.insert_asset_specs", + side_effect=Exception("DB connection failed"), + ), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert len(status.errors) > 0 + assert "DB connection failed" in status.errors[0] + + def test_dependencies_unavailable_captured_in_errors( + self, fresh_seeder: _AssetSeeder + ): + with patch("app.assets.seeder.dependencies_available", return_value=False): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert len(status.errors) > 0 + assert "dependencies" in status.errors[0].lower() + + def test_thread_crash_resets_state_to_idle(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch( + "app.assets.seeder.sync_root_safely", + side_effect=RuntimeError("Unexpected crash"), + ), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert status.state == State.IDLE + assert len(status.errors) > 0 + + +class TestSeederThreadSafety: + """Test thread safety of concurrent operations.""" + + def test_concurrent_start_calls_spawn_only_one_thread( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + results = [] + + def try_start(): + results.append(fresh_seeder.start(roots=("models",))) + + threads = [threading.Thread(target=try_start) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + barrier.set() + + assert sum(results) == 1 + + def test_get_status_safe_during_scan( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + statuses = [] + for _ in range(100): + statuses.append(fresh_seeder.get_status()) + + barrier.set() + + assert all( + s.state in (State.RUNNING, State.IDLE, State.CANCELLING) + for s in statuses + ) + + +class TestSeederMarkMissing: + """Test mark_missing_outside_prefixes behavior.""" + + def test_mark_missing_when_idle(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch( + "app.assets.seeder.get_all_known_prefixes", + return_value=["/models", "/input", "/output"], + ), + patch( + "app.assets.seeder.mark_missing_outside_prefixes_safely", return_value=5 + ) as mock_mark, + ): + result = fresh_seeder.mark_missing_outside_prefixes() + assert result == 5 + mock_mark.assert_called_once_with(["/models", "/input", "/output"]) + + def test_mark_missing_raises_when_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + with pytest.raises(ScanInProgressError): + fresh_seeder.mark_missing_outside_prefixes() + + barrier.set() + + def test_mark_missing_returns_zero_when_dependencies_unavailable( + self, fresh_seeder: _AssetSeeder + ): + with patch("app.assets.seeder.dependencies_available", return_value=False): + result = fresh_seeder.mark_missing_outside_prefixes() + assert result == 0 + + def test_prune_first_flag_triggers_mark_missing_before_scan( + self, fresh_seeder: _AssetSeeder + ): + call_order = [] + + def track_mark(prefixes): + call_order.append("mark_missing") + return 3 + + def track_sync(root): + call_order.append(f"sync_{root}") + return set() + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]), + patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark), + patch("app.assets.seeder.sync_root_safely", side_effect=track_sync), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",), prune_first=True) + fresh_seeder.wait(timeout=5.0) + + assert call_order[0] == "mark_missing" + assert "sync_models" in call_order + + +class TestSeederPhases: + """Test phased scanning behavior.""" + + def test_start_fast_only_runs_fast_phase(self, fresh_seeder: _AssetSeeder): + """Verify start_fast only runs the fast phase.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start_fast(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 1 + assert len(enrich_called) == 0 + + def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: _AssetSeeder): + """Verify start_enrich only runs the enrich phase.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start_enrich(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 0 + assert len(enrich_called) == 1 + + def test_full_scan_runs_both_phases(self, fresh_seeder: _AssetSeeder): + """Verify full scan runs both fast and enrich phases.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.FULL) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 1 + assert len(enrich_called) == 1 + + +class TestSeederPauseResume: + """Test pause/resume behavior.""" + + def test_pause_transitions_to_paused( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + paused = fresh_seeder.pause() + assert paused is True + assert fresh_seeder.get_status().state == State.PAUSED + + barrier.set() + + def test_pause_when_idle_returns_false(self, fresh_seeder: _AssetSeeder): + paused = fresh_seeder.pause() + assert paused is False + + def test_resume_returns_to_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + fresh_seeder.pause() + assert fresh_seeder.get_status().state == State.PAUSED + + resumed = fresh_seeder.resume() + assert resumed is True + assert fresh_seeder.get_status().state == State.RUNNING + + barrier.set() + + def test_resume_when_not_paused_returns_false( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + resumed = fresh_seeder.resume() + assert resumed is False + + barrier.set() + + def test_cancel_while_paused_works( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached_checkpoint = threading.Event() + + def slow_collect(*args): + reached_checkpoint.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached_checkpoint.wait(timeout=2.0) + + fresh_seeder.pause() + assert fresh_seeder.get_status().state == State.PAUSED + + cancelled = fresh_seeder.cancel() + assert cancelled is True + + barrier.set() + fresh_seeder.wait(timeout=5.0) + assert fresh_seeder.get_status().state == State.IDLE + +class TestSeederStopRestart: + """Test stop and restart behavior.""" + + def test_stop_is_alias_for_cancel( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + stopped = fresh_seeder.stop() + assert stopped is True + assert fresh_seeder.get_status().state == State.CANCELLING + + barrier.set() + + def test_restart_cancels_and_starts_new_scan( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + start_count = 0 + + def slow_collect(*args): + nonlocal start_count + start_count += 1 + if start_count == 1: + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + barrier.set() + restarted = fresh_seeder.restart() + assert restarted is True + + fresh_seeder.wait(timeout=5.0) + assert start_count == 2 + + def test_restart_preserves_previous_params(self, fresh_seeder: _AssetSeeder): + """Verify restart uses previous params when not overridden.""" + collected_roots = [] + + def track_collect(roots): + collected_roots.append(roots) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("input", "output")) + fresh_seeder.wait(timeout=5.0) + + fresh_seeder.restart() + fresh_seeder.wait(timeout=5.0) + + assert len(collected_roots) == 2 + assert collected_roots[0] == ("input", "output") + assert collected_roots[1] == ("input", "output") + + def test_restart_can_override_params(self, fresh_seeder: _AssetSeeder): + """Verify restart can override previous params.""" + collected_roots = [] + + def track_collect(roots): + collected_roots.append(roots) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + fresh_seeder.restart(roots=("input",)) + fresh_seeder.wait(timeout=5.0) + + assert len(collected_roots) == 2 + assert collected_roots[0] == ("models",) + assert collected_roots[1] == ("input",) + + +def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow: + return UnenrichedReferenceRow( + reference_id=ref_id, asset_id=asset_id, + file_path=f"/fake/{ref_id}.bin", enrichment_level=0, + ) + + +class TestEnrichPhaseDefensiveLogic: + """Test skip_ids filtering and consecutive_empty termination.""" + + def test_failed_refs_are_skipped_on_subsequent_batches( + self, fresh_seeder: _AssetSeeder, + ): + """References that fail enrichment are filtered out of future batches.""" + row_a = _make_row("r1") + row_b = _make_row("r2") + call_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 2: + return [row_a, row_b] + return [] + + enriched_refs: list[list[str]] = [] + + def fake_enrich(rows, **kwargs): + ref_ids = [r.reference_id for r in rows] + enriched_refs.append(ref_ids) + # r1 always fails, r2 succeeds + failed = [r.reference_id for r in rows if r.reference_id == "r1"] + enriched = len(rows) - len(failed) + return enriched, failed + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # First batch: both refs attempted + assert "r1" in enriched_refs[0] + assert "r2" in enriched_refs[0] + # Second batch: r1 filtered out + assert "r1" not in enriched_refs[1] + assert "r2" in enriched_refs[1] + + def test_stops_after_consecutive_empty_batches( + self, fresh_seeder: _AssetSeeder, + ): + """Enrich phase terminates after 3 consecutive batches with zero progress.""" + row = _make_row("r1") + batch_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal batch_count + batch_count += 1 + # Always return the same row (simulating a permanently failing ref) + return [row] + + def fake_enrich(rows, **kwargs): + # Always fail — zero enriched, all failed + return 0, [r.reference_id for r in rows] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # Should stop after exactly 3 consecutive empty batches + # Batch 1: returns row, enrich fails → filtered out in batch 2+ + # But get_unenriched keeps returning it, filter removes it → empty → break + # Actually: batch 1 has row, fails. Batch 2 get_unenriched returns [row], + # skip_ids filters it → empty list → breaks via `if not unenriched: break` + # So it terminates in 2 calls to get_unenriched. + assert batch_count == 2 + + def test_consecutive_empty_counter_resets_on_success( + self, fresh_seeder: _AssetSeeder, + ): + """A successful batch resets the consecutive empty counter.""" + call_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 6: + return [_make_row(f"r{call_count}", f"a{call_count}")] + return [] + + def fake_enrich(rows, **kwargs): + ref_id = rows[0].reference_id + # Fail batches 1-2, succeed batch 3, fail batches 4-5, succeed batch 6 + if ref_id in ("r1", "r2", "r4", "r5"): + return 0, [ref_id] + return 1, [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # All 6 batches should run + 1 final call returning empty + assert call_count == 7 + status = fresh_seeder.get_status() + assert status.state == State.IDLE diff --git a/utils/mime_types.py b/utils/mime_types.py new file mode 100644 index 000000000..916e963c5 --- /dev/null +++ b/utils/mime_types.py @@ -0,0 +1,37 @@ +"""Centralized MIME type initialization. + +Call init_mime_types() once at startup to initialize the MIME type database +and register all custom types used across ComfyUI. +""" + +import mimetypes + +_initialized = False + + +def init_mime_types(): + """Initialize the MIME type database and register all custom types. + + Safe to call multiple times; only runs once. + """ + global _initialized + if _initialized: + return + _initialized = True + + mimetypes.init() + + # Web types (used by server.py for static file serving) + mimetypes.add_type('application/javascript; charset=utf-8', '.js') + mimetypes.add_type('image/webp', '.webp') + + # Model and data file types (used by asset scanning / metadata extraction) + mimetypes.add_type("application/safetensors", ".safetensors") + mimetypes.add_type("application/safetensors", ".sft") + mimetypes.add_type("application/pytorch", ".pt") + mimetypes.add_type("application/pytorch", ".pth") + mimetypes.add_type("application/pickle", ".ckpt") + mimetypes.add_type("application/pickle", ".pkl") + mimetypes.add_type("application/gguf", ".gguf") + mimetypes.add_type("application/yaml", ".yaml") + mimetypes.add_type("application/yaml", ".yml") From 7723f20bbe010a3ea4eac602f77b0ff496f123c4 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 8 Mar 2026 13:17:40 -0700 Subject: [PATCH 80/81] comfy-aimdo 0.2.9 (#12840) Comfy-aimdo 0.2.9 fixes a context issue where if a non-main thread does a spurious garbage collection, cudaFrees are attempted with bad context. Some new APIs for displaying aimdo stats in UI widgets are also added. These are purely additive getters that dont touch cuda APIs. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9527135ec..b1db1cf24 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy filelock av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.7 +comfy-aimdo>=0.2.9 requests simpleeval>=1.0.0 blake3 From e4b0bb8305a4069ef7ff8396bfc6057c736ab95b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 8 Mar 2026 13:25:30 -0700 Subject: [PATCH 81/81] Import assets seeder later, print some package versions. (#12841) --- app/assets/services/hashing.py | 6 +++++- main.py | 8 +++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/app/assets/services/hashing.py b/app/assets/services/hashing.py index 92aee6402..41d8b4615 100644 --- a/app/assets/services/hashing.py +++ b/app/assets/services/hashing.py @@ -3,8 +3,12 @@ import os from contextlib import contextmanager from dataclasses import dataclass from typing import IO, Any, Callable, Iterator +import logging -from blake3 import blake3 +try: + from blake3 import blake3 +except ModuleNotFoundError: + logging.warning("WARNING: blake3 package not installed") DEFAULT_CHUNK = 8 * 1024 * 1024 diff --git a/main.py b/main.py index a8fc1a28d..1977f9362 100644 --- a/main.py +++ b/main.py @@ -3,11 +3,11 @@ comfy.options.enable_args_parsing() import os import importlib.util +import importlib.metadata import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger -from app.assets.seeder import asset_seeder import itertools import utils.extra_config from utils.mime_types import init_mime_types @@ -182,6 +182,7 @@ if 'torch' in sys.modules: import comfy.utils +from app.assets.seeder import asset_seeder import execution import server @@ -451,6 +452,11 @@ if __name__ == "__main__": # Running directly, just start ComfyUI. logging.info("Python version: {}".format(sys.version)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) + for package in ("comfy-aimdo", "comfy-kitchen"): + try: + logging.info("{} version: {}".format(package, importlib.metadata.version(package))) + except: + pass if sys.version_info.major == 3 and sys.version_info.minor < 10: logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")